├── .gitignore
├── 1703.03872.pdf
├── Combined_Dataset
│   ├── Test_set
│   │   └── Composition_code_revised.py
│   └── Training_set
│       └── Composition_code_revised.py
├── LICENSE
├── README.md
├── clean.sh
├── config.py
├── custom_layers
│   ├── __init__.py
│   ├── scale_layer.py
│   └── unpooling_layer.py
├── data_generator.py
├── demo.py
├── encoder_decoder.svg
├── history
│   ├── 2018-05-17 07-52-29.png
│   └── 2018-05-21 21-44-41.png
├── images
│   ├── 0_alpha.png
│   ├── 0_compose.png
│   ├── 0_image.png
│   ├── 0_new_bg.png
│   ├── 0_out.png
│   ├── 0_trimap.png
│   ├── 1_alpha.png
│   ├── 1_compose.png
│   ├── 1_image.png
│   ├── 1_new_bg.png
│   ├── 1_out.png
│   ├── 1_trimap.png
│   ├── 2_alpha.png
│   ├── 2_compose.png
│   ├── 2_image.png
│   ├── 2_new_bg.png
│   ├── 2_out.png
│   ├── 2_trimap.png
│   ├── 3_alpha.png
│   ├── 3_compose.png
│   ├── 3_image.png
│   ├── 3_new_bg.png
│   ├── 3_out.png
│   ├── 3_trimap.png
│   ├── 4_alpha.png
│   ├── 4_compose.png
│   ├── 4_image.png
│   ├── 4_new_bg.png
│   ├── 4_out.png
│   ├── 4_trimap.png
│   ├── 5_alpha.png
│   ├── 5_compose.png
│   ├── 5_image.png
│   ├── 5_new_bg.png
│   ├── 5_out.png
│   ├── 5_trimap.png
│   ├── 6_alpha.png
│   ├── 6_compose.png
│   ├── 6_image.png
│   ├── 6_new_bg.png
│   ├── 6_out.png
│   ├── 6_trimap.png
│   ├── 7_alpha.png
│   ├── 7_compose.png
│   ├── 7_image.png
│   ├── 7_new_bg.png
│   ├── 7_out.png
│   ├── 7_trimap.png
│   ├── 8_alpha.png
│   ├── 8_compose.png
│   ├── 8_image.png
│   ├── 8_new_bg.png
│   ├── 8_out.png
│   ├── 8_trimap.png
│   ├── 9_alpha.png
│   ├── 9_compose.png
│   ├── 9_image.png
│   ├── 9_new_bg.png
│   ├── 9_out.png
│   └── 9_trimap.png
├── migrate.py
├── model.py
├── model.svg
├── parallel_model.svg
├── plot_model.py
├── pre_process.py
├── predit_single.py
├── refinement.svg
├── segnet.py
├── test.py
├── test_alphamatting.py
├── train.py
├── train_encoder_decoder.py
├── train_final.py
├── train_names.txt
├── train_refinement.py
├── unit_tests.py
├── utils.py
├── valid_names.txt
└── vgg16.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | Adobe_Deep_Matting_Dataset.zip
3 | __pycache__/
4 | custom_layers/__pycache__/
5 | models/
6 | train2014/
7 | logs/
8 | bg/
9 | fg/
10 | mask/
11 | merged/
12 | bg_test/
13 | fg_test/
14 | mask_test/
15 | merged_test/
16 | temp/
17 | .cache/
18 | VOC2008test.tar
19 | VOCdevkit/
20 | VOCtest_06-Nov-2007.tar
21 | VOCtrainval_06-Nov-2007.tar
22 | VOCtrainval_14-Jul-2008.tar
23 | train2014.zip
24 | Deep-Image-Matting.7z
25 | matting_evaluation.zip
26 | training.txt
27 | Combined_Dataset/Adobe Deep Image Mattng Dataset License Agreement.pdf
28 | Combined_Dataset/README.txt
29 | Combined_Dataset/Test_set/Adobe-licensed images/
30 | Combined_Dataset/Test_set/Composition_code.py
31 | Combined_Dataset/Test_set/test_bg_names.txt
32 | Combined_Dataset/Test_set/test_fg_names.txt
33 | Combined_Dataset/Training_set/Adobe-licensed images/
34 | Combined_Dataset/Training_set/Composition_code.py
35 | Combined_Dataset/Training_set/Other/
36 | Combined_Dataset/Training_set/training_bg_names.txt
37 | Combined_Dataset/Training_set/training_fg_names.txt
--------------------------------------------------------------------------------
/1703.03872.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/1703.03872.pdf
--------------------------------------------------------------------------------
/Combined_Dataset/Test_set/Composition_code_revised.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | import cv2 as cv
4 | import numpy as np
5 | from tqdm import tqdm
6 |
7 | ##############################################################
8 | # Set your paths here
9 |
10 | # path to provided foreground images
11 | fg_path = 'data/fg_test/'
12 |
13 | # path to provided alpha mattes
14 | a_path = 'data/mask_test/'
15 |
16 | # Path to background images (MSCOCO)
17 | bg_path = 'data/bg_test/'
18 |
19 | # Path to folder where you want the composited images to go
20 | out_path = 'data/merged_test/'
21 |
22 |
23 | ##############################################################
24 |
25 | def composite4(fg, bg, a, w, h):
26 | fg = np.array(fg, np.float32)
27 | bg = np.array(bg[0:h, 0:w], np.float32)
28 | alpha = np.zeros((h, w, 1), np.float32)
29 | alpha[:, :, 0] = a / 255.
30 | comp = alpha * fg + (1 - alpha) * bg
31 | comp = comp.astype(np.uint8)
32 | return comp
33 |
34 |
35 | def process(im_name, bg_name, fcount, bcount):
36 | im = cv.imread(fg_path + im_name)
37 | a = cv.imread(a_path + im_name, 0)
38 | h, w = im.shape[:2]
39 | bg = cv.imread(bg_path + bg_name)
40 | bh, bw = bg.shape[:2]
41 | wratio = w / bw
42 | hratio = h / bh
43 | ratio = wratio if wratio > hratio else hratio
44 | if ratio > 1:
45 | bg = cv.resize(src=bg, dsize=(math.ceil(bw * ratio), math.ceil(bh * ratio)), interpolation=cv.INTER_CUBIC)
46 |
47 | out = composite4(im, bg, a, w, h)
48 | filename = out_path + str(fcount) + '_' + str(bcount) + '.png'
49 | cv.imwrite(filename, out)
50 |
51 |
52 | def do_composite_test():
53 | num_bgs = 20
54 |
55 | with open('data/Combined_Dataset/Test_set/test_bg_names.txt') as f:
56 | bg_files = f.read().splitlines()
57 | with open('data/Combined_Dataset/Test_set/test_fg_names.txt') as f:
58 | fg_files = f.read().splitlines()
59 |
60 | # a_files = os.listdir(a_path)
61 | num_samples = len(fg_files) * num_bgs
62 |
63 | # pb = ProgressBar(total=100, prefix='Compose test images', suffix='', decimals=3, length=50, fill='=')
64 | start = time.time()
65 | bcount = 0
66 | for fcount in tqdm(range(len(fg_files))):
67 | im_name = fg_files[fcount]
68 |
69 | for i in range(num_bgs):
70 | bg_name = bg_files[bcount]
71 | process(im_name, bg_name, fcount, bcount)
72 | bcount += 1
73 |
74 | # pb.print_progress_bar(bcount * 100.0 / num_samples)
75 |
76 | end = time.time()
77 | elapsed = end - start
78 | print('elapsed: {} seconds'.format(elapsed))
79 |
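80 | # Minimal entry point (an assumption, not in the original file): pre_process.py
81 | # imports do_composite_test from this module, but the script can also be run
82 | # on its own once the paths above point at extracted data.
83 | if __name__ == '__main__':
84 |     do_composite_test()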
--------------------------------------------------------------------------------
/Combined_Dataset/Training_set/Composition_code_revised.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | import cv2 as cv
4 | import numpy as np
5 | from tqdm import tqdm
6 |
7 | ##############################################################
8 | # Set your paths here
9 |
10 | # path to provided foreground images
11 | fg_path = 'data/fg/'
12 |
13 | # path to provided alpha mattes
14 | a_path = 'data/mask/'
15 |
16 | # Path to background images (MSCOCO)
17 | bg_path = 'data/bg/'
18 |
19 | # Path to folder where you want the composited images to go
20 | out_path = 'data/merged/'
21 |
22 |
23 | ##############################################################
24 |
25 | def composite4(fg, bg, a, w, h):
26 | fg = np.array(fg, np.float32)
27 | bg = np.array(bg[0:h, 0:w], np.float32)
28 | alpha = np.zeros((h, w, 1), np.float32)
29 | alpha[:, :, 0] = a / 255.
30 | comp = alpha * fg + (1 - alpha) * bg
31 | comp = comp.astype(np.uint8)
32 | return comp
33 |
34 |
35 | def process(im_name, bg_name, fcount, bcount):
36 | im = cv.imread(fg_path + im_name)
37 | a = cv.imread(a_path + im_name, 0)
38 | h, w = im.shape[:2]
39 | bg = cv.imread(bg_path + bg_name)
40 | bh, bw = bg.shape[:2]
41 | wratio = w / bw
42 | hratio = h / bh
43 | ratio = wratio if wratio > hratio else hratio
44 | if ratio > 1:
45 | bg = cv.resize(src=bg, dsize=(math.ceil(bw * ratio), math.ceil(bh * ratio)), interpolation=cv.INTER_CUBIC)
46 |
47 | out = composite4(im, bg, a, w, h)
48 | filename = out_path + str(fcount) + '_' + str(bcount) + '.png'
49 | cv.imwrite(filename, out)
50 |
51 |
52 | def do_composite():
53 | num_bgs = 100
54 |
55 | with open('data/Combined_Dataset/Training_set/training_bg_names.txt') as f:
56 | bg_files = f.read().splitlines()
57 | with open('data/Combined_Dataset/Training_set/training_fg_names.txt') as f:
58 | fg_files = f.read().splitlines()
59 |
60 | # a_files = os.listdir(a_path)
61 | num_samples = len(fg_files) * num_bgs
62 |
63 | # pb = ProgressBar(total=100, prefix='Compose train images', suffix='', decimals=3, length=50, fill='=')
64 | start = time.time()
65 | bcount = 0
66 | for fcount in tqdm(range(len(fg_files))):
67 | im_name = fg_files[fcount]
68 |
69 | for i in range(num_bgs):
70 | bg_name = bg_files[bcount]
71 | process(im_name, bg_name, fcount, bcount)
72 | bcount += 1
73 |
74 | # pb.print_progress_bar(bcount * 100.0 / num_samples)
75 |
76 | end = time.time()
77 | elapsed = end - start
78 | print('elapsed: {} seconds'.format(elapsed))
79 |
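80 | # Minimal entry point (an assumption, not in the original file): pre_process.py
81 | # imports do_composite from this module, but the script can also be run on its
82 | # own once the paths above point at extracted data.
83 | if __name__ == '__main__':
84 |     do_composite()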
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Yang Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Just in case you are interested: [Deep Image Matting v2](https://github.com/foamliu/Deep-Image-Matting-v2) is an upgraded version of this repository.
2 |
3 | # Deep Image Matting
4 | This repository reproduces the paper [Deep Image Matting](https://arxiv.org/abs/1703.03872) (a copy is included as 1703.03872.pdf).
5 |
6 | ## Dependencies
7 | - [NumPy](http://docs.scipy.org/doc/numpy-1.10.1/user/install.html)
8 | - [Tensorflow 1.9.0](https://www.tensorflow.org/)
9 | - [Keras 2.1.6](https://keras.io/#installation)
10 | - [OpenCV](https://opencv-python-tutroals.readthedocs.io/en/latest/)
11 |
12 | ## Dataset
13 | ### Adobe Deep Image Matting Dataset
14 | Follow the [instructions](https://sites.google.com/view/deepimagematting) to contact the authors for the dataset.
15 |
16 | ### MSCOCO
17 | Go to [MSCOCO](http://cocodataset.org/#download) to download:
18 | * [2014 Train images](http://images.cocodataset.org/zips/train2014.zip)
19 |
20 |
21 | ### PASCAL VOC
22 | Go to [PASCAL VOC](http://host.robots.ox.ac.uk/pascal/VOC/) to download:
23 | * VOC challenge 2008 [training/validation data](http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar)
24 | * The test data for the [VOC2008 challenge](http://host.robots.ox.ac.uk/pascal/VOC/voc2008/index.html#testdata)
25 |
26 | ## ImageNet Pretrained Models
27 | Download [VGG16](https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels.h5) into the "models" folder.
28 |
29 |
30 | ## Usage
31 | ### Data Pre-processing
32 | Extract training images:
33 | ```bash
34 | $ python pre_process.py
35 | ```
36 |
37 | ### Train
38 | ```bash
39 | $ python train.py
40 | ```
41 |
42 | If you want to visualize training progress, run in your terminal:
43 | ```bash
44 | $ tensorboard --logdir path_to_current_dir/logs
45 | ```
46 |
47 | ### Demo
48 | Download the pre-trained Deep Image Matting [model](https://github.com/foamliu/Deep-Image-Matting/releases/download/v1.0/final.42-0.0398.hdf5) into the "models" folder, then run:
49 | ```bash
50 | $ python demo.py
51 | ```
52 |
53 | |Image/Trimap|Output/GT|New BG/Compose|
54 | |---|---|---|
55 | |![image](images/0_image.png)|![out](images/0_out.png)|![new_bg](images/0_new_bg.png)|
56 | |![trimap](images/0_trimap.png)|![alpha](images/0_alpha.png)|![compose](images/0_compose.png)|
57 | |![image](images/1_image.png)|![out](images/1_out.png)|![new_bg](images/1_new_bg.png)|
58 | |![trimap](images/1_trimap.png)|![alpha](images/1_alpha.png)|![compose](images/1_compose.png)|
59 | |![image](images/2_image.png)|![out](images/2_out.png)|![new_bg](images/2_new_bg.png)|
60 | |![trimap](images/2_trimap.png)|![alpha](images/2_alpha.png)|![compose](images/2_compose.png)|
61 | |![image](images/3_image.png)|![out](images/3_out.png)|![new_bg](images/3_new_bg.png)|
62 | |![trimap](images/3_trimap.png)|![alpha](images/3_alpha.png)|![compose](images/3_compose.png)|
63 | |![image](images/4_image.png)|![out](images/4_out.png)|![new_bg](images/4_new_bg.png)|
64 | |![trimap](images/4_trimap.png)|![alpha](images/4_alpha.png)|![compose](images/4_compose.png)|
65 | |![image](images/5_image.png)|![out](images/5_out.png)|![new_bg](images/5_new_bg.png)|
66 | |![trimap](images/5_trimap.png)|![alpha](images/5_alpha.png)|![compose](images/5_compose.png)|
67 | |![image](images/6_image.png)|![out](images/6_out.png)|![new_bg](images/6_new_bg.png)|
68 | |![trimap](images/6_trimap.png)|![alpha](images/6_alpha.png)|![compose](images/6_compose.png)|
69 | |![image](images/7_image.png)|![out](images/7_out.png)|![new_bg](images/7_new_bg.png)|
70 | |![trimap](images/7_trimap.png)|![alpha](images/7_alpha.png)|![compose](images/7_compose.png)|
71 | |![image](images/8_image.png)|![out](images/8_out.png)|![new_bg](images/8_new_bg.png)|
72 | |![trimap](images/8_trimap.png)|![alpha](images/8_alpha.png)|![compose](images/8_compose.png)|
73 | |![image](images/9_image.png)|![out](images/9_out.png)|![new_bg](images/9_new_bg.png)|
74 | |![trimap](images/9_trimap.png)|![alpha](images/9_alpha.png)|![compose](images/9_compose.png)|
75 |
76 |
--------------------------------------------------------------------------------
/clean.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | rm -f models/model.*
3 | rm -rf logs
4 | rm -f training.txt
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | img_rows, img_cols = 320, 320
2 | # img_rows_half, img_cols_half = 160, 160
3 | channel = 4
4 | batch_size = 16
5 | epochs = 1000
6 | patience = 50
7 | num_samples = 43100
8 | num_train_samples = 34480
9 | # num_samples - num_train_samples
10 | num_valid_samples = 8620
11 | unknown_code = 128
12 | epsilon = 1e-6
13 | epsilon_sqr = epsilon ** 2
14 |
15 | ##############################################################
16 | # Set your paths here
17 |
18 | # path to provided foreground images
19 | fg_path = 'data/fg/'
20 |
21 | # path to provided alpha mattes
22 | a_path = 'data/mask/'
23 |
24 | # Path to background images (MSCOCO)
25 | bg_path = 'data/bg/'
26 |
27 | # Path to folder where you want the composited images to go
28 | out_path = 'data/merged/'
29 |
30 | ##############################################################
31 |
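32 | # Sanity check on the counts above (added note): 431 foregrounds x 100
33 | # backgrounds per foreground = 43100 composited samples, split 80/20 into
34 | # 34480 training and 8620 validation samples (see data_generator.shuffle_data).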
--------------------------------------------------------------------------------
/custom_layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Python Package
2 |
--------------------------------------------------------------------------------
/custom_layers/scale_layer.py:
--------------------------------------------------------------------------------
1 | from keras.layers.core import Layer
2 | from keras.engine import InputSpec
3 | from keras import backend as K
4 | try:
5 | from keras import initializations
6 | except ImportError:
7 | from keras import initializers as initializations
8 |
9 | class Scale(Layer):
10 | '''Learns a set of weights and biases used for scaling the input data.
11 |     The output is simply an element-wise multiplication of the input by a set
12 |     of learned weights, plus a set of learned biases:
13 |
14 |     out = in * gamma + beta,
15 |
16 |     where 'gamma' and 'beta' are the weights and biases being learned.
17 |
18 | # Arguments
19 |         axis: integer, axis along which to apply the scaling. For instance,
20 |             if your input tensor has shape (samples, channels, rows, cols),
21 |             set axis to 1 to scale per feature map (channels axis).
22 |         momentum: retained in the layer config (see `get_config`) for
23 |             compatibility with the original implementation, but not used
24 |             anywhere in this layer's computation.
25 | weights: Initialization weights.
26 | List of 2 Numpy arrays, with shapes:
27 | `[(input_shape,), (input_shape,)]`
28 | beta_init: name of initialization function for shift parameter
29 | (see [initializations](../initializations.md)), or alternatively,
30 | Theano/TensorFlow function to use for weights initialization.
31 | This parameter is only relevant if you don't pass a `weights` argument.
32 | gamma_init: name of initialization function for scale parameter (see
33 | [initializations](../initializations.md)), or alternatively,
34 | Theano/TensorFlow function to use for weights initialization.
35 | This parameter is only relevant if you don't pass a `weights` argument.
36 | '''
37 |     def __init__(self, weights=None, axis=-1, momentum=0.9, beta_init='zero', gamma_init='one', **kwargs):
38 | self.momentum = momentum
39 | self.axis = axis
40 | self.beta_init = initializations.get(beta_init)
41 | self.gamma_init = initializations.get(gamma_init)
42 | self.initial_weights = weights
43 | super(Scale, self).__init__(**kwargs)
44 |
45 | def build(self, input_shape):
46 | self.input_spec = [InputSpec(shape=input_shape)]
47 | shape = (int(input_shape[self.axis]),)
48 |
49 | # Compatibility with TensorFlow >= 1.0.0
50 | self.gamma = K.variable(self.gamma_init(shape), name='{}_gamma'.format(self.name))
51 | self.beta = K.variable(self.beta_init(shape), name='{}_beta'.format(self.name))
52 | #self.gamma = self.gamma_init(shape, name='{}_gamma'.format(self.name))
53 | #self.beta = self.beta_init(shape, name='{}_beta'.format(self.name))
54 | self.trainable_weights = [self.gamma, self.beta]
55 |
56 | if self.initial_weights is not None:
57 | self.set_weights(self.initial_weights)
58 | del self.initial_weights
59 |
60 | def call(self, x, mask=None):
61 | input_shape = self.input_spec[0].shape
62 | broadcast_shape = [1] * len(input_shape)
63 | broadcast_shape[self.axis] = input_shape[self.axis]
64 |
65 | out = K.reshape(self.gamma, broadcast_shape) * x + K.reshape(self.beta, broadcast_shape)
66 | return out
67 |
68 | def get_config(self):
69 | config = {"momentum": self.momentum, "axis": self.axis}
70 | base_config = super(Scale, self).get_config()
71 | return dict(list(base_config.items()) + list(config.items()))
72 |
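73 | # Minimal usage sketch (an assumption for illustration; model.py itself only
74 | # uses the Unpooling custom layer): Scale learns one gamma/beta pair per
75 | # channel and computes out = x * gamma + beta.
76 | if __name__ == '__main__':
77 |     from keras.layers import Input
78 |     from keras.models import Model
79 |     inp = Input(shape=(32, 32, 16))
80 |     out = Scale(axis=-1)(inp)
81 |     Model(inp, out).summary()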
--------------------------------------------------------------------------------
/custom_layers/unpooling_layer.py:
--------------------------------------------------------------------------------
1 | from keras import backend as K
2 | from keras.engine.topology import Layer
3 | from keras.layers import Lambda, Multiply
4 |
5 |
6 | class Unpooling(Layer):
7 |
8 | def __init__(self, **kwargs):
9 | super(Unpooling, self).__init__(**kwargs)
10 |
11 | def build(self, input_shape):
12 | super(Unpooling, self).build(input_shape)
13 |
14 | def call(self, inputs, **kwargs):
15 | x = inputs[:, 1]
16 | # print('x.shape: ' + str(K.int_shape(x)))
17 | bool_mask = Lambda(lambda t: K.greater_equal(t[:, 0], t[:, 1]),
18 | output_shape=K.int_shape(x)[1:])(inputs)
19 | # print('bool_mask.shape: ' + str(K.int_shape(bool_mask)))
20 | mask = Lambda(lambda t: K.cast(t, dtype='float32'))(bool_mask)
21 | # print('mask.shape: ' + str(K.int_shape(mask)))
22 | x = Multiply()([mask, x])
23 | # print('x.shape: ' + str(K.int_shape(x)))
24 | return x
25 |
26 | def compute_output_shape(self, input_shape):
27 | return input_shape[0], input_shape[2], input_shape[3], input_shape[4]
28 |
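29 | # Descriptive note (not in the original file): model.py feeds this layer a
30 | # (batch, 2, H, W, C) tensor built by reshaping the encoder's pre-pool skip
31 | # feature map ("orig", index 0 on axis 1) and the decoder's upsampled feature
32 | # map ("x", index 1) and concatenating them on axis 1. call() then zeroes x
33 | # wherever orig < x, a mask intended to mimic max-unpooling with the encoder's
34 | # pooling switches.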
--------------------------------------------------------------------------------
/data_generator.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 | from random import shuffle
5 |
6 | import cv2 as cv
7 | import numpy as np
8 | from keras.utils import Sequence
9 |
10 | from config import batch_size
11 | from config import fg_path, bg_path, a_path
12 | from config import img_cols, img_rows
13 | from config import unknown_code
14 | from utils import safe_crop
15 |
16 | kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE, (3, 3))
17 | with open('data/Combined_Dataset/Training_set/training_fg_names.txt') as f:
18 | fg_files = f.read().splitlines()
19 | with open('data/Combined_Dataset/Test_set/test_fg_names.txt') as f:
20 | fg_test_files = f.read().splitlines()
21 | with open('data/Combined_Dataset/Training_set/training_bg_names.txt') as f:
22 | bg_files = f.read().splitlines()
23 | with open('data/Combined_Dataset/Test_set/test_bg_names.txt') as f:
24 | bg_test_files = f.read().splitlines()
25 |
26 |
27 | def get_alpha(name):
28 | fg_i = int(name.split("_")[0])
29 | name = fg_files[fg_i]
30 | filename = os.path.join('data/mask', name)
31 | alpha = cv.imread(filename, 0)
32 | return alpha
33 |
34 |
35 | def get_alpha_test(name):
36 | fg_i = int(name.split("_")[0])
37 | name = fg_test_files[fg_i]
38 | filename = os.path.join('data/mask_test', name)
39 | alpha = cv.imread(filename, 0)
40 | return alpha
41 |
42 |
43 | def composite4(fg, bg, a, w, h):
44 | fg = np.array(fg, np.float32)
45 | bg_h, bg_w = bg.shape[:2]
46 | x = 0
47 | if bg_w > w:
48 | x = np.random.randint(0, bg_w - w)
49 | y = 0
50 | if bg_h > h:
51 | y = np.random.randint(0, bg_h - h)
52 | bg = np.array(bg[y:y + h, x:x + w], np.float32)
53 | alpha = np.zeros((h, w, 1), np.float32)
54 | alpha[:, :, 0] = a / 255.
55 | im = alpha * fg + (1 - alpha) * bg
56 | im = im.astype(np.uint8)
57 | return im, a, fg, bg
58 |
59 |
60 | def process(im_name, bg_name):
61 | im = cv.imread(fg_path + im_name)
62 | a = cv.imread(a_path + im_name, 0)
63 | h, w = im.shape[:2]
64 | bg = cv.imread(bg_path + bg_name)
65 | bh, bw = bg.shape[:2]
66 | wratio = w / bw
67 | hratio = h / bh
68 | ratio = wratio if wratio > hratio else hratio
69 | if ratio > 1:
70 | bg = cv.resize(src=bg, dsize=(math.ceil(bw * ratio), math.ceil(bh * ratio)), interpolation=cv.INTER_CUBIC)
71 |
72 | return composite4(im, bg, a, w, h)
73 |
74 |
75 | def generate_trimap(alpha):
76 | fg = np.array(np.equal(alpha, 255).astype(np.float32))
77 | # fg = cv.erode(fg, kernel, iterations=np.random.randint(1, 3))
78 | unknown = np.array(np.not_equal(alpha, 0).astype(np.float32))
79 | unknown = cv.dilate(unknown, kernel, iterations=np.random.randint(1, 20))
80 | trimap = fg * 255 + (unknown - fg) * 128
81 | return trimap.astype(np.uint8)
82 |
83 |
84 | # Randomly crop (image, trimap) pairs centered on pixels in the unknown regions.
85 | def random_choice(trimap, crop_size=(320, 320)):
86 | crop_height, crop_width = crop_size
87 | y_indices, x_indices = np.where(trimap == unknown_code)
88 | num_unknowns = len(y_indices)
89 | x, y = 0, 0
90 | if num_unknowns > 0:
91 | ix = np.random.choice(range(num_unknowns))
92 | center_x = x_indices[ix]
93 | center_y = y_indices[ix]
94 | x = max(0, center_x - int(crop_width / 2))
95 | y = max(0, center_y - int(crop_height / 2))
96 | return x, y
97 |
98 |
99 | class DataGenSequence(Sequence):
100 | def __init__(self, usage):
101 | self.usage = usage
102 |
103 | filename = '{}_names.txt'.format(usage)
104 | with open(filename, 'r') as f:
105 | self.names = f.read().splitlines()
106 |
107 | np.random.shuffle(self.names)
108 |
109 | def __len__(self):
110 | return int(np.ceil(len(self.names) / float(batch_size)))
111 |
112 | def __getitem__(self, idx):
113 | i = idx * batch_size
114 |
115 | length = min(batch_size, (len(self.names) - i))
116 | batch_x = np.empty((length, img_rows, img_cols, 4), dtype=np.float32)
117 | batch_y = np.empty((length, img_rows, img_cols, 2), dtype=np.float32)
118 |
119 | for i_batch in range(length):
120 | name = self.names[i]
121 | fcount = int(name.split('.')[0].split('_')[0])
122 | bcount = int(name.split('.')[0].split('_')[1])
123 | im_name = fg_files[fcount]
124 | bg_name = bg_files[bcount]
125 | image, alpha, fg, bg = process(im_name, bg_name)
126 |
127 |             # crop sizes 320, 480, 640, sampled with equal probability (1:1:1)
128 | different_sizes = [(320, 320), (480, 480), (640, 640)]
129 | crop_size = random.choice(different_sizes)
130 |
131 | trimap = generate_trimap(alpha)
132 | x, y = random_choice(trimap, crop_size)
133 | image = safe_crop(image, x, y, crop_size)
134 | alpha = safe_crop(alpha, x, y, crop_size)
135 |
136 | trimap = generate_trimap(alpha)
137 |
138 | # Flip array left to right randomly (prob=1:1)
139 | if np.random.random_sample() > 0.5:
140 | image = np.fliplr(image)
141 | trimap = np.fliplr(trimap)
142 | alpha = np.fliplr(alpha)
143 |
144 | batch_x[i_batch, :, :, 0:3] = image / 255.
145 | batch_x[i_batch, :, :, 3] = trimap / 255.
146 |
147 | mask = np.equal(trimap, 128).astype(np.float32)
148 | batch_y[i_batch, :, :, 0] = alpha / 255.
149 | batch_y[i_batch, :, :, 1] = mask
150 |
151 | i += 1
152 |
153 | return batch_x, batch_y
154 |
155 | def on_epoch_end(self):
156 | np.random.shuffle(self.names)
157 |
158 |
159 | def train_gen():
160 | return DataGenSequence('train')
161 |
162 |
163 | def valid_gen():
164 | return DataGenSequence('valid')
165 |
166 |
167 | def shuffle_data():
168 | num_fgs = 431
169 | num_bgs = 43100
170 | num_bgs_per_fg = 100
171 | num_valid_samples = 8620
172 | names = []
173 | bcount = 0
174 | for fcount in range(num_fgs):
175 | for i in range(num_bgs_per_fg):
176 | names.append(str(fcount) + '_' + str(bcount) + '.png')
177 | bcount += 1
178 |
179 |     # note: num_valid_samples above matches the value defined in config.py
180 | valid_names = random.sample(names, num_valid_samples)
181 | train_names = [n for n in names if n not in valid_names]
182 | shuffle(valid_names)
183 | shuffle(train_names)
184 |
185 | with open('valid_names.txt', 'w') as file:
186 | file.write('\n'.join(valid_names))
187 |
188 | with open('train_names.txt', 'w') as file:
189 | file.write('\n'.join(train_names))
190 |
191 |
192 | if __name__ == '__main__':
193 |     filename = 'data/merged/357_35748.png'
194 | bgr_img = cv.imread(filename)
195 | bg_h, bg_w = bgr_img.shape[:2]
196 | print(bg_w, bg_h)
197 |
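198 | # Worked example (illustrative note, not in the original file): generate_trimap
199 | # marks alpha == 255 as definite foreground (255), dilates the nonzero-alpha
200 | # region by 1-19 random iterations of a 3x3 elliptical kernel, and labels the
201 | # dilated band that is not definite foreground with unknown_code (128):
202 | #
203 | #     alpha = cv.imread('data/mask/some_mask.png', 0)  # hypothetical file name
204 | #     trimap = generate_trimap(alpha)                  # values in {0, 128, 255}
205 | #     x, y = random_choice(trimap, (320, 320))         # crop centered on unknown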
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 |
5 | import cv2 as cv
6 | import keras.backend as K
7 | import numpy as np
8 |
9 | from data_generator import generate_trimap, random_choice, get_alpha_test
10 | from model import build_encoder_decoder, build_refinement
11 | from utils import compute_mse_loss, compute_sad_loss
12 | from utils import get_final_output, safe_crop, draw_str
13 |
14 |
15 | def composite4(fg, bg, a, w, h):
16 | fg = np.array(fg, np.float32)
17 | bg_h, bg_w = bg.shape[:2]
18 | x = 0
19 | if bg_w > w:
20 | x = np.random.randint(0, bg_w - w)
21 | y = 0
22 | if bg_h > h:
23 | y = np.random.randint(0, bg_h - h)
24 | bg = np.array(bg[y:y + h, x:x + w], np.float32)
25 | alpha = np.zeros((h, w, 1), np.float32)
26 | alpha[:, :, 0] = a / 255.
27 | im = alpha * fg + (1 - alpha) * bg
28 | im = im.astype(np.uint8)
29 | return im, bg
30 |
31 |
32 | if __name__ == '__main__':
33 | img_rows, img_cols = 320, 320
34 | channel = 4
35 |
36 | pretrained_path = 'models/final.42-0.0398.hdf5'
37 | encoder_decoder = build_encoder_decoder()
38 | final = build_refinement(encoder_decoder)
39 | final.load_weights(pretrained_path)
40 | print(final.summary())
41 |
42 | out_test_path = 'data/merged_test/'
43 | test_images = [f for f in os.listdir(out_test_path) if
44 | os.path.isfile(os.path.join(out_test_path, f)) and f.endswith('.png')]
45 | samples = random.sample(test_images, 10)
46 |
47 | bg_test = 'data/bg_test/'
48 | test_bgs = [f for f in os.listdir(bg_test) if
49 | os.path.isfile(os.path.join(bg_test, f)) and f.endswith('.jpg')]
50 | sample_bgs = random.sample(test_bgs, 10)
51 |
52 | total_loss = 0.0
53 | for i in range(len(samples)):
54 | filename = samples[i]
55 | image_name = filename.split('.')[0]
56 |
57 | print('\nStart processing image: {}'.format(filename))
58 |
59 | bgr_img = cv.imread(os.path.join(out_test_path, filename))
60 | bg_h, bg_w = bgr_img.shape[:2]
61 | print('bg_h, bg_w: ' + str((bg_h, bg_w)))
62 |
63 | a = get_alpha_test(image_name)
64 | a_h, a_w = a.shape[:2]
65 | print('a_h, a_w: ' + str((a_h, a_w)))
66 |
67 | alpha = np.zeros((bg_h, bg_w), np.float32)
68 | alpha[0:a_h, 0:a_w] = a
69 | trimap = generate_trimap(alpha)
70 | different_sizes = [(320, 320), (320, 320), (320, 320), (480, 480), (640, 640)]
71 | crop_size = random.choice(different_sizes)
72 | x, y = random_choice(trimap, crop_size)
73 | print('x, y: ' + str((x, y)))
74 |
75 | bgr_img = safe_crop(bgr_img, x, y, crop_size)
76 | alpha = safe_crop(alpha, x, y, crop_size)
77 | trimap = safe_crop(trimap, x, y, crop_size)
78 | cv.imwrite('images/{}_image.png'.format(i), np.array(bgr_img).astype(np.uint8))
79 | cv.imwrite('images/{}_trimap.png'.format(i), np.array(trimap).astype(np.uint8))
80 | cv.imwrite('images/{}_alpha.png'.format(i), np.array(alpha).astype(np.uint8))
81 |
82 | x_test = np.empty((1, img_rows, img_cols, 4), dtype=np.float32)
83 | x_test[0, :, :, 0:3] = bgr_img / 255.
84 | x_test[0, :, :, 3] = trimap / 255.
85 |
86 | y_true = np.empty((1, img_rows, img_cols, 2), dtype=np.float32)
87 | y_true[0, :, :, 0] = alpha / 255.
88 | y_true[0, :, :, 1] = trimap / 255.
89 |
90 | y_pred = final.predict(x_test)
91 | # print('y_pred.shape: ' + str(y_pred.shape))
92 |
93 | y_pred = np.reshape(y_pred, (img_rows, img_cols))
94 | print(y_pred.shape)
95 | y_pred = y_pred * 255.0
96 | y_pred = get_final_output(y_pred, trimap)
97 | y_pred = y_pred.astype(np.uint8)
98 |
99 | sad_loss = compute_sad_loss(y_pred, alpha, trimap)
100 | mse_loss = compute_mse_loss(y_pred, alpha, trimap)
101 | str_msg = 'sad_loss: %.4f, mse_loss: %.4f, crop_size: %s' % (sad_loss, mse_loss, str(crop_size))
102 | print(str_msg)
103 |
104 | out = y_pred.copy()
105 | draw_str(out, (10, 20), str_msg)
106 | cv.imwrite('images/{}_out.png'.format(i), out)
107 |
108 | sample_bg = sample_bgs[i]
109 | bg = cv.imread(os.path.join(bg_test, sample_bg))
110 | bh, bw = bg.shape[:2]
111 | wratio = img_cols / bw
112 | hratio = img_rows / bh
113 | ratio = wratio if wratio > hratio else hratio
114 | if ratio > 1:
115 | bg = cv.resize(src=bg, dsize=(math.ceil(bw * ratio), math.ceil(bh * ratio)), interpolation=cv.INTER_CUBIC)
116 | im, bg = composite4(bgr_img, bg, y_pred, img_cols, img_rows)
117 | cv.imwrite('images/{}_compose.png'.format(i), im)
118 | cv.imwrite('images/{}_new_bg.png'.format(i), bg)
119 |
120 | K.clear_session()
121 |
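122 | # Descriptive note (not in the original file): the composite written to
123 | # images/{i}_compose.png re-runs composite4 above with the *predicted* alpha on
124 | # a fresh background, giving a visual check of matte quality, while sad_loss and
125 | # mse_loss from utils compare the prediction against the ground-truth alpha.
126 | # get_final_output keeps the network's alpha only inside the trimap's unknown
127 | # region, passing known foreground/background values through unchanged.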
--------------------------------------------------------------------------------
/history/2018-05-17 07-52-29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/history/2018-05-17 07-52-29.png
--------------------------------------------------------------------------------
/history/2018-05-21 21-44-41.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/history/2018-05-21 21-44-41.png
--------------------------------------------------------------------------------
/images/0_alpha.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/0_alpha.png
--------------------------------------------------------------------------------
/images/0_compose.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/0_compose.png
--------------------------------------------------------------------------------
/images/0_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/0_image.png
--------------------------------------------------------------------------------
/images/0_new_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/0_new_bg.png
--------------------------------------------------------------------------------
/images/0_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/0_out.png
--------------------------------------------------------------------------------
/images/0_trimap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/0_trimap.png
--------------------------------------------------------------------------------
/images/1_alpha.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/1_alpha.png
--------------------------------------------------------------------------------
/images/1_compose.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/1_compose.png
--------------------------------------------------------------------------------
/images/1_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/1_image.png
--------------------------------------------------------------------------------
/images/1_new_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/1_new_bg.png
--------------------------------------------------------------------------------
/images/1_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/1_out.png
--------------------------------------------------------------------------------
/images/1_trimap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/1_trimap.png
--------------------------------------------------------------------------------
/images/2_alpha.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/2_alpha.png
--------------------------------------------------------------------------------
/images/2_compose.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/2_compose.png
--------------------------------------------------------------------------------
/images/2_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/2_image.png
--------------------------------------------------------------------------------
/images/2_new_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/2_new_bg.png
--------------------------------------------------------------------------------
/images/2_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/2_out.png
--------------------------------------------------------------------------------
/images/2_trimap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/2_trimap.png
--------------------------------------------------------------------------------
/images/3_alpha.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/3_alpha.png
--------------------------------------------------------------------------------
/images/3_compose.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/3_compose.png
--------------------------------------------------------------------------------
/images/3_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/3_image.png
--------------------------------------------------------------------------------
/images/3_new_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/3_new_bg.png
--------------------------------------------------------------------------------
/images/3_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/3_out.png
--------------------------------------------------------------------------------
/images/3_trimap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/3_trimap.png
--------------------------------------------------------------------------------
/images/4_alpha.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/4_alpha.png
--------------------------------------------------------------------------------
/images/4_compose.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/4_compose.png
--------------------------------------------------------------------------------
/images/4_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/4_image.png
--------------------------------------------------------------------------------
/images/4_new_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/4_new_bg.png
--------------------------------------------------------------------------------
/images/4_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/4_out.png
--------------------------------------------------------------------------------
/images/4_trimap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/4_trimap.png
--------------------------------------------------------------------------------
/images/5_alpha.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/5_alpha.png
--------------------------------------------------------------------------------
/images/5_compose.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/5_compose.png
--------------------------------------------------------------------------------
/images/5_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/5_image.png
--------------------------------------------------------------------------------
/images/5_new_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/5_new_bg.png
--------------------------------------------------------------------------------
/images/5_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/5_out.png
--------------------------------------------------------------------------------
/images/5_trimap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/5_trimap.png
--------------------------------------------------------------------------------
/images/6_alpha.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/6_alpha.png
--------------------------------------------------------------------------------
/images/6_compose.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/6_compose.png
--------------------------------------------------------------------------------
/images/6_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/6_image.png
--------------------------------------------------------------------------------
/images/6_new_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/6_new_bg.png
--------------------------------------------------------------------------------
/images/6_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/6_out.png
--------------------------------------------------------------------------------
/images/6_trimap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/6_trimap.png
--------------------------------------------------------------------------------
/images/7_alpha.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/7_alpha.png
--------------------------------------------------------------------------------
/images/7_compose.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/7_compose.png
--------------------------------------------------------------------------------
/images/7_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/7_image.png
--------------------------------------------------------------------------------
/images/7_new_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/7_new_bg.png
--------------------------------------------------------------------------------
/images/7_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/7_out.png
--------------------------------------------------------------------------------
/images/7_trimap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/7_trimap.png
--------------------------------------------------------------------------------
/images/8_alpha.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/8_alpha.png
--------------------------------------------------------------------------------
/images/8_compose.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/8_compose.png
--------------------------------------------------------------------------------
/images/8_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/8_image.png
--------------------------------------------------------------------------------
/images/8_new_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/8_new_bg.png
--------------------------------------------------------------------------------
/images/8_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/8_out.png
--------------------------------------------------------------------------------
/images/8_trimap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/8_trimap.png
--------------------------------------------------------------------------------
/images/9_alpha.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/9_alpha.png
--------------------------------------------------------------------------------
/images/9_compose.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/9_compose.png
--------------------------------------------------------------------------------
/images/9_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/9_image.png
--------------------------------------------------------------------------------
/images/9_new_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/9_new_bg.png
--------------------------------------------------------------------------------
/images/9_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/9_out.png
--------------------------------------------------------------------------------
/images/9_trimap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Deep-Image-Matting/bd4c18bd3574069fdda99ed70c9519c65c503f8c/images/9_trimap.png
--------------------------------------------------------------------------------
/migrate.py:
--------------------------------------------------------------------------------
1 | import keras.backend as K
2 | import numpy as np
3 |
4 | from config import channel
5 | from model import build_encoder_decoder
6 | from vgg16 import vgg16_model
7 |
8 |
9 | def migrate_model(new_model):
10 | old_model = vgg16_model(224, 224, 3)
11 | # print(old_model.summary())
12 | old_layers = [l for l in old_model.layers]
13 | new_layers = [l for l in new_model.layers]
14 |
15 | old_conv1_1 = old_model.get_layer('conv1_1')
16 | old_weights = old_conv1_1.get_weights()[0]
17 | old_biases = old_conv1_1.get_weights()[1]
18 | new_weights = np.zeros((3, 3, channel, 64), dtype=np.float32)
19 | new_weights[:, :, 0:3, :] = old_weights
20 | new_weights[:, :, 3:channel, :] = 0.0
21 | new_conv1_1 = new_model.get_layer('conv1_1')
22 | new_conv1_1.set_weights([new_weights, old_biases])
23 |
24 | for i in range(2, 31):
25 | old_layer = old_layers[i]
26 | new_layer = new_layers[i + 1]
27 | new_layer.set_weights(old_layer.get_weights())
28 |
29 | # flatten = old_model.get_layer('flatten')
30 | # f_dim = flatten.input_shape
31 | # print('f_dim: ' + str(f_dim))
32 | # old_dense1 = old_model.get_layer('dense1')
33 | # input_shape = old_dense1.input_shape
34 | # output_dim = old_dense1.get_weights()[1].shape[0]
35 | # print('output_dim: ' + str(output_dim))
36 | # W, b = old_dense1.get_weights()
37 | # shape = (7, 7, 512, output_dim)
38 | # new_W = W.reshape(shape)
39 | # new_conv6 = new_model.get_layer('conv6')
40 | # new_conv6.set_weights([new_W, b])
41 |
42 | del old_model
43 |
44 |
45 | if __name__ == '__main__':
46 | model = build_encoder_decoder()
47 | migrate_model(model)
48 | print(model.summary())
49 | model.save_weights('models/model_weights.h5')
50 |
51 | K.clear_session()
52 |
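53 | # Descriptive note (not in the original file): conv1_1's kernel is widened from
54 | # VGG16's (3, 3, 3, 64) to (3, 3, 4, 64); the RGB slices keep the pretrained
55 | # weights and the extra trimap channel starts at zero, so the migrated encoder
56 | # initially responds exactly like plain VGG16. The loop over indices 2..30 then
57 | # copies the remaining shared layers' weights, offset by one to line the two
58 | # models' layer lists up.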
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import keras.backend as K
2 | import tensorflow as tf
3 | from keras.layers import Input, Conv2D, UpSampling2D, BatchNormalization, ZeroPadding2D, MaxPooling2D, Concatenate, \
4 | Reshape, Lambda
5 | from keras.models import Model
6 | from keras.utils import multi_gpu_model
7 | from keras.utils import plot_model
8 |
9 | from custom_layers.unpooling_layer import Unpooling
10 |
11 |
12 | def build_encoder_decoder():
13 | # Encoder
14 | input_tensor = Input(shape=(320, 320, 4))
15 | x = ZeroPadding2D((1, 1))(input_tensor)
16 | x = Conv2D(64, (3, 3), activation='relu', name='conv1_1')(x)
17 | x = ZeroPadding2D((1, 1))(x)
18 | x = Conv2D(64, (3, 3), activation='relu', name='conv1_2')(x)
19 | orig_1 = x
20 | x = MaxPooling2D((2, 2), strides=(2, 2))(x)
21 |
22 | x = ZeroPadding2D((1, 1))(x)
23 | x = Conv2D(128, (3, 3), activation='relu', name='conv2_1')(x)
24 | x = ZeroPadding2D((1, 1))(x)
25 | x = Conv2D(128, (3, 3), activation='relu', name='conv2_2')(x)
26 | orig_2 = x
27 | x = MaxPooling2D((2, 2), strides=(2, 2))(x)
28 |
29 | x = ZeroPadding2D((1, 1))(x)
30 | x = Conv2D(256, (3, 3), activation='relu', name='conv3_1')(x)
31 | x = ZeroPadding2D((1, 1))(x)
32 | x = Conv2D(256, (3, 3), activation='relu', name='conv3_2')(x)
33 | x = ZeroPadding2D((1, 1))(x)
34 | x = Conv2D(256, (3, 3), activation='relu', name='conv3_3')(x)
35 | orig_3 = x
36 | x = MaxPooling2D((2, 2), strides=(2, 2))(x)
37 |
38 | x = ZeroPadding2D((1, 1))(x)
39 | x = Conv2D(512, (3, 3), activation='relu', name='conv4_1')(x)
40 | x = ZeroPadding2D((1, 1))(x)
41 | x = Conv2D(512, (3, 3), activation='relu', name='conv4_2')(x)
42 | x = ZeroPadding2D((1, 1))(x)
43 | x = Conv2D(512, (3, 3), activation='relu', name='conv4_3')(x)
44 | orig_4 = x
45 | x = MaxPooling2D((2, 2), strides=(2, 2))(x)
46 |
47 | x = ZeroPadding2D((1, 1))(x)
48 | x = Conv2D(512, (3, 3), activation='relu', name='conv5_1')(x)
49 | x = ZeroPadding2D((1, 1))(x)
50 | x = Conv2D(512, (3, 3), activation='relu', name='conv5_2')(x)
51 | x = ZeroPadding2D((1, 1))(x)
52 | x = Conv2D(512, (3, 3), activation='relu', name='conv5_3')(x)
53 | orig_5 = x
54 | x = MaxPooling2D((2, 2), strides=(2, 2))(x)
55 |
56 | # Decoder
57 | # x = Conv2D(4096, (7, 7), activation='relu', padding='valid', name='conv6')(x)
58 | # x = BatchNormalization()(x)
59 | # x = UpSampling2D(size=(7, 7))(x)
60 |
61 | x = Conv2D(512, (1, 1), activation='relu', padding='same', name='deconv6', kernel_initializer='he_normal',
62 | bias_initializer='zeros')(x)
63 | x = BatchNormalization()(x)
64 | x = UpSampling2D(size=(2, 2))(x)
65 | the_shape = K.int_shape(orig_5)
66 | shape = (1, the_shape[1], the_shape[2], the_shape[3])
67 | origReshaped = Reshape(shape)(orig_5)
68 | # print('origReshaped.shape: ' + str(K.int_shape(origReshaped)))
69 | xReshaped = Reshape(shape)(x)
70 | # print('xReshaped.shape: ' + str(K.int_shape(xReshaped)))
71 | together = Concatenate(axis=1)([origReshaped, xReshaped])
72 | # print('together.shape: ' + str(K.int_shape(together)))
73 | x = Unpooling()(together)
74 |
75 | x = Conv2D(512, (5, 5), activation='relu', padding='same', name='deconv5', kernel_initializer='he_normal',
76 | bias_initializer='zeros')(x)
77 | x = BatchNormalization()(x)
78 | x = UpSampling2D(size=(2, 2))(x)
79 | the_shape = K.int_shape(orig_4)
80 | shape = (1, the_shape[1], the_shape[2], the_shape[3])
81 | origReshaped = Reshape(shape)(orig_4)
82 | xReshaped = Reshape(shape)(x)
83 | together = Concatenate(axis=1)([origReshaped, xReshaped])
84 | x = Unpooling()(together)
85 |
86 | x = Conv2D(256, (5, 5), activation='relu', padding='same', name='deconv4', kernel_initializer='he_normal',
87 | bias_initializer='zeros')(x)
88 | x = BatchNormalization()(x)
89 | x = UpSampling2D(size=(2, 2))(x)
90 | the_shape = K.int_shape(orig_3)
91 | shape = (1, the_shape[1], the_shape[2], the_shape[3])
92 | origReshaped = Reshape(shape)(orig_3)
93 | xReshaped = Reshape(shape)(x)
94 | together = Concatenate(axis=1)([origReshaped, xReshaped])
95 | x = Unpooling()(together)
96 |
97 | x = Conv2D(128, (5, 5), activation='relu', padding='same', name='deconv3', kernel_initializer='he_normal',
98 | bias_initializer='zeros')(x)
99 | x = BatchNormalization()(x)
100 | x = UpSampling2D(size=(2, 2))(x)
101 | the_shape = K.int_shape(orig_2)
102 | shape = (1, the_shape[1], the_shape[2], the_shape[3])
103 | origReshaped = Reshape(shape)(orig_2)
104 | xReshaped = Reshape(shape)(x)
105 | together = Concatenate(axis=1)([origReshaped, xReshaped])
106 | x = Unpooling()(together)
107 |
108 | x = Conv2D(64, (5, 5), activation='relu', padding='same', name='deconv2', kernel_initializer='he_normal',
109 | bias_initializer='zeros')(x)
110 | x = BatchNormalization()(x)
111 | x = UpSampling2D(size=(2, 2))(x)
112 | the_shape = K.int_shape(orig_1)
113 | shape = (1, the_shape[1], the_shape[2], the_shape[3])
114 | origReshaped = Reshape(shape)(orig_1)
115 | xReshaped = Reshape(shape)(x)
116 | together = Concatenate(axis=1)([origReshaped, xReshaped])
117 | x = Unpooling()(together)
118 | x = Conv2D(64, (5, 5), activation='relu', padding='same', name='deconv1', kernel_initializer='he_normal',
119 | bias_initializer='zeros')(x)
120 | x = BatchNormalization()(x)
121 |
122 | x = Conv2D(1, (5, 5), activation='sigmoid', padding='same', name='pred', kernel_initializer='he_normal',
123 | bias_initializer='zeros')(x)
124 |
125 | model = Model(inputs=input_tensor, outputs=x)
126 | return model
127 |
128 |
129 | def build_refinement(encoder_decoder):
130 | input_tensor = encoder_decoder.input
131 |
132 | input = Lambda(lambda i: i[:, :, :, 0:3])(input_tensor)
133 |
134 | x = Concatenate(axis=3)([input, encoder_decoder.output])
135 | x = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal',
136 | bias_initializer='zeros')(x)
137 | x = BatchNormalization()(x)
138 | x = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal',
139 | bias_initializer='zeros')(x)
140 | x = BatchNormalization()(x)
141 | x = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal',
142 | bias_initializer='zeros')(x)
143 | x = BatchNormalization()(x)
144 | x = Conv2D(1, (3, 3), activation='sigmoid', padding='same', name='refinement_pred', kernel_initializer='he_normal',
145 | bias_initializer='zeros')(x)
146 |
147 | model = Model(inputs=input_tensor, outputs=x)
148 | return model
149 |
150 |
151 | if __name__ == '__main__':
152 | with tf.device("/cpu:0"):
153 | encoder_decoder = build_encoder_decoder()
154 | print(encoder_decoder.summary())
155 | plot_model(encoder_decoder, to_file='encoder_decoder.svg', show_layer_names=True, show_shapes=True)
156 |
157 | with tf.device("/cpu:0"):
158 | refinement = build_refinement(encoder_decoder)
159 | print(refinement.summary())
160 | plot_model(refinement, to_file='refinement.svg', show_layer_names=True, show_shapes=True)
161 |
162 | parallel_model = multi_gpu_model(refinement, gpus=None)
163 | print(parallel_model.summary())
164 | plot_model(parallel_model, to_file='parallel_model.svg', show_layer_names=True, show_shapes=True)
165 |
166 | K.clear_session()
167 |
--------------------------------------------------------------------------------
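
Note on model.py: the decoder above recovers max-pooling switches without a
pooling-argmax op. Each pre-pool encoder map (orig_1 ... orig_4) is reshaped,
stacked with the upsampled decoder map along a new axis, and handed to the
custom Unpooling layer, which (roughly) keeps the upsampled values only where
the encoder map held its pooling window's maximum. A minimal numpy sketch of
that masking idea; the authoritative version is
custom_layers/unpooling_layer.py, which is not shown in this listing:

    import numpy as np

    orig = np.array([[1., 4.],
                     [2., 3.]])                 # one 2x2 pooling window, pre-pool
    upsampled = np.full((2, 2), orig.max())     # pooled max broadcast back up
    mask = (orig == orig.max()).astype(float)   # 1 only at the argmax position
    print(mask * upsampled)                     # [[0. 4.] [0. 0.]]
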
/model.svg:
--------------------------------------------------------------------------------
1 | <!-- SVG content omitted: Graphviz rendering of the encoder-decoder model graph (generated by plot_model.py) -->
--------------------------------------------------------------------------------
/parallel_model.svg:
--------------------------------------------------------------------------------
1 | <!-- SVG content omitted: Graphviz rendering of the multi-GPU parallel model graph -->
--------------------------------------------------------------------------------
/plot_model.py:
--------------------------------------------------------------------------------
1 | # dependencies: pip install pydot and brew install graphviz
2 | from model import build_encoder_decoder
3 | from keras.utils import plot_model
4 |
5 | if __name__ == '__main__':
6 | img_rows, img_cols = 320, 320
7 | channel = 3
8 | model = build_encoder_decoder(img_rows, img_cols, channel)
9 | plot_model(model, to_file='model.svg', show_layer_names=True, show_shapes=True)
10 |
--------------------------------------------------------------------------------
/pre_process.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | import shutil
5 | import zipfile
6 | import tarfile
7 |
8 | from Combined_Dataset.Training_set.Composition_code_revised import do_composite
9 | from Combined_Dataset.Test_set.Composition_code_revised import do_composite_test
10 |
11 | if __name__ == '__main__':
12 | # path to provided foreground images
13 | fg_path = 'data/fg/'
14 | # path to provided alpha mattes
15 | a_path = 'data/mask/'
16 | # Path to background images (MSCOCO)
17 | bg_path = 'data/bg/'
18 | # Path to folder where you want the composited images to go
19 | out_path = 'data/merged/'
20 |
21 | train_folder = 'data/Combined_Dataset/Training_set/'
22 |
23 | # if not os.path.exists('Combined_Dataset'):
24 | zip_file = 'data/Adobe_Deep_Matting_Dataset.zip'
25 | print('Extracting {}...'.format(zip_file))
26 |
27 | zip_ref = zipfile.ZipFile(zip_file, 'r')
28 | zip_ref.extractall('data')
29 | zip_ref.close()
30 |
31 | if not os.path.exists(bg_path):
32 | zip_file = 'data/train2014.zip'
33 | print('Extracting {}...'.format(zip_file))
34 |
35 | zip_ref = zipfile.ZipFile(zip_file, 'r')
36 | zip_ref.extractall('data')
37 | zip_ref.close()
38 |
39 | with open(os.path.join(train_folder, 'training_bg_names.txt')) as f:
40 | training_bg_names = f.read().splitlines()
41 |
42 | os.makedirs(bg_path)
43 | for bg_name in training_bg_names:
44 | src_path = os.path.join('data/train2014', bg_name)
45 | dest_path = os.path.join(bg_path, bg_name)
46 | shutil.move(src_path, dest_path)
47 |
48 | if not os.path.exists(fg_path):
49 | os.makedirs(fg_path)
50 |
51 | for old_folder in [train_folder + 'Adobe-licensed images/fg', train_folder + 'Other/fg']:
52 | fg_files = os.listdir(old_folder)
53 | for fg_file in fg_files:
54 | src_path = os.path.join(old_folder, fg_file)
55 | dest_path = os.path.join(fg_path, fg_file)
56 | shutil.move(src_path, dest_path)
57 |
58 | if not os.path.exists(a_path):
59 | os.makedirs(a_path)
60 |
61 | for old_folder in [train_folder + 'Adobe-licensed images/alpha', train_folder + 'Other/alpha']:
62 | a_files = os.listdir(old_folder)
63 | for a_file in a_files:
64 | src_path = os.path.join(old_folder, a_file)
65 | dest_path = os.path.join(a_path, a_file)
66 | shutil.move(src_path, dest_path)
67 |
68 | if not os.path.exists(out_path):
69 | os.makedirs(out_path)
70 | # do_composite()
71 |
72 | # path to provided foreground images
73 | fg_test_path = 'data/fg_test/'
74 | # path to provided alpha mattes
75 | a_test_path = 'data/mask_test/'
76 | # Path to background images (PASCAL VOC)
77 | bg_test_path = 'data/bg_test/'
78 | # Path to folder where you want the composited images to go
79 | out_test_path = 'data/merged_test/'
80 |
81 | # test data gen
82 | test_folder = 'data/Combined_Dataset/Test_set/'
83 |
84 | if not os.path.exists(bg_test_path):
85 | os.makedirs(bg_test_path)
86 |
87 | tar_file = 'data/VOCtrainval_14-Jul-2008.tar'
88 | print('Extracting {}...'.format(tar_file))
89 |
90 | tar = tarfile.open(tar_file)
91 | tar.extractall('data')
92 | tar.close()
93 |
94 | tar_file = 'data/VOC2008test.tar'
95 | print('Extracting {}...'.format(tar_file))
96 |
97 | tar = tarfile.open(tar_file)
98 | tar.extractall('data')
99 | tar.close()
100 |
101 | with open(os.path.join(test_folder, 'test_bg_names.txt')) as f:
102 | test_bg_names = f.read().splitlines()
103 |
104 | for bg_name in test_bg_names:
105 | tokens = bg_name.split('_')
106 | src_path = os.path.join('data/VOCdevkit/VOC2008/JPEGImages', bg_name)
107 | dest_path = os.path.join(bg_test_path, bg_name)
108 | shutil.move(src_path, dest_path)
109 |
110 | if not os.path.exists(fg_test_path):
111 | os.makedirs(fg_test_path)
112 |
113 | for old_folder in [test_folder + 'Adobe-licensed images/fg']:
114 | fg_files = os.listdir(old_folder)
115 | for fg_file in fg_files:
116 | src_path = os.path.join(old_folder, fg_file)
117 | dest_path = os.path.join(fg_test_path, fg_file)
118 | shutil.move(src_path, dest_path)
119 |
120 | if not os.path.exists(a_test_path):
121 | os.makedirs(a_test_path)
122 |
123 | for old_folder in [test_folder + 'Adobe-licensed images/alpha']:
124 | a_files = os.listdir(old_folder)
125 | for a_file in a_files:
126 | src_path = os.path.join(old_folder, a_file)
127 | dest_path = os.path.join(a_test_path, a_file)
128 | shutil.move(src_path, dest_path)
129 |
130 | if not os.path.exists(out_test_path):
131 | os.makedirs(out_test_path)
132 |
133 | do_composite_test()
134 |
--------------------------------------------------------------------------------
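
Note on pre_process.py: do_composite and do_composite_test (defined in the two
Composition_code_revised.py files, not shown in this listing) build the merged
training and test images. The core operation is standard alpha compositing,
I = alpha * F + (1 - alpha) * B. A minimal sketch with placeholder file names:

    import cv2 as cv
    import numpy as np

    fg = cv.imread('data/fg/example.png').astype(np.float32)
    bg = cv.imread('data/bg/example.jpg').astype(np.float32)
    bg = cv.resize(bg, (fg.shape[1], fg.shape[0]))           # match fg size
    alpha = cv.imread('data/mask/example.png', 0).astype(np.float32) / 255.
    alpha = alpha[:, :, None]                                # broadcast over B, G, R
    merged = alpha * fg + (1. - alpha) * bg                  # I = aF + (1 - a)B
    cv.imwrite('data/merged/example.png', merged.astype(np.uint8))
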
/predit_single.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import cv2 as cv
3 | import keras.backend as K
4 | import numpy as np
5 |
6 | from model import build_encoder_decoder, build_refinement
7 | from utils import get_final_output, create_patches, patch_dims, assemble_patches
8 | import tensorflow as tf
9 | import time
10 |
11 | config = tf.ConfigProto(device_count={"GPU": 1, "CPU": 1})
12 | sess = tf.Session(config=config)
13 | K.set_session(sess)
14 |
15 | if __name__ == '__main__':
16 | # load network
17 | PATCH_SIZE = 320
18 | PRETRAINED_PATH = 'models/final.42-0.0398.hdf5'
19 | TRIMAP_PATH = "images/trimap2.png"
20 | IMG_PATH = "images/frame2.png"
21 |
22 | encoder_decoder = build_encoder_decoder()
23 | final = build_refinement(encoder_decoder)
24 | final.load_weights(PRETRAINED_PATH)
25 | print(final.summary())
26 |
27 | # loading input files
28 | trimap = cv.imread(TRIMAP_PATH, cv.IMREAD_GRAYSCALE)
29 | img = cv.imread(IMG_PATH)
30 | result = np.zeros(trimap.shape, dtype=np.uint8)
31 |
32 | img_size = np.array(trimap.shape)
33 |
34 | # create patches
35 | x = np.dstack((img, np.expand_dims(trimap, axis=2))) / 255.
36 | patches = create_patches(x, PATCH_SIZE)
37 |
38 | # create mat for patches predictions
39 | patches_count = np.product(
40 | patch_dims(mat_size=trimap.shape, patch_size=PATCH_SIZE)
41 | )
42 | patches_predictions = np.zeros(shape=(patches_count, PATCH_SIZE, PATCH_SIZE))
43 |
44 | # predicting
45 | for i in range(patches.shape[0]):
46 |         print("Predicting patch {}/{}".format(i + 1, patches_count))
47 |
48 | patch_prediction = final.predict(np.expand_dims(patches[i, :, :, :], axis=0))
49 | patches_predictions[i] = np.reshape(patch_prediction, (PATCH_SIZE, PATCH_SIZE)) * 255.
50 |
51 | # assemble
52 | result = assemble_patches(patches_predictions, trimap.shape, PATCH_SIZE)
53 | result = result[:img_size[0], :img_size[1]]
54 |
55 | prediction = get_final_output(result, trimap).astype(np.uint8)
56 |
57 | # save into files
58 | cv.imshow("result", prediction)
59 | cv.imshow("image", img)
60 | cv.waitKey(0)
61 |
62 | K.clear_session()
63 |
64 |
--------------------------------------------------------------------------------
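
Note on predit_single.py: a worked example of the patch grid. For a 480x852
trimap and PATCH_SIZE = 320, patch_dims (from utils.py) gives
ceil(480/320) x ceil(852/320) = 2 x 3, so create_patches yields six
zero-padded 320x320 patches and assemble_patches stitches a 640x960 canvas
that the script then crops back to 480x852:

    import numpy as np

    mat_size, patch_size = (480, 852), 320
    grid = np.ceil(np.array(mat_size) / patch_size).astype(int)
    print(grid, grid.prod())  # [2 3] 6
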
/segnet.py:
--------------------------------------------------------------------------------
1 | import keras.backend as K
2 | import tensorflow as tf
3 | from keras.layers import Input, Conv2D, UpSampling2D, BatchNormalization, ZeroPadding2D, MaxPooling2D, Reshape, \
4 | Concatenate, Lambda
5 | from keras.models import Model
6 | from keras.utils import multi_gpu_model
7 | from keras.utils import plot_model
8 |
9 | from custom_layers.unpooling_layer import Unpooling
10 |
11 |
12 | def build_encoder_decoder():
13 | kernel = 3
14 |
15 | # Encoder
16 | #
17 | input_tensor = Input(shape=(320, 320, 4))
18 | x = ZeroPadding2D((1, 1))(input_tensor)
19 | x = Conv2D(64, (kernel, kernel), activation='relu', name='conv1_1')(x)
20 | x = ZeroPadding2D((1, 1))(x)
21 | x = Conv2D(64, (kernel, kernel), activation='relu', name='conv1_2')(x)
22 | orig_1 = x
23 | x = MaxPooling2D((2, 2), strides=(2, 2))(x)
24 |
25 | x = ZeroPadding2D((1, 1))(x)
26 | x = Conv2D(128, (kernel, kernel), activation='relu', name='conv2_1')(x)
27 | x = ZeroPadding2D((1, 1))(x)
28 | x = Conv2D(128, (kernel, kernel), activation='relu', name='conv2_2')(x)
29 | orig_2 = x
30 | x = MaxPooling2D((2, 2), strides=(2, 2))(x)
31 |
32 | x = ZeroPadding2D((1, 1))(x)
33 | x = Conv2D(256, (kernel, kernel), activation='relu', name='conv3_1')(x)
34 | x = ZeroPadding2D((1, 1))(x)
35 | x = Conv2D(256, (kernel, kernel), activation='relu', name='conv3_2')(x)
36 | x = ZeroPadding2D((1, 1))(x)
37 | x = Conv2D(256, (kernel, kernel), activation='relu', name='conv3_3')(x)
38 | orig_3 = x
39 | x = MaxPooling2D((2, 2), strides=(2, 2))(x)
40 |
41 | x = ZeroPadding2D((1, 1))(x)
42 | x = Conv2D(512, (kernel, kernel), activation='relu', name='conv4_1')(x)
43 | x = ZeroPadding2D((1, 1))(x)
44 | x = Conv2D(512, (kernel, kernel), activation='relu', name='conv4_2')(x)
45 | x = ZeroPadding2D((1, 1))(x)
46 | x = Conv2D(512, (kernel, kernel), activation='relu', name='conv4_3')(x)
47 | orig_4 = x
48 | x = MaxPooling2D((2, 2), strides=(2, 2))(x)
49 |
50 | x = ZeroPadding2D((1, 1))(x)
51 | x = Conv2D(512, (kernel, kernel), activation='relu', name='conv5_1')(x)
52 | x = ZeroPadding2D((1, 1))(x)
53 | x = Conv2D(512, (kernel, kernel), activation='relu', name='conv5_2')(x)
54 | x = ZeroPadding2D((1, 1))(x)
55 | x = Conv2D(512, (kernel, kernel), activation='relu', name='conv5_3')(x)
56 | orig_5 = x
57 | x = MaxPooling2D((2, 2), strides=(2, 2))(x)
58 |
59 | # Decoder
60 | #
61 | x = UpSampling2D(size=(2, 2))(x)
62 | the_shape = K.int_shape(orig_5)
63 | shape = (1, the_shape[1], the_shape[2], the_shape[3])
64 | origReshaped = Reshape(shape)(orig_5)
65 | xReshaped = Reshape(shape)(x)
66 | together = Concatenate(axis=1)([origReshaped, xReshaped])
67 | x = Unpooling()(together)
68 | x = Conv2D(512, (kernel, kernel), activation='relu', padding='same', name='deconv5_1',
69 | kernel_initializer='he_normal',
70 | bias_initializer='zeros')(x)
71 | x = BatchNormalization()(x)
72 | x = Conv2D(512, (kernel, kernel), activation='relu', padding='same', name='deconv5_2',
73 | kernel_initializer='he_normal',
74 | bias_initializer='zeros')(x)
75 | x = BatchNormalization()(x)
76 | x = Conv2D(512, (kernel, kernel), activation='relu', padding='same', name='deconv5_3',
77 | kernel_initializer='he_normal',
78 | bias_initializer='zeros')(x)
79 | x = BatchNormalization()(x)
80 |
81 | x = UpSampling2D(size=(2, 2))(x)
82 | the_shape = K.int_shape(orig_4)
83 | shape = (1, the_shape[1], the_shape[2], the_shape[3])
84 | origReshaped = Reshape(shape)(orig_4)
85 | xReshaped = Reshape(shape)(x)
86 | together = Concatenate(axis=1)([origReshaped, xReshaped])
87 | x = Unpooling()(together)
88 | x = Conv2D(256, (kernel, kernel), activation='relu', padding='same', name='deconv4_1',
89 | kernel_initializer='he_normal',
90 | bias_initializer='zeros')(x)
91 | x = BatchNormalization()(x)
92 | x = Conv2D(256, (kernel, kernel), activation='relu', padding='same', name='deconv4_2',
93 | kernel_initializer='he_normal',
94 | bias_initializer='zeros')(x)
95 | x = BatchNormalization()(x)
96 | x = Conv2D(256, (kernel, kernel), activation='relu', padding='same', name='deconv4_3',
97 | kernel_initializer='he_normal',
98 | bias_initializer='zeros')(x)
99 | x = BatchNormalization()(x)
100 |
101 | x = UpSampling2D(size=(2, 2))(x)
102 | the_shape = K.int_shape(orig_3)
103 | shape = (1, the_shape[1], the_shape[2], the_shape[3])
104 | origReshaped = Reshape(shape)(orig_3)
105 | xReshaped = Reshape(shape)(x)
106 | together = Concatenate(axis=1)([origReshaped, xReshaped])
107 | x = Unpooling()(together)
108 | x = Conv2D(128, (kernel, kernel), activation='relu', padding='same', name='deconv3_1',
109 | kernel_initializer='he_normal',
110 | bias_initializer='zeros')(x)
111 | x = BatchNormalization()(x)
112 | x = Conv2D(128, (kernel, kernel), activation='relu', padding='same', name='deconv3_2',
113 | kernel_initializer='he_normal',
114 | bias_initializer='zeros')(x)
115 | x = BatchNormalization()(x)
116 | x = Conv2D(128, (kernel, kernel), activation='relu', padding='same', name='deconv3_3',
117 | kernel_initializer='he_normal',
118 | bias_initializer='zeros')(x)
119 | x = BatchNormalization()(x)
120 |
121 | x = UpSampling2D(size=(2, 2))(x)
122 | the_shape = K.int_shape(orig_2)
123 | shape = (1, the_shape[1], the_shape[2], the_shape[3])
124 | origReshaped = Reshape(shape)(orig_2)
125 | xReshaped = Reshape(shape)(x)
126 | together = Concatenate(axis=1)([origReshaped, xReshaped])
127 | x = Unpooling()(together)
128 | x = Conv2D(64, (kernel, kernel), activation='relu', padding='same', name='deconv2_1',
129 | kernel_initializer='he_normal',
130 | bias_initializer='zeros')(x)
131 | x = BatchNormalization()(x)
132 | x = Conv2D(64, (kernel, kernel), activation='relu', padding='same', name='deconv2_2',
133 | kernel_initializer='he_normal',
134 | bias_initializer='zeros')(x)
135 | x = BatchNormalization()(x)
136 |
137 | x = UpSampling2D(size=(2, 2))(x)
138 | the_shape = K.int_shape(orig_1)
139 | shape = (1, the_shape[1], the_shape[2], the_shape[3])
140 | origReshaped = Reshape(shape)(orig_1)
141 | xReshaped = Reshape(shape)(x)
142 | together = Concatenate(axis=1)([origReshaped, xReshaped])
143 | x = Unpooling()(together)
144 | x = Conv2D(64, (kernel, kernel), activation='relu', padding='same', name='deconv1_1',
145 | kernel_initializer='he_normal',
146 | bias_initializer='zeros')(x)
147 | x = BatchNormalization()(x)
148 | x = Conv2D(64, (kernel, kernel), activation='relu', padding='same', name='deconv1_2',
149 | kernel_initializer='he_normal',
150 | bias_initializer='zeros')(x)
151 | x = BatchNormalization()(x)
152 |
153 | x = Conv2D(1, (kernel, kernel), activation='sigmoid', padding='same', name='pred', kernel_initializer='he_normal',
154 | bias_initializer='zeros')(x)
155 |
156 | model = Model(inputs=input_tensor, outputs=x)
157 | return model
158 |
159 |
160 | def build_refinement(encoder_decoder):
161 | input_tensor = encoder_decoder.input
162 |
163 |     rgb = Lambda(lambda i: i[:, :, :, 0:3])(input_tensor)
164 |
165 |     x = Concatenate(axis=3)([rgb, encoder_decoder.output])
166 | x = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal',
167 | bias_initializer='zeros')(x)
168 | x = BatchNormalization()(x)
169 | x = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal',
170 | bias_initializer='zeros')(x)
171 | x = BatchNormalization()(x)
172 | x = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal',
173 | bias_initializer='zeros')(x)
174 | x = BatchNormalization()(x)
175 | x = Conv2D(1, (3, 3), activation='sigmoid', padding='same', name='refinement_pred', kernel_initializer='he_normal',
176 | bias_initializer='zeros')(x)
177 |
178 | model = Model(inputs=input_tensor, outputs=x)
179 | return model
180 |
181 |
182 | if __name__ == '__main__':
183 | with tf.device("/cpu:0"):
184 | encoder_decoder = build_encoder_decoder()
185 | print(encoder_decoder.summary())
186 | plot_model(encoder_decoder, to_file='encoder_decoder.svg', show_layer_names=True, show_shapes=True)
187 |
188 | with tf.device("/cpu:0"):
189 | refinement = build_refinement(encoder_decoder)
190 | print(refinement.summary())
191 | plot_model(refinement, to_file='refinement.svg', show_layer_names=True, show_shapes=True)
192 |
193 | parallel_model = multi_gpu_model(refinement, gpus=None)
194 | print(parallel_model.summary())
195 | plot_model(parallel_model, to_file='parallel_model.svg', show_layer_names=True, show_shapes=True)
196 |
197 | K.clear_session()
198 |
--------------------------------------------------------------------------------
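
Note on segnet.py: a quick check of the spatial sizes through the encoder's
five 2x2 max-pools, which the decoder mirrors with five UpSampling2D stages:

    size = 320
    for stage in range(1, 6):
        size //= 2
        print('after pool{}: {}x{}'.format(stage, size, size))
    # 160x160, 80x80, 40x40, 20x20, 10x10
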
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import cv2 as cv
4 | import numpy as np
5 |
6 | from model import build_encoder_decoder, build_refinement
7 | from utils import get_final_output
8 |
9 | # python test.py -i "images/image.png" -t "images/trimap.png"
10 | if __name__ == '__main__':
11 | img_rows, img_cols = 320, 320
12 | channel = 4
13 |
14 | model_weights_path = 'models/final.42-0.0398.hdf5'
15 | encoder_decoder = build_encoder_decoder()
16 | final = build_refinement(encoder_decoder)
17 | final.load_weights(model_weights_path)
18 | print(final.summary())
19 |
20 | ap = argparse.ArgumentParser()
21 | ap.add_argument("-i", "--image", help="path to the image file")
22 | ap.add_argument("-t", "--trimap", help="path to the trimap file")
23 | args = vars(ap.parse_args())
24 | image_path = args["image"]
25 | trimap_path = args["trimap"]
26 |
27 | if image_path is None:
28 |         image_path = 'images/image.png'
29 | if trimap_path is None:
30 |         trimap_path = 'images/trimap.png'
31 |
32 | print('Start processing image: {}'.format(image_path))
33 |
34 | x_test = np.empty((1, img_rows, img_cols, 4), dtype=np.float32)
35 | bgr_img = cv.imread(image_path)
36 | trimap = cv.imread(trimap_path, 0)
37 |
38 | x_test[0, :, :, 0:3] = bgr_img / 255.
39 | x_test[0, :, :, 3] = trimap / 255.
40 |
41 | out = final.predict(x_test)
42 | out = np.reshape(out, (img_rows, img_cols))
43 | print(out.shape)
44 | out = out * 255.0
45 | out = get_final_output(out, trimap)
46 | out = out.astype(np.uint8)
47 | cv.imshow('out', out)
48 | cv.imwrite('images/out.png', out)
49 | cv.waitKey(0)
50 | cv.destroyAllWindows()
51 |
--------------------------------------------------------------------------------
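
Note on test.py: the script fills a fixed (1, 320, 320, 4) input tensor, so
the image and trimap must already be 320x320; larger inputs belong in the
patch-based predit_single.py, or need a pre-resize such as the sketch below
(nearest-neighbour for the trimap so its three label values survive):

    import cv2 as cv

    bgr_img = cv.resize(cv.imread('images/image.png'), (320, 320))
    trimap = cv.resize(cv.imread('images/trimap.png', 0), (320, 320),
                       interpolation=cv.INTER_NEAREST)
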
/test_alphamatting.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2 as cv
4 | import numpy as np
5 |
6 | from model import build_encoder_decoder, build_refinement
7 |
8 | if __name__ == '__main__':
9 | pretrained_path = 'models/final.42-0.0398.hdf5'
10 | encoder_decoder = build_encoder_decoder()
11 | final = build_refinement(encoder_decoder)
12 | final.load_weights(pretrained_path)
13 | print(final.summary())
14 |
15 | images = [f for f in os.listdir('alphamatting/input_lowres') if f.endswith('.png')]
16 |
17 | for image_name in images:
18 | filename = os.path.join('alphamatting/input_lowres', image_name)
19 | im = cv.imread(filename)
20 | im_h, im_w = im.shape[:2]
21 |
22 |         for trimap_id in [1, 2, 3]:
23 |             trimap_name = os.path.join('alphamatting/trimap_lowres/Trimap{}'.format(trimap_id), image_name)
24 | trimap = cv.imread(trimap_name, 0)
25 |
26 |             for i in range(int(np.ceil(im_h / 320))):
27 |                 for j in range(int(np.ceil(im_w / 320))):
28 | x = j * 320
29 | y = i * 320
30 | w = min(320, im_w - x)
31 | h = min(320, im_h - y)
32 | im_crop = im[y:y + h, x:x + w]
33 | tri_crop = trimap[y:y + h, x:x + w]
34 |
--------------------------------------------------------------------------------
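
Note on test_alphamatting.py: the script ends after computing the crops; the
prediction and stitching step is left unfinished in the repository. One
plausible continuation for the innermost loop, modelled on test.py (an
assumption, not the author's code): zero-pad each crop to 320x320, predict,
and keep the valid h x w corner:

    # inside the innermost loop, after tri_crop is taken
    inp = np.zeros((1, 320, 320, 4), dtype=np.float32)
    inp[0, :h, :w, 0:3] = im_crop / 255.
    inp[0, :h, :w, 3] = tri_crop / 255.
    pred = final.predict(inp).reshape(320, 320)[:h, :w] * 255.
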
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import keras
4 | import tensorflow as tf
5 | from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
6 | from keras.utils import multi_gpu_model
7 |
8 | from config import patience, batch_size, epochs, num_train_samples, num_valid_samples
9 | from data_generator import train_gen, valid_gen
10 | from migrate import migrate_model
11 | from segnet import build_encoder_decoder, build_refinement
12 | from utils import overall_loss, get_available_cpus, get_available_gpus
13 |
14 | if __name__ == '__main__':
15 | checkpoint_models_path = 'models/'
16 | # Parse arguments
17 | ap = argparse.ArgumentParser()
18 |     ap.add_argument("-p", "--pretrained", help="path to pretrained model weights to load")
19 | args = vars(ap.parse_args())
20 | pretrained_path = args["pretrained"]
21 |
22 | # Callbacks
23 | tensor_board = keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0, write_graph=True, write_images=True)
24 | model_names = checkpoint_models_path + 'final.{epoch:02d}-{val_loss:.4f}.hdf5'
25 | model_checkpoint = ModelCheckpoint(model_names, monitor='val_loss', verbose=1, save_best_only=True)
26 | early_stop = EarlyStopping('val_loss', patience=patience)
27 | reduce_lr = ReduceLROnPlateau('val_loss', factor=0.1, patience=int(patience / 4), verbose=1)
28 |
29 |
30 | class MyCbk(keras.callbacks.Callback):
31 | def __init__(self, model):
32 | keras.callbacks.Callback.__init__(self)
33 | self.model_to_save = model
34 |
35 | def on_epoch_end(self, epoch, logs=None):
36 | fmt = checkpoint_models_path + 'final.%02d-%.4f.hdf5'
37 | self.model_to_save.save(fmt % (epoch, logs['val_loss']))
38 |
39 |
40 | # Load our model, added support for Multi-GPUs
41 | num_gpu = len(get_available_gpus())
42 | if num_gpu >= 2:
43 | with tf.device("/cpu:0"):
44 | model = build_encoder_decoder()
45 | model = build_refinement(model)
46 | if pretrained_path is not None:
47 | model.load_weights(pretrained_path)
48 | else:
49 | migrate_model(model)
50 |
51 | final = multi_gpu_model(model, gpus=num_gpu)
52 | # rewrite the callback: saving through the original model and not the multi-gpu model.
53 | model_checkpoint = MyCbk(model)
54 | else:
55 | model = build_encoder_decoder()
56 | final = build_refinement(model)
57 | if pretrained_path is not None:
58 | final.load_weights(pretrained_path)
59 | else:
60 | migrate_model(final)
61 |
62 | decoder_target = tf.placeholder(dtype='float32', shape=(None, None, None, None))
63 | final.compile(optimizer='nadam', loss=overall_loss, target_tensors=[decoder_target])
64 |
65 | print(final.summary())
66 |
67 | # Final callbacks
68 | callbacks = [tensor_board, model_checkpoint, early_stop, reduce_lr]
69 |
70 | # Start Fine-tuning
71 | final.fit_generator(train_gen(),
72 | steps_per_epoch=num_train_samples // batch_size,
73 | validation_data=valid_gen(),
74 | validation_steps=num_valid_samples // batch_size,
75 | epochs=epochs,
76 | verbose=1,
77 | callbacks=callbacks,
78 | use_multiprocessing=True,
79 | workers=2
80 | )
81 |
--------------------------------------------------------------------------------
/train_encoder_decoder.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import keras
4 | import tensorflow as tf
5 | from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
6 | from keras.utils import multi_gpu_model
7 |
8 | import migrate
9 | from config import patience, batch_size, epochs, num_train_samples, num_valid_samples
10 | from data_generator import train_gen, valid_gen
11 | from model import build_encoder_decoder
12 | from utils import overall_loss, get_available_cpus, get_available_gpus
13 |
14 | if __name__ == '__main__':
15 | # Parse arguments
16 | ap = argparse.ArgumentParser()
17 | ap.add_argument("-c", "--checkpoint", help="path to save checkpoint model files")
18 |     ap.add_argument("-p", "--pretrained", help="path to pretrained model weights to load")
19 | args = vars(ap.parse_args())
20 | checkpoint_path = args["checkpoint"]
21 | pretrained_path = args["pretrained"]
22 | if checkpoint_path is None:
23 | checkpoint_models_path = 'models/'
24 | else:
25 | # python train_encoder_decoder.py -c /mnt/Deep-Image-Matting/models/
26 | checkpoint_models_path = '{}/'.format(checkpoint_path)
27 |
28 | # Callbacks
29 | tensor_board = keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0, write_graph=True, write_images=True)
30 | model_names = checkpoint_models_path + 'model.{epoch:02d}-{val_loss:.4f}.hdf5'
31 | model_checkpoint = ModelCheckpoint(model_names, monitor='val_loss', verbose=1, save_best_only=True)
32 | early_stop = EarlyStopping('val_loss', patience=patience)
33 | reduce_lr = ReduceLROnPlateau('val_loss', factor=0.1, patience=int(patience / 4), verbose=1)
34 |
35 | class MyCbk(keras.callbacks.Callback):
36 | def __init__(self, model):
37 | keras.callbacks.Callback.__init__(self)
38 | self.model_to_save = model
39 |
40 | def on_epoch_end(self, epoch, logs=None):
41 | fmt = checkpoint_models_path + 'model.%02d-%.4f.hdf5'
42 | self.model_to_save.save(fmt % (epoch, logs['val_loss']))
43 |
44 | # Load our model, added support for Multi-GPUs
45 | num_gpu = len(get_available_gpus())
46 | if num_gpu >= 2:
47 | with tf.device("/cpu:0"):
48 | if pretrained_path is not None:
49 | model = build_encoder_decoder()
50 | model.load_weights(pretrained_path)
51 | else:
52 | model = build_encoder_decoder()
53 | migrate.migrate_model(model)
54 |
55 | new_model = multi_gpu_model(model, gpus=num_gpu)
56 | # rewrite the callback: saving through the original model and not the multi-gpu model.
57 | model_checkpoint = MyCbk(model)
58 | else:
59 | if pretrained_path is not None:
60 | new_model = build_encoder_decoder()
61 | new_model.load_weights(pretrained_path)
62 | else:
63 | new_model = build_encoder_decoder()
64 | migrate.migrate_model(new_model)
65 |
66 | # sgd = SGD(lr=1e-3, decay=1e-6, momentum=0.9, nesterov=True)
67 | new_model.compile(optimizer='nadam', loss=overall_loss)
68 |
69 | print(new_model.summary())
70 |
71 | # Summarize then go!
72 | num_cpu = get_available_cpus()
73 | workers = int(round(num_cpu / 2))
74 | print('num_gpu={}\nnum_cpu={}\nworkers={}\ntrained_models_path={}.'.format(num_gpu, num_cpu, workers,
75 | checkpoint_models_path))
76 |
77 | # Final callbacks
78 | callbacks = [tensor_board, model_checkpoint, early_stop, reduce_lr]
79 |
80 | # Start Fine-tuning
81 | new_model.fit_generator(train_gen(),
82 | steps_per_epoch=num_train_samples // batch_size,
83 | validation_data=valid_gen(),
84 | validation_steps=num_valid_samples // batch_size,
85 | epochs=epochs,
86 | verbose=1,
87 | callbacks=callbacks,
88 | use_multiprocessing=True,
89 | workers=workers
90 | )
91 |
--------------------------------------------------------------------------------
/train_final.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import keras
4 | import tensorflow as tf
5 | from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
6 | from keras.utils import multi_gpu_model
7 |
8 | from config import patience, batch_size, epochs, num_train_samples, num_valid_samples
9 | from data_generator import train_gen, valid_gen
10 | from model import build_encoder_decoder, build_refinement
11 | from utils import alpha_prediction_loss, get_available_cpus, get_available_gpus
12 |
13 | if __name__ == '__main__':
14 | checkpoint_models_path = 'models/'
15 | # Parse arguments
16 | ap = argparse.ArgumentParser()
17 |     ap.add_argument("-p", "--pretrained", help="path to pretrained model weights to load")
18 | args = vars(ap.parse_args())
19 | pretrained_path = args["pretrained"]
20 |
21 | # Callbacks
22 | tensor_board = keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0, write_graph=True, write_images=True)
23 | model_names = checkpoint_models_path + 'final.{epoch:02d}-{val_loss:.4f}.hdf5'
24 | model_checkpoint = ModelCheckpoint(model_names, monitor='val_loss', verbose=1, save_best_only=True)
25 | early_stop = EarlyStopping('val_loss', patience=patience)
26 | reduce_lr = ReduceLROnPlateau('val_loss', factor=0.1, patience=int(patience / 4), verbose=1)
27 |
28 |
29 | class MyCbk(keras.callbacks.Callback):
30 | def __init__(self, model):
31 | keras.callbacks.Callback.__init__(self)
32 | self.model_to_save = model
33 |
34 | def on_epoch_end(self, epoch, logs=None):
35 |             fmt = checkpoint_models_path + 'final.%02d-%.4f.hdf5'
36 | self.model_to_save.save(fmt % (epoch, logs['val_loss']))
37 |
38 | num_gpu = len(get_available_gpus())
39 | if num_gpu >= 2:
40 | with tf.device("/cpu:0"):
41 | # Load our model, added support for Multi-GPUs
42 | model = build_encoder_decoder()
43 | model = build_refinement(model)
44 | model.load_weights(pretrained_path)
45 |
46 | final = multi_gpu_model(model, gpus=num_gpu)
47 | # rewrite the callback: saving through the original model and not the multi-gpu model.
48 | model_checkpoint = MyCbk(model)
49 | else:
50 | model = build_encoder_decoder()
51 | final = build_refinement(model)
52 | final.load_weights(pretrained_path)
53 |
54 | # finetune the whole network together.
55 | for layer in final.layers:
56 | layer.trainable = True
57 |
58 | sgd = keras.optimizers.SGD(lr=1e-5, decay=1e-6, momentum=0.9, nesterov=True)
59 | # nadam = keras.optimizers.Nadam(lr=2e-5)
60 | decoder_target = tf.placeholder(dtype='float32', shape=(None, None, None, None))
61 | final.compile(optimizer=sgd, loss=alpha_prediction_loss, target_tensors=[decoder_target])
62 |
63 | print(final.summary())
64 |
65 | # Summarize then go!
66 | num_cpu = get_available_cpus()
67 | workers = int(round(num_cpu / 2))
68 |
69 | # Final callbacks
70 | callbacks = [tensor_board, model_checkpoint, early_stop, reduce_lr]
71 |
72 | # Start Fine-tuning
73 | final.fit_generator(train_gen(),
74 | steps_per_epoch=num_train_samples // batch_size,
75 | validation_data=valid_gen(),
76 | validation_steps=num_valid_samples // batch_size,
77 | epochs=epochs,
78 | verbose=1,
79 | callbacks=callbacks,
80 | use_multiprocessing=True,
81 | workers=workers
82 | )
83 |
--------------------------------------------------------------------------------
/train_refinement.py:
--------------------------------------------------------------------------------
1 | import keras
2 | from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
3 | from keras.optimizers import SGD
4 |
5 | from config import patience, batch_size, epochs, num_train_samples, num_valid_samples
6 | from data_generator import train_gen, valid_gen
7 | from model import build_encoder_decoder, build_refinement
8 | from utils import custom_loss_wrapper, get_available_cpus
9 |
10 | if __name__ == '__main__':
11 | checkpoint_models_path = 'models/'
12 |
13 | # Callbacks
14 | tensor_board = keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0, write_graph=True, write_images=True)
15 | model_names = checkpoint_models_path + 'refinement.{epoch:02d}-{val_loss:.4f}.hdf5'
16 | model_checkpoint = ModelCheckpoint(model_names, monitor='val_loss', verbose=1, save_best_only=True)
17 | early_stop = EarlyStopping('val_loss', patience=patience)
18 | reduce_lr = ReduceLROnPlateau('val_loss', factor=0.1, patience=int(patience / 4), verbose=1)
19 |
20 | pretrained_path = 'models/model.98-0.0459.hdf5'
21 | encoder_decoder = build_encoder_decoder()
22 | encoder_decoder.load_weights(pretrained_path)
23 | # fix encoder-decoder part parameters and then update the refinement part.
24 | for layer in encoder_decoder.layers:
25 | layer.trainable = False
26 |
27 | refinement = build_refinement(encoder_decoder)
28 |
29 | # sgd = SGD(lr=1e-3, decay=1e-6, momentum=0.9, nesterov=True)
30 | refinement.compile(optimizer='nadam', loss=custom_loss_wrapper(refinement.input))
31 |
32 | print(refinement.summary())
33 |
34 | # Summarize then go!
35 | num_cpu = get_available_cpus()
36 | workers = int(round(num_cpu / 2))
37 |
38 | # Final callbacks
39 | callbacks = [tensor_board, model_checkpoint, early_stop, reduce_lr]
40 |
41 | # Start Fine-tuning
42 | refinement.fit_generator(train_gen(),
43 | steps_per_epoch=num_train_samples // batch_size,
44 | validation_data=valid_gen(),
45 | validation_steps=num_valid_samples // batch_size,
46 | epochs=epochs,
47 | verbose=1,
48 | callbacks=callbacks,
49 | use_multiprocessing=True,
50 | workers=workers
51 | )
52 |
--------------------------------------------------------------------------------
/unit_tests.py:
--------------------------------------------------------------------------------
1 | import random
2 | import unittest
3 |
4 | import cv2 as cv
5 | import numpy as np
6 | import os
7 | from config import unknown_code
8 | from data_generator import generate_trimap
9 | from data_generator import get_alpha
10 | from data_generator import random_choice
11 | from utils import safe_crop
12 |
13 |
14 | class TestStringMethods(unittest.TestCase):
15 |
16 | def test_generate_trimap(self):
17 | image = cv.imread('fg/1-1252426161dfXY.jpg')
18 | alpha = cv.imread('mask/1-1252426161dfXY.jpg', 0)
19 | trimap = generate_trimap(alpha)
20 | self.assertEqual(trimap.shape, (615, 410))
21 |
22 | # ensure np.where works as expected.
23 | count = 0
24 | h, w = trimap.shape[:2]
25 | for i in range(h):
26 | for j in range(w):
27 | if trimap[i, j] == unknown_code:
28 | count += 1
29 | x_indices, y_indices = np.where(trimap == unknown_code)
30 | num_unknowns = len(x_indices)
31 | self.assertEqual(count, num_unknowns)
32 |
33 | # ensure an unknown pixel is chosen
34 | ix = random.choice(range(num_unknowns))
35 | center_x = x_indices[ix]
36 | center_y = y_indices[ix]
37 |
38 | self.assertEqual(trimap[center_x, center_y], unknown_code)
39 |
40 | x, y = random_choice(trimap)
41 | # print(x, y)
42 | image = safe_crop(image, x, y)
43 | trimap = safe_crop(trimap, x, y)
44 | alpha = safe_crop(alpha, x, y)
45 | cv.imwrite('temp/test_generate_trimap_image.png', image)
46 | cv.imwrite('temp/test_generate_trimap_trimap.png', trimap)
47 | cv.imwrite('temp/test_generate_trimap_alpha.png', alpha)
48 |
49 | def test_flip(self):
50 | image = cv.imread('fg/1-1252426161dfXY.jpg')
51 | # print(image.shape)
52 | alpha = cv.imread('mask/1-1252426161dfXY.jpg', 0)
53 | trimap = generate_trimap(alpha)
54 | x, y = random_choice(trimap)
55 | image = safe_crop(image, x, y)
56 | trimap = safe_crop(trimap, x, y)
57 | alpha = safe_crop(alpha, x, y)
58 | image = np.fliplr(image)
59 | trimap = np.fliplr(trimap)
60 | alpha = np.fliplr(alpha)
61 | cv.imwrite('temp/test_flip_image.png', image)
62 | cv.imwrite('temp/test_flip_trimap.png', trimap)
63 | cv.imwrite('temp/test_flip_alpha.png', alpha)
64 |
65 | def test_different_sizes(self):
66 | different_sizes = [(320, 320), (320, 320), (320, 320), (480, 480), (640, 640)]
67 | crop_size = random.choice(different_sizes)
68 | # print('crop_size=' + str(crop_size))
69 |
70 | def test_resize(self):
71 | name = '0_0.png'
72 | filename = os.path.join('merged', name)
73 | image = cv.imread(filename)
74 | bg_h, bg_w = image.shape[:2]
75 | a = get_alpha(name)
76 | a_h, a_w = a.shape[:2]
77 | alpha = np.zeros((bg_h, bg_w), np.float32)
78 | alpha[0:a_h, 0:a_w] = a
79 | trimap = generate_trimap(alpha)
80 |         # crop sizes 320:640:480, sampled in the ratio 3:1:1
81 | crop_size = (480, 480)
82 | x, y = random_choice(trimap, crop_size)
83 | image = safe_crop(image, x, y, crop_size)
84 | trimap = safe_crop(trimap, x, y, crop_size)
85 | alpha = safe_crop(alpha, x, y, crop_size)
86 | cv.imwrite('temp/test_resize_image.png', image)
87 | cv.imwrite('temp/test_resize_trimap.png', trimap)
88 | cv.imwrite('temp/test_resize_alpha.png', alpha)
89 |
90 |
91 | if __name__ == '__main__':
92 | unittest.main()
93 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import math
3 | import cv2 as cv
4 | import keras.backend as K
5 | import numpy as np
6 | from tensorflow.python.client import device_lib
7 |
8 | from config import epsilon, epsilon_sqr
9 | from config import img_cols
10 | from config import img_rows
11 | from config import unknown_code
12 |
13 |
14 | # overall loss: weighted summation of the two individual losses.
15 | #
16 | def overall_loss(y_true, y_pred):
17 | w_l = 0.5
18 | return w_l * alpha_prediction_loss(y_true, y_pred) + (1 - w_l) * compositional_loss(y_true, y_pred)
19 |
20 |
21 | # alpha prediction loss: the absolute difference between the ground truth alpha values and the
22 | # predicted alpha values at each pixel. However, due to the non-differentiable property of
23 | # absolute values, we use the following loss function to approximate it.
24 | def alpha_prediction_loss(y_true, y_pred):
25 | mask = y_true[:, :, :, 1]
26 | diff = y_pred[:, :, :, 0] - y_true[:, :, :, 0]
27 | diff = diff * mask
28 | num_pixels = K.sum(mask)
29 | return K.sum(K.sqrt(K.square(diff) + epsilon_sqr)) / (num_pixels + epsilon)
30 |
31 |
32 | # compositional loss: the absolute difference between the ground truth RGB colors and the predicted
33 | # RGB colors composited by the ground truth foreground, the ground truth background and the predicted
34 | # alpha mattes.
35 | def compositional_loss(y_true, y_pred):
36 | mask = y_true[:, :, :, 1]
37 | mask = K.reshape(mask, (-1, img_rows, img_cols, 1))
38 | image = y_true[:, :, :, 2:5]
39 | fg = y_true[:, :, :, 5:8]
40 | bg = y_true[:, :, :, 8:11]
41 | c_g = image
42 | c_p = y_pred * fg + (1.0 - y_pred) * bg
43 | diff = c_p - c_g
44 | diff = diff * mask
45 | num_pixels = K.sum(mask)
46 | return K.sum(K.sqrt(K.square(diff) + epsilon_sqr)) / (num_pixels + epsilon)
47 |
48 |
49 | # compute the MSE error given a prediction, a ground truth and a trimap.
50 | # pred: the predicted alpha matte
51 | # target: the ground truth alpha matte
52 | # trimap: the given trimap
53 | #
54 | def compute_mse_loss(pred, target, trimap):
55 | error_map = (pred - target) / 255.
56 | mask = np.equal(trimap, unknown_code).astype(np.float32)
57 | # print('unknown: ' + str(unknown))
58 | loss = np.sum(np.square(error_map) * mask) / np.sum(mask)
59 | # print('mse_loss: ' + str(loss))
60 | return loss
61 |
62 |
63 | # compute the SAD error given a prediction, a ground truth and a trimap.
64 | #
65 | def compute_sad_loss(pred, target, trimap):
66 | error_map = np.abs(pred - target) / 255.
67 | mask = np.equal(trimap, unknown_code).astype(np.float32)
68 | loss = np.sum(error_map * mask)
69 |
70 | # the loss is scaled by 1000 due to the large images used in our experiment.
71 | loss = loss / 1000
72 | # print('sad_loss: ' + str(loss))
73 | return loss
74 |
75 |
76 | # getting the number of GPUs
77 | def get_available_gpus():
78 | local_device_protos = device_lib.list_local_devices()
79 | return [x.name for x in local_device_protos if x.device_type == 'GPU']
80 |
81 |
82 | # getting the number of CPUs
83 | def get_available_cpus():
84 | return multiprocessing.cpu_count()
85 |
86 |
87 | def get_final_output(out, trimap):
88 | mask = np.equal(trimap, unknown_code).astype(np.float32)
89 | return (1 - mask) * trimap + mask * out
90 |
91 | def patch_dims(mat_size, patch_size):
92 | return np.ceil(np.array(mat_size) / patch_size).astype(int)
93 |
94 | def create_patches(mat, patch_size):
95 | mat_size = mat.shape
96 |     assert len(mat_size) == 3, "Input mat needs to be 3-dimensional (height, width, channels)"
97 |     assert mat_size[-1] == 4, "Input mat needs to have 4 channels (B, G, R, trimap)"
98 |
99 | patches_dim = patch_dims(mat_size=mat_size[:2], patch_size=patch_size)
100 | patches_count = np.product(patches_dim)
101 |
102 | patches = np.zeros(shape=(patches_count, patch_size, patch_size, 4), dtype=np.float32)
103 | for y in range(patches_dim[0]):
104 | y_start = y * patch_size
105 | for x in range(patches_dim[1]):
106 | x_start = x * patch_size
107 |
108 | # extract patch from input mat
109 | single_patch = mat[y_start: y_start + patch_size, x_start: x_start + patch_size, :]
110 |
111 | # zero pad patch in bottom and right side if real patch size is smaller than patch size
112 | real_patch_h, real_patch_w = single_patch.shape[:2]
113 | patch_id = y + x * patches_dim[0]
114 | patches[patch_id, :real_patch_h, :real_patch_w, :] = single_patch
115 |
116 | return patches
117 |
118 | def assemble_patches(pred_patches, mat_size, patch_size):
119 | patch_dim_h, patch_dim_w = patch_dims(mat_size=mat_size, patch_size=patch_size)
120 | result = np.zeros(shape=(patch_size * patch_dim_h, patch_size * patch_dim_w), dtype=np.uint8)
121 | patches_count = pred_patches.shape[0]
122 |
123 | for i in range(patches_count):
124 | y = (i % patch_dim_h) * patch_size
125 | x = int(math.floor(i / patch_dim_h)) * patch_size
126 |
127 | result[y:y+patch_size, x:x+patch_size] = pred_patches[i]
128 |
129 | return result
130 |
131 | def safe_crop(mat, x, y, crop_size=(img_rows, img_cols)):
132 | crop_height, crop_width = crop_size
133 | if len(mat.shape) == 2:
134 | ret = np.zeros((crop_height, crop_width), np.float32)
135 | else:
136 | ret = np.zeros((crop_height, crop_width, 3), np.float32)
137 | crop = mat[y:y + crop_height, x:x + crop_width]
138 | h, w = crop.shape[:2]
139 | ret[0:h, 0:w] = crop
140 | if crop_size != (img_rows, img_cols):
141 | ret = cv.resize(ret, dsize=(img_rows, img_cols), interpolation=cv.INTER_NEAREST)
142 | return ret
143 |
144 |
145 | def draw_str(dst, target, s):
146 | x, y = target
147 | cv.putText(dst, s, (x + 1, y + 1), cv.FONT_HERSHEY_PLAIN, 1.0, (0, 0, 0), thickness=2, lineType=cv.LINE_AA)
148 | cv.putText(dst, s, (x, y), cv.FONT_HERSHEY_PLAIN, 1.0, (255, 255, 255), lineType=cv.LINE_AA)
149 |
--------------------------------------------------------------------------------
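
Note on utils.py: the Keras losses above read a packed 11-channel y_true
rather than a bare alpha map: channel 0 is the ground-truth alpha, channel 1
the unknown-region mask, channels 2:5 the merged image, 5:8 the foreground
and 8:11 the background. (This is also why the training scripts compile with
an unconstrained target placeholder.) A numpy rendition of
alpha_prediction_loss on dummy data, with assumed epsilon values standing in
for the ones defined in config.py:

    import numpy as np

    eps, eps_sqr = 1e-6, 1e-12                  # assumed config.py values
    y_true = np.zeros((1, 320, 320, 11), np.float32)
    y_true[:, 100:200, 100:200, 1] = 1.         # mark an unknown region
    y_pred = np.random.rand(1, 320, 320, 1).astype(np.float32)
    mask = y_true[:, :, :, 1]
    diff = (y_pred[:, :, :, 0] - y_true[:, :, :, 0]) * mask
    print(np.sum(np.sqrt(diff ** 2 + eps_sqr)) / (np.sum(mask) + eps))
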
/vgg16.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import keras.backend as K
4 | from keras.layers import Conv2D, ZeroPadding2D, MaxPooling2D
5 | from keras.layers import Dense, Dropout, Flatten
6 | from keras.models import Sequential
7 |
8 |
9 | def vgg16_model(img_rows, img_cols, channel=3):
10 | model = Sequential()
11 | # Encoder
12 | model.add(ZeroPadding2D((1, 1), input_shape=(img_rows, img_cols, channel), name='input'))
13 | model.add(Conv2D(64, (3, 3), activation='relu', name='conv1_1'))
14 | model.add(ZeroPadding2D((1, 1)))
15 | model.add(Conv2D(64, (3, 3), activation='relu', name='conv1_2'))
16 | model.add(MaxPooling2D((2, 2), strides=(2, 2)))
17 |
18 | model.add(ZeroPadding2D((1, 1)))
19 | model.add(Conv2D(128, (3, 3), activation='relu', name='conv2_1'))
20 | model.add(ZeroPadding2D((1, 1)))
21 | model.add(Conv2D(128, (3, 3), activation='relu', name='conv2_2'))
22 | model.add(MaxPooling2D((2, 2), strides=(2, 2)))
23 |
24 | model.add(ZeroPadding2D((1, 1)))
25 | model.add(Conv2D(256, (3, 3), activation='relu', name='conv3_1'))
26 | model.add(ZeroPadding2D((1, 1)))
27 | model.add(Conv2D(256, (3, 3), activation='relu', name='conv3_2'))
28 | model.add(ZeroPadding2D((1, 1)))
29 | model.add(Conv2D(256, (3, 3), activation='relu', name='conv3_3'))
30 | model.add(MaxPooling2D((2, 2), strides=(2, 2)))
31 |
32 | model.add(ZeroPadding2D((1, 1)))
33 | model.add(Conv2D(512, (3, 3), activation='relu', name='conv4_1'))
34 | model.add(ZeroPadding2D((1, 1)))
35 | model.add(Conv2D(512, (3, 3), activation='relu', name='conv4_2'))
36 | model.add(ZeroPadding2D((1, 1)))
37 | model.add(Conv2D(512, (3, 3), activation='relu', name='conv4_3'))
38 | model.add(MaxPooling2D((2, 2), strides=(2, 2)))
39 |
40 | model.add(ZeroPadding2D((1, 1)))
41 | model.add(Conv2D(512, (3, 3), activation='relu', name='conv5_1'))
42 | model.add(ZeroPadding2D((1, 1)))
43 | model.add(Conv2D(512, (3, 3), activation='relu', name='conv5_2'))
44 | model.add(ZeroPadding2D((1, 1)))
45 | model.add(Conv2D(512, (3, 3), activation='relu', name='conv5_3'))
46 | model.add(MaxPooling2D((2, 2), strides=(2, 2)))
47 |
48 | # Add Fully Connected Layer
49 | model.add(Flatten(name='flatten'))
50 | model.add(Dense(4096, activation='relu', name='dense1'))
51 | model.add(Dropout(0.5))
52 | model.add(Dense(4096, activation='relu', name='dense2'))
53 | model.add(Dropout(0.5))
54 | model.add(Dense(1000, activation='softmax', name='softmax'))
55 |
56 | # Loads ImageNet pre-trained data
57 | weights_path = 'models/vgg16_weights_tf_dim_ordering_tf_kernels.h5'
58 | model.load_weights(weights_path)
59 |
60 | return model
61 |
62 |
63 | if __name__ == '__main__':
64 | model = vgg16_model(224, 224, 3)
65 | # input_layer = model.get_layer('input')
66 | print(model.summary())
67 |
68 | K.clear_session()
69 |
--------------------------------------------------------------------------------
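
Note on vgg16.py: vgg16_model expects the ImageNet weights at
models/vgg16_weights_tf_dim_ordering_tf_kernels.h5. One way to fetch them,
assuming the standard Keras release URL still hosts the file:

    from keras.utils import get_file

    WEIGHTS_URL = ('https://github.com/fchollet/deep-learning-models/'
                   'releases/download/v0.1/'
                   'vgg16_weights_tf_dim_ordering_tf_kernels.h5')
    weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels.h5',
                            WEIGHTS_URL, cache_subdir='models')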