├── README.md
├── data
│   ├── 0000-AIT.png
│   ├── 0000.jpg
│   ├── DIM_list.txt
│   ├── bg_list.txt
│   ├── data_util.py
│   ├── dataset.py
│   └── gen_trimap.py
├── model
│   ├── M_Net.py
│   ├── T_Net_psp.py
│   ├── extractors.py
│   └── network.py
├── test.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch_semantic_human_matting
2 | This is an unofficial implementation of the paper "Semantic human matting".
3 |
4 | # testing environment:
5 | Ubuntu 16.04
6 |
7 | Pytorch 0.4.1
8 |
9 | # data preparation
10 |
11 | To make life easier, there are only two types of data:
12 | 1. RGBA-PNG images: images whose background has already been removed;
13 | you can generate such an image with the function 'get_rgba' from ./data/data_util.py (see the sketch below).
14 | 2. Background images: these can be fetched from the COCO dataset or anywhere else (e.g. the internet); they are used to randomly composite new images on-the-fly with the foreground images extracted from the type-1 RGBA-PNG images, as described in the paper.
15 |
16 | For example:
17 |
18 | For 1: I used the Adobe Deep Image Matting dataset, among others; I composited the alpha and foreground images together to get my RGBA-PNG images.
19 |
20 | For 2: I used the COCO dataset and some images crawled from the internet.
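
A minimal sketch of step 1 with get_rgba (assuming the repo root is on your PYTHONPATH; the image/mask paths below are just placeholders):

```python
import cv2
from data.data_util import get_rgba

# hypothetical paths: an ordinary photo plus its alpha matte (grayscale mask)
rgba = get_rgba('fg/0000.jpg', 'alpha/0000.png')
cv2.imwrite('img_png/0000.png', rgba)  # 4-channel PNG (BGR + alpha), ready to list in DIM_list.txt
```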
21 |
22 | Once you have these two types of data, generate list files containing the full paths of the training images,
23 | such as 'DIM_list.txt' and 'bg_list.txt' in my case. Specifically, for the flags:
24 |
25 | --fgLists: a list of list files; within each list file, all images share the same fg-bg ratio, e.g. ['DIM.txt','SHM.txt']
26 |
27 | --bg_list: a txt file containing all bg images used for composition, e.g. 'bg_list.txt'.
28 |
29 | --dataRatio: a list containing the bg-fg ratio for each list file in fgLists. For example, similar to the paper,
30 | given [100, 1] we composite 100 images for each foreground image in 'DIM.txt' and 1 image for each foreground image in 'SHM.txt'. A small sanity check is sketched below.
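
The dataset loader (data/dataset.py) asserts that the number of background images equals the sum over fgLists of ratio × foreground count. A rough sketch of building a list file and checking that relationship (the directory and file names here are only assumptions based on this repo's examples):

```python
import glob

# hypothetical helper: write one list file from a directory of RGBA PNGs
def write_list(img_dir, list_path):
    with open(list_path, 'w') as f:
        f.write('\n'.join(sorted(glob.glob(img_dir + '/*.png'))) + '\n')

def count_lines(path):
    with open(path) as f:
        return sum(1 for line in f if line.strip())

write_list('/home/kevin/git/matting_data/DIM/img_png', 'data/DIM_list.txt')

fg_lists, data_ratio = ['data/DIM_list.txt'], [100]   # --fgLists and --dataRatio
expected_bg = sum(r * count_lines(p) for r, p in zip(data_ratio, fg_lists))
assert count_lines('data/bg_list.txt') == expected_bg, 'bg count must equal sum(ratio * fg count)'
```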
31 |
32 | # Implementation details
33 | The training model is fully implemented as described in the paper; details are as follows:
34 | * T-net: PSP-50 is deployed for trimap generation; the input is the image (3 channels) and the output is the trimap (3-channel class probabilities for background / unknown / foreground);
35 |
36 | * M-net: 13 convolutional layers and 4 max-pooling layers with the same hyper-parameters as VGG-16 form the encoder, and 6 convolutional layers and 4 unpooling layers form the decoder; the input is the image plus trimap (6 channels) and the output is the alpha matte (1 channel);
37 |
38 | * Fusion: the fusion step and its loss functions are implemented as described in the paper (see the sketch after this list);
39 |
40 | * **_Once well trained, this model can run inference on images of any size._**
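
For reference, a minimal, self-contained sketch of the 6-channel M-net input and the fusion described in the paper. This is illustrative only; the tensor names and the assumption that the raw alpha is already in [0, 1] are mine, not the repo's exact code:

```python
import torch
import torch.nn.functional as F

# stand-in tensors for one 320x320 sample
image         = torch.rand(1, 3, 320, 320)   # RGB in [0, 1]
trimap_onehot = torch.rand(1, 3, 320, 320)   # one-hot trimap (bg / unknown / fg)
trimap_logits = torch.rand(1, 3, 320, 320)   # T-net output
alpha_r       = torch.rand(1, 1, 320, 320)   # raw alpha from M-net, assumed in [0, 1]

# 6-channel M-net input: image (3 channels) concatenated with the trimap (3 channels)
m_input = torch.cat((image, trimap_onehot), dim=1)   # N x 6 x H x W

# probabilistic fusion from the paper: alpha_p = F_s + U_s * alpha_r,
# where F_s / U_s are the fg / unknown probabilities from T-net's softmax
# (channel order follows this repo's label mapping: 0=bg, 1=unknown, 2=fg)
probs = F.softmax(trimap_logits, dim=1)
unknown, fg = probs[:, 1:2], probs[:, 2:3]
alpha_p = fg + unknown * alpha_r
```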
41 |
42 | # How to run the code
43 | ## pre_train T_net
44 | python train.py --patch_size=400 --train_phase=pre_train_t_net
45 |
46 | optional: --continue_train
47 |
48 | ## pre_train M_net
49 | python train.py --patch_size=320 --train_phase=pre_train_m_net
50 |
51 | optional: --continue_train
52 |
53 | ## end to end training
54 | python train.py --patch_size=800 --pretrain --train_phase=end_to_end
55 |
56 | optional: --continue_train
57 |
58 | Notes:
59 | 1. the end-to-end training process is really time-consuming.
60 | 2. I tried to implement the crop-on-the-fly trick for M-net inputs as described in the original paper,
61 | but the training process seemed very slow and unstable, so the same input size is used for both
62 | nets throughout end-to-end training.
63 |
64 | ## testing
65 | python test.py --train_phase=pre_train_t_net/pre_train_m_net/end_to_end
66 |
67 | # Results
68 | Note: the following result is produced by T-net & M-net together, as I haven't completed the end-to-end training phase yet.
69 |
70 | Original image from the Internet:
71 |
72 |
73 |
74 | Output image produced by the SHM:
75 |
76 |
77 |
78 | # Observations
79 | 1. The performance of T-net is essential for the whole pipeline.
80 | 2. The trimap training data (generated from the alpha images) is essential for T-net training, which means the trimaps
81 | (and therefore the underlying alphas) should be high-quality and clean.
82 | 3. End-to-end training is rather slow and can be unstable, especially when T-net is not robust
83 | (not powerful and stable enough). So I haven't completed the end-to-end training phase. Even so, the results seem satisfying.
84 |
85 | Leave a comment if you have any other observations or suggestions.
86 |
--------------------------------------------------------------------------------
/data/0000-AIT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsing90/pytorch_semantic_human_matting/9cbb85b3ccf1c4a77fd4e2b237e5f658a3d6501b/data/0000-AIT.png
--------------------------------------------------------------------------------
/data/0000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsing90/pytorch_semantic_human_matting/9cbb85b3ccf1c4a77fd4e2b237e5f658a3d6501b/data/0000.jpg
--------------------------------------------------------------------------------
/data/DIM_list.txt:
--------------------------------------------------------------------------------
1 | /home/kevin/git/matting_data/DIM/img_png/1-1252426161dfXY.png
2 | /home/kevin/git/matting_data/DIM/img_png/1-1255621189mTnS.png
3 | /home/kevin/git/matting_data/DIM/img_png/1-1259162624NMFK.png
4 | /home/kevin/git/matting_data/DIM/img_png/1-1259245823Un3j.png
5 | /home/kevin/git/matting_data/DIM/img_png/10743257206_18e7f44f2e_b.png
6 | /home/kevin/git/matting_data/DIM/img_png/10845279884_d2d4c7b4d1_b.png
7 | /home/kevin/git/matting_data/DIM/img_png/11363165393_05d7a21d76_b.png
8 | /home/kevin/git/matting_data/DIM/img_png/13564741125_753939e9ce_o.png
9 | /home/kevin/git/matting_data/DIM/img_png/14731860273_5b40b19b51_o.png
10 | /home/kevin/git/matting_data/DIM/img_png/16087-a-young-woman-showing-a-bitten-green-apple-pv.png
11 | /home/kevin/git/matting_data/DIM/img_png/1609484818_b9bb12b.png
12 | /home/kevin/git/matting_data/DIM/img_png/17620-a-beautiful-woman-in-a-bikini-pv.png
13 | /home/kevin/git/matting_data/DIM/img_png/20672673163_20c8467827_b.png
14 | /home/kevin/git/matting_data/DIM/img_png/3262986095_2d5afe583c_b.png
15 | /home/kevin/git/matting_data/DIM/img_png/3588101233_f91aa5e3a3.png
16 | /home/kevin/git/matting_data/DIM/img_png/3858897226_cae5b75963_o.png
17 | /home/kevin/git/matting_data/DIM/img_png/474px-Jerry_Moran,_official_portrait,_112th_Congress_headshot.png
18 | /home/kevin/git/matting_data/DIM/img_png/4981835627_c4e6c4ffa8_o.png
19 | /home/kevin/git/matting_data/DIM/img_png/5025666458_576b974455_o.png
20 | /home/kevin/git/matting_data/DIM/img_png/5149410930_3a943dc43f_b.png
21 | /home/kevin/git/matting_data/DIM/img_png/5892503248_4b882863c7_o.png
22 | /home/kevin/git/matting_data/DIM/img_png/7669262460_e4be408343_b.png
23 | /home/kevin/git/matting_data/DIM/img_png/8244818049_dfa59a3eb8_b.png
24 | /home/kevin/git/matting_data/DIM/img_png/8688417335_01f3bafbe5_o.png
25 | /home/kevin/git/matting_data/DIM/img_png/9434599749_e7ccfc7812_b.png
26 | /home/kevin/git/matting_data/DIM/img_png/Aaron_Friedman_Headshot.png
27 | /home/kevin/git/matting_data/DIM/img_png/GT03.png
28 | /home/kevin/git/matting_data/DIM/img_png/GT08.png
29 | /home/kevin/git/matting_data/DIM/img_png/Girl_in_front_of_a_green_background.png
30 | /home/kevin/git/matting_data/DIM/img_png/MFettes-headshot.png
31 | /home/kevin/git/matting_data/DIM/img_png/Model_in_green_dress_3.png
32 | /home/kevin/git/matting_data/DIM/img_png/Modern_shingle_bob_haircut.png
33 | /home/kevin/git/matting_data/DIM/img_png/Motivate_(Fitness_model).png
34 | /home/kevin/git/matting_data/DIM/img_png/Official_portrait_of_Barack_Obama.png
35 | /home/kevin/git/matting_data/DIM/img_png/Professor_Steven_Chu_ForMemRS_headshot.png
36 | /home/kevin/git/matting_data/DIM/img_png/Wild_hair.png
37 | /home/kevin/git/matting_data/DIM/img_png/Woman_in_white_shirt_on_August_2009_02.png
38 | /home/kevin/git/matting_data/DIM/img_png/a-single-person-1084191_960_720.png
39 | /home/kevin/git/matting_data/DIM/img_png/apple-841169_960_720.png
40 | /home/kevin/git/matting_data/DIM/img_png/archeology_00000.png
41 | /home/kevin/git/matting_data/DIM/img_png/archeology_00005.png
42 | /home/kevin/git/matting_data/DIM/img_png/archeology_00010.png
43 | /home/kevin/git/matting_data/DIM/img_png/archeology_00015.png
44 | /home/kevin/git/matting_data/DIM/img_png/archeology_00020.png
45 | /home/kevin/git/matting_data/DIM/img_png/archeology_00025.png
46 | /home/kevin/git/matting_data/DIM/img_png/archeology_00030.png
47 | /home/kevin/git/matting_data/DIM/img_png/archeology_00035.png
48 | /home/kevin/git/matting_data/DIM/img_png/archeology_00040.png
49 | /home/kevin/git/matting_data/DIM/img_png/archeology_00045.png
50 | /home/kevin/git/matting_data/DIM/img_png/archeology_00050.png
51 | /home/kevin/git/matting_data/DIM/img_png/archeology_00055.png
52 | /home/kevin/git/matting_data/DIM/img_png/archeology_00060.png
53 | /home/kevin/git/matting_data/DIM/img_png/archeology_00065.png
54 | /home/kevin/git/matting_data/DIM/img_png/archeology_00070.png
55 | /home/kevin/git/matting_data/DIM/img_png/archeology_00075.png
56 | /home/kevin/git/matting_data/DIM/img_png/archeology_00080.png
57 | /home/kevin/git/matting_data/DIM/img_png/archeology_00085.png
58 | /home/kevin/git/matting_data/DIM/img_png/archeology_00090.png
59 | /home/kevin/git/matting_data/DIM/img_png/archeology_00095.png
60 | /home/kevin/git/matting_data/DIM/img_png/archeology_00100.png
61 | /home/kevin/git/matting_data/DIM/img_png/archeology_00105.png
62 | /home/kevin/git/matting_data/DIM/img_png/archeology_00110.png
63 | /home/kevin/git/matting_data/DIM/img_png/archeology_00115.png
64 | /home/kevin/git/matting_data/DIM/img_png/archeology_00120.png
65 | /home/kevin/git/matting_data/DIM/img_png/archeology_00125.png
66 | /home/kevin/git/matting_data/DIM/img_png/archeology_00130.png
67 | /home/kevin/git/matting_data/DIM/img_png/archeology_00135.png
68 | /home/kevin/git/matting_data/DIM/img_png/archeology_00140.png
69 | /home/kevin/git/matting_data/DIM/img_png/archeology_00145.png
70 | /home/kevin/git/matting_data/DIM/img_png/archeology_00150.png
71 | /home/kevin/git/matting_data/DIM/img_png/archeology_00155.png
72 | /home/kevin/git/matting_data/DIM/img_png/archeology_00160.png
73 | /home/kevin/git/matting_data/DIM/img_png/archeology_00165.png
74 | /home/kevin/git/matting_data/DIM/img_png/archeology_00170.png
75 | /home/kevin/git/matting_data/DIM/img_png/archeology_00175.png
76 | /home/kevin/git/matting_data/DIM/img_png/archeology_00180.png
77 | /home/kevin/git/matting_data/DIM/img_png/archeology_00185.png
78 | /home/kevin/git/matting_data/DIM/img_png/archeology_00190.png
79 | /home/kevin/git/matting_data/DIM/img_png/archeology_00195.png
80 | /home/kevin/git/matting_data/DIM/img_png/arrgh___r___28_by_mjranum_stock.png
81 | /home/kevin/git/matting_data/DIM/img_png/arrgh___r___29_by_mjranum_stock.png
82 | /home/kevin/git/matting_data/DIM/img_png/arrgh___r___30_by_mjranum_stock.png
83 | /home/kevin/git/matting_data/DIM/img_png/ballerina-855652_1920.png
84 | /home/kevin/git/matting_data/DIM/img_png/beautiful-19075_960_720.png
85 | /home/kevin/git/matting_data/DIM/img_png/boy-1518482_1920.png
86 | /home/kevin/git/matting_data/DIM/img_png/boy-454633_1920.png
87 | /home/kevin/git/matting_data/DIM/img_png/face-1223346_960_720.png
88 | /home/kevin/git/matting_data/DIM/img_png/fashion-model-portrait.png
89 | /home/kevin/git/matting_data/DIM/img_png/fashion-model-pose.png
90 | /home/kevin/git/matting_data/DIM/img_png/girl-1219339_1920.png
91 | /home/kevin/git/matting_data/DIM/img_png/girl-1467820_1280.png
92 | /home/kevin/git/matting_data/DIM/img_png/girl-1535859_1920.png
93 | /home/kevin/git/matting_data/DIM/img_png/girl-beautiful-young-face-53000.png
94 | /home/kevin/git/matting_data/DIM/img_png/godiva_close_00000.png
95 | /home/kevin/git/matting_data/DIM/img_png/godiva_close_00005.png
96 | /home/kevin/git/matting_data/DIM/img_png/godiva_close_00010.png
97 | /home/kevin/git/matting_data/DIM/img_png/godiva_close_00015.png
98 | /home/kevin/git/matting_data/DIM/img_png/godiva_close_00020.png
99 | /home/kevin/git/matting_data/DIM/img_png/godiva_close_00025.png
100 | /home/kevin/git/matting_data/DIM/img_png/godiva_close_00030.png
101 | /home/kevin/git/matting_data/DIM/img_png/godiva_close_00035.png
102 | /home/kevin/git/matting_data/DIM/img_png/godiva_close_00040.png
103 | /home/kevin/git/matting_data/DIM/img_png/goth_by_bugidifino-d4w7zms.png
104 | /home/kevin/git/matting_data/DIM/img_png/hair-flying-142210_1920.png
105 | /home/kevin/git/matting_data/DIM/img_png/headshotid_by_bokogreat_stock-d355xf3.png
106 | /home/kevin/git/matting_data/DIM/img_png/lil_white_goth_grl___23_by_mjranum_stock.png
107 | /home/kevin/git/matting_data/DIM/img_png/lil_white_goth_grl___26_by_mjranum_stock.png
108 | /home/kevin/git/matting_data/DIM/img_png/locked_00000.png
109 | /home/kevin/git/matting_data/DIM/img_png/locked_00005.png
110 | /home/kevin/git/matting_data/DIM/img_png/locked_00010.png
111 | /home/kevin/git/matting_data/DIM/img_png/locked_00015.png
112 | /home/kevin/git/matting_data/DIM/img_png/locked_00020.png
113 | /home/kevin/git/matting_data/DIM/img_png/locked_00025.png
114 | /home/kevin/git/matting_data/DIM/img_png/locked_00030.png
115 | /home/kevin/git/matting_data/DIM/img_png/locked_00035.png
116 | /home/kevin/git/matting_data/DIM/img_png/locked_00040.png
117 | /home/kevin/git/matting_data/DIM/img_png/locked_00045.png
118 | /home/kevin/git/matting_data/DIM/img_png/locked_00050.png
119 | /home/kevin/git/matting_data/DIM/img_png/locked_00055.png
120 | /home/kevin/git/matting_data/DIM/img_png/locked_00060.png
121 | /home/kevin/git/matting_data/DIM/img_png/locked_00065.png
122 | /home/kevin/git/matting_data/DIM/img_png/locked_00070.png
123 | /home/kevin/git/matting_data/DIM/img_png/locked_00075.png
124 | /home/kevin/git/matting_data/DIM/img_png/long-1245787_1920.png
125 | /home/kevin/git/matting_data/DIM/img_png/man-388104_960_720.png
126 | /home/kevin/git/matting_data/DIM/img_png/man_headshot.png
127 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00000.png
128 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00005.png
129 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00010.png
130 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00015.png
131 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00020.png
132 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00025.png
133 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00030.png
134 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00035.png
135 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00040.png
136 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00045.png
137 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00050.png
138 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00055.png
139 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00060.png
140 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00065.png
141 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00070.png
142 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00075.png
143 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00080.png
144 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00085.png
145 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00090.png
146 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00095.png
147 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00100.png
148 | /home/kevin/git/matting_data/DIM/img_png/mmtest_00105.png
149 | /home/kevin/git/matting_data/DIM/img_png/model-429733_960_720.png
150 | /home/kevin/git/matting_data/DIM/img_png/model-600238_1920.png
151 | /home/kevin/git/matting_data/DIM/img_png/model-610352_960_720.png
152 | /home/kevin/git/matting_data/DIM/img_png/model-858753_960_720.png
153 | /home/kevin/git/matting_data/DIM/img_png/model-858755_960_720.png
154 | /home/kevin/git/matting_data/DIM/img_png/model-873675_960_720.png
155 | /home/kevin/git/matting_data/DIM/img_png/model-873678_960_720.png
156 | /home/kevin/git/matting_data/DIM/img_png/model-873690_960_720.png
157 | /home/kevin/git/matting_data/DIM/img_png/model-881425_960_720.png
158 | /home/kevin/git/matting_data/DIM/img_png/model-881431_960_720.png
159 | /home/kevin/git/matting_data/DIM/img_png/model-female-girl-beautiful-51969.png
160 | /home/kevin/git/matting_data/DIM/img_png/person-woman-eyes-face.png
161 | /home/kevin/git/matting_data/DIM/img_png/pexels-photo-58463.png
162 | /home/kevin/git/matting_data/DIM/img_png/pink-hair-855660_960_720.png
163 | /home/kevin/git/matting_data/DIM/img_png/portrait-750774_1920.png
164 | /home/kevin/git/matting_data/DIM/img_png/sailor_flying_4_by_senshistock-d4k2wmr.png
165 | /home/kevin/git/matting_data/DIM/img_png/sea-sunny-person-beach.png
166 | /home/kevin/git/matting_data/DIM/img_png/skin-care-937667_960_720.png
167 | /home/kevin/git/matting_data/DIM/img_png/sorcery___8_by_mjranum_stock.png
168 | /home/kevin/git/matting_data/DIM/img_png/wedding-846926_1920.png
169 | /home/kevin/git/matting_data/DIM/img_png/wedding-dresses-1486260_1280.png
170 | /home/kevin/git/matting_data/DIM/img_png/with_wings___pose_reference_by_senshistock-d6by42n.png
171 | /home/kevin/git/matting_data/DIM/img_png/with_wings___pose_reference_by_senshistock-d6by42n_2.png
172 | /home/kevin/git/matting_data/DIM/img_png/woman-1138435_960_720.png
173 | /home/kevin/git/matting_data/DIM/img_png/woman-659354_960_720.png
174 | /home/kevin/git/matting_data/DIM/img_png/woman-804072_960_720.png
175 | /home/kevin/git/matting_data/DIM/img_png/woman-868519_960_720.png
176 | /home/kevin/git/matting_data/DIM/img_png/woman-952506_1920 (1).png
177 | /home/kevin/git/matting_data/DIM/img_png/woman-morning-bathrobe-bathroom.png
178 | /home/kevin/git/matting_data/DIM/img_png/woman1.png
179 | /home/kevin/git/matting_data/DIM/img_png/woman2.png
180 | /home/kevin/git/matting_data/DIM/img_png/women-878869_1920.png
181 |
--------------------------------------------------------------------------------
/data/data_util.py:
--------------------------------------------------------------------------------
1 | # code for temporary usage or previously used
2 | # author: L.Q. Chen
3 |
4 | import cv2
5 | import os
6 | import random
7 | import numpy as np
8 | import math
9 | from data import gen_trimap  # needed by read_files below (same import style as data/dataset.py)
10 | # for RGBA format image generation: inputs are the paths of normal image & alpha image (mask)
11 | def get_rgba(img_path, msk_path):
12 | img = cv2.imread(img_path)
13 | msk = cv2.imread(msk_path, 0)
14 |
15 | image = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
16 | image[:,:,3] = msk
17 | return image
18 |
19 | # for convenience, we make fake fg and bg image
20 | def fake_fg_bg(img,alpha):
21 | color_fg = [random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)]
22 | color_bg = [random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)]
23 | fg = np.full(img.shape, color_fg)
24 | bg = np.full(img.shape, color_bg)
25 | a = np.expand_dims(alpha/255., axis=2)
26 |
27 | fg = img * a + fg * (1-a)
28 | bg = bg * a + img * (1-a)
29 | return fg.astype(np.uint8), bg.astype(np.uint8)
30 |
31 | # composite fg and bg with arbitrary size
32 | def composite(image, a, bg):
33 | h, w = a.shape
34 | bh, bw = bg.shape[:2]
35 | wratio, hratio = w/bw, h/bh
36 | ratio = wratio if wratio > hratio else hratio
37 |     bg = cv2.resize(bg, (math.ceil(bw * ratio), math.ceil(bh * ratio)), interpolation=cv2.INTER_CUBIC)
38 | bh, bw = bg.shape[:2]
39 | assert bh>=h and bw>=w
40 | bg = np.array(bg[(bh - h) // 2: (bh - h) // 2 + h, (bw - w) // 2: (bw - w) // 2 + w], np.float32)
41 |     assert (h, w) == bg.shape[:2]
42 |
43 | fg = np.array(image, np.float32)
44 | alpha = np.expand_dims(a/255., axis=2).astype(np.float32)
45 | comp = alpha * fg + (1 - alpha) * bg
46 | comp = comp.astype(np.uint8)
47 |
48 | # we need bg for calculating compositional loss in M_net
49 | return comp, bg
50 |
51 | def read_files(name, return_a_bg=True):
52 | if name[2] == 'bg':
53 | img_path, bg_path = name[0].strip(), name[1].strip()
54 | assert os.path.isfile(img_path) and os.path.isfile(bg_path), (img_path, bg_path)
55 |
56 | image = cv2.imread(img_path, -1) # it's RGBA image
57 | fg = image[:,:,:3]
58 | a = image[:,:,3]
59 | bg = cv2.imread(bg_path)
60 |
61 | image, bg = composite(fg, a, bg)
62 | trimap = gen_trimap.rand_trimap(a)
63 |
64 | elif name[2] == 'msk':
65 | img_path, a_path = name[0].strip(), name[1].strip()
66 | assert os.path.isfile(img_path) and os.path.isfile(a_path)
67 |
68 | image = cv2.imread(img_path) # it's composited image
69 | a = cv2.imread(a_path, 0) # it's grayscale image
70 |
71 | a[a>0] = 255
72 | trimap = gen_trimap.rand_trimap(a)
73 |
74 | # for M-net and fusion training, we need bg for compositional loss
75 | # here to simplify, we convert original image
76 | if return_a_bg:
77 | fg, bg = fake_fg_bg(image, a)
78 |
79 |     # NOTE ! ! ! trimap should have 3 classes for classification: fg, bg, unknown
80 | trimap[trimap == 0] = 0
81 | trimap[trimap == 128] = 1
82 | trimap[trimap == 255] = 2
83 |
84 | assert image.shape[:2] == trimap.shape[:2] == a.shape[:2]
85 |
86 | if return_a_bg:
87 | return image, trimap, a, bg, fg
88 | else:
89 | return image, trimap
90 |
91 | # crop image/alpha/trimap with random size of [max_size//2, max_size] then resize to patch_size
92 | def random_patch(image, trimap, patch_size, alpha=None, bg=None, fg=None):
93 | h, w = image.shape[:2]
94 | max_size = max(h, w)
95 | min_size = max_size//2
96 | if isinstance(alpha, np.ndarray) and isinstance(bg, np.ndarray):
97 | patch_a_bg= True
98 | else:
99 | patch_a_bg = False
100 |
101 | count = 0 # debug usage
102 | while True:
103 | sqr_tri = np.zeros((max_size, max_size), np.uint8)
104 | sqr_img = np.zeros((max_size, max_size, 3), np.uint8)
105 | if patch_a_bg:
106 | sqr_alp = np.zeros((max_size, max_size), np.uint8)
107 | sqr_bg = np.zeros((max_size, max_size, 3), np.uint8)
108 | sqr_fg = np.zeros((max_size, max_size, 3), np.uint8)
109 | if h>=w:
110 | sqr_tri[:, (h-w)//2 : (h-w)//2+w] = trimap
111 | sqr_img[:, (h-w)//2 : (h-w)//2+w] = image
112 | if patch_a_bg:
113 | sqr_alp[:, (h-w)//2 : (h-w)//2+w] = alpha
114 | sqr_bg[:, (h-w)//2 : (h-w)//2+w] = bg
115 | sqr_fg[:, (h-w)//2 : (h-w)//2+w] = fg
116 | else:
117 | sqr_tri[(w-h)//2 : (w-h)//2+h, :] = trimap
118 | sqr_img[(w-h)//2 : (w-h)//2+h, :] = image
119 | if patch_a_bg:
120 | sqr_alp[(w-h)//2 : (w-h)//2+h, :] = alpha
121 | sqr_bg[(w-h)//2 : (w-h)//2+h, :] = bg
122 | sqr_fg[(w-h)//2 : (w-h)//2+h, :] = fg
123 |
124 | crop_size = random.randint(min_size, max_size) # both value are inclusive
125 | x = random.randint(0, max_size-crop_size) # 0 is inclusive
126 | y = random.randint(0, max_size-crop_size)
127 | trimap_temp = sqr_tri[y: y+crop_size, x: x+crop_size]
128 | if len(np.where(trimap_temp == 1)[0])>0: # check if unknown area is included
129 | image = sqr_img[y: y+crop_size, x: x+crop_size]
130 | trimap = trimap_temp
131 | if patch_a_bg:
132 | alpha = sqr_alp[y: y+crop_size, x: x+crop_size]
133 | bg = sqr_bg[y: y+crop_size, x: x+crop_size]
134 | fg = sqr_fg[y: y+crop_size, x: x+crop_size]
135 | break
136 | elif len(np.where(trimap==1)[0]) == 0:
137 | print('Warning & Error: No unknown area in current trimap! Refer to saved trimap in folder.')
138 | image = sqr_img
139 | trimap = sqr_tri
140 | if patch_a_bg:
141 | alpha = sqr_alp
142 | bg = sqr_bg
143 | fg = sqr_fg
144 | os.makedirs('ckpt/exceptions', exist_ok=True)
145 | cv2.imwrite('ckpt/exceptions/img_{}_{}.png'.format(str(h), str(w)), image)
146 | cv2.imwrite('ckpt/exceptions/tri_{}_{}.png'.format(str(h), str(w)), trimap)
147 | break
148 | elif count>5:
149 | print('Warning: cannot find right patch randomly, use max_size instead! Refer to saved files in folder.')
150 | image = sqr_img
151 | trimap = sqr_tri
152 | if patch_a_bg:
153 | alpha = sqr_alp
154 | bg = sqr_bg
155 | fg = sqr_fg
156 | os.makedirs('ckpt/exceptions', exist_ok=True)
157 | cv2.imwrite('ckpt/exceptions/img_{}_{}.png'.format(str(h), str(w)), image)
158 | cv2.imwrite('ckpt/exceptions/tri_{}_{}.png'.format(str(h), str(w)), trimap)
159 | break
160 |
161 | count += 1 # debug usage
162 |
163 | image = cv2.resize(image, (patch_size, patch_size), interpolation=cv2.INTER_CUBIC)
164 | trimap = cv2.resize(trimap, (patch_size, patch_size), interpolation=cv2.INTER_NEAREST)
165 | if patch_a_bg:
166 | alpha = cv2.resize(alpha, (patch_size, patch_size), interpolation=cv2.INTER_CUBIC)
167 | bg = cv2.resize(bg, (patch_size, patch_size), interpolation=cv2.INTER_CUBIC)
168 | fg = cv2.resize(fg, (patch_size, patch_size), interpolation=cv2.INTER_CUBIC)
169 | return image, trimap, alpha, bg, fg
170 | else:
171 | return image, trimap
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import cv2
3 | import os
4 | import random
5 | import numpy as np
6 | import math
7 |
8 | from data import gen_trimap
9 |
10 | import torch
11 | import torch.utils.data as data
12 |
13 | def composite(image, a, bg):
14 | fg = np.array(image, np.float32)
15 | alpha = np.expand_dims(a/255., axis=2).astype(np.float32)
16 | comp = alpha * fg + (1 - alpha) * bg
17 | comp = comp.astype(np.uint8)
18 |
19 | # we need bg for calculating compositional loss in M_net
20 | return comp
21 |
22 | ## crop & resize bg so that img and bg have same size
23 | def resize_bg(bg, h, w):
24 | bh, bw = bg.shape[:2]
25 | wratio, hratio = w/bw, h/bh
26 | ratio = wratio if wratio > hratio else hratio
27 | bg = cv2.resize(bg, (math.ceil(bw * ratio), math.ceil(bh * ratio)), interpolation=cv2.INTER_CUBIC)
28 | bh, bw = bg.shape[:2]
29 | assert bh>=h and bw>=w
30 | bg = np.array(bg[(bh - h) // 2: (bh - h) // 2 + h, (bw - w) // 2: (bw - w) // 2 + w], np.float32)
31 |     assert (h, w) == bg.shape[:2]
32 |
33 | return bg # attention: bg is in float32
34 |
35 | # randomly crop images, then resize to patch_size
36 | def random_patch(image, trimap, patch_size, bg=None, a=None):
37 | h, w = image.shape[:2]
38 | max_size = max(h, w)
39 | min_size = max_size // 2
40 |
41 | count=0
42 | while True:
43 | sqr_tri = np.zeros((max_size, max_size), np.uint8)
44 | if isinstance(bg, np.ndarray):
45 | sqr_img = np.zeros((max_size, max_size, 4), np.uint8)
46 | sqr_bga = np.zeros((max_size, max_size, 3), np.uint8)
47 | bga = resize_bg(bg, h, w)
48 | elif isinstance(a, np.ndarray):
49 | sqr_img = np.zeros((max_size, max_size, 3), np.uint8)
50 | sqr_bga = np.zeros((max_size, max_size), np.uint8)
51 | bga = a
52 | else:
53 | raise ValueError('no bg or alpha given from input!')
54 |
55 | if h >= w:
56 | sqr_tri[:, (h - w) // 2: (h - w) // 2 + w] = trimap
57 | sqr_img[:, (h - w) // 2: (h - w) // 2 + w] = image
58 | sqr_bga[:, (h - w) // 2: (h - w) // 2 + w] = bga
59 | else:
60 | sqr_tri[(w - h) // 2: (w - h) // 2 + h, :] = trimap
61 | sqr_img[(w - h) // 2: (w - h) // 2 + h, :] = image
62 | sqr_bga[(w - h) // 2: (w - h) // 2 + h, :] = bga
63 |
64 | crop_size = random.randint(min_size, max_size) # both value are inclusive
65 | x = random.randint(0, max_size - crop_size) # 0 is inclusive
66 | y = random.randint(0, max_size - crop_size)
67 | trimap_temp = sqr_tri[y: y + crop_size, x: x + crop_size]
68 | if len(np.where(trimap_temp == 128)[0]) > 0: # check if unknown area is included
69 | image = sqr_img[y: y + crop_size, x: x + crop_size]
70 | bga = sqr_bga[y: y + crop_size, x: x + crop_size]
71 | break
72 |         elif len(np.where(trimap == 128)[0]) == 0:
73 |             print('Warning: No unknown area in current trimap! Refer to folder.')
74 |             os.makedirs('ckpt/exceptions', exist_ok=True)
75 |             cv2.imwrite('ckpt/exceptions/img_{}_{}.png'.format(str(h), str(w)), image)
76 |             cv2.imwrite('ckpt/exceptions/tri_{}_{}.png'.format(str(h), str(w)), trimap)
   |             image, bga = sqr_img, sqr_bga
   |             break  # fall back to the padded square; otherwise this branch would loop forever
77 | elif count > 3:
78 | print('Warning & Error: cannot find right patch randomly, use max_size instead! Refer to folder.')
79 | os.makedirs('ckpt/exceptions', exist_ok=True)
80 | cv2.imwrite('ckpt/exceptions/img_{}_{}.png'.format(str(h), str(w)), image)
81 | cv2.imwrite('ckpt/exceptions/tri_{}_{}.png'.format(str(h), str(w)), trimap)
82 | image = sqr_img
83 | bga = sqr_bga
84 | break
85 |
86 | count += 1 # debug usage
87 |
88 | image = cv2.resize(image, (patch_size, patch_size), interpolation=cv2.INTER_CUBIC)
89 | bga = cv2.resize(bga, (patch_size, patch_size), interpolation=cv2.INTER_CUBIC)
90 |
91 | return image, bga
92 |
93 | # read/ crop/ resize inputs including fb, bg, alpha, trimap, composited image
94 | # attention: all resize operation should be done before composition
95 | def read_crop_resize(name, patch_size, stage):
96 | if name[2] == 'bg':
97 | img_path, bg_path = name[0].strip(), name[1].strip()
98 | assert os.path.isfile(img_path) and os.path.isfile(bg_path), (img_path, bg_path)
99 |
100 | image = cv2.imread(img_path, -1) # it's RGBA image
101 | fg = image[:,:,:3]
102 | a = image[:,:,3]
103 | bg = cv2.imread(bg_path)
104 | trimap = gen_trimap.rand_trimap(a)
105 |
106 | if stage == 't_net':
107 | image, bg = random_patch(image, trimap, patch_size, bg=bg)
108 | fg = image[:,:,:3]
109 | a = image[:,:,3]
110 | # composite fg and bg, generate trimap
111 | img = composite(fg, a, bg)
112 | trimap = gen_trimap.rand_trimap(a)
113 |
114 | return img, trimap
115 |
116 | elif stage == 'm_net':
117 | fg, a, bg = mask_center_crop(fg, a, bg, trimap, patch_size)
118 | # generate trimap again to avoid alpha resize side-effect on trimap
119 | trimap = gen_trimap.rand_trimap(a)
120 | # composite fg and bg
121 | img = composite(fg, a, bg)
122 |
123 | return img, trimap, a, bg, fg
124 |
125 | elif stage == 'end2end':
126 | # for t_net
127 | t_patch, m_patch = patch_size
128 | image_m, bg_m = random_patch(image, trimap, t_patch, bg=bg)
129 | fg_m = image_m[:,:,:3]
130 | a_m = image_m[:,:,3]
131 | img_m = composite(fg_m, a_m, bg_m)
132 | trimap_m = gen_trimap.rand_trimap(a_m)
133 |
134 | # random flip and rotation before going to m_net
135 | bg_m = bg_m.astype(np.uint8)
136 | img_m, trimap_m, a_m, bg_m, fg_m = random_flip_rotation(img_m, trimap_m, a_m, bg_m, fg_m)
137 |
138 | # for m_net
139 | fg, a, bg = mask_center_crop(fg_m, a_m, bg_m, trimap_m, m_patch)
140 | # generate trimap again to avoid alpha resize side-effect on trimap
141 | trimap = gen_trimap.rand_trimap(a)
142 | # composite fg and bg
143 | img = composite(fg, a, bg)
144 |
145 | return (img_m, trimap_m), (img, trimap, a, bg, fg)
146 |
147 | # note: this type of data is only used for t_net training
148 | elif name[2] == 'msk':
149 | img_path, a_path = name[0].strip(), name[1].strip()
150 | assert os.path.isfile(img_path) and os.path.isfile(a_path)
151 |
152 | image = cv2.imread(img_path) # it's composited image
153 | a = cv2.imread(a_path, 0) # it's grayscale image
154 | a[a > 0] = 255
155 | trimap = gen_trimap.rand_trimap(a)
156 | img, a = random_patch(image, trimap, patch_size, a=a)
157 |
158 | # generate trimap again to avoid alpha resize side-effect on trimap
159 | trimap = gen_trimap.rand_trimap(a)
160 |
161 | return img, trimap
162 |
163 | # crop image into crop_size
164 | def safe_crop(img, x, y, crop_size, resize_patch=None):
165 | if len(img.shape) == 2:
166 | new = np.zeros((crop_size, crop_size), np.uint8)
167 | else:
168 | new = np.zeros((crop_size, crop_size, 3), np.uint8)
169 | cropped = img[y:y+crop_size, x:x+crop_size] # if y+crop_size bigger than len_y, return len_y
170 | h, w = cropped.shape[:2] # h or w falls into the range of [crop_size/2, crop_size]
171 | new[0:h, 0:w] = cropped
172 |
173 | if resize_patch:
174 | new = cv2.resize(new, (resize_patch, resize_patch), interpolation=cv2.INTER_CUBIC)
175 |
176 | return new
177 |
178 | # crop image around trimap unknown area with random size of [patch_size, 2*patch_size] then resize to patch_size
179 | def mask_center_crop(fg, a, bg, trimap, patch_size):
180 | max_size = patch_size * 2
181 | min_size = patch_size
182 | crop_size = random.randint(min_size, max_size) # both value are inclusive
183 |
184 | # get crop center around trimap unknown area
185 | y_idx, x_idx = np.where(trimap == 128)
186 | num_unknowns = len(y_idx)
187 | if num_unknowns > 0:
188 | idx = np.random.choice(range(num_unknowns))
189 | cx = x_idx[idx]
190 | cy = y_idx[idx]
191 | x = max(0, cx - int(crop_size / 2))
192 | y = max(0, cy - int(crop_size /2 ))
193 | else:
194 | raise ValueError('no unknown area in trimap!')
195 |
196 | # crop image/trimap/alpha/fg/bg
197 | fg = safe_crop(fg, x, y, crop_size, resize_patch=patch_size)
198 | a = safe_crop(a, x, y, crop_size, resize_patch=patch_size)
199 | bg = safe_crop(bg, x, y, crop_size, resize_patch=patch_size)
200 |
201 | return fg, a, bg
202 |
203 | # randomly flip and rotate images
204 | def random_flip_rotation(image, trimap, alpha=None, bg=None, fg=None):
205 | if isinstance(alpha, np.ndarray) and isinstance(bg, np.ndarray):
206 | a_bg = True
207 | else:
208 | a_bg = False
209 | # horizontal flipping
210 | if random.random() < 0.5:
211 | image = cv2.flip(image, 1)
212 | trimap = cv2.flip(trimap, 1)
213 | if a_bg:
214 | alpha = cv2.flip(alpha, 1)
215 | bg = cv2.flip(bg, 1)
216 | fg = cv2.flip(fg, 1)
217 |
218 | # rotation
219 | if random.random() < 0.05:
220 | degree = random.randint(1,3) # rotate 1/2/3 * 90 degrees
221 | image = np.rot90(image, degree)
222 | trimap = np.rot90(trimap, degree)
223 | if a_bg:
224 | alpha = np.rot90(alpha, degree)
225 | bg = np.rot90(bg, degree)
226 | fg = np.rot90(fg, degree)
227 |
228 | if a_bg:
229 | return image, trimap, alpha, bg, fg
230 | else:
231 | return image, trimap
232 |
233 | def np2Tensor(array):
234 | if len(array.shape)>2:
235 | ts = (2, 0, 1)
236 | tensor = torch.FloatTensor(array.transpose(ts).astype(float))
237 | else:
238 | tensor = torch.FloatTensor(array.astype(float))
239 | return tensor
240 |
241 | class human_matting_data(data.Dataset):
242 | """
243 | human_matting
244 | """
245 | def __init__(self, args):
246 | super().__init__()
247 | self.data_root = args.dataDir
248 | self.patch_size = args.patch_size
249 | self.phase = args.train_phase
250 | self.dataRatio = args.dataRatio
251 |
252 | self.fg_paths = []
253 | for file in args.fgLists:
254 | fg_path = os.path.join(self.data_root, file)
255 | assert os.path.isfile(fg_path), "missing file at {}".format(fg_path)
256 | with open(fg_path, 'r') as f:
257 | self.fg_paths.append(f.readlines())
258 |
259 | bg_path = os.path.join(self.data_root, args.bg_list)
260 |         assert os.path.isfile(bg_path), "missing bg file at: {}".format(bg_path)
261 | with open(bg_path, 'r') as f:
262 | self.path_bg = f.readlines()
263 |
264 | assert len(self.path_bg) == sum([self.dataRatio[i]*len(self.fg_paths[i]) for i in range(len(self.fg_paths))]), \
265 | 'the total num of bg is not equal to fg: bg-{}, fg-{}'\
266 | .format(len(self.path_bg), [self.dataRatio[i]*len(self.fg_paths[i]) for i in range(len(self.fg_paths))])
267 | self.num = len(self.path_bg)
268 |
269 | #self.shuffle_count = 0
270 | self.shuffle_data()
271 |
272 | print("Dataset : total training images:{}".format(self.num))
273 |
274 | def __getitem__(self, index):
275 | # data structure returned :: dict {}
276 | # image: c, h, w / range[0-1] , float
277 | # trimap: 1, h, w / [0,1,2] , float
278 | # alpha: 1, h, w / range[0-1] , float
279 |
280 | if self.phase == 'pre_train_t_net':
281 | # read files, random crop and resize
282 | image, trimap = read_crop_resize(self.names[index], self.patch_size, stage='t_net')
283 | # augmentation
284 | image, trimap = random_flip_rotation(image, trimap)
285 |
286 |             # NOTE ! ! ! trimap should have 3 classes for classification: fg, bg, unknown
287 | trimap[trimap == 0] = 0
288 | trimap[trimap == 128] = 1
289 | trimap[trimap == 255] = 2
290 | assert image.shape[:2] == trimap.shape[:2]
291 |
292 | # normalize
293 | image = image.astype(np.float32) / 255.0
294 |
295 | # to tensor
296 | image = np2Tensor(image)
297 | trimap = np2Tensor(trimap)
298 |
299 | trimap = trimap.unsqueeze_(0) # shape: 1, h, w
300 |
301 | sample = {'image':image, 'trimap':trimap}
302 |
303 | elif self.phase == 'pre_train_m_net':
304 | # read files
305 | image, trimap, alpha, bg, fg = read_crop_resize(self.names[index], self.patch_size, stage='m_net')
306 | # augmentation
307 | image, trimap, alpha, bg, fg = random_flip_rotation(image, trimap, alpha, bg, fg)
308 |
309 |             # NOTE ! ! ! trimap should have 3 classes for classification: fg, bg, unknown
310 | trimap[trimap == 0] = 0
311 | trimap[trimap == 128] = 1
312 | trimap[trimap == 255] = 2
313 |
314 | assert image.shape[:2] == trimap.shape[:2] == alpha.shape[:2]
315 |
316 | # normalize
317 | image = image.astype(np.float32) / 255.0
318 | alpha = alpha.astype(np.float32) / 255.0
319 | bg = bg.astype(np.float32) / 255.0
320 | fg = fg.astype(np.float32) / 255.0
321 |
322 | # trimap one-hot encoding: when pre-train M_net, trimap should have 3 channels
323 | trimap = np.eye(3)[trimap.reshape(-1)].reshape(list(trimap.shape)+[3])
324 |
325 | # to tensor
326 | image = np2Tensor(image)
327 | trimap = np2Tensor(trimap)
328 | alpha = np2Tensor(alpha)
329 | bg = np2Tensor(bg)
330 | fg = np2Tensor(fg)
331 |
332 | alpha = alpha.unsqueeze_(0) # shape: 1, h, w
333 |
334 | sample = {'image': image, 'trimap': trimap, 'alpha': alpha, 'bg':bg, 'fg':fg}
335 |
336 | elif self.phase == 'end_to_end':
337 | # read files
338 | assert len(self.patch_size) == 2, 'patch_size should have two values for end2end training !'
339 | input_t, input_m = read_crop_resize(self.names[index], self.patch_size, stage='end2end')
340 | img_t, tri_t = input_t
341 | img_m, tri_m, a_m, bg_m, fg_m = input_m
342 |
343 |             # NOTE ! ! ! trimap should have 3 classes for classification: fg, bg, unknown
344 | tri_t[tri_t == 0] = 0
345 | tri_t[tri_t == 128] = 1
346 | tri_t[tri_t == 255] = 2
347 | tri_m[tri_m == 0] = 0
348 | tri_m[tri_m == 128] = 1
349 | tri_m[tri_m == 255] = 2
350 |
351 | assert img_t.shape[:2] == tri_t.shape[:2]
352 | assert img_m.shape[:2] == tri_m.shape[:2] == a_m.shape[:2]
353 |
354 | # t_net processing
355 | img_t = img_t.astype(np.float32) / 255.0
356 | img_t = np2Tensor(img_t)
357 | tri_t = np2Tensor(tri_t)
358 | tri_t = tri_t.unsqueeze_(0) # shape: 1, h, w
359 |
360 | # m_net processing
361 | img_m = img_m.astype(np.float32) / 255.0
362 | a_m = a_m.astype(np.float32) / 255.0
363 | bg_m = bg_m.astype(np.float32) / 255.0
364 | fg_m = fg_m.astype(np.float32) / 255.0
365 |
366 | # trimap one-hot encoding: when pre-train M_net, trimap should have 3 channels
367 | tri_m = np.eye(3)[tri_m.reshape(-1)].reshape(list(tri_m.shape) + [3])
368 | # to tensor
369 | img_m = np2Tensor(img_m)
370 | tri_m = np2Tensor(tri_m)
371 | a_m = np2Tensor(a_m)
372 | bg_m = np2Tensor(bg_m)
373 | fg_m = np2Tensor(fg_m)
374 |
375 | tri_m = tri_m.unsqueeze_(0) # shape: 1, h, w
376 | a_m = a_m.unsqueeze_(0)
377 |
378 | sample = [{'image':img_t, 'trimap':tri_t},
379 | {'image':img_m, 'trimap':tri_m, 'alpha': a_m, 'bg':bg_m, 'fg':fg_m}]
380 |
381 | return sample
382 |
383 | def shuffle_data(self):
384 | # data structure of self.names:: list
385 |         # (.png_img, .bg, 'bg') or (.composite, .mask, 'msk') :: tuple
386 | self.names = []
387 |
388 | random.shuffle(self.path_bg)
389 |
390 | count = 0
391 | for idx, path_list in enumerate(self.fg_paths):
392 | bg_per_fg = self.dataRatio[idx]
393 | for path in path_list:
394 | for i in range(bg_per_fg):
395 | self.names.append((path, self.path_bg[count], 'bg')) # 'bg' means we need to composite fg & bg
396 | count += 1
397 |
398 | assert count == len(self.path_bg)
399 |
400 | """# debug usage: check shuffled data after each call
401 | with open('shuffled_data_{}.txt'.format(self.shuffle_count),'w') as f:
402 | for name in self.names:
403 | f.write(name[0].strip()+' || '+name[1].strip()+'\n')
404 | self.shuffle_count += 1
405 | """
406 |
407 | def __len__(self):
408 | return self.num
409 |
410 |
--------------------------------------------------------------------------------
/data/gen_trimap.py:
--------------------------------------------------------------------------------
1 | import cv2, os
2 | import numpy as np
3 | import argparse
4 | import tqdm
5 |
6 | """
7 | def get_args():
8 | parser = argparse.ArgumentParser(description='Trimap')
9 | parser.add_argument('--mskDir', type=str, required=True, help="masks directory")
10 | parser.add_argument('--saveDir', type=str, required=True, help="where trimap result save to")
11 | parser.add_argument('--list', type=str, required=True, help="list of images id")
12 | parser.add_argument('--size', type=int, required=True, help="kernel size")
13 | args = parser.parse_args()
14 | print(args)
15 | return args
16 | """
17 | # a simple trimap generation code for fixed kernel size
18 | def erode_dilate(msk, size=(10, 10), smooth=True):
19 | if smooth:
20 | size = (size[0]-4, size[1]-4)
21 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, size)
22 |
23 | dilated = cv2.dilate(msk, kernel, iterations=1)
24 |     if smooth:  # if it is .jpg, prevent the output from being jagged
25 | dilated[(dilated>5)] = 255
26 | dilated[(dilated <= 5)] = 0
27 | else:
28 | dilated[(dilated>0)] = 255
29 |
30 | eroded = cv2.erode(msk, kernel, iterations=1)
31 | if smooth:
32 | eroded[(eroded<250)] = 0
33 | eroded[(eroded >= 250)] = 255
34 | else:
35 | eroded[(eroded < 255)] = 0
36 |
37 | res = dilated.copy()
38 | res[((dilated == 255) & (eroded == 0))] = 128
39 |
40 | """# make sure there are only 3 values in trimap
41 | cnt0 = len(np.where(res >= 0)[0])
42 | cnt1 = len(np.where(res == 0)[0])
43 | cnt2 = len(np.where(res == 128)[0])
44 | cnt3 = len(np.where(res == 255)[0])
45 | assert cnt0 == cnt1 + cnt2 + cnt3
46 | """
47 |
48 | return res
49 |
50 | # trimap generation with different/random kernel size
51 | def rand_trimap(msk, smooth=False):
52 | h, w = msk.shape
53 | scale_up, scale_down = 0.022, 0.006 # hyper parameter
54 | dmin = 0 # hyper parameter
55 | emax = 255 - dmin # hyper parameter
56 |
57 |     # .jpg (or low-quality .png) masks tend to be jagged, so smoothing tricks need to be applied
58 | if smooth:
59 |         # set thresholds for the dilation and erosion results
60 | scale_up, scale_down = 0.02, 0.006
61 | dmin = 5
62 | emax = 255 - dmin
63 |
64 |         # apply Gaussian smoothing
65 | if h<1000:
66 | gau_ker = round(h*0.01) # we restrict the kernel size to 5-9
67 | gau_ker = gau_ker if gau_ker % 2 ==1 else gau_ker-1 # make sure it's odd
68 | if h<500:
69 | gau_ker = max(3, gau_ker)
70 | msk = cv2.GaussianBlur(msk, (gau_ker, gau_ker), 0)
71 |
72 | kernel_size_high = max(10, round((h + w) / 2 * scale_up))
73 | kernel_size_low = max(1, round((h + w) /2 * scale_down))
74 | erode_kernel_size = np.random.randint(kernel_size_low, kernel_size_high)
75 | dilate_kernel_size = np.random.randint(kernel_size_low, kernel_size_high)
76 |
77 | erode_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (erode_kernel_size, erode_kernel_size))
78 | dilate_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (dilate_kernel_size, dilate_kernel_size))
79 | eroded_alpha = cv2.erode(msk, erode_kernel)
80 | dilated_alpha = cv2.dilate(msk, dilate_kernel)
81 |
82 | dilated_alpha = np.where(dilated_alpha > dmin, 255, 0)
83 | eroded_alpha = np.where(eroded_alpha < emax, 0, 255)
84 |
85 | res = dilated_alpha.copy()
86 | res[((dilated_alpha == 255) & (eroded_alpha == 0))] = 128
87 |
88 | return res
89 |
90 |
91 | def get_trimap(msk, smooth=True):
92 | h, w = msk.shape[:2]
93 | scale_up, scale_down = 0.022, 0.008 # hyper parameter
94 | dmin = 0 # hyper parameter
95 | emax = 255 - dmin # hyper parameter
96 |
97 |     # .jpg (or low-quality .png) masks tend to be jagged, so smoothing tricks need to be applied
98 | if smooth:
99 | scale_up, scale_down = 0.02, 0.006
100 | dmin = 5
101 | emax = 255 - dmin
102 |
103 | kernel_size_high = max(10, round(h * scale_up))
104 | kernel_size_low = max(1, round(h * scale_down))
105 | kernel_size = (kernel_size_high + kernel_size_low)//2
106 |
107 | print('kernel size:', kernel_size)
108 |
109 | erode_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
110 | dilate_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
111 | eroded_alpha = cv2.erode(msk, erode_kernel)
112 | dilated_alpha = cv2.dilate(msk, dilate_kernel)
113 |
114 | dilated_alpha = np.where(dilated_alpha > dmin, 255, 0)
115 | eroded_alpha = np.where(eroded_alpha < emax, 0, 255)
116 |
117 | res = dilated_alpha.copy()
118 | res[((dilated_alpha == 255) & (eroded_alpha == 0))] = 128
119 |
120 | return res
121 |
122 |
123 | def main():
124 | kernel_size = 18
125 | alphaDir = 'alpha' # where alpha matting images are
126 | trimapDir = 'trimap'
127 | names = os.listdir('image')
128 | if names == []:
129 | raise ValueError('No images are in the dir: ./image')
130 | print("Images Count: {}".format(len(names)))
131 |
132 | for name in tqdm.tqdm(names):
133 | alpha_path = alphaDir + "/" + name
134 | trimap_path = trimapDir + "/" + name.strip()[:-4] + ".png" # output must be .png format
135 | alpha = cv2.imread(alpha_path, 0)
136 | if name[-3:] != 'png':
137 | trimap = erode_dilate(alpha, size=(kernel_size, kernel_size), smooth=True)
138 | else:
139 | trimap = erode_dilate(alpha, size=(kernel_size,kernel_size))
140 |
141 | #print("Write to {}".format(trimap_name))
142 | cv2.imwrite(trimap_path, trimap)
143 |
144 | if __name__ == "__main__":
145 | main()
146 |
147 |
148 |
--------------------------------------------------------------------------------
/model/M_Net.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class M_net(nn.Module):
7 | '''
8 | encoder + decoder
9 | '''
10 |
11 | def __init__(self, classes=2):
12 |
13 | super(M_net, self).__init__()
14 | # -----------------------------------------------------------------
15 | # encoder
16 | # ---------------------
17 | # stage-1
18 | self.conv_1_1 = nn.Sequential(nn.Conv2d(6, 64, 3, 1, 1, bias=True), nn.BatchNorm2d(64), nn.ReLU())
19 | self.conv_1_2 = nn.Sequential(nn.Conv2d(64, 64, 3, 1, 1, bias=True), nn.BatchNorm2d(64), nn.ReLU())
20 | self.max_pooling_1 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
21 |
22 | # stage-2
23 | self.conv_2_1 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1, bias=True), nn.BatchNorm2d(128), nn.ReLU())
24 | self.conv_2_2 = nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1, bias=True), nn.BatchNorm2d(128), nn.ReLU())
25 | self.max_pooling_2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
26 |
27 | # stage-3
28 | self.conv_3_1 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1, bias=True), nn.BatchNorm2d(256), nn.ReLU())
29 | self.conv_3_2 = nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1, bias=True), nn.BatchNorm2d(256), nn.ReLU())
30 | self.conv_3_3 = nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1, bias=True), nn.BatchNorm2d(256), nn.ReLU())
31 | self.max_pooling_3 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
32 |
33 | # stage-4
34 | self.conv_4_1 = nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1, bias=True), nn.BatchNorm2d(512), nn.ReLU())
35 | self.conv_4_2 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1, bias=True), nn.BatchNorm2d(512), nn.ReLU())
36 | self.conv_4_3 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1, bias=True), nn.BatchNorm2d(512), nn.ReLU())
37 | self.max_pooling_4 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
38 |
39 | # stage-5
40 | self.conv_5_1 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1, bias=True), nn.BatchNorm2d(512), nn.ReLU())
41 | self.conv_5_2 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1, bias=True), nn.BatchNorm2d(512), nn.ReLU())
42 | self.conv_5_3 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1, bias=True), nn.BatchNorm2d(512), nn.ReLU())
43 |
44 | # -----------------------------------------------------------------
45 | # decoder
46 | # ---------------------
47 | # stage-5
48 | self.deconv_5 = nn.Sequential(nn.Conv2d(512, 512, 5, 1, 2, bias=True), nn.BatchNorm2d(512), nn.ReLU())
49 |
50 | # stage-4
51 | self.up_pool_4 = nn.MaxUnpool2d(2, stride=2)
52 | self.deconv_4 = nn.Sequential(nn.Conv2d(512, 256, 5, 1, 2, bias=True), nn.BatchNorm2d(256), nn.ReLU())
53 |
54 | # stage-3
55 | self.up_pool_3 = nn.MaxUnpool2d(2, stride=2)
56 | self.deconv_3 = nn.Sequential(nn.Conv2d(256, 128, 5, 1, 2, bias=True), nn.BatchNorm2d(128), nn.ReLU())
57 |
58 | # stage-2
59 | self.up_pool_2 = nn.MaxUnpool2d(2, stride=2)
60 | self.deconv_2 = nn.Sequential(nn.Conv2d(128, 64, 5, 1, 2, bias=True), nn.BatchNorm2d(64), nn.ReLU())
61 |
62 | # stage-1
63 | self.up_pool_1 = nn.MaxUnpool2d(2, stride=2)
64 | self.deconv_1 = nn.Sequential(nn.Conv2d(64, 64, 5, 1, 2, bias=True), nn.BatchNorm2d(64), nn.ReLU())
65 |
66 | # stage-0
67 | self.conv_0 = nn.Conv2d(64, 1, 5, 1, 2, bias=True)
68 |
69 |
70 | def forward(self, input):
71 |
72 | # ----------------
73 | # encoder
74 | # --------
75 | x11 = self.conv_1_1(input)
76 | x12 = self.conv_1_2(x11)
77 | x1p, id1 = self.max_pooling_1(x12)
78 |
79 | x21 = self.conv_2_1(x1p)
80 | x22 = self.conv_2_2(x21)
81 | x2p, id2 = self.max_pooling_2(x22)
82 |
83 | x31 = self.conv_3_1(x2p)
84 | x32 = self.conv_3_2(x31)
85 | x33 = self.conv_3_3(x32)
86 | x3p, id3 = self.max_pooling_3(x33)
87 |
88 | x41 = self.conv_4_1(x3p)
89 | x42 = self.conv_4_2(x41)
90 | x43 = self.conv_4_3(x42)
91 | x4p, id4 = self.max_pooling_4(x43)
92 |
93 | x51 = self.conv_5_1(x4p)
94 | x52 = self.conv_5_2(x51)
95 | x53 = self.conv_5_3(x52)
96 | # ----------------
97 | # decoder
98 | # --------
99 | x5d = self.deconv_5(x53)
100 |
101 | x4u = self.up_pool_4(x5d, id4)
102 | x4d = self.deconv_4(x4u)
103 |
104 | x3u = self.up_pool_3(x4d, id3)
105 | x3d = self.deconv_3(x3u)
106 |
107 | x2u = self.up_pool_2(x3d, id2)
108 | x2d = self.deconv_2(x2u)
109 |
110 | x1u = self.up_pool_1(x2d, id1)
111 | x1d = self.deconv_1(x1u)
112 |
113 | # raw alpha pred
114 | raw_alpha = self.conv_0(x1d)
115 |
116 | return raw_alpha
117 |
118 |
119 |
120 |
121 |
122 |
--------------------------------------------------------------------------------
/model/T_Net_psp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | from model import extractors
6 |
7 |
8 | class PSPModule(nn.Module):
9 | def __init__(self, features, out_features=1024, sizes=(1, 2, 3, 6)):
10 | super().__init__()
11 | self.stages = []
12 | self.stages = nn.ModuleList([self._make_stage(features, size) for size in sizes])
13 | self.bottleneck = nn.Conv2d(features * (len(sizes) + 1), out_features, kernel_size=1)
14 | self.relu = nn.ReLU()
15 |
16 | def _make_stage(self, features, size):
17 | prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
18 | conv = nn.Conv2d(features, features, kernel_size=1, bias=False)
19 | return nn.Sequential(prior, conv)
20 |
21 | def forward(self, feats):
22 | h, w = feats.size(2), feats.size(3)
23 | priors = [F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in self.stages] + [feats]
24 | bottle = self.bottleneck(torch.cat(priors, 1))
25 | return self.relu(bottle)
26 |
27 |
28 | class PSPUpsample(nn.Module):
29 | def __init__(self, in_channels, out_channels):
30 | super().__init__()
31 | self.conv = nn.Sequential(
32 | nn.Conv2d(in_channels, out_channels, 3, padding=1),
33 | nn.BatchNorm2d(out_channels),
34 | nn.PReLU()
35 | )
36 |
37 | def forward(self, x):
38 | h, w = 2 * x.size(2), 2 * x.size(3)
39 | p = F.interpolate(input=x, size=(h, w), mode='bilinear', align_corners=True)
40 | return self.conv(p)
41 |
42 | class PSPNet(nn.Module):
43 | def __init__(self, n_classes=3, sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet50',
44 | pretrained=True):
45 | super().__init__()
46 | self.feats = getattr(extractors, backend)(pretrained)
47 | self.psp = PSPModule(psp_size, 1024, sizes)
48 | self.drop_1 = nn.Dropout2d(p=0.3)
49 |
50 | self.up_1 = PSPUpsample(1024, 256)
51 | self.up_2 = PSPUpsample(256, 64)
52 | self.up_3 = PSPUpsample(64, 64)
53 |
54 | self.drop_2 = nn.Dropout2d(p=0.15)
55 |
56 | """
57 | self.final = nn.Sequential(
58 | nn.Conv2d(64, n_classes, kernel_size=1),
59 | nn.LogSoftmax()
60 | )
61 | """
62 | self.final = nn.Sequential(nn.Conv2d(64, n_classes, kernel_size=1))
63 |
64 | self.classifier = nn.Sequential(
65 | nn.Linear(deep_features_size, 256),
66 | nn.ReLU(),
67 | nn.Linear(256, n_classes)
68 | )
69 |
70 | def forward(self, x):
71 | f, class_f = self.feats(x)
72 | p = self.psp(f)
73 | p = self.drop_1(p)
74 |
75 | p = self.up_1(p)
76 | p = self.drop_2(p)
77 |
78 | p = self.up_2(p)
79 | p = self.drop_2(p)
80 |
81 | p = self.up_3(p)
82 | p = self.drop_2(p)
83 |
84 | #auxiliary = F.adaptive_max_pool2d(input=class_f, output_size=(1, 1)).view(-1, class_f.size(1))
85 |
86 | return self.final(p)
87 |
88 |
--------------------------------------------------------------------------------
/model/extractors.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.utils import model_zoo
8 | from torchvision.models.densenet import densenet121, densenet161
9 | from torchvision.models.squeezenet import squeezenet1_1
10 |
11 |
12 | def load_weights_sequential(target, source_state):
13 | """
14 | new_dict = OrderedDict()
15 | for (k1, v1), (k2, v2) in zip(target.state_dict().items(), source_state.items()):
16 | new_dict[k1] = v2
17 | """
18 | model_to_load = {k: v for k, v in source_state.items() if k in target.state_dict().keys()}
19 | target.load_state_dict(model_to_load)
20 |
21 | '''
22 | Implementation of dilated ResNet-101 with deep supervision. Downsampling is changed to 8x
23 | '''
24 | model_urls = {
25 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
26 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
27 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
28 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
29 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
30 | }
31 |
32 |
33 | def conv3x3(in_planes, out_planes, stride=1, dilation=1):
34 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
35 | padding=dilation, dilation=dilation, bias=False)
36 |
37 |
38 | class BasicBlock(nn.Module):
39 | expansion = 1
40 |
41 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
42 | super(BasicBlock, self).__init__()
43 | self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
44 | self.bn1 = nn.BatchNorm2d(planes)
45 | self.relu = nn.ReLU(inplace=True)
46 | self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
47 | self.bn2 = nn.BatchNorm2d(planes)
48 | self.downsample = downsample
49 | self.stride = stride
50 |
51 | def forward(self, x):
52 | residual = x
53 |
54 | out = self.conv1(x)
55 | out = self.bn1(out)
56 | out = self.relu(out)
57 |
58 | out = self.conv2(out)
59 | out = self.bn2(out)
60 |
61 | if self.downsample is not None:
62 | residual = self.downsample(x)
63 |
64 | out += residual
65 | out = self.relu(out)
66 |
67 | return out
68 |
69 |
70 | class Bottleneck(nn.Module):
71 | expansion = 4
72 |
73 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
74 | super(Bottleneck, self).__init__()
75 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
76 | self.bn1 = nn.BatchNorm2d(planes)
77 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation,
78 | padding=dilation, bias=False)
79 | self.bn2 = nn.BatchNorm2d(planes)
80 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
81 | self.bn3 = nn.BatchNorm2d(planes * 4)
82 | self.relu = nn.ReLU(inplace=True)
83 | self.downsample = downsample
84 | self.stride = stride
85 |
86 | def forward(self, x):
87 | residual = x
88 |
89 | out = self.conv1(x)
90 | out = self.bn1(out)
91 | out = self.relu(out)
92 |
93 | out = self.conv2(out)
94 | out = self.bn2(out)
95 | out = self.relu(out)
96 |
97 | out = self.conv3(out)
98 | out = self.bn3(out)
99 |
100 | if self.downsample is not None:
101 | residual = self.downsample(x)
102 |
103 | out += residual
104 | out = self.relu(out)
105 |
106 | return out
107 |
108 |
109 | class ResNet(nn.Module):
110 | def __init__(self, block, layers=(3, 4, 23, 3)):
111 | self.inplanes = 64
112 | super(ResNet, self).__init__()
113 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
114 | bias=False)
115 | self.bn1 = nn.BatchNorm2d(64)
116 | self.relu = nn.ReLU(inplace=True)
117 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
118 | self.layer1 = self._make_layer(block, 64, layers[0])
119 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
120 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
121 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
122 |
123 | for m in self.modules():
124 | if isinstance(m, nn.Conv2d):
125 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
126 | m.weight.data.normal_(0, math.sqrt(2. / n))
127 | elif isinstance(m, nn.BatchNorm2d):
128 | m.weight.data.fill_(1)
129 | m.bias.data.zero_()
130 |
131 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
132 | downsample = None
133 | if stride != 1 or self.inplanes != planes * block.expansion:
134 | downsample = nn.Sequential(
135 | nn.Conv2d(self.inplanes, planes * block.expansion,
136 | kernel_size=1, stride=stride, bias=False),
137 | nn.BatchNorm2d(planes * block.expansion),
138 | )
139 |
140 | layers = [block(self.inplanes, planes, stride, downsample)]
141 | self.inplanes = planes * block.expansion
142 | for i in range(1, blocks):
143 | layers.append(block(self.inplanes, planes, dilation=dilation))
144 |
145 | return nn.Sequential(*layers)
146 |
147 | def forward(self, x):
148 | x = self.conv1(x)
149 | x = self.bn1(x)
150 | x = self.relu(x)
151 | x = self.maxpool(x)
152 |
153 | x = self.layer1(x)
154 | x = self.layer2(x)
155 | x_3 = self.layer3(x)
156 | x = self.layer4(x_3)
157 |
158 | return x, x_3
159 |
160 |
161 | '''
162 | Implementation of DenseNet with deep supervision. Downsampling is changed to 8x
163 | '''
164 |
165 |
166 | class _DenseLayer(nn.Sequential):
167 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
168 | super(_DenseLayer, self).__init__()
169 | self.add_module('norm.1', nn.BatchNorm2d(num_input_features)),
170 | self.add_module('relu.1', nn.ReLU(inplace=True)),
171 | self.add_module('conv.1', nn.Conv2d(num_input_features, bn_size *
172 | growth_rate, kernel_size=1, stride=1, bias=False)),
173 | self.add_module('norm.2', nn.BatchNorm2d(bn_size * growth_rate)),
174 | self.add_module('relu.2', nn.ReLU(inplace=True)),
175 | self.add_module('conv.2', nn.Conv2d(bn_size * growth_rate, growth_rate,
176 | kernel_size=3, stride=1, padding=1, bias=False)),
177 | self.drop_rate = drop_rate
178 |
179 | def forward(self, x):
180 | new_features = super(_DenseLayer, self).forward(x)
181 | if self.drop_rate > 0:
182 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
183 | return torch.cat([x, new_features], 1)
184 |
185 |
186 | class _DenseBlock(nn.Sequential):
187 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
188 | super(_DenseBlock, self).__init__()
189 | for i in range(num_layers):
190 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
191 | self.add_module('denselayer%d' % (i + 1), layer)
192 |
193 |
194 | class _Transition(nn.Sequential):
195 | def __init__(self, num_input_features, num_output_features, downsample=True):
196 | super(_Transition, self).__init__()
197 | self.add_module('norm', nn.BatchNorm2d(num_input_features))
198 | self.add_module('relu', nn.ReLU(inplace=True))
199 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
200 | kernel_size=1, stride=1, bias=False))
201 | if downsample:
202 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
203 | else:
204 | self.add_module('pool', nn.AvgPool2d(kernel_size=1, stride=1)) # compatibility hack
205 |
206 |
207 | class DenseNet(nn.Module):
208 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
209 | num_init_features=64, bn_size=4, drop_rate=0, pretrained=True):
210 |
211 | super(DenseNet, self).__init__()
212 |
213 | # First convolution
214 | self.start_features = nn.Sequential(OrderedDict([
215 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
216 | ('norm0', nn.BatchNorm2d(num_init_features)),
217 | ('relu0', nn.ReLU(inplace=True)),
218 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
219 | ]))
220 |
221 | # Each denseblock
222 | num_features = num_init_features
223 |
224 |         init_weights = list(densenet121(pretrained=True).features.children()) if pretrained else []
225 | start = 0
226 | for i, c in enumerate(self.start_features.children()):
227 | if pretrained:
228 | c.load_state_dict(init_weights[i].state_dict())
229 | start += 1
230 | self.blocks = nn.ModuleList()
231 | for i, num_layers in enumerate(block_config):
232 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
233 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
234 | if pretrained:
235 | block.load_state_dict(init_weights[start].state_dict())
236 | start += 1
237 | self.blocks.append(block)
238 | setattr(self, 'denseblock%d' % (i + 1), block)
239 |
240 | num_features = num_features + num_layers * growth_rate
241 | if i != len(block_config) - 1:
242 | downsample = i < 1
243 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2,
244 | downsample=downsample)
245 | if pretrained:
246 | trans.load_state_dict(init_weights[start].state_dict())
247 | start += 1
248 | self.blocks.append(trans)
249 | setattr(self, 'transition%d' % (i + 1), trans)
250 | num_features = num_features // 2
251 |
252 | def forward(self, x):
253 | out = self.start_features(x)
254 | deep_features = None
255 | for i, block in enumerate(self.blocks):
256 | out = block(out)
257 | if i == 5:
258 | deep_features = out
259 |
260 | return out, deep_features
261 |
262 |
263 | class Fire(nn.Module):
264 |
265 | def __init__(self, inplanes, squeeze_planes,
266 | expand1x1_planes, expand3x3_planes, dilation=1):
267 | super(Fire, self).__init__()
268 | self.inplanes = inplanes
269 | self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
270 | self.squeeze_activation = nn.ReLU(inplace=True)
271 | self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
272 | kernel_size=1)
273 | self.expand1x1_activation = nn.ReLU(inplace=True)
274 | self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
275 | kernel_size=3, padding=dilation, dilation=dilation)
276 | self.expand3x3_activation = nn.ReLU(inplace=True)
277 |
278 | def forward(self, x):
279 | x = self.squeeze_activation(self.squeeze(x))
280 | return torch.cat([
281 | self.expand1x1_activation(self.expand1x1(x)),
282 | self.expand3x3_activation(self.expand3x3(x))
283 | ], 1)
284 |
285 |
286 | class SqueezeNet(nn.Module):
287 |
288 | def __init__(self, pretrained=False):
289 | super(SqueezeNet, self).__init__()
290 |
291 | self.feat_1 = nn.Sequential(
292 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
293 | nn.ReLU(inplace=True)
294 | )
295 | self.feat_2 = nn.Sequential(
296 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
297 | Fire(64, 16, 64, 64),
298 | Fire(128, 16, 64, 64)
299 | )
300 | self.feat_3 = nn.Sequential(
301 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
302 | Fire(128, 32, 128, 128, 2),
303 | Fire(256, 32, 128, 128, 2)
304 | )
305 | self.feat_4 = nn.Sequential(
306 | Fire(256, 48, 192, 192, 4),
307 | Fire(384, 48, 192, 192, 4),
308 | Fire(384, 64, 256, 256, 4),
309 | Fire(512, 64, 256, 256, 4)
310 | )
311 | if pretrained:
312 | weights = squeezenet1_1(pretrained=True).features.state_dict()
313 | load_weights_sequential(self, weights)
314 |
315 | def forward(self, x):
316 | f1 = self.feat_1(x)
317 | f2 = self.feat_2(f1)
318 | f3 = self.feat_3(f2)
319 | f4 = self.feat_4(f3)
320 | return f4, f3
321 |
322 |
323 | '''
324 | Handy methods for construction
325 | '''
326 |
327 |
328 | def squeezenet(pretrained=True):
329 | return SqueezeNet(pretrained)
330 |
331 |
332 | def densenet(pretrained=True):
333 | return DenseNet(pretrained=pretrained)
334 |
335 |
336 | def resnet18(pretrained=True):
337 | model = ResNet(BasicBlock, [2, 2, 2, 2])
338 | if pretrained:
339 | load_weights_sequential(model, model_zoo.load_url(model_urls['resnet18']))
340 | return model
341 |
342 |
343 | def resnet34(pretrained=True):
344 | model = ResNet(BasicBlock, [3, 4, 6, 3])
345 | if pretrained:
346 | load_weights_sequential(model, model_zoo.load_url(model_urls['resnet34']))
347 | return model
348 |
349 |
350 | def resnet50(pretrained=True):
351 | model = ResNet(Bottleneck, [3, 4, 6, 3])
352 | if pretrained:
353 | load_weights_sequential(model, model_zoo.load_url(model_urls['resnet50']))
354 | return model
355 |
356 |
357 | def resnet101(pretrained=True):
358 | model = ResNet(Bottleneck, [3, 4, 23, 3])
359 | if pretrained:
360 | load_weights_sequential(model, model_zoo.load_url(model_urls['resnet101']))
361 | return model
362 |
363 |
364 | def resnet152(pretrained=True):
365 | model = ResNet(Bottleneck, [3, 8, 36, 3])
366 | if pretrained:
367 | load_weights_sequential(model, model_zoo.load_url(model_urls['resnet152']))
368 | return model
369 |
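370 | # Minimal sanity check for the dilated backbones above: layer3/layer4 use dilation
371 | # instead of stride, so the deepest features stay at 1/8 of the input resolution
372 | # (224 -> 28). pretrained=False is used here only to avoid downloading weights.
373 | if __name__ == '__main__':
374 |     net = resnet50(pretrained=False)
375 |     feats, feats_aux = net(torch.randn(1, 3, 224, 224))
376 |     print(feats.shape, feats_aux.shape)  # expect (1, 2048, 28, 28) and (1, 1024, 28, 28)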
--------------------------------------------------------------------------------
/model/network.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from model.M_Net import M_net
7 | from model.T_Net_psp import PSPNet
8 |
9 | class net_T(nn.Module):
10 | # Train T_net
11 | def __init__(self):
12 |
13 | super(net_T, self).__init__()
14 |
15 | self.t_net = PSPNet()
16 |
17 | def forward(self, input):
18 |
19 | # trimap
20 | trimap = self.t_net(input)
21 | return trimap
22 |
23 | class net_M(nn.Module):
24 | '''
25 | train M_net
26 | '''
27 |
28 | def __init__(self):
29 |
30 | super(net_M, self).__init__()
31 | self.m_net = M_net()
32 |
33 | def forward(self, input, trimap):
34 |
35 |         # trimap channels: (bg, unsure, fg), matching the one-hot layout used in test.py (paper: Bs, Us, Fs)
36 |         bg, unsure, fg = torch.split(trimap, 1, dim=1)
37 |
38 | # concat input and trimap
39 | m_net_input = torch.cat((input, trimap), 1)
40 |
41 | # matting
42 | alpha_r = self.m_net(m_net_input)
43 | # fusion module
44 | # paper : alpha_p = fs + us * alpha_r
45 | alpha_p = fg + unsure * alpha_r
46 |
47 | return alpha_p
48 |
49 | class net_F(nn.Module):
50 | '''
51 | end to end net
52 | '''
53 |
54 | def __init__(self):
55 |
56 | super(net_F, self).__init__()
57 |
58 | self.t_net = PSPNet()
59 | self.m_net = M_net()
60 |
61 |
62 |
63 | def forward(self, input):
64 |
65 | # trimap
66 | trimap = self.t_net(input)
67 | trimap_softmax = F.softmax(trimap, dim=1)
68 |
69 |         # trimap_softmax channels: (bg, unsure, fg) -- the paper's Bs, Us, Fs
70 | bg, unsure, fg = torch.split(trimap_softmax, 1, dim=1)
71 |
72 | # concat input and trimap
73 | m_net_input = torch.cat((input, trimap_softmax), 1)
74 |
75 | # matting
76 | alpha_r = self.m_net(m_net_input)
77 | # fusion module
78 | # paper : alpha_p = fs + us * alpha_r
79 | alpha_p = fg + unsure * alpha_r
80 |
81 | return trimap, alpha_p
82 |
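83 | # Minimal illustration of the fusion rule used above (alpha_p = fg + unsure * alpha_r)
84 | # on dummy probability maps, no networks needed: confident foreground pixels are forced
85 | # to 1, confident background to 0, and only the unknown band keeps the raw M-net output.
86 | if __name__ == '__main__':
87 |     fg = torch.tensor([[1.0, 0.0, 0.0]])
88 |     unsure = torch.tensor([[0.0, 1.0, 0.0]])
89 |     alpha_r = torch.tensor([[0.3, 0.7, 0.9]])
90 |     print(fg + unsure * alpha_r)  # tensor([[1.0000, 0.7000, 0.0000]])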
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 |
2 | import time
3 | import cv2
4 | import torch
5 | import argparse
6 | import numpy as np
7 | import os
8 | import torch.nn.functional as F
9 |
10 | parser = argparse.ArgumentParser(description='human matting')
11 | parser.add_argument('--model', default='./ckpt', help='root dir of preTrained model')
12 | parser.add_argument('--size', type=int, default=400, help='input size')
13 | parser.add_argument('--without_gpu', action='store_true', default=False, help='no use gpu')
14 | parser.add_argument('--train_phase', default='end_to_end',help='which phase of the model')
15 |
16 | args = parser.parse_args()
17 |
18 | torch.set_grad_enabled(False)
19 |
20 |
21 | #################################
22 | #----------------
23 | if args.without_gpu:
24 | print("use CPU !")
25 | device = torch.device('cpu')
26 | elif torch.cuda.is_available():
27 |     print("use GPU")
28 |     device = torch.device('cuda')
29 | else:
30 |     device = torch.device('cpu')  # no GPU found, fall back to CPU
31 |
32 | #################################
33 | #---------------
34 | def load_model(args):
35 | model_path = os.path.join(args.model, args.train_phase, 'model/model_obj.pth')
36 | assert os.path.isfile(model_path), 'Wrong model path: {}'.format(model_path)
37 | print('Loading model from {}...'.format(model_path))
38 |
39 |     if device.type == 'cpu':
40 | myModel = torch.load(model_path, map_location=lambda storage, loc: storage)
41 | else:
42 | myModel = torch.load(model_path)
43 |
44 | myModel.eval()
45 | myModel.to(device)
46 |
47 | return myModel
48 |
49 | def seg_process(args, inputs, net):
50 |
51 | if args.train_phase == 'pre_train_t_net':
52 | trimap = net(inputs)
53 | trimap = torch.argmax(trimap[0], dim=0)
54 |         # map class ids {0,1,2} to trimap gray levels {0,128,255}
55 | trimap[trimap == 1] = 128
56 | trimap[trimap == 2] = 255
57 |
58 | if args.without_gpu:
59 | trimap_np = trimap.data.numpy()
60 | else:
61 | trimap_np = trimap.cpu().data.numpy()
62 |
63 | trimap_np = trimap_np.astype(np.uint8)
64 | return trimap_np
65 |
66 | elif args.train_phase == 'pre_train_m_net':
67 | alpha = net(inputs[0], inputs[1])
68 |
69 | if args.without_gpu:
70 | alpha = alpha.data.numpy()
71 | else:
72 | alpha = alpha.cpu().data.numpy()
73 |
74 | alpha = alpha[0][0] * 255.0
75 | alpha = alpha.astype(np.uint8)
76 |
77 | return alpha
78 |
79 | else:
80 |         _, alpha = net(inputs)  # end-to-end net_F returns (trimap, alpha)
81 |
82 | if args.without_gpu:
83 | alpha = alpha.data.numpy()
84 | else:
85 | alpha = alpha.cpu().data.numpy()
86 |
87 | alpha = alpha[0][0] * 255.0
88 | alpha = alpha.astype(np.uint8)
89 | return alpha
90 |
91 | def test(args, net):
92 |
93 | t0 = time.time()
94 | out_dir = 'result/' + args.train_phase + '/'
95 | os.makedirs(out_dir, exist_ok=True)
96 |
97 | # get a frame
98 | imgList = os.listdir('result/test')
99 | if imgList==[]:
100 | raise ValueError('Empty dir at: ./result/test')
101 | for imgname in imgList:
102 | img = cv2.imread('result/test/'+imgname)
103 |
104 | if args.train_phase == 'pre_train_t_net':
105 | img = img / 255.0
106 |
107 | tensor_img = torch.from_numpy(img.astype(np.float32)[np.newaxis, :, :, :]).permute(0, 3, 1, 2)
108 |
109 | tensor_img = tensor_img.to(device)
110 |
111 | frame_seg = seg_process(args, tensor_img, net)
112 |
113 | elif args.train_phase == 'pre_train_m_net':
114 | # No resize
115 | h, w, _ = img.shape
116 | img = img / 255.0
117 |
118 | tri_path = 'result/trimap/'+imgname
119 | assert os.path.isfile(tri_path), 'wrong trimap path: {}'.format(tri_path)
120 |
121 | trimap_src = cv2.imread(tri_path, 0)
122 |
123 | trimap = trimap_src.copy()
124 |             # map trimap gray levels {0,128,255} to class ids {0,1,2} before one-hot encoding
125 | trimap[trimap == 128] = 1
126 | trimap[trimap == 255] = 2
127 | trimap = np.eye(3)[trimap.reshape(-1)].reshape(list(trimap.shape) + [3])
128 |
129 | tensor_img = torch.from_numpy(img.astype(np.float32)[np.newaxis, :, :, :]).permute(0, 3, 1, 2)
130 | tensor_tri = torch.from_numpy(trimap.astype(np.float32)[np.newaxis, :, :, :]).permute(0, 3, 1, 2)
131 | tensor_img = tensor_img.to(device)
132 | tensor_tri = tensor_tri.to(device)
133 |
134 |             frame_seg = seg_process(args, (tensor_img, tensor_tri), net)
135 |
136 | else:
137 | img = img / 255.0
138 | tensor_img = torch.from_numpy(img.astype(np.float32)[np.newaxis, :, :, :]).permute(0, 3, 1, 2)
139 | tensor_img = tensor_img.to(device)
140 | frame_seg = seg_process(args, tensor_img, net)
141 |
142 | # show a frame
143 | cv2.imwrite(out_dir+imgname, frame_seg)
144 |
145 |     print('Average time cost: {:.2f} s/image'.format((time.time() - t0) / len(imgList)))
146 | print('output images were saved at: ', out_dir)
147 |
148 | def main(args):
149 |
150 | myModel = load_model(args)
151 | test(args, myModel)
152 |
153 |
154 | if __name__ == "__main__":
155 | main(args)
156 |
157 |
158 |
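159 | # Expected layout, inferred from the hard-coded paths above: test images are read from
160 | # ./result/test/, and for --train_phase=pre_train_m_net a trimap with the same file name
161 | # must exist in ./result/trimap/; results are written to ./result/<train_phase>/.
162 | #
163 | # A typical downstream use of the saved alpha is re-compositing onto a new background,
164 | # roughly (fg_img / bg_img are illustrative names for same-sized float BGR images):
165 | #   alpha = cv2.imread('result/end_to_end/xxx.png', 0) / 255.0
166 | #   comp = alpha[..., None] * fg_img + (1 - alpha[..., None]) * bg_img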
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from torch.utils.data import DataLoader
7 | import time
8 | import os
9 | import cv2
10 | import numpy as np
11 | from data import dataset
12 | from model import network
13 | import torch.nn.functional as F
14 |
15 |
16 | def get_args():
17 | # Training settings
18 | parser = argparse.ArgumentParser(description='Semantic Human Matting !')
19 | parser.add_argument('--dataDir', default='./data/', help='dataset directory')
20 |     parser.add_argument('--fgLists', type=str, nargs='+', required=True, help="training foreground image list files (one or more)")
21 |     parser.add_argument('--bg_list', type=str, required=True, help='training background images list, one file')
22 |     parser.add_argument('--dataRatio', type=int, nargs='+', required=True, help="bg:fg composition ratio for each fgLists entry, e.g. 100 1")
23 | parser.add_argument('--saveDir', default='./ckpt', help='model save dir')
24 | parser.add_argument('--trainData', default='human_matting_data', help='train dataset name')
25 |
26 |     parser.add_argument('--continue_train', action='store_true', default=False, help='resume training from the latest checkpoint')
27 | parser.add_argument('--pretrain', action='store_true', help='load pretrained model from t_net & m_net ')
28 | parser.add_argument('--without_gpu', action='store_true', default=False, help='no use gpu')
29 |
30 | parser.add_argument('--nThreads', type=int, default=4, help='number of threads for data loading')
31 | parser.add_argument('--train_batch', type=int, default=8, help='input batch size for train')
32 | parser.add_argument('--patch_size', type=int, default=400, help='patch size for train')
33 |
34 | parser.add_argument('--lr', type=float, default=1e-5, help='learning rate')
35 | parser.add_argument('--lrDecay', type=int, default=100)
36 | parser.add_argument('--lrdecayType', default='keep')
37 | parser.add_argument('--nEpochs', type=int, default=300, help='number of epochs to train')
38 | parser.add_argument('--save_epoch', type=int, default=5, help='number of epochs to save model')
39 |     parser.add_argument('--print_iter', type=int, default=1000, help='iterations between printing the loss and saving sample images')
40 |
41 | parser.add_argument('--train_phase', default= 'end_to_end', help='train phase')
42 | parser.add_argument('--debug', action='store_true', default=False, help='debug mode')
43 |
44 | args = parser.parse_args()
45 | return args
46 |
47 | def torch2numpy(tensor, without_gpu=False):
48 | if without_gpu:
49 | array = tensor.data.numpy()
50 | else:
51 | array = tensor.cpu().data.numpy()
52 |
53 | array = array * 255.0
54 | return array.astype(np.uint8)
55 |
56 | def save_img(args, all_img, epoch, i=0):
57 | img_dir = os.path.join(args.saveDir, args.train_phase, 'save_img')
58 | if not os.path.isdir(img_dir):
59 | os.mkdir(img_dir)
60 |
61 | if args.train_phase == 'pre_train_t_net':
62 | img, trimap_pre, trimap_gt = all_img
63 | trimap_fake = torch.argmax(trimap_pre[0], dim=0)
64 |         trimap_cat = torch.cat((trimap_gt[0][0], trimap_fake.float()), dim=-1)  # horizontal concatenation
65 | trimap_cat = torch.stack((trimap_cat,) * 3, dim=0)
66 | img_cat = torch.cat((img[0] * 255.0, trimap_cat * 127.5), dim=-1)
67 | if args.without_gpu:
68 | img_cat = img_cat.data.numpy()
69 | else:
70 | img_cat = img_cat.cpu().data.numpy()
71 | cv2.imwrite(img_dir + '/trimap_{}_{}.png'.format(str(epoch), str(i)),
72 | img_cat.transpose((1, 2, 0)).astype(np.uint8))
73 |
74 | if args.train_phase == 'pre_train_m_net':
75 | img, alpha_pre, alpha_gt = all_img
76 | img = img[0] * 255.0
77 | alpha_pre = torch.stack((alpha_pre[0][0] * 255.0,) * 3, dim=0)
78 | alpha_gt = torch.stack((alpha_gt[0][0] * 255.0,) * 3, dim=0)
79 | img_cat = torch.cat((img, alpha_gt, alpha_pre), dim=-1)
80 |
81 | if args.without_gpu:
82 | img_cat = img_cat.data.numpy()
83 | else:
84 | img_cat = img_cat.cpu().data.numpy()
85 | cv2.imwrite(img_dir + '/alpha_{}_{}.png'.format(str(epoch), str(i)),
86 | img_cat.transpose((1, 2, 0)).astype(np.uint8))
87 |
88 | if args.train_phase == 'end_to_end':
89 | img, trimap_gt, alpha_gt, trimap_pre, alpha_pre = all_img
90 | img = img[0] * 255.0
91 | trimap_pre = torch.argmax(trimap_pre[0], dim=0)
92 | trimap_pre = torch.stack((trimap_pre.float(),)*3, dim=0) * 127.5
93 | trimap_gt = torch.stack((trimap_gt[0][0],)*3, dim=0) * 127.5
94 | alpha_pre = torch.stack((alpha_pre[0][0] * 255.0,) * 3, dim=0)
95 | alpha_gt = torch.stack((alpha_gt[0][0] * 255.0,) * 3, dim=0)
96 |
97 | img_cat = torch.cat((img, trimap_gt, trimap_pre, alpha_gt, alpha_pre), dim=-1)
98 | #alpha_cat = torch.cat((alpha_gt, alpha_pre), dim=-1)
99 | if args.without_gpu:
100 | img_cat = img_cat.data.numpy()
101 | else:
102 | img_cat = img_cat.cpu().data.numpy()
103 |
104 | img_cat = img_cat.transpose((1,2,0)).astype(np.uint8)
105 |
106 | # save image
107 | cv2.imwrite(img_dir + '/e2e_{}_{}.png'.format(str(epoch), str(i)),img_cat)
108 |
109 | def set_lr(args, epoch, optimizer):
110 |
111 | lrDecay = args.lrDecay
112 | decayType = args.lrdecayType
113 | if decayType == 'keep':
114 | lr = args.lr
115 | elif decayType == 'step':
116 | epoch_iter = (epoch + 1) // lrDecay
117 | lr = args.lr / 2**epoch_iter
118 | elif decayType == 'exp':
119 | k = math.log(2) / lrDecay
120 | lr = args.lr * math.exp(-k * epoch)
121 | elif decayType == 'poly':
122 | lr = args.lr * math.pow((1 - epoch / args.nEpochs), 0.9)
123 |
124 | for param_group in optimizer.param_groups:
125 | param_group['lr'] = lr
126 |
127 | return lr
128 |
129 |
130 |
131 | class Train_Log():
132 | def __init__(self, args):
133 | self.args = args
134 |
135 | self.save_dir = os.path.join(args.saveDir, args.train_phase)
136 | if not os.path.exists(self.save_dir):
137 | os.makedirs(self.save_dir)
138 |
139 | self.save_dir_model = os.path.join(self.save_dir, 'model')
140 | if not os.path.exists(self.save_dir_model):
141 | os.makedirs(self.save_dir_model)
142 |
143 | if os.path.exists(self.save_dir + '/log.txt'):
144 | self.logFile = open(self.save_dir + '/log.txt', 'a')
145 | else:
146 | self.logFile = open(self.save_dir + '/log.txt', 'w')
147 |
148 | # in case pretrained weights need to be loaded
149 | if self.args.pretrain:
150 | self.t_path = os.path.join(args.saveDir, 'pre_train_t_net', 'model', 'ckpt_lastest.pth')
151 | self.m_path = os.path.join(args.saveDir, 'pre_train_m_net', 'model', 'ckpt_lastest.pth')
152 | assert os.path.isfile(self.t_path) and os.path.isfile(self.m_path), \
153 | 'Wrong dir for pretrained models:\n{},{}'.format(self.t_path, self.m_path)
154 |
155 |
156 | def save_model(self, model, epoch, save_as=False):
157 | if save_as: # for args.save_epoch
158 | lastest_out_path = "{}/ckpt_{}.pth".format(self.save_dir_model, epoch)
159 | model_out_path = "{}/model_obj.pth".format(self.save_dir_model)
160 | torch.save(
161 | model,
162 | model_out_path)
163 | else: # for regular save
164 | lastest_out_path = "{}/ckpt_lastest.pth".format(self.save_dir_model)
165 |
166 |         torch.save({
167 |             'epoch': epoch,
168 |             'state_dict': model.state_dict(),
169 |         }, lastest_out_path)  # state-dict checkpoint is written in both cases (ckpt_<epoch> or ckpt_lastest)
170 |
171 | def load_pretrain(self, model):
172 | t_ckpt = torch.load(self.t_path)
173 | model.load_state_dict(t_ckpt['state_dict'], strict=False)
174 | m_ckpt = torch.load(self.m_path)
175 | model.load_state_dict(m_ckpt['state_dict'], strict=False)
176 | print('=> loaded pretrained t_net & m_net pretrained models !')
177 |
178 | return model
179 |
180 | def load_model(self, model):
181 | lastest_out_path = "{}/ckpt_lastest.pth".format(self.save_dir_model)
182 | ckpt = torch.load(lastest_out_path)
183 | start_epoch = ckpt['epoch'] + 1
184 | model.load_state_dict(ckpt['state_dict'])
185 | print("=> loaded checkpoint '{}' (epoch {})".format(lastest_out_path, ckpt['epoch']))
186 |
187 | return start_epoch, model
188 |
189 | def save_log(self, log):
190 | self.logFile.write(log + '\n')
191 |
192 | # initialise conv2d weights
193 | def weight_init(m):
194 | if isinstance(m, nn.Conv2d):
195 | torch.nn.init.xavier_normal_(m.weight.data)
196 | #n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
197 | #m.weight.data.normal_(0, math.sqrt(2. / n))
198 | elif isinstance(m, nn.BatchNorm2d):
199 | m.weight.data.fill_(1)
200 | m.bias.data.zero_()
201 |
202 | def loss_f_T(trimap_pre, trimap_gt):
203 | criterion = nn.CrossEntropyLoss()
204 | L_t = criterion(trimap_pre, trimap_gt[:, 0, :, :].long())
205 |
206 | return L_t
207 |
208 |
209 | def loss_f_M(img, alpha_pre, alpha_gt, bg, fg, trimap):
210 | # -------------------------------------
211 | # prediction loss L_p
212 | # ------------------------
213 | eps = 1e-6
214 | # l_alpha
215 | L_alpha = torch.sqrt(torch.pow(alpha_pre - alpha_gt, 2.) + eps).mean()
216 |
217 | comp_pred = alpha_pre * fg + (1 - alpha_pre) * bg
218 |
219 |     # careful here: if img's range is [0,1], then eps should also be divided by 255
220 | L_composition = torch.sqrt(torch.pow(img - comp_pred, 2.) + eps).mean()
221 |
222 | L_p = 0.5 * L_alpha + 0.5 * L_composition
223 |
224 | return L_p, L_alpha, L_composition
225 |
226 | def loss_function(img, trimap_pre, trimap_gt, alpha_pre, alpha_gt, bg, fg):
227 |
228 | # -------------------------------------
229 | # classification loss L_t
230 | # ------------------------
231 | criterion = nn.CrossEntropyLoss()
232 | L_t = criterion(trimap_pre, trimap_gt[:,0,:,:].long())
233 |
234 | # -------------------------------------
235 | # prediction loss L_p
236 | # ------------------------
237 | eps = 1e-6
238 | # l_alpha
239 | L_alpha = torch.sqrt(torch.pow(alpha_pre - alpha_gt, 2.) + eps).mean()
240 |
241 | comp_pred = alpha_pre * fg + (1 - alpha_pre) * bg
242 |
243 |     # careful here: if img's range is [0,1], then eps should also be divided by 255
244 | L_composition = torch.sqrt(torch.pow(img - comp_pred, 2.) + eps).mean()
245 | L_p = 0.5 * L_alpha + 0.5 * L_composition
246 |
247 | # train_phase
248 | loss = L_p + 0.01*L_t
249 |
250 | return loss, L_alpha, L_composition, L_t
251 |
252 |
253 | def main():
254 | args = get_args()
255 |
256 | if args.without_gpu:
257 | print("use CPU !")
258 | device = torch.device('cpu')
259 |     elif torch.cuda.is_available():
260 |         device = torch.device('cuda')
261 |     else:
262 |         print("No GPU is available, falling back to CPU !")
263 |         device = torch.device('cpu')
264 |
265 | print("============> Building model ...")
266 | if args.train_phase == 'pre_train_t_net':
267 | model = network.net_T()
268 | elif args.train_phase == 'pre_train_m_net':
269 | model = network.net_M()
270 | model.apply(weight_init)
271 | elif args.train_phase == 'end_to_end':
272 | model = network.net_F()
273 |         if args.pretrain:
274 |             model = Train_Log(args).load_pretrain(model)  # needs a Train_Log instance, which holds the pretrained ckpt paths
275 | else:
276 | raise ValueError('Wrong train phase request!')
277 | train_data = dataset.human_matting_data(args)
278 | model.to(device)
279 |
280 | # debug setting
281 | save_latest_freq = int(len(train_data)//args.train_batch*0.55)
282 | if args.debug:
283 | args.save_epoch = 1
284 |         args.train_batch = 1  # default debug: 1
285 | args.nEpochs = 1
286 | args.print_iter = 1
287 | save_latest_freq = 10
288 |
289 | print(args)
290 | print("============> Loading datasets ...")
291 |
292 | trainloader = DataLoader(train_data,
293 | batch_size=args.train_batch,
294 | drop_last=True,
295 | shuffle=True,
296 | num_workers=args.nThreads,
297 | pin_memory=True)
298 |
299 | print("============> Set optimizer ...")
300 | lr = args.lr
301 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), \
302 | lr=lr, betas=(0.9, 0.999),
303 | weight_decay=0.0005)
304 |
305 | print("============> Start Train ! ...")
306 | start_epoch = 1
307 | trainlog = Train_Log(args)
308 | if args.continue_train:
309 | start_epoch, model = trainlog.load_model(model)
310 |
311 | model.train()
312 | for epoch in range(start_epoch, args.nEpochs+1):
313 |
314 | loss_ = 0
315 | L_alpha_ = 0
316 | L_composition_ = 0
317 | L_cross_ = 0
318 | if args.lrdecayType != 'keep':
319 | lr = set_lr(args, epoch, optimizer)
320 |
321 | t0 = time.time()
322 | for i, sample_batched in enumerate(trainloader):
323 |
324 | optimizer.zero_grad()
325 |
326 | if args.train_phase == 'pre_train_t_net':
327 | img, trimap_gt = sample_batched['image'], sample_batched['trimap']
328 | img, trimap_gt = img.to(device), trimap_gt.to(device)
329 |
330 | trimap_pre = model(img)
331 | if args.debug: #debug only
332 | assert tuple(trimap_pre.shape) == (args.train_batch, 3, args.patch_size, args.patch_size)
333 | assert tuple(trimap_gt.shape) == (args.train_batch, 1, args.patch_size, args.patch_size)
334 |
335 | loss = loss_f_T(trimap_pre, trimap_gt)
336 |
337 | loss_ += loss.item()
338 |
339 | if i!=0 and i % args.print_iter == 0:
340 | save_img(args, (img, trimap_pre, trimap_gt), epoch, i)
341 | print("[epoch:{} iter:{}] \tloss: {:.5f}".format(epoch, i, loss))
342 | if i!=0 and i % save_latest_freq == 0:
343 | print("average loss: {:.5f}\nsaving model ....".format(loss_ / (i+1)))
344 | trainlog.save_model(model, epoch)
345 |
346 | elif args.train_phase == 'pre_train_m_net':
347 | img, trimap_gt, alpha_gt, bg, fg = sample_batched['image'], sample_batched['trimap'], sample_batched['alpha'], sample_batched['bg'], sample_batched['fg']
348 | img, trimap_gt, alpha_gt, bg, fg = img.to(device), trimap_gt.to(device), alpha_gt.to(device), bg.to(device), fg.to(device)
349 |
350 | alpha_pre = model(img, trimap_gt)
351 | if args.debug:
352 | assert tuple(alpha_pre.shape) == (args.train_batch, 1, args.patch_size, args.patch_size)
353 | img_dir = os.path.join(args.saveDir, args.train_phase, 'save_img')
354 | img_fg_bg = np.concatenate((torch2numpy(trimap_gt[0]), torch2numpy(fg[0]), torch2numpy(bg[0])), axis=-1)
355 | img_fg_bg = np.transpose(img_fg_bg, (1,2,0))
356 | cv2.imwrite(img_dir + '/fgbg_{}_{}.png'.format(str(epoch), str(i)), img_fg_bg)
357 |
358 | loss, L_alpha, L_composition = loss_f_M(img, alpha_pre, alpha_gt, bg, fg, trimap_gt)
359 |
360 | loss_ += loss.item()
361 | L_alpha_ += L_alpha.item()
362 | L_composition_ += L_composition.item()
363 |
364 | if i!=0 and i % args.print_iter == 0:
365 | save_img(args, (img, alpha_pre, alpha_gt), epoch, i)
366 | print("[epoch:{} iter:{}] loss: {:.5f} loss_a: {:.5f} loss_c: {:.5f}"\
367 | .format(epoch, i, loss, L_alpha, L_composition))
368 | if i!=0 and i % save_latest_freq == 0:
369 | print("average loss: {:.5f}\nsaving model ....".format(loss_ / (i + 1)))
370 | trainlog.save_model(model, epoch)
371 |
372 | elif args.train_phase == 'end_to_end':
373 | img, trimap_gt, alpha_gt, bg, fg = sample_batched['image'], sample_batched['trimap'], sample_batched['alpha'], sample_batched['bg'], sample_batched['fg']
374 | img, trimap_gt, alpha_gt, bg, fg = img.to(device), trimap_gt.to(device), alpha_gt.to(device), bg.to(device), fg.to(device)
375 |
376 | trimap_pre, alpha_pre = model(img)
377 |                 loss, L_alpha, L_composition, L_cross = loss_function(img, trimap_pre, trimap_gt, alpha_pre, alpha_gt, bg, fg)
378 |
379 | loss_ += loss.item()
380 | L_alpha_ += L_alpha.item()
381 | L_composition_ += L_composition.item()
382 | L_cross_ += L_cross.item()
383 |
384 |             loss.backward()   # backprop runs for every train_phase
385 |             optimizer.step()
386 |
387 |
388 | # shuffle data after each epoch to recreate the dataset
389 | print('epoch end, shuffle datasets again ...')
390 | train_data.shuffle_data()
391 | #trainloader.dataset.shuffle_data()
392 |
393 | t1 = time.time()
394 |
395 | if args.train_phase == 'pre_train_t_net':
396 | loss_ = loss_ / (i+1)
397 | log = "[{} / {}] \tloss: {:.5f}\ttime: {:.0f}".format(epoch, args.nEpochs, loss_, t1-t0)
398 |
399 | elif args.train_phase == 'pre_train_m_net':
400 | loss_ = loss_ / (i + 1)
401 | L_alpha_ = L_alpha_ / (i + 1)
402 | L_composition_ = L_composition_ / (i + 1)
403 | log = "[{} / {}] loss: {:.5f} loss_a: {:.5f} loss_c: {:.5f} time: {:.0f}"\
404 | .format(epoch, args.nEpochs,
405 | loss_,
406 | L_alpha_,
407 | L_composition_,
408 | t1 - t0)
409 |
410 | elif args.train_phase == 'end_to_end':
411 | loss_ = loss_ / (i+1)
412 | L_alpha_ = L_alpha_ / (i+1)
413 | L_composition_ = L_composition_ / (i+1)
414 | L_cross_ = L_cross_ / (i+1)
415 |
416 | log = "[{} / {}] loss: {:.5f} loss_a: {:.5f} loss_c: {:.5f} loss_t: {:.5f} time: {:.0f}"\
417 | .format(epoch, args.nEpochs,
418 | loss_,
419 | L_alpha_,
420 | L_composition_,
421 | L_cross_,
422 | t1 - t0)
423 | print(log)
424 | trainlog.save_log(log)
425 | trainlog.save_model(model, epoch)
426 |
427 | if epoch % args.save_epoch == 0:
428 | trainlog.save_model(model, epoch, save_as=True)
429 |
430 |
431 |
432 | if __name__ == "__main__":
433 | main()
434 |
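435 | # Loss reference, matching the functions above: pre-training uses L_t (cross-entropy over
436 | # the 3 trimap classes) or L_p = 0.5*L_alpha + 0.5*L_composition; end-to-end training uses
437 | # loss = L_p + 0.01*L_t, so e.g. L_p=0.05 and L_t=1.0 give 0.05 + 0.01 = 0.06 -- the trimap
438 | # term mainly acts as a regulariser during joint training.
439 | #
440 | # Learning-rate schedules in set_lr (defaults lr=1e-5, lrDecay=100): 'step' roughly halves
441 | # the rate every lrDecay epochs, 'exp' reaches lr/2 at epoch lrDecay, 'poly' decays to zero
442 | # at nEpochs with power 0.9, and 'keep' (the default) leaves it unchanged.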
--------------------------------------------------------------------------------