├── README.md
├── data
    ├── 0000-AIT.png
    ├── 0000.jpg
    ├── DIM_list.txt
    ├── bg_list.txt
    ├── data_util.py
    ├── dataset.py
    └── gen_trimap.py
├── model
    ├── M_Net.py
    ├── T_Net_psp.py
    ├── extractors.py
    └── network.py
├── test.py
└── train.py

/README.md:
--------------------------------------------------------------------------------
# pytorch_semantic_human_matting
This is an unofficial implementation of the paper "Semantic Human Matting".

# testing environment:
Ubuntu 16.04

PyTorch 0.4.1

# data preparation

To keep things simple, there are only two types of data:
1. RGBA PNG images: the background has already been removed and the alpha matte is stored in the fourth channel;
you can generate such an image with the function 'get_rgba' from ./data/data_util.py.
2. Background images: these can be fetched from the COCO dataset or anywhere else (e.g. the Internet). As described in the paper, they are randomly composited on the fly with the foreground images from type 1 (RGBA PNG images) to create new training images.

For example:

For 1: I used the Adobe Deep Image Matting dataset, among others; I composited the alpha and foreground images together to get my RGBA PNG images.

For 2: I used the COCO dataset and some images crawled from the Internet.

Once you have these two types of data, generate list files containing the full paths of the training images,
such as 'DIM_list.txt' and 'bg_list.txt' in my case. The related flags are:

--fgLists: a list of list files; all images within one list file share the same fg-bg ratio, e.g. ['DIM.txt','SHM.txt']

--bg_list: a txt file listing all bg images used for composition, e.g. 'bg_list.txt'.

--dataRatio: a list giving the bg-to-fg ratio for each list file in fgLists. For example, similar to the paper,
given [100, 1], we composite 100 images for each foreground image in 'DIM.txt' and 1 image for each fg in 'SHM.txt'.
Note that data/dataset.py asserts that the number of lines in bg_list equals the sum of dataRatio[i] times the number of lines in fgLists[i], so with [100, 1] the bg list needs 100 * (lines in 'DIM.txt') + 1 * (lines in 'SHM.txt') entries.

# Implementation details
The training model is implemented as described in the paper; details are as follows:
* T-net: PSP-50 is used for trimap generation; input is the image (3 channels) and output is the trimap (one channel);

* M-net: 13 convolutional layers and 4 max-pooling layers with the same hyper-parameters as VGG-16 are used as the encoder, and 6 convolutional layers and 4 unpooling layers are used as the decoder; input is the image and trimap (6 channels) and output is the alpha image (1 channel);

* Fusion: the fusion loss functions are implemented as described in the paper;

* **_Once well trained, the model can run inference on images of any size._**

# How to run the code
## pre_train T_net
python train.py --patch_size=400 --train_phase=pre_train_t_net

optional: --continue_train

## pre_train M_net
python train.py --patch_size=320 --train_phase=pre_train_m_net

optional: --continue_train

## end to end training
python train.py --patch_size=800 --pretrain --train_phase=end_to_end

optional: --continue_train

note:
1. the end-to-end training process is very time-consuming.
2. I tried to implement the crop-on-the-fly trick for M-net inputs described in the original paper,
but training was very slow and unstable, so the same input size is used for both
nets throughout end-to-end training.

## testing
python test.py --train_phase=pre_train_t_net/pre_train_m_net/end_to_end

(choose one of the three phases for --train_phase)

# Results
Note: the following result is produced by T-net & M-net together, as I haven't completed the end-to-end training phase yet.
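The note above says the results come from T-net and M-net used together. As a rough illustration of how the two outputs can be combined at inference time, here is a NumPy sketch of the fusion step as formulated in the SHM paper, `alpha_p = F_s + U_s * alpha_r`, where `F_s` and `U_s` are T-net's foreground and unknown-region probabilities and `alpha_r` is M-net's raw alpha. This is not code from this repository (model/network.py is not shown here); the helpers `softmax` and `fuse_alpha` are hypothetical, and the channel order bg/unknown/fg is an assumption chosen to match the 0/1/2 trimap labels used in data/dataset.py.

```python
import numpy as np


def softmax(logits, axis=0):
    # numerically stable softmax over the class axis
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def fuse_alpha(trimap_logits, raw_alpha):
    """Hypothetical fusion helper: alpha_p = F_s + U_s * alpha_r.

    trimap_logits: (3, H, W) T-net scores, channel order ASSUMED to be bg, unknown, fg
    raw_alpha:     (H, W) M-net alpha prediction in [0, 1]
    """
    probs = softmax(trimap_logits, axis=0)
    unknown_s, fg_s = probs[1], probs[2]
    # confident fg pixels -> ~1, confident bg pixels -> ~0,
    # unknown pixels -> M-net's raw alpha
    return fg_s + unknown_s * raw_alpha


# toy 1x3 image: pixel 0 confidently fg, pixel 1 confidently bg, pixel 2 unknown
logits = np.zeros((3, 1, 3))
logits[2, 0, 0] = 20.0
logits[0, 0, 1] = 20.0
logits[1, 0, 2] = 20.0
print(fuse_alpha(logits, np.full((1, 3), 0.3)))
```

Where T-net is confident, the fused alpha is pinned close to 1 or 0 and M-net's prediction only matters inside the unknown band. Since `F_s + U_s <= 1` and `alpha_r` lies in [0, 1], the fused alpha also stays in [0, 1].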

Original image from the Internet:

Output image produced by the SHM:

# Observations
1. The performance of T-net is essential for the whole process.
2. The training trimaps (generated from alpha images) are essential for T-net training, which means the trimaps should
be high-quality and clean (in practice, the alpha mattes themselves must be high-quality).
3. End-to-end training is rather slow and may be unstable, especially when T-net is not robust
(not powerful and stable). So I haven't completed the end-to-end training phase. Even so, the result seems satisfying.

Leave your comments if you have any other observations and suggestions.

--------------------------------------------------------------------------------
/data/0000-AIT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsing90/pytorch_semantic_human_matting/9cbb85b3ccf1c4a77fd4e2b237e5f658a3d6501b/data/0000-AIT.png
--------------------------------------------------------------------------------
/data/0000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsing90/pytorch_semantic_human_matting/9cbb85b3ccf1c4a77fd4e2b237e5f658a3d6501b/data/0000.jpg
--------------------------------------------------------------------------------
/data/DIM_list.txt:
--------------------------------------------------------------------------------
/home/kevin/git/matting_data/DIM/img_png/1-1252426161dfXY.png
/home/kevin/git/matting_data/DIM/img_png/1-1255621189mTnS.png
/home/kevin/git/matting_data/DIM/img_png/1-1259162624NMFK.png
/home/kevin/git/matting_data/DIM/img_png/1-1259245823Un3j.png
/home/kevin/git/matting_data/DIM/img_png/10743257206_18e7f44f2e_b.png
/home/kevin/git/matting_data/DIM/img_png/10845279884_d2d4c7b4d1_b.png
/home/kevin/git/matting_data/DIM/img_png/11363165393_05d7a21d76_b.png
/home/kevin/git/matting_data/DIM/img_png/13564741125_753939e9ce_o.png
/home/kevin/git/matting_data/DIM/img_png/14731860273_5b40b19b51_o.png
/home/kevin/git/matting_data/DIM/img_png/16087-a-young-woman-showing-a-bitten-green-apple-pv.png
/home/kevin/git/matting_data/DIM/img_png/1609484818_b9bb12b.png
/home/kevin/git/matting_data/DIM/img_png/17620-a-beautiful-woman-in-a-bikini-pv.png
/home/kevin/git/matting_data/DIM/img_png/20672673163_20c8467827_b.png
/home/kevin/git/matting_data/DIM/img_png/3262986095_2d5afe583c_b.png
/home/kevin/git/matting_data/DIM/img_png/3588101233_f91aa5e3a3.png
/home/kevin/git/matting_data/DIM/img_png/3858897226_cae5b75963_o.png
/home/kevin/git/matting_data/DIM/img_png/474px-Jerry_Moran,_official_portrait,_112th_Congress_headshot.png
/home/kevin/git/matting_data/DIM/img_png/4981835627_c4e6c4ffa8_o.png
/home/kevin/git/matting_data/DIM/img_png/5025666458_576b974455_o.png
/home/kevin/git/matting_data/DIM/img_png/5149410930_3a943dc43f_b.png
/home/kevin/git/matting_data/DIM/img_png/5892503248_4b882863c7_o.png
/home/kevin/git/matting_data/DIM/img_png/7669262460_e4be408343_b.png
/home/kevin/git/matting_data/DIM/img_png/8244818049_dfa59a3eb8_b.png
/home/kevin/git/matting_data/DIM/img_png/8688417335_01f3bafbe5_o.png
/home/kevin/git/matting_data/DIM/img_png/9434599749_e7ccfc7812_b.png
/home/kevin/git/matting_data/DIM/img_png/Aaron_Friedman_Headshot.png
/home/kevin/git/matting_data/DIM/img_png/GT03.png
/home/kevin/git/matting_data/DIM/img_png/GT08.png
/home/kevin/git/matting_data/DIM/img_png/Girl_in_front_of_a_green_background.png
/home/kevin/git/matting_data/DIM/img_png/MFettes-headshot.png
/home/kevin/git/matting_data/DIM/img_png/Model_in_green_dress_3.png
/home/kevin/git/matting_data/DIM/img_png/Modern_shingle_bob_haircut.png
/home/kevin/git/matting_data/DIM/img_png/Motivate_(Fitness_model).png
/home/kevin/git/matting_data/DIM/img_png/Official_portrait_of_Barack_Obama.png
/home/kevin/git/matting_data/DIM/img_png/Professor_Steven_Chu_ForMemRS_headshot.png
/home/kevin/git/matting_data/DIM/img_png/Wild_hair.png
/home/kevin/git/matting_data/DIM/img_png/Woman_in_white_shirt_on_August_2009_02.png
/home/kevin/git/matting_data/DIM/img_png/a-single-person-1084191_960_720.png
/home/kevin/git/matting_data/DIM/img_png/apple-841169_960_720.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00000.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00005.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00010.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00015.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00020.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00025.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00030.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00035.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00040.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00045.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00050.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00055.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00060.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00065.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00070.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00075.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00080.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00085.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00090.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00095.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00100.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00105.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00110.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00115.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00120.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00125.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00130.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00135.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00140.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00145.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00150.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00155.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00160.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00165.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00170.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00175.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00180.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00185.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00190.png
/home/kevin/git/matting_data/DIM/img_png/archeology_00195.png
/home/kevin/git/matting_data/DIM/img_png/arrgh___r___28_by_mjranum_stock.png
/home/kevin/git/matting_data/DIM/img_png/arrgh___r___29_by_mjranum_stock.png
/home/kevin/git/matting_data/DIM/img_png/arrgh___r___30_by_mjranum_stock.png
/home/kevin/git/matting_data/DIM/img_png/ballerina-855652_1920.png
/home/kevin/git/matting_data/DIM/img_png/beautiful-19075_960_720.png
/home/kevin/git/matting_data/DIM/img_png/boy-1518482_1920.png
/home/kevin/git/matting_data/DIM/img_png/boy-454633_1920.png
/home/kevin/git/matting_data/DIM/img_png/face-1223346_960_720.png
/home/kevin/git/matting_data/DIM/img_png/fashion-model-portrait.png
/home/kevin/git/matting_data/DIM/img_png/fashion-model-pose.png
/home/kevin/git/matting_data/DIM/img_png/girl-1219339_1920.png
/home/kevin/git/matting_data/DIM/img_png/girl-1467820_1280.png
/home/kevin/git/matting_data/DIM/img_png/girl-1535859_1920.png
/home/kevin/git/matting_data/DIM/img_png/girl-beautiful-young-face-53000.png
/home/kevin/git/matting_data/DIM/img_png/godiva_close_00000.png
/home/kevin/git/matting_data/DIM/img_png/godiva_close_00005.png
/home/kevin/git/matting_data/DIM/img_png/godiva_close_00010.png
/home/kevin/git/matting_data/DIM/img_png/godiva_close_00015.png
/home/kevin/git/matting_data/DIM/img_png/godiva_close_00020.png
/home/kevin/git/matting_data/DIM/img_png/godiva_close_00025.png
/home/kevin/git/matting_data/DIM/img_png/godiva_close_00030.png
/home/kevin/git/matting_data/DIM/img_png/godiva_close_00035.png
/home/kevin/git/matting_data/DIM/img_png/godiva_close_00040.png
/home/kevin/git/matting_data/DIM/img_png/goth_by_bugidifino-d4w7zms.png
/home/kevin/git/matting_data/DIM/img_png/hair-flying-142210_1920.png
/home/kevin/git/matting_data/DIM/img_png/headshotid_by_bokogreat_stock-d355xf3.png
/home/kevin/git/matting_data/DIM/img_png/lil_white_goth_grl___23_by_mjranum_stock.png
/home/kevin/git/matting_data/DIM/img_png/lil_white_goth_grl___26_by_mjranum_stock.png
/home/kevin/git/matting_data/DIM/img_png/locked_00000.png
/home/kevin/git/matting_data/DIM/img_png/locked_00005.png
/home/kevin/git/matting_data/DIM/img_png/locked_00010.png
/home/kevin/git/matting_data/DIM/img_png/locked_00015.png
/home/kevin/git/matting_data/DIM/img_png/locked_00020.png
/home/kevin/git/matting_data/DIM/img_png/locked_00025.png
/home/kevin/git/matting_data/DIM/img_png/locked_00030.png
/home/kevin/git/matting_data/DIM/img_png/locked_00035.png
/home/kevin/git/matting_data/DIM/img_png/locked_00040.png
/home/kevin/git/matting_data/DIM/img_png/locked_00045.png
/home/kevin/git/matting_data/DIM/img_png/locked_00050.png
/home/kevin/git/matting_data/DIM/img_png/locked_00055.png
/home/kevin/git/matting_data/DIM/img_png/locked_00060.png
/home/kevin/git/matting_data/DIM/img_png/locked_00065.png
/home/kevin/git/matting_data/DIM/img_png/locked_00070.png
/home/kevin/git/matting_data/DIM/img_png/locked_00075.png
/home/kevin/git/matting_data/DIM/img_png/long-1245787_1920.png
/home/kevin/git/matting_data/DIM/img_png/man-388104_960_720.png
/home/kevin/git/matting_data/DIM/img_png/man_headshot.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00000.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00005.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00010.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00015.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00020.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00025.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00030.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00035.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00040.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00045.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00050.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00055.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00060.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00065.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00070.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00075.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00080.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00085.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00090.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00095.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00100.png
/home/kevin/git/matting_data/DIM/img_png/mmtest_00105.png
/home/kevin/git/matting_data/DIM/img_png/model-429733_960_720.png
/home/kevin/git/matting_data/DIM/img_png/model-600238_1920.png
/home/kevin/git/matting_data/DIM/img_png/model-610352_960_720.png
/home/kevin/git/matting_data/DIM/img_png/model-858753_960_720.png
/home/kevin/git/matting_data/DIM/img_png/model-858755_960_720.png
/home/kevin/git/matting_data/DIM/img_png/model-873675_960_720.png
/home/kevin/git/matting_data/DIM/img_png/model-873678_960_720.png
/home/kevin/git/matting_data/DIM/img_png/model-873690_960_720.png
/home/kevin/git/matting_data/DIM/img_png/model-881425_960_720.png
/home/kevin/git/matting_data/DIM/img_png/model-881431_960_720.png
/home/kevin/git/matting_data/DIM/img_png/model-female-girl-beautiful-51969.png
/home/kevin/git/matting_data/DIM/img_png/person-woman-eyes-face.png
/home/kevin/git/matting_data/DIM/img_png/pexels-photo-58463.png
/home/kevin/git/matting_data/DIM/img_png/pink-hair-855660_960_720.png
/home/kevin/git/matting_data/DIM/img_png/portrait-750774_1920.png
/home/kevin/git/matting_data/DIM/img_png/sailor_flying_4_by_senshistock-d4k2wmr.png
/home/kevin/git/matting_data/DIM/img_png/sea-sunny-person-beach.png
/home/kevin/git/matting_data/DIM/img_png/skin-care-937667_960_720.png
/home/kevin/git/matting_data/DIM/img_png/sorcery___8_by_mjranum_stock.png
/home/kevin/git/matting_data/DIM/img_png/wedding-846926_1920.png
/home/kevin/git/matting_data/DIM/img_png/wedding-dresses-1486260_1280.png
/home/kevin/git/matting_data/DIM/img_png/with_wings___pose_reference_by_senshistock-d6by42n.png
/home/kevin/git/matting_data/DIM/img_png/with_wings___pose_reference_by_senshistock-d6by42n_2.png
/home/kevin/git/matting_data/DIM/img_png/woman-1138435_960_720.png
/home/kevin/git/matting_data/DIM/img_png/woman-659354_960_720.png
/home/kevin/git/matting_data/DIM/img_png/woman-804072_960_720.png
/home/kevin/git/matting_data/DIM/img_png/woman-868519_960_720.png
/home/kevin/git/matting_data/DIM/img_png/woman-952506_1920 (1).png
/home/kevin/git/matting_data/DIM/img_png/woman-morning-bathrobe-bathroom.png
/home/kevin/git/matting_data/DIM/img_png/woman1.png
/home/kevin/git/matting_data/DIM/img_png/woman2.png
/home/kevin/git/matting_data/DIM/img_png/women-878869_1920.png

--------------------------------------------------------------------------------
/data/data_util.py:
--------------------------------------------------------------------------------
# code for temporary usage or previously used
# author: L.Q. Chen

import cv2
import os
import random
import numpy as np
import math

from data import gen_trimap  # used by read_files(); same import as in data/dataset.py

# for RGBA format image generation: inputs are the paths of normal image & alpha image (mask)
def get_rgba(img_path, msk_path):
    img = cv2.imread(img_path)
    msk = cv2.imread(msk_path, 0)

    image = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
    image[:,:,3] = msk
    return image

# for convenience, we make fake fg and bg image
def fake_fg_bg(img,alpha):
    color_fg = [random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)]
    color_bg = [random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)]
    fg = np.full(img.shape, color_fg)
    bg = np.full(img.shape, color_bg)
    a = np.expand_dims(alpha/255., axis=2)

    fg = img * a + fg * (1-a)
    bg = bg * a + img * (1-a)
    return fg.astype(np.uint8), bg.astype(np.uint8)

# composite fg and bg with arbitrary size
def composite(image, a, bg):
    h, w = a.shape
    bh, bw = bg.shape[:2]
    wratio, hratio = w/bw, h/bh
    ratio = wratio if wratio > hratio else hratio
    # interpolation must be passed by keyword: the third positional argument of cv2.resize is dst
    bg = cv2.resize(bg, (math.ceil(bw * ratio), math.ceil(bh * ratio)), interpolation=cv2.INTER_CUBIC)
    bh, bw = bg.shape[:2]
    assert bh>=h and bw>=w
    bg = np.array(bg[(bh - h) // 2: (bh - h) // 2 + h, (bw - w) // 2: (bw - w) // 2 + w], np.float32)
    assert (h, w) == bg.shape[:2]

    fg = np.array(image, np.float32)
    alpha = np.expand_dims(a/255., axis=2).astype(np.float32)
    comp = alpha * fg + (1 - alpha) * bg
    comp = comp.astype(np.uint8)

    # we need bg for calculating compositional loss in M_net
    return comp, bg

def read_files(name, return_a_bg=True):
    if name[2] == 'bg':
        img_path, bg_path = name[0].strip(), name[1].strip()
        assert os.path.isfile(img_path) and os.path.isfile(bg_path), (img_path, bg_path)

        image = cv2.imread(img_path, -1) # it's RGBA image
        fg = image[:,:,:3]
        a = image[:,:,3]
        bg = cv2.imread(bg_path)

        image, bg = composite(fg, a, bg)
        trimap = gen_trimap.rand_trimap(a)

    elif name[2] == 'msk':
        img_path, a_path = name[0].strip(), name[1].strip()
        assert os.path.isfile(img_path) and os.path.isfile(a_path)

        image = cv2.imread(img_path) # it's composited image
        a = cv2.imread(a_path, 0) # it's grayscale image

        a[a>0] = 255
        trimap = gen_trimap.rand_trimap(a)

        # for M-net and fusion training, we need bg for compositional loss
        # here to simplify, we convert original image
        if return_a_bg:
            fg, bg = fake_fg_bg(image, a)

    # NOTE ! ! ! trimap should be 3 classes for classification : fg, bg. unsure
    trimap[trimap == 0] = 0
    trimap[trimap == 128] = 1
    trimap[trimap == 255] = 2

    assert image.shape[:2] == trimap.shape[:2] == a.shape[:2]

    if return_a_bg:
        return image, trimap, a, bg, fg
    else:
        return image, trimap

# crop image/alpha/trimap with random size of [max_size//2, max_size] then resize to patch_size
def random_patch(image, trimap, patch_size, alpha=None, bg=None, fg=None):
    h, w = image.shape[:2]
    max_size = max(h, w)
    min_size = max_size//2
    if isinstance(alpha, np.ndarray) and isinstance(bg, np.ndarray):
        patch_a_bg= True
    else:
        patch_a_bg = False

    count = 0 # debug usage
    while True:
        sqr_tri = np.zeros((max_size, max_size), np.uint8)
        sqr_img = np.zeros((max_size, max_size, 3), np.uint8)
        if patch_a_bg:
            sqr_alp = np.zeros((max_size, max_size), np.uint8)
            sqr_bg = np.zeros((max_size, max_size, 3), np.uint8)
            sqr_fg = np.zeros((max_size, max_size, 3), np.uint8)
        if h>=w:
            sqr_tri[:, (h-w)//2 : (h-w)//2+w] = trimap
            sqr_img[:, (h-w)//2 : (h-w)//2+w] = image
            if patch_a_bg:
                sqr_alp[:, (h-w)//2 : (h-w)//2+w] = alpha
                sqr_bg[:, (h-w)//2 : (h-w)//2+w] = bg
                sqr_fg[:, (h-w)//2 : (h-w)//2+w] = fg
        else:
            sqr_tri[(w-h)//2 : (w-h)//2+h, :] = trimap
            sqr_img[(w-h)//2 : (w-h)//2+h, :] = image
            if patch_a_bg:
                sqr_alp[(w-h)//2 : (w-h)//2+h, :] = alpha
                sqr_bg[(w-h)//2 : (w-h)//2+h, :] = bg
                sqr_fg[(w-h)//2 : (w-h)//2+h, :] = fg

        crop_size = random.randint(min_size, max_size) # both value are inclusive
        x = random.randint(0, max_size-crop_size) # 0 is inclusive
        y = random.randint(0, max_size-crop_size)
        trimap_temp = sqr_tri[y: y+crop_size, x: x+crop_size]
        if len(np.where(trimap_temp == 1)[0])>0: # check if unknown area is included
            image = sqr_img[y: y+crop_size, x: x+crop_size]
            trimap = trimap_temp
            if patch_a_bg:
                alpha = sqr_alp[y: y+crop_size, x: x+crop_size]
                bg = sqr_bg[y: y+crop_size, x: x+crop_size]
                fg = sqr_fg[y: y+crop_size, x: x+crop_size]
            break
        elif len(np.where(trimap==1)[0]) == 0:
            print('Warning & Error: No unknown area in current trimap! Refer to saved trimap in folder.')
            image = sqr_img
            trimap = sqr_tri
            if patch_a_bg:
                alpha = sqr_alp
                bg = sqr_bg
                fg = sqr_fg
            os.makedirs('ckpt/exceptions', exist_ok=True)
            cv2.imwrite('ckpt/exceptions/img_{}_{}.png'.format(str(h), str(w)), image)
            cv2.imwrite('ckpt/exceptions/tri_{}_{}.png'.format(str(h), str(w)), trimap)
            break
        elif count>5:
            print('Warning: cannot find right patch randomly, use max_size instead! Refer to saved files in folder.')
            image = sqr_img
            trimap = sqr_tri
            if patch_a_bg:
                alpha = sqr_alp
                bg = sqr_bg
                fg = sqr_fg
            os.makedirs('ckpt/exceptions', exist_ok=True)
            cv2.imwrite('ckpt/exceptions/img_{}_{}.png'.format(str(h), str(w)), image)
            cv2.imwrite('ckpt/exceptions/tri_{}_{}.png'.format(str(h), str(w)), trimap)
            break

        count += 1 # debug usage

    image = cv2.resize(image, (patch_size, patch_size), interpolation=cv2.INTER_CUBIC)
    trimap = cv2.resize(trimap, (patch_size, patch_size), interpolation=cv2.INTER_NEAREST)
    if patch_a_bg:
        alpha = cv2.resize(alpha, (patch_size, patch_size), interpolation=cv2.INTER_CUBIC)
        bg = cv2.resize(bg, (patch_size, patch_size), interpolation=cv2.INTER_CUBIC)
        fg = cv2.resize(fg, (patch_size, patch_size), interpolation=cv2.INTER_CUBIC)
        return image, trimap, alpha, bg, fg
    else:
        return image, trimap
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------

import cv2
import os
import random
import numpy as np
import math

from data import gen_trimap

import torch
import torch.utils.data as data

def composite(image, a, bg):
    fg = np.array(image, np.float32)
    alpha = np.expand_dims(a/255., axis=2).astype(np.float32)
    comp = alpha * fg + (1 - alpha) * bg
    comp = comp.astype(np.uint8)

    # we need bg for calculating compositional loss in M_net
    return comp

## crop & resize bg so that img and bg have same size
def resize_bg(bg, h, w):
    bh, bw = bg.shape[:2]
    wratio, hratio = w/bw, h/bh
    ratio = wratio if wratio > hratio else hratio
    bg = cv2.resize(bg, (math.ceil(bw * ratio), math.ceil(bh * ratio)), interpolation=cv2.INTER_CUBIC)
    bh, bw = bg.shape[:2]
    assert bh>=h and bw>=w
    bg = np.array(bg[(bh - h) // 2: (bh - h) // 2 + h, (bw - w) // 2: (bw - w) // 2 + w], np.float32)
    assert (h, w) == bg.shape[:2]

    return bg # attention: bg is in float32

# randomly crop images, then resize to patch_size
def random_patch(image, trimap, patch_size, bg=None, a=None):
    h, w = image.shape[:2]
    max_size = max(h, w)
    min_size = max_size // 2

    count=0
    while True:
        sqr_tri = np.zeros((max_size, max_size), np.uint8)
        if isinstance(bg, np.ndarray):
            sqr_img = np.zeros((max_size, max_size, 4), np.uint8)
            sqr_bga = np.zeros((max_size, max_size, 3), np.uint8)
            bga = resize_bg(bg, h, w)
        elif isinstance(a, np.ndarray):
            sqr_img = np.zeros((max_size, max_size, 3), np.uint8)
            sqr_bga = np.zeros((max_size, max_size), np.uint8)
            bga = a
        else:
            raise ValueError('no bg or alpha given from input!')

        if h >= w:
            sqr_tri[:, (h - w) // 2: (h - w) // 2 + w] = trimap
            sqr_img[:, (h - w) // 2: (h - w) // 2 + w] = image
            sqr_bga[:, (h - w) // 2: (h - w) // 2 + w] = bga
        else:
            sqr_tri[(w - h) // 2: (w - h) // 2 + h, :] = trimap
            sqr_img[(w - h) // 2: (w - h) // 2 + h, :] = image
            sqr_bga[(w - h) // 2: (w - h) // 2 + h, :] = bga

        crop_size = random.randint(min_size, max_size) # both value are inclusive
        x = random.randint(0, max_size - crop_size) # 0 is inclusive
        y = random.randint(0, max_size - crop_size)
        trimap_temp = sqr_tri[y: y + crop_size, x: x + crop_size]
        if len(np.where(trimap_temp == 128)[0]) > 0: # check if unknown area is included
            image = sqr_img[y: y + crop_size, x: x + crop_size]
            bga = sqr_bga[y: y + crop_size, x: x + crop_size]
            break
        elif len(np.where(trimap == 128)[0]) == 0:
            print('Warning: No unknown area in current trimap! Refer to folder.')
            os.makedirs('ckpt/exceptions', exist_ok=True)
            cv2.imwrite('ckpt/exceptions/img_{}_{}.png'.format(str(h), str(w)), image)
            cv2.imwrite('ckpt/exceptions/tri_{}_{}.png'.format(str(h), str(w)), trimap)
            # fall back to the whole padded square; without this break the loop never terminates
            image = sqr_img
            bga = sqr_bga
            break
        elif count > 3:
            print('Warning & Error: cannot find right patch randomly, use max_size instead! Refer to folder.')
            os.makedirs('ckpt/exceptions', exist_ok=True)
            cv2.imwrite('ckpt/exceptions/img_{}_{}.png'.format(str(h), str(w)), image)
            cv2.imwrite('ckpt/exceptions/tri_{}_{}.png'.format(str(h), str(w)), trimap)
            image = sqr_img
            bga = sqr_bga
            break

        count += 1 # debug usage

    image = cv2.resize(image, (patch_size, patch_size), interpolation=cv2.INTER_CUBIC)
    bga = cv2.resize(bga, (patch_size, patch_size), interpolation=cv2.INTER_CUBIC)

    return image, bga

# read/ crop/ resize inputs including fg, bg, alpha, trimap, composited image
# attention: all resize operation should be done before composition
def read_crop_resize(name, patch_size, stage):
    if name[2] == 'bg':
        img_path, bg_path = name[0].strip(), name[1].strip()
        assert os.path.isfile(img_path) and os.path.isfile(bg_path), (img_path, bg_path)

        image = cv2.imread(img_path, -1) # it's RGBA image
        fg = image[:,:,:3]
        a = image[:,:,3]
        bg = cv2.imread(bg_path)
        trimap = gen_trimap.rand_trimap(a)

        if stage == 't_net':
            image, bg = random_patch(image, trimap, patch_size, bg=bg)
            fg = image[:,:,:3]
            a = image[:,:,3]
            # composite fg and bg, generate trimap
            img = composite(fg, a, bg)
            trimap = gen_trimap.rand_trimap(a)

            return img, trimap

        elif stage == 'm_net':
            fg, a, bg = mask_center_crop(fg, a, bg, trimap, patch_size)
            # generate trimap again to avoid alpha resize side-effect on trimap
            trimap = gen_trimap.rand_trimap(a)
            # composite fg and bg
            img = composite(fg, a, bg)

            return img, trimap, a, bg, fg

        elif stage == 'end2end':
            # for t_net
            t_patch, m_patch = patch_size
            image_m, bg_m = random_patch(image, trimap, t_patch, bg=bg)
            fg_m = image_m[:,:,:3]
            a_m = image_m[:,:,3]
            img_m = composite(fg_m, a_m, bg_m)
            trimap_m = gen_trimap.rand_trimap(a_m)

            # random flip and rotation before going to m_net
            bg_m = bg_m.astype(np.uint8)
            img_m, trimap_m, a_m, bg_m, fg_m = random_flip_rotation(img_m, trimap_m, a_m, bg_m, fg_m)

            # for m_net
            fg, a, bg = mask_center_crop(fg_m, a_m, bg_m, trimap_m, m_patch)
            # generate trimap again to avoid alpha resize side-effect on trimap
            trimap = gen_trimap.rand_trimap(a)
            # composite fg and bg
            img = composite(fg, a, bg)

            return (img_m, trimap_m), (img, trimap, a, bg, fg)

    # note: this type of data is only used for t_net training
    elif name[2] == 'msk':
        img_path, a_path = name[0].strip(), name[1].strip()
        assert os.path.isfile(img_path) and os.path.isfile(a_path)

        image = cv2.imread(img_path) # it's composited image
        a = cv2.imread(a_path, 0) # it's grayscale image
        a[a > 0] = 255
        trimap = gen_trimap.rand_trimap(a)
        img, a = random_patch(image, trimap, patch_size, a=a)

        # generate trimap again to avoid alpha resize side-effect on trimap
        trimap = gen_trimap.rand_trimap(a)

        return img, trimap

# crop image into crop_size
def safe_crop(img, x, y, crop_size, resize_patch=None):
    if len(img.shape) == 2:
        new = np.zeros((crop_size, crop_size), np.uint8)
    else:
        new = np.zeros((crop_size, crop_size, 3), np.uint8)
    cropped = img[y:y+crop_size, x:x+crop_size] # if y+crop_size bigger than len_y, return len_y
    h, w = cropped.shape[:2] # h or w falls into the range of [crop_size/2, crop_size]
    new[0:h, 0:w] = cropped

    if resize_patch:
        new = cv2.resize(new, (resize_patch, resize_patch), interpolation=cv2.INTER_CUBIC)

    return new

# crop image around trimap unknown area with random size of [patch_size, 2*patch_size] then resize to patch_size
def mask_center_crop(fg, a, bg, trimap, patch_size):
    max_size = patch_size * 2
    min_size = patch_size
    crop_size = random.randint(min_size, max_size) # both value are inclusive

    # get crop center around trimap unknown area
    y_idx, x_idx = np.where(trimap == 128)
    num_unknowns = len(y_idx)
    if num_unknowns > 0:
        idx = np.random.choice(range(num_unknowns))
        cx = x_idx[idx]
        cy = y_idx[idx]
        x = max(0, cx - int(crop_size / 2))
        y = max(0, cy - int(crop_size / 2))
    else:
        raise ValueError('no unknown area in trimap!')

    # crop image/trimap/alpha/fg/bg
    fg = safe_crop(fg, x, y, crop_size, resize_patch=patch_size)
    a = safe_crop(a, x, y, crop_size, resize_patch=patch_size)
    bg = safe_crop(bg, x, y, crop_size, resize_patch=patch_size)

    return fg, a, bg

# randomly flip and rotate images
def random_flip_rotation(image, trimap, alpha=None, bg=None, fg=None):
    if isinstance(alpha, np.ndarray) and isinstance(bg, np.ndarray):
        a_bg = True
    else:
        a_bg = False
    # horizontal flipping
    if random.random() < 0.5:
        image = cv2.flip(image, 1)
        trimap = cv2.flip(trimap, 1)
        if a_bg:
            alpha = cv2.flip(alpha, 1)
            bg = cv2.flip(bg, 1)
            fg = cv2.flip(fg, 1)

    # rotation
    if random.random() < 0.05:
        degree = random.randint(1,3) # rotate 1/2/3 * 90 degrees
        image = np.rot90(image, degree)
        trimap = np.rot90(trimap, degree)
        if a_bg:
            alpha = np.rot90(alpha, degree)
            bg = np.rot90(bg, degree)
            fg = np.rot90(fg, degree)

    if a_bg:
        return image, trimap, alpha, bg, fg
    else:
        return image, trimap

def np2Tensor(array):
    if len(array.shape)>2:
        ts = (2, 0, 1)
        tensor = torch.FloatTensor(array.transpose(ts).astype(float))
    else:
        tensor = torch.FloatTensor(array.astype(float))
    return tensor

class human_matting_data(data.Dataset):
    """
    human_matting
    """
    def __init__(self, args):
        super().__init__()
        self.data_root = args.dataDir
        self.patch_size = args.patch_size
        self.phase = args.train_phase
        self.dataRatio = args.dataRatio

        self.fg_paths = []
        for file in args.fgLists:
            fg_path = os.path.join(self.data_root, file)
            assert os.path.isfile(fg_path), "missing file at {}".format(fg_path)
            with open(fg_path, 'r') as f:
                self.fg_paths.append(f.readlines())

        bg_path = os.path.join(self.data_root, args.bg_list)
        assert os.path.isfile(bg_path), "missing bg file at: {}".format(bg_path)
        with open(bg_path, 'r') as f:
            self.path_bg = f.readlines()

        assert len(self.path_bg) == sum([self.dataRatio[i]*len(self.fg_paths[i]) for i in range(len(self.fg_paths))]), \
            'the total num of bg is not equal to fg: bg-{}, fg-{}'\
            .format(len(self.path_bg), [self.dataRatio[i]*len(self.fg_paths[i]) for i in range(len(self.fg_paths))])
        self.num = len(self.path_bg)

        #self.shuffle_count = 0
        self.shuffle_data()

        print("Dataset : total training images:{}".format(self.num))

    def __getitem__(self, index):
        # data structure returned :: dict {}
        # image: c, h, w / range[0-1] , float
        # trimap: 1, h, w / [0,1,2] , float
        # alpha: 1, h, w / range[0-1] , float

        if self.phase == 'pre_train_t_net':
            # read files, random crop and resize
            image, trimap = read_crop_resize(self.names[index], self.patch_size, stage='t_net')
            # augmentation
            image, trimap = random_flip_rotation(image, trimap)

            # NOTE ! ! ! trimap should be 3 classes for classification : fg, bg. unsure
            trimap[trimap == 0] = 0
            trimap[trimap == 128] = 1
            trimap[trimap == 255] = 2
            assert image.shape[:2] == trimap.shape[:2]

            # normalize
            image = image.astype(np.float32) / 255.0

            # to tensor
            image = np2Tensor(image)
            trimap = np2Tensor(trimap)

            trimap = trimap.unsqueeze_(0) # shape: 1, h, w

            sample = {'image':image, 'trimap':trimap}

        elif self.phase == 'pre_train_m_net':
            # read files
            image, trimap, alpha, bg, fg = read_crop_resize(self.names[index], self.patch_size, stage='m_net')
            # augmentation
            image, trimap, alpha, bg, fg = random_flip_rotation(image, trimap, alpha, bg, fg)

            # NOTE ! ! ! trimap should be 3 classes for classification : fg, bg.
unsure 310 | trimap[trimap == 0] = 0 311 | trimap[trimap == 128] = 1 312 | trimap[trimap == 255] = 2 313 | 314 | assert image.shape[:2] == trimap.shape[:2] == alpha.shape[:2] 315 | 316 | # normalize 317 | image = image.astype(np.float32) / 255.0 318 | alpha = alpha.astype(np.float32) / 255.0 319 | bg = bg.astype(np.float32) / 255.0 320 | fg = fg.astype(np.float32) / 255.0 321 | 322 | # trimap one-hot encoding: when pre-train M_net, trimap should have 3 channels 323 | trimap = np.eye(3)[trimap.reshape(-1)].reshape(list(trimap.shape)+[3]) 324 | 325 | # to tensor 326 | image = np2Tensor(image) 327 | trimap = np2Tensor(trimap) 328 | alpha = np2Tensor(alpha) 329 | bg = np2Tensor(bg) 330 | fg = np2Tensor(fg) 331 | 332 | alpha = alpha.unsqueeze_(0) # shape: 1, h, w 333 | 334 | sample = {'image': image, 'trimap': trimap, 'alpha': alpha, 'bg':bg, 'fg':fg} 335 | 336 | elif self.phase == 'end_to_end': 337 | # read files 338 | assert len(self.patch_size) == 2, 'patch_size should have two values for end2end training !' 339 | input_t, input_m = read_crop_resize(self.names[index], self.patch_size, stage='end2end') 340 | img_t, tri_t = input_t 341 | img_m, tri_m, a_m, bg_m, fg_m = input_m 342 | 343 | # NOTE ! ! ! trimap should be 3 classes for classification : fg, bg. 
unsure 344 | tri_t[tri_t == 0] = 0 345 | tri_t[tri_t == 128] = 1 346 | tri_t[tri_t == 255] = 2 347 | tri_m[tri_m == 0] = 0 348 | tri_m[tri_m == 128] = 1 349 | tri_m[tri_m == 255] = 2 350 | 351 | assert img_t.shape[:2] == tri_t.shape[:2] 352 | assert img_m.shape[:2] == tri_m.shape[:2] == a_m.shape[:2] 353 | 354 | # t_net processing 355 | img_t = img_t.astype(np.float32) / 255.0 356 | img_t = np2Tensor(img_t) 357 | tri_t = np2Tensor(tri_t) 358 | tri_t = tri_t.unsqueeze_(0) # shape: 1, h, w 359 | 360 | # m_net processing 361 | img_m = img_m.astype(np.float32) / 255.0 362 | a_m = a_m.astype(np.float32) / 255.0 363 | bg_m = bg_m.astype(np.float32) / 255.0 364 | fg_m = fg_m.astype(np.float32) / 255.0 365 | 366 | # trimap one-hot encoding: when pre-train M_net, trimap should have 3 channels 367 | tri_m = np.eye(3)[tri_m.reshape(-1)].reshape(list(tri_m.shape) + [3]) 368 | # to tensor 369 | img_m = np2Tensor(img_m) 370 | tri_m = np2Tensor(tri_m) 371 | a_m = np2Tensor(a_m) 372 | bg_m = np2Tensor(bg_m) 373 | fg_m = np2Tensor(fg_m) 374 | 375 | tri_m = tri_m.unsqueeze_(0) # shape: 1, h, w 376 | a_m = a_m.unsqueeze_(0) 377 | 378 | sample = [{'image':img_t, 'trimap':tri_t}, 379 | {'image':img_m, 'trimap':tri_m, 'alpha': a_m, 'bg':bg_m, 'fg':fg_m}] 380 | 381 | return sample 382 | 383 | def shuffle_data(self): 384 | # data structure of self.names:: list 385 | # (.png_img, .bg, 'bg') or (.composite, .mask, 'msk) :: tuple 386 | self.names = [] 387 | 388 | random.shuffle(self.path_bg) 389 | 390 | count = 0 391 | for idx, path_list in enumerate(self.fg_paths): 392 | bg_per_fg = self.dataRatio[idx] 393 | for path in path_list: 394 | for i in range(bg_per_fg): 395 | self.names.append((path, self.path_bg[count], 'bg')) # 'bg' means we need to composite fg & bg 396 | count += 1 397 | 398 | assert count == len(self.path_bg) 399 | 400 | """# debug usage: check shuffled data after each call 401 | with open('shuffled_data_{}.txt'.format(self.shuffle_count),'w') as f: 402 | for name in 
self.names: 403 | f.write(name[0].strip()+' || '+name[1].strip()+'\n') 404 | self.shuffle_count += 1 405 | """ 406 | 407 | def __len__(self): 408 | return self.num 409 | 410 | -------------------------------------------------------------------------------- /data/gen_trimap.py: -------------------------------------------------------------------------------- 1 | import cv2, os 2 | import numpy as np 3 | import argparse 4 | import tqdm 5 | 6 | """ 7 | def get_args(): 8 | parser = argparse.ArgumentParser(description='Trimap') 9 | parser.add_argument('--mskDir', type=str, required=True, help="masks directory") 10 | parser.add_argument('--saveDir', type=str, required=True, help="where trimap result save to") 11 | parser.add_argument('--list', type=str, required=True, help="list of images id") 12 | parser.add_argument('--size', type=int, required=True, help="kernel size") 13 | args = parser.parse_args() 14 | print(args) 15 | return args 16 | """ 17 | # a simple trimap generation code for fixed kernel size 18 | def erode_dilate(msk, size=(10, 10), smooth=True): 19 | if smooth: 20 | size = (size[0]-4, size[1]-4) 21 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, size) 22 | 23 | dilated = cv2.dilate(msk, kernel, iterations=1) 24 | if smooth: # if it is .jpg, prevent output to be jagged 25 | dilated[(dilated>5)] = 255 26 | dilated[(dilated <= 5)] = 0 27 | else: 28 | dilated[(dilated>0)] = 255 29 | 30 | eroded = cv2.erode(msk, kernel, iterations=1) 31 | if smooth: 32 | eroded[(eroded<250)] = 0 33 | eroded[(eroded >= 250)] = 255 34 | else: 35 | eroded[(eroded < 255)] = 0 36 | 37 | res = dilated.copy() 38 | res[((dilated == 255) & (eroded == 0))] = 128 39 | 40 | """# make sure there are only 3 values in trimap 41 | cnt0 = len(np.where(res >= 0)[0]) 42 | cnt1 = len(np.where(res == 0)[0]) 43 | cnt2 = len(np.where(res == 128)[0]) 44 | cnt3 = len(np.where(res == 255)[0]) 45 | assert cnt0 == cnt1 + cnt2 + cnt3 46 | """ 47 | 48 | return res 49 | 50 | # trimap generation with 
different/random kernel size 51 | def rand_trimap(msk, smooth=False): 52 | h, w = msk.shape 53 | scale_up, scale_down = 0.022, 0.006 # hyper parameter 54 | dmin = 0 # hyper parameter 55 | emax = 255 - dmin # hyper parameter 56 | 57 | # .jpg (or low quality .png) tend to be jagged, smoothing tricks need to be applied 58 | if smooth: 59 | # give thrshold for dilation and erode results 60 | scale_up, scale_down = 0.02, 0.006 61 | dmin = 5 62 | emax = 255 - dmin 63 | 64 | # apply gussian smooth 65 | if h<1000: 66 | gau_ker = round(h*0.01) # we restrict the kernel size to 5-9 67 | gau_ker = gau_ker if gau_ker % 2 ==1 else gau_ker-1 # make sure it's odd 68 | if h<500: 69 | gau_ker = max(3, gau_ker) 70 | msk = cv2.GaussianBlur(msk, (gau_ker, gau_ker), 0) 71 | 72 | kernel_size_high = max(10, round((h + w) / 2 * scale_up)) 73 | kernel_size_low = max(1, round((h + w) /2 * scale_down)) 74 | erode_kernel_size = np.random.randint(kernel_size_low, kernel_size_high) 75 | dilate_kernel_size = np.random.randint(kernel_size_low, kernel_size_high) 76 | 77 | erode_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (erode_kernel_size, erode_kernel_size)) 78 | dilate_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (dilate_kernel_size, dilate_kernel_size)) 79 | eroded_alpha = cv2.erode(msk, erode_kernel) 80 | dilated_alpha = cv2.dilate(msk, dilate_kernel) 81 | 82 | dilated_alpha = np.where(dilated_alpha > dmin, 255, 0) 83 | eroded_alpha = np.where(eroded_alpha < emax, 0, 255) 84 | 85 | res = dilated_alpha.copy() 86 | res[((dilated_alpha == 255) & (eroded_alpha == 0))] = 128 87 | 88 | return res 89 | 90 | 91 | def get_trimap(msk, smooth=True): 92 | h, w = msk.shape[:2] 93 | scale_up, scale_down = 0.022, 0.008 # hyper parameter 94 | dmin = 0 # hyper parameter 95 | emax = 255 - dmin # hyper parameter 96 | 97 | # .jpg (or low quality .png) tend to be jagged, smoothing tricks need to be applied 98 | if smooth: 99 | scale_up, scale_down = 0.02, 0.006 100 | dmin = 5 101 | emax = 255 - dmin 

    kernel_size_high = max(10, round(h * scale_up))
    kernel_size_low = max(1, round(h * scale_down))
    kernel_size = (kernel_size_high + kernel_size_low) // 2

    print('kernel size:', kernel_size)

    erode_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
    dilate_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
    eroded_alpha = cv2.erode(msk, erode_kernel)
    dilated_alpha = cv2.dilate(msk, dilate_kernel)

    dilated_alpha = np.where(dilated_alpha > dmin, 255, 0)
    eroded_alpha = np.where(eroded_alpha < emax, 0, 255)

    res = dilated_alpha.copy()
    res[((dilated_alpha == 255) & (eroded_alpha == 0))] = 128

    return res


def main():
    kernel_size = 18
    alphaDir = 'alpha'  # where the alpha matting images are
    trimapDir = 'trimap'
    names = os.listdir('image')
    if names == []:
        raise ValueError('No images are in the dir: ./image')
    print("Images Count: {}".format(len(names)))

    for name in tqdm.tqdm(names):
        alpha_path = alphaDir + "/" + name
        trimap_path = trimapDir + "/" + name.strip()[:-4] + ".png"  # output must be .png format
        alpha = cv2.imread(alpha_path, 0)
        if name[-3:] != 'png':
            trimap = erode_dilate(alpha, size=(kernel_size, kernel_size), smooth=True)
        else:
            # erode_dilate defaults to smooth=True, so disable it explicitly for .png masks
            trimap = erode_dilate(alpha, size=(kernel_size, kernel_size), smooth=False)

        #print("Write to {}".format(trimap_path))
        cv2.imwrite(trimap_path, trimap)

if __name__ == "__main__":
    main()


--------------------------------------------------------------------------------
/model/M_Net.py:
--------------------------------------------------------------------------------

import torch
import torch.nn as nn


class M_net(nn.Module):
    '''
        encoder + decoder
    '''

    def __init__(self, classes=2):

        super(M_net, self).__init__()
        # -----------------------------------------------------------------
        # encoder
        # ---------------------
        # stage-1
        self.conv_1_1 = nn.Sequential(nn.Conv2d(6, 64, 3, 1, 1, bias=True), nn.BatchNorm2d(64), nn.ReLU())
        self.conv_1_2 = nn.Sequential(nn.Conv2d(64, 64, 3, 1, 1, bias=True), nn.BatchNorm2d(64), nn.ReLU())
        self.max_pooling_1 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

        # stage-2
        self.conv_2_1 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1, bias=True), nn.BatchNorm2d(128), nn.ReLU())
        self.conv_2_2 = nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1, bias=True), nn.BatchNorm2d(128), nn.ReLU())
        self.max_pooling_2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

        # stage-3
        self.conv_3_1 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1, bias=True), nn.BatchNorm2d(256), nn.ReLU())
        self.conv_3_2 = nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1, bias=True), nn.BatchNorm2d(256), nn.ReLU())
        self.conv_3_3 = nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1, bias=True), nn.BatchNorm2d(256), nn.ReLU())
        self.max_pooling_3 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

        # stage-4
        self.conv_4_1 = nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1, bias=True), nn.BatchNorm2d(512), nn.ReLU())
        self.conv_4_2 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1, bias=True), nn.BatchNorm2d(512), nn.ReLU())
        self.conv_4_3 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1, bias=True), nn.BatchNorm2d(512), nn.ReLU())
        self.max_pooling_4 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

        # stage-5
        self.conv_5_1 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1, bias=True), nn.BatchNorm2d(512), nn.ReLU())
        self.conv_5_2 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1, bias=True), nn.BatchNorm2d(512), nn.ReLU())
        self.conv_5_3 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1, bias=True), nn.BatchNorm2d(512), nn.ReLU())

        # -----------------------------------------------------------------
        # decoder
        # ---------------------
        # stage-5
        self.deconv_5 = nn.Sequential(nn.Conv2d(512, 512, 5, 1, 2, bias=True), nn.BatchNorm2d(512), nn.ReLU())

        # stage-4
        self.up_pool_4 = nn.MaxUnpool2d(2, stride=2)
        self.deconv_4 = nn.Sequential(nn.Conv2d(512, 256, 5, 1, 2, bias=True), nn.BatchNorm2d(256), nn.ReLU())

        # stage-3
        self.up_pool_3 = nn.MaxUnpool2d(2, stride=2)
        self.deconv_3 = nn.Sequential(nn.Conv2d(256, 128, 5, 1, 2, bias=True), nn.BatchNorm2d(128), nn.ReLU())

        # stage-2
        self.up_pool_2 = nn.MaxUnpool2d(2, stride=2)
        self.deconv_2 = nn.Sequential(nn.Conv2d(128, 64, 5, 1, 2, bias=True), nn.BatchNorm2d(64), nn.ReLU())

        # stage-1
        self.up_pool_1 = nn.MaxUnpool2d(2, stride=2)
        self.deconv_1 = nn.Sequential(nn.Conv2d(64, 64, 5, 1, 2, bias=True), nn.BatchNorm2d(64), nn.ReLU())

        # stage-0
        self.conv_0 = nn.Conv2d(64, 1, 5, 1, 2, bias=True)


    def forward(self, input):

        # ----------------
        # encoder
        # --------
        x11 = self.conv_1_1(input)
        x12 = self.conv_1_2(x11)
        x1p, id1 = self.max_pooling_1(x12)

        x21 = self.conv_2_1(x1p)
        x22 = self.conv_2_2(x21)
        x2p, id2 = self.max_pooling_2(x22)

        x31 = self.conv_3_1(x2p)
        x32 = self.conv_3_2(x31)
        x33 = self.conv_3_3(x32)
        x3p, id3 = self.max_pooling_3(x33)

        x41 = self.conv_4_1(x3p)
        x42 = self.conv_4_2(x41)
        x43 = self.conv_4_3(x42)
        x4p, id4 = self.max_pooling_4(x43)

        x51 = self.conv_5_1(x4p)
        x52 = self.conv_5_2(x51)
        x53 = self.conv_5_3(x52)
        # ----------------
        # decoder
        # --------
        x5d = self.deconv_5(x53)

        x4u = self.up_pool_4(x5d, id4)
        x4d = self.deconv_4(x4u)

        x3u = self.up_pool_3(x4d, id3)
        x3d = self.deconv_3(x3u)

        x2u = self.up_pool_2(x3d, id2)
        x2d = self.deconv_2(x2u)

        x1u = self.up_pool_1(x2d, id1)
        x1d = self.deconv_1(x1u)

        # raw alpha pred
        raw_alpha = self.conv_0(x1d)

        return raw_alpha


--------------------------------------------------------------------------------
/model/T_Net_psp.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import functional as F

from model import extractors


class PSPModule(nn.Module):
    def __init__(self, features, out_features=1024, sizes=(1, 2, 3, 6)):
        super().__init__()
        self.stages = nn.ModuleList([self._make_stage(features, size) for size in sizes])
        self.bottleneck = nn.Conv2d(features * (len(sizes) + 1), out_features, kernel_size=1)
        self.relu = nn.ReLU()

    def _make_stage(self, features, size):
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = nn.Conv2d(features, features, kernel_size=1, bias=False)
        return nn.Sequential(prior, conv)

    def forward(self, feats):
        h, w = feats.size(2), feats.size(3)
        priors = [F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in self.stages] + [feats]
        bottle = self.bottleneck(torch.cat(priors, 1))
        return self.relu(bottle)


class PSPUpsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.PReLU()
        )

    def forward(self, x):
        h, w = 2 * x.size(2), 2 * x.size(3)
        p = F.interpolate(input=x, size=(h, w), mode='bilinear', align_corners=True)
        return self.conv(p)

class PSPNet(nn.Module):
    def __init__(self, n_classes=3, sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet50',
                 pretrained=True):
        super().__init__()
        self.feats = getattr(extractors, backend)(pretrained)
        self.psp = PSPModule(psp_size, 1024, sizes)
        self.drop_1 = nn.Dropout2d(p=0.3)

        self.up_1 = PSPUpsample(1024, 256)
        self.up_2 = PSPUpsample(256, 64)
        self.up_3 = PSPUpsample(64, 64)

        self.drop_2 = nn.Dropout2d(p=0.15)

        """
        self.final = nn.Sequential(
            nn.Conv2d(64, n_classes, kernel_size=1),
            nn.LogSoftmax()
        )
        """
        # raw per-class logits; softmax / argmax are applied by the caller
        self.final = nn.Sequential(nn.Conv2d(64, n_classes, kernel_size=1))

        self.classifier = nn.Sequential(
            nn.Linear(deep_features_size, 256),
            nn.ReLU(),
            nn.Linear(256, n_classes)
        )

    def forward(self, x):
        f, class_f = self.feats(x)
        p = self.psp(f)
        p = self.drop_1(p)

        p = self.up_1(p)
        p = self.drop_2(p)

        p = self.up_2(p)
        p = self.drop_2(p)

        p = self.up_3(p)
        p = self.drop_2(p)

        #auxiliary = F.adaptive_max_pool2d(input=class_f, output_size=(1, 1)).view(-1, class_f.size(1))

        return self.final(p)


--------------------------------------------------------------------------------
/model/extractors.py:
--------------------------------------------------------------------------------
from collections import OrderedDict
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo
from torchvision.models.densenet import densenet121, densenet161
from torchvision.models.squeezenet import squeezenet1_1


def load_weights_sequential(target, source_state):
    """
    new_dict = OrderedDict()
    for (k1, v1), (k2, v2) in zip(target.state_dict().items(), source_state.items()):
        new_dict[k1] = v2
    """
    # keep only the pretrained tensors whose keys also exist in target (e.g. drops the ImageNet fc layer)
    model_to_load = {k: v for k, v in source_state.items() if k in target.state_dict().keys()}
    target.load_state_dict(model_to_load)

'''
Implementation of dilated ResNet (18/34/50/101/152) with deep supervision. Downsampling is changed to 8x
'''
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, dilation=dilation, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation,
                               padding=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers=(3, 4, 23, 3)):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        # layer3/layer4 use dilation instead of stride, which keeps the output stride at 8
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x_3 = self.layer3(x)
        x = self.layer4(x_3)

        return x, x_3


'''
Implementation of DenseNet with deep supervision. Downsampling is changed to 8x
'''


class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.add_module('norm.1', nn.BatchNorm2d(num_input_features)),
        self.add_module('relu.1', nn.ReLU(inplace=True)),
        self.add_module('conv.1', nn.Conv2d(num_input_features, bn_size *
                        growth_rate, kernel_size=1, stride=1, bias=False)),
        self.add_module('norm.2', nn.BatchNorm2d(bn_size * growth_rate)),
        self.add_module('relu.2', nn.ReLU(inplace=True)),
        self.add_module('conv.2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                        kernel_size=3, stride=1, padding=1, bias=False)),
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return torch.cat([x, new_features], 1)


class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
            self.add_module('denselayer%d' % (i + 1), layer)


class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features, downsample=True):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        if downsample:
            self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
        else:
            self.add_module('pool', nn.AvgPool2d(kernel_size=1, stride=1))  # compatibility hack
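

# Illustrative sketch (hypothetical helper, not part of the original repo and not used
# by any model here): it checks the "Downsampling is changed to 8x" claim by applying
# the output-size formula from the PyTorch Conv2d/MaxPool2d docs,
#     out = floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1,
# to the layers of each backbone that change the spatial size. The layer lists are
# assumptions read off this file: stride-1 3x3 convs with padding=dilation keep the
# size, so they are included only to show that the size is preserved.
def _hypothetical_output_size(size, layers):
    """Return the spatial size after a list of (kernel, stride, padding, dilation) layers."""
    for kernel, stride, padding, dilation in layers:
        size = (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
    return size


# Dilated ResNet: conv1 (7x7, stride 2), maxpool (3x3, stride 2), the strided 3x3 conv
# that opens layer2, then dilated 3x3 convs in layer3/layer4 (stride 1).
_RESNET_DOWNSAMPLING = [(7, 2, 3, 1), (3, 2, 1, 1), (3, 2, 1, 1), (3, 1, 2, 2), (3, 1, 4, 4)]

# DenseNet: conv0 (stride 2), pool0 (stride 2), then the three _Transition pools; only the
# first uses AvgPool2d(2, 2), the others use the 1x1 / stride-1 "compatibility hack".
_DENSENET_DOWNSAMPLING = [(7, 2, 3, 1), (3, 2, 1, 1), (2, 2, 0, 1), (1, 1, 0, 1), (1, 1, 0, 1)]

# Both lists map a 400x400 T-net patch to 50x50 (1/8); the three 2x PSPUpsample
# stages in T_Net_psp.py then bring 50 back to 400.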


class DenseNet(nn.Module):
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, pretrained=True):

        super(DenseNet, self).__init__()

        # First convolution
        self.start_features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock
        num_features = num_init_features

        init_weights = list(densenet121(pretrained=True).features.children())
        start = 0
        for i, c in enumerate(self.start_features.children()):
            if pretrained:
                c.load_state_dict(init_weights[i].state_dict())
            start += 1
        self.blocks = nn.ModuleList()
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
            if pretrained:
                block.load_state_dict(init_weights[start].state_dict())
            start += 1
            self.blocks.append(block)
            setattr(self, 'denseblock%d' % (i + 1), block)

            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                downsample = i < 1  # only the first transition downsamples (total stride 8)
                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2,
                                    downsample=downsample)
                if pretrained:
                    trans.load_state_dict(init_weights[start].state_dict())
                start += 1
                self.blocks.append(trans)
                setattr(self, 'transition%d' % (i + 1), trans)
                num_features = num_features // 2

    def forward(self, x):
        out = self.start_features(x)
        deep_features = None
        for i, block in enumerate(self.blocks):
            out = block(out)
            if i == 5:
                deep_features = out

        return out, deep_features


class Fire(nn.Module):

    def __init__(self, inplanes, squeeze_planes,
                 expand1x1_planes, expand3x3_planes, dilation=1):
        super(Fire, self).__init__()
        self.inplanes = inplanes
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
                                   kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
                                   kernel_size=3, padding=dilation, dilation=dilation)
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.squeeze_activation(self.squeeze(x))
        return torch.cat([
            self.expand1x1_activation(self.expand1x1(x)),
            self.expand3x3_activation(self.expand3x3(x))
        ], 1)


class SqueezeNet(nn.Module):

    def __init__(self, pretrained=False):
        super(SqueezeNet, self).__init__()

        self.feat_1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True)
        )
        self.feat_2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            Fire(64, 16, 64, 64),
            Fire(128, 16, 64, 64)
        )
        self.feat_3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            Fire(128, 32, 128, 128, 2),
            Fire(256, 32, 128, 128, 2)
        )
        self.feat_4 = nn.Sequential(
            Fire(256, 48, 192, 192, 4),
            Fire(384, 48, 192, 192, 4),
            Fire(384, 64, 256, 256, 4),
            Fire(512, 64, 256, 256, 4)
        )
        if pretrained:
            weights = squeezenet1_1(pretrained=True).features.state_dict()
            load_weights_sequential(self, weights)

    def forward(self, x):
        f1 = self.feat_1(x)
        f2 = self.feat_2(f1)
        f3 = self.feat_3(f2)
        f4 = self.feat_4(f3)
        return f4, f3


'''
Handy methods for construction
'''


def squeezenet(pretrained=True):
    return SqueezeNet(pretrained)


def densenet(pretrained=True):
    return DenseNet(pretrained=pretrained)


def resnet18(pretrained=True):
    model = ResNet(BasicBlock, [2, 2, 2, 2])
    if pretrained:
        load_weights_sequential(model, model_zoo.load_url(model_urls['resnet18']))
    return model


def resnet34(pretrained=True):
    model = ResNet(BasicBlock, [3, 4, 6, 3])
    if pretrained:
        load_weights_sequential(model, model_zoo.load_url(model_urls['resnet34']))
    return model


def resnet50(pretrained=True):
    model = ResNet(Bottleneck, [3, 4, 6, 3])
    if pretrained:
        load_weights_sequential(model, model_zoo.load_url(model_urls['resnet50']))
    return model


def resnet101(pretrained=True):
    model = ResNet(Bottleneck, [3, 4, 23, 3])
    if pretrained:
        load_weights_sequential(model, model_zoo.load_url(model_urls['resnet101']))
    return model


def resnet152(pretrained=True):
    model = ResNet(Bottleneck, [3, 8, 36, 3])
    if pretrained:
        load_weights_sequential(model, model_zoo.load_url(model_urls['resnet152']))
    return model

--------------------------------------------------------------------------------
/model/network.py:
--------------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F

from model.M_Net import M_net
from model.T_Net_psp import PSPNet

class net_T(nn.Module):
    # Train T_net
    def __init__(self):

        super(net_T, self).__init__()

        self.t_net = PSPNet()

    def forward(self, input):

        # trimap
        trimap = self.t_net(input)
        return trimap

class net_M(nn.Module):
    '''
        train M_net
    '''

    def __init__(self):

        super(net_M, self).__init__()
        self.m_net = M_net()

    def forward(self, input, trimap):

        # paper: bs, fs, us
        # the one-hot trimap channels are ordered bg (0), unsure (1), fg (2), as in the dataset and net_F
        bg, unsure, fg = torch.split(trimap, 1, dim=1)

        # concat input and trimap
        m_net_input = torch.cat((input, trimap), 1)

        # matting
        alpha_r = self.m_net(m_net_input)
        # fusion module
        # paper : alpha_p = fs + us * alpha_r
        alpha_p = fg + unsure * alpha_r

        return alpha_p

class net_F(nn.Module):
    '''
        end to end net
    '''

    def __init__(self):

        super(net_F, self).__init__()

        self.t_net = PSPNet()
        self.m_net = M_net()



    def forward(self, input):

        # trimap
        trimap = self.t_net(input)
        trimap_softmax = F.softmax(trimap, dim=1)

        # paper: bs, fs, us
        bg, unsure, fg = torch.split(trimap_softmax, 1, dim=1)

        # concat input and trimap
        m_net_input = torch.cat((input, trimap_softmax), 1)

        # matting
        alpha_r = self.m_net(m_net_input)
        # fusion module
        # paper : alpha_p = fs + us * alpha_r
        alpha_p = fg + unsure * alpha_r

        return trimap, alpha_p

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------

import time
import cv2
import torch
import argparse
import numpy as np
import os
import torch.nn.functional as F

parser = argparse.ArgumentParser(description='human matting')
parser.add_argument('--model', default='./ckpt', help='root dir of preTrained model')
parser.add_argument('--size', type=int, default=400, help='input size')
parser.add_argument('--without_gpu', action='store_true', default=False, help='run on CPU even if a GPU is available')
parser.add_argument('--train_phase', default='end_to_end', help='which phase of the model')

args = parser.parse_args()

torch.set_grad_enabled(False)


#################################
#----------------
if args.without_gpu:
    print("use CPU !")
    device = torch.device('cpu')
else:
    if torch.cuda.is_available():
        n_gpu = torch.cuda.device_count()
        print("use GPU")
        device = torch.device('cuda')
    else:
        # no CUDA device: fall back to the CPU code paths below
        print("CUDA is not available, use CPU !")
        args.without_gpu = True
        device = torch.device('cpu')

#################################
#---------------
def load_model(args):
    model_path = os.path.join(args.model, args.train_phase, 'model/model_obj.pth')
    assert os.path.isfile(model_path), 'Wrong model path: {}'.format(model_path)
    print('Loading model from {}...'.format(model_path))

    if args.without_gpu:
        myModel = torch.load(model_path, map_location=lambda storage, loc: storage)
    else:
        myModel = torch.load(model_path)

    myModel.eval()
    myModel.to(device)

    return myModel

def seg_process(args, inputs, net):

    if args.train_phase == 'pre_train_t_net':
        trimap = net(inputs)
        trimap = torch.argmax(trimap[0], dim=0)
        trimap[trimap == 0] = 0
        trimap[trimap == 1] = 128
        trimap[trimap == 2] = 255

        if args.without_gpu:
            trimap_np = trimap.data.numpy()
        else:
            trimap_np = trimap.cpu().data.numpy()

        trimap_np = trimap_np.astype(np.uint8)
        return trimap_np

    elif args.train_phase == 'pre_train_m_net':
        alpha = net(inputs[0], inputs[1])

        if args.without_gpu:
            alpha = alpha.data.numpy()
        else:
            alpha = alpha.cpu().data.numpy()

        # clip before the uint8 cast: the fused alpha is not bounded to [0, 1]
        alpha = np.clip(alpha[0][0], 0, 1) * 255.0
        alpha = alpha.astype(np.uint8)

        return alpha

    else:
        # net_F returns (trimap logits, alpha)
        _, alpha = net(inputs)

        if args.without_gpu:
            alpha = alpha.data.numpy()
        else:
            alpha = alpha.cpu().data.numpy()

        alpha = np.clip(alpha[0][0], 0, 1) * 255.0
        alpha = alpha.astype(np.uint8)
        return alpha

def test(args, net):

    t0 =
time.time() 94 | out_dir = 'result/' + args.train_phase + '/' 95 | os.makedirs(out_dir, exist_ok=True) 96 | 97 | # get a frame 98 | imgList = os.listdir('result/test') 99 | if imgList==[]: 100 | raise ValueError('Empty dir at: ./result/test') 101 | for imgname in imgList: 102 | img = cv2.imread('result/test/'+imgname) 103 | 104 | if args.train_phase == 'pre_train_t_net': 105 | img = img / 255.0 106 | 107 | tensor_img = torch.from_numpy(img.astype(np.float32)[np.newaxis, :, :, :]).permute(0, 3, 1, 2) 108 | 109 | tensor_img = tensor_img.to(device) 110 | 111 | frame_seg = seg_process(args, tensor_img, net) 112 | 113 | elif args.train_phase == 'pre_train_m_net': 114 | # No resize 115 | h, w, _ = img.shape 116 | img = img / 255.0 117 | 118 | tri_path = 'result/trimap/'+imgname 119 | assert os.path.isfile(tri_path), 'wrong trimap path: {}'.format(tri_path) 120 | 121 | trimap_src = cv2.imread(tri_path, 0) 122 | 123 | trimap = trimap_src.copy() 124 | trimap[trimap == 0] = 0 125 | trimap[trimap == 128] = 1 126 | trimap[trimap == 255] = 2 127 | trimap = np.eye(3)[trimap.reshape(-1)].reshape(list(trimap.shape) + [3]) 128 | 129 | tensor_img = torch.from_numpy(img.astype(np.float32)[np.newaxis, :, :, :]).permute(0, 3, 1, 2) 130 | tensor_tri = torch.from_numpy(trimap.astype(np.float32)[np.newaxis, :, :, :]).permute(0, 3, 1, 2) 131 | tensor_img = tensor_img.to(device) 132 | tensor_tri = tensor_tri.to(device) 133 | 134 | frame_seg = seg_process(args, (tensor_img, tensor_tri), net, trimap=trimap_src) 135 | 136 | else: 137 | img = img / 255.0 138 | tensor_img = torch.from_numpy(img.astype(np.float32)[np.newaxis, :, :, :]).permute(0, 3, 1, 2) 139 | tensor_img = tensor_img.to(device) 140 | frame_seg = seg_process(args, tensor_img, net) 141 | 142 | # show a frame 143 | cv2.imwrite(out_dir+imgname, frame_seg) 144 | 145 | print('Average time cost: {:.0f} s/image'.format((time.time() - t0) / len(imgList))) 146 | print('output images were saved at: ', out_dir) 147 | 148 | def main(args): 
149 | 150 | myModel = load_model(args) 151 | test(args, myModel) 152 | 153 | 154 | if __name__ == "__main__": 155 | main(args) 156 | 157 | 158 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.utils.data import DataLoader 7 | import time 8 | import os 9 | import cv2 10 | import numpy as np 11 | from data import dataset 12 | from model import network 13 | import torch.nn.functional as F 14 | 15 | 16 | def get_args(): 17 | # Training settings 18 | parser = argparse.ArgumentParser(description='Semantic Human Matting !') 19 | parser.add_argument('--dataDir', default='./data/', help='dataset directory') 20 | parser.add_argument('--fgLists', type=list, default=[], required=True, help="training fore-ground images lists") 21 | parser.add_argument('--bg_list', type=str, required=True, help='train back-ground images list, one file') 22 | parser.add_argument('--dataRatio', type=list, default=[], required=True, help="train bg:fg raio, eg. 
[100]") 23 | parser.add_argument('--saveDir', default='./ckpt', help='model save dir') 24 | parser.add_argument('--trainData', default='human_matting_data', help='train dataset name') 25 | 26 | parser.add_argument('--continue_train', action='store_true', default=False, help='continue training the training') 27 | parser.add_argument('--pretrain', action='store_true', help='load pretrained model from t_net & m_net ') 28 | parser.add_argument('--without_gpu', action='store_true', default=False, help='no use gpu') 29 | 30 | parser.add_argument('--nThreads', type=int, default=4, help='number of threads for data loading') 31 | parser.add_argument('--train_batch', type=int, default=8, help='input batch size for train') 32 | parser.add_argument('--patch_size', type=int, default=400, help='patch size for train') 33 | 34 | parser.add_argument('--lr', type=float, default=1e-5, help='learning rate') 35 | parser.add_argument('--lrDecay', type=int, default=100) 36 | parser.add_argument('--lrdecayType', default='keep') 37 | parser.add_argument('--nEpochs', type=int, default=300, help='number of epochs to train') 38 | parser.add_argument('--save_epoch', type=int, default=5, help='number of epochs to save model') 39 | parser.add_argument('--print_iter', type=int, default=1000, help='pring loss and save image') 40 | 41 | parser.add_argument('--train_phase', default= 'end_to_end', help='train phase') 42 | parser.add_argument('--debug', action='store_true', default=False, help='debug mode') 43 | 44 | args = parser.parse_args() 45 | return args 46 | 47 | def torch2numpy(tensor, without_gpu=False): 48 | if without_gpu: 49 | array = tensor.data.numpy() 50 | else: 51 | array = tensor.cpu().data.numpy() 52 | 53 | array = array * 255.0 54 | return array.astype(np.uint8) 55 | 56 | def save_img(args, all_img, epoch, i=0): 57 | img_dir = os.path.join(args.saveDir, args.train_phase, 'save_img') 58 | if not os.path.isdir(img_dir): 59 | os.mkdir(img_dir) 60 | 61 | if args.train_phase == 
'pre_train_t_net': 62 | img, trimap_pre, trimap_gt = all_img 63 | trimap_fake = torch.argmax(trimap_pre[0], dim=0) 64 | trimap_cat = torch.cat((trimap_gt[0][0], trimap_fake.float()), dim=-1) # horizontal concate 65 | trimap_cat = torch.stack((trimap_cat,) * 3, dim=0) 66 | img_cat = torch.cat((img[0] * 255.0, trimap_cat * 127.5), dim=-1) 67 | if args.without_gpu: 68 | img_cat = img_cat.data.numpy() 69 | else: 70 | img_cat = img_cat.cpu().data.numpy() 71 | cv2.imwrite(img_dir + '/trimap_{}_{}.png'.format(str(epoch), str(i)), 72 | img_cat.transpose((1, 2, 0)).astype(np.uint8)) 73 | 74 | if args.train_phase == 'pre_train_m_net': 75 | img, alpha_pre, alpha_gt = all_img 76 | img = img[0] * 255.0 77 | alpha_pre = torch.stack((alpha_pre[0][0] * 255.0,) * 3, dim=0) 78 | alpha_gt = torch.stack((alpha_gt[0][0] * 255.0,) * 3, dim=0) 79 | img_cat = torch.cat((img, alpha_gt, alpha_pre), dim=-1) 80 | 81 | if args.without_gpu: 82 | img_cat = img_cat.data.numpy() 83 | else: 84 | img_cat = img_cat.cpu().data.numpy() 85 | cv2.imwrite(img_dir + '/alpha_{}_{}.png'.format(str(epoch), str(i)), 86 | img_cat.transpose((1, 2, 0)).astype(np.uint8)) 87 | 88 | if args.train_phase == 'end_to_end': 89 | img, trimap_gt, alpha_gt, trimap_pre, alpha_pre = all_img 90 | img = img[0] * 255.0 91 | trimap_pre = torch.argmax(trimap_pre[0], dim=0) 92 | trimap_pre = torch.stack((trimap_pre.float(),)*3, dim=0) * 127.5 93 | trimap_gt = torch.stack((trimap_gt[0][0],)*3, dim=0) * 127.5 94 | alpha_pre = torch.stack((alpha_pre[0][0] * 255.0,) * 3, dim=0) 95 | alpha_gt = torch.stack((alpha_gt[0][0] * 255.0,) * 3, dim=0) 96 | 97 | img_cat = torch.cat((img, trimap_gt, trimap_pre, alpha_gt, alpha_pre), dim=-1) 98 | #alpha_cat = torch.cat((alpha_gt, alpha_pre), dim=-1) 99 | if args.without_gpu: 100 | img_cat = img_cat.data.numpy() 101 | else: 102 | img_cat = img_cat.cpu().data.numpy() 103 | 104 | img_cat = img_cat.transpose((1,2,0)).astype(np.uint8) 105 | 106 | # save image 107 | cv2.imwrite(img_dir + 
'/e2e_{}_{}.png'.format(str(epoch), str(i)),img_cat) 108 | 109 | def set_lr(args, epoch, optimizer): 110 | 111 | lrDecay = args.lrDecay 112 | decayType = args.lrdecayType 113 | if decayType == 'keep': 114 | lr = args.lr 115 | elif decayType == 'step': 116 | epoch_iter = (epoch + 1) // lrDecay 117 | lr = args.lr / 2**epoch_iter 118 | elif decayType == 'exp': 119 | k = math.log(2) / lrDecay 120 | lr = args.lr * math.exp(-k * epoch) 121 | elif decayType == 'poly': 122 | lr = args.lr * math.pow((1 - epoch / args.nEpochs), 0.9) 123 | 124 | for param_group in optimizer.param_groups: 125 | param_group['lr'] = lr 126 | 127 | return lr 128 | 129 | 130 | 131 | class Train_Log(): 132 | def __init__(self, args): 133 | self.args = args 134 | 135 | self.save_dir = os.path.join(args.saveDir, args.train_phase) 136 | if not os.path.exists(self.save_dir): 137 | os.makedirs(self.save_dir) 138 | 139 | self.save_dir_model = os.path.join(self.save_dir, 'model') 140 | if not os.path.exists(self.save_dir_model): 141 | os.makedirs(self.save_dir_model) 142 | 143 | if os.path.exists(self.save_dir + '/log.txt'): 144 | self.logFile = open(self.save_dir + '/log.txt', 'a') 145 | else: 146 | self.logFile = open(self.save_dir + '/log.txt', 'w') 147 | 148 | # in case pretrained weights need to be loaded 149 | if self.args.pretrain: 150 | self.t_path = os.path.join(args.saveDir, 'pre_train_t_net', 'model', 'ckpt_lastest.pth') 151 | self.m_path = os.path.join(args.saveDir, 'pre_train_m_net', 'model', 'ckpt_lastest.pth') 152 | assert os.path.isfile(self.t_path) and os.path.isfile(self.m_path), \ 153 | 'Wrong dir for pretrained models:\n{},{}'.format(self.t_path, self.m_path) 154 | 155 | 156 | def save_model(self, model, epoch, save_as=False): 157 | if save_as: # for args.save_epoch 158 | lastest_out_path = "{}/ckpt_{}.pth".format(self.save_dir_model, epoch) 159 | model_out_path = "{}/model_obj.pth".format(self.save_dir_model) 160 | torch.save( 161 | model, 162 | model_out_path) 163 | else: # for 
regular save 164 | lastest_out_path = "{}/ckpt_lastest.pth".format(self.save_dir_model) 165 | 166 | torch.save({ 167 | 'epoch': epoch, 168 | 'state_dict': model.state_dict(), 169 | }, lastest_out_path) 170 | 171 | def load_pretrain(self, model): 172 | t_ckpt = torch.load(self.t_path) 173 | model.load_state_dict(t_ckpt['state_dict'], strict=False) 174 | m_ckpt = torch.load(self.m_path) 175 | model.load_state_dict(m_ckpt['state_dict'], strict=False) 176 | print('=> loaded pretrained t_net & m_net pretrained models !') 177 | 178 | return model 179 | 180 | def load_model(self, model): 181 | lastest_out_path = "{}/ckpt_lastest.pth".format(self.save_dir_model) 182 | ckpt = torch.load(lastest_out_path) 183 | start_epoch = ckpt['epoch'] + 1 184 | model.load_state_dict(ckpt['state_dict']) 185 | print("=> loaded checkpoint '{}' (epoch {})".format(lastest_out_path, ckpt['epoch'])) 186 | 187 | return start_epoch, model 188 | 189 | def save_log(self, log): 190 | self.logFile.write(log + '\n') 191 | 192 | # initialise conv2d weights 193 | def weight_init(m): 194 | if isinstance(m, nn.Conv2d): 195 | torch.nn.init.xavier_normal_(m.weight.data) 196 | #n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 197 | #m.weight.data.normal_(0, math.sqrt(2. / n)) 198 | elif isinstance(m, nn.BatchNorm2d): 199 | m.weight.data.fill_(1) 200 | m.bias.data.zero_() 201 | 202 | def loss_f_T(trimap_pre, trimap_gt): 203 | criterion = nn.CrossEntropyLoss() 204 | L_t = criterion(trimap_pre, trimap_gt[:, 0, :, :].long()) 205 | 206 | return L_t 207 | 208 | 209 | def loss_f_M(img, alpha_pre, alpha_gt, bg, fg, trimap): 210 | # ------------------------------------- 211 | # prediction loss L_p 212 | # ------------------------ 213 | eps = 1e-6 214 | # l_alpha 215 | L_alpha = torch.sqrt(torch.pow(alpha_pre - alpha_gt, 2.) 
+ eps).mean() 216 | 217 | comp_pred = alpha_pre * fg + (1 - alpha_pre) * bg 218 | 219 | # be careful about here: if img's range is [0,1] then eps should divede 255 220 | L_composition = torch.sqrt(torch.pow(img - comp_pred, 2.) + eps).mean() 221 | 222 | L_p = 0.5 * L_alpha + 0.5 * L_composition 223 | 224 | return L_p, L_alpha, L_composition 225 | 226 | def loss_function(img, trimap_pre, trimap_gt, alpha_pre, alpha_gt, bg, fg): 227 | 228 | # ------------------------------------- 229 | # classification loss L_t 230 | # ------------------------ 231 | criterion = nn.CrossEntropyLoss() 232 | L_t = criterion(trimap_pre, trimap_gt[:,0,:,:].long()) 233 | 234 | # ------------------------------------- 235 | # prediction loss L_p 236 | # ------------------------ 237 | eps = 1e-6 238 | # l_alpha 239 | L_alpha = torch.sqrt(torch.pow(alpha_pre - alpha_gt, 2.) + eps).mean() 240 | 241 | comp_pred = alpha_pre * fg + (1 - alpha_pre) * bg 242 | 243 | # be careful about here: if img's range is [0,1] then eps should divede 255 244 | L_composition = torch.sqrt(torch.pow(img - comp_pred, 2.) 
+ eps).mean() 245 | L_p = 0.5 * L_alpha + 0.5 * L_composition 246 | 247 | # train_phase 248 | loss = L_p + 0.01*L_t 249 | 250 | return loss, L_alpha, L_composition, L_t 251 | 252 | 253 | def main(): 254 | args = get_args() 255 | 256 | if args.without_gpu: 257 | print("use CPU !") 258 | device = torch.device('cpu') 259 | else: 260 | if torch.cuda.is_available(): 261 | device = torch.device('cuda') 262 | else: 263 | print("No GPU is is available !") 264 | 265 | print("============> Building model ...") 266 | if args.train_phase == 'pre_train_t_net': 267 | model = network.net_T() 268 | elif args.train_phase == 'pre_train_m_net': 269 | model = network.net_M() 270 | model.apply(weight_init) 271 | elif args.train_phase == 'end_to_end': 272 | model = network.net_F() 273 | if args.pretrain: 274 | model = Train_Log.load_pretrain(model) 275 | else: 276 | raise ValueError('Wrong train phase request!') 277 | train_data = dataset.human_matting_data(args) 278 | model.to(device) 279 | 280 | # debug setting 281 | save_latest_freq = int(len(train_data)//args.train_batch*0.55) 282 | if args.debug: 283 | args.save_epoch = 1 284 | args.train_batch = 1 # defualt debug: 1 285 | args.nEpochs = 1 286 | args.print_iter = 1 287 | save_latest_freq = 10 288 | 289 | print(args) 290 | print("============> Loading datasets ...") 291 | 292 | trainloader = DataLoader(train_data, 293 | batch_size=args.train_batch, 294 | drop_last=True, 295 | shuffle=True, 296 | num_workers=args.nThreads, 297 | pin_memory=True) 298 | 299 | print("============> Set optimizer ...") 300 | lr = args.lr 301 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), \ 302 | lr=lr, betas=(0.9, 0.999), 303 | weight_decay=0.0005) 304 | 305 | print("============> Start Train ! 
...") 306 | start_epoch = 1 307 | trainlog = Train_Log(args) 308 | if args.continue_train: 309 | start_epoch, model = trainlog.load_model(model) 310 | 311 | model.train() 312 | for epoch in range(start_epoch, args.nEpochs+1): 313 | 314 | loss_ = 0 315 | L_alpha_ = 0 316 | L_composition_ = 0 317 | L_cross_ = 0 318 | if args.lrdecayType != 'keep': 319 | lr = set_lr(args, epoch, optimizer) 320 | 321 | t0 = time.time() 322 | for i, sample_batched in enumerate(trainloader): 323 | 324 | optimizer.zero_grad() 325 | 326 | if args.train_phase == 'pre_train_t_net': 327 | img, trimap_gt = sample_batched['image'], sample_batched['trimap'] 328 | img, trimap_gt = img.to(device), trimap_gt.to(device) 329 | 330 | trimap_pre = model(img) 331 | if args.debug: #debug only 332 | assert tuple(trimap_pre.shape) == (args.train_batch, 3, args.patch_size, args.patch_size) 333 | assert tuple(trimap_gt.shape) == (args.train_batch, 1, args.patch_size, args.patch_size) 334 | 335 | loss = loss_f_T(trimap_pre, trimap_gt) 336 | 337 | loss_ += loss.item() 338 | 339 | if i!=0 and i % args.print_iter == 0: 340 | save_img(args, (img, trimap_pre, trimap_gt), epoch, i) 341 | print("[epoch:{} iter:{}] \tloss: {:.5f}".format(epoch, i, loss)) 342 | if i!=0 and i % save_latest_freq == 0: 343 | print("average loss: {:.5f}\nsaving model ....".format(loss_ / (i+1))) 344 | trainlog.save_model(model, epoch) 345 | 346 | elif args.train_phase == 'pre_train_m_net': 347 | img, trimap_gt, alpha_gt, bg, fg = sample_batched['image'], sample_batched['trimap'], sample_batched['alpha'], sample_batched['bg'], sample_batched['fg'] 348 | img, trimap_gt, alpha_gt, bg, fg = img.to(device), trimap_gt.to(device), alpha_gt.to(device), bg.to(device), fg.to(device) 349 | 350 | alpha_pre = model(img, trimap_gt) 351 | if args.debug: 352 | assert tuple(alpha_pre.shape) == (args.train_batch, 1, args.patch_size, args.patch_size) 353 | img_dir = os.path.join(args.saveDir, args.train_phase, 'save_img') 354 | img_fg_bg = 
np.concatenate((torch2numpy(trimap_gt[0]), torch2numpy(fg[0]), torch2numpy(bg[0])), axis=-1) 355 | img_fg_bg = np.transpose(img_fg_bg, (1,2,0)) 356 | cv2.imwrite(img_dir + '/fgbg_{}_{}.png'.format(str(epoch), str(i)), img_fg_bg) 357 | 358 | loss, L_alpha, L_composition = loss_f_M(img, alpha_pre, alpha_gt, bg, fg, trimap_gt) 359 | 360 | loss_ += loss.item() 361 | L_alpha_ += L_alpha.item() 362 | L_composition_ += L_composition.item() 363 | 364 | if i!=0 and i % args.print_iter == 0: 365 | save_img(args, (img, alpha_pre, alpha_gt), epoch, i) 366 | print("[epoch:{} iter:{}] loss: {:.5f} loss_a: {:.5f} loss_c: {:.5f}"\ 367 | .format(epoch, i, loss, L_alpha, L_composition)) 368 | if i!=0 and i % save_latest_freq == 0: 369 | print("average loss: {:.5f}\nsaving model ....".format(loss_ / (i + 1))) 370 | trainlog.save_model(model, epoch) 371 | 372 | elif args.train_phase == 'end_to_end': 373 | img, trimap_gt, alpha_gt, bg, fg = sample_batched['image'], sample_batched['trimap'], sample_batched['alpha'], sample_batched['bg'], sample_batched['fg'] 374 | img, trimap_gt, alpha_gt, bg, fg = img.to(device), trimap_gt.to(device), alpha_gt.to(device), bg.to(device), fg.to(device) 375 | 376 | trimap_pre, alpha_pre = model(img) 377 | loss, L_alpha, L_composition, L_cross = loss_function(img, trimap_pre, trimap_gt, alpha_pre, alpha_gt, bg) 378 | 379 | loss_ += loss.item() 380 | L_alpha_ += L_alpha.item() 381 | L_composition_ += L_composition.item() 382 | L_cross_ += L_cross.item() 383 | 384 | loss.backward() 385 | optimizer.step() 386 | 387 | 388 | # shuffle data after each epoch to recreate the dataset 389 | print('epoch end, shuffle datasets again ...') 390 | train_data.shuffle_data() 391 | #trainloader.dataset.shuffle_data() 392 | 393 | t1 = time.time() 394 | 395 | if args.train_phase == 'pre_train_t_net': 396 | loss_ = loss_ / (i+1) 397 | log = "[{} / {}] \tloss: {:.5f}\ttime: {:.0f}".format(epoch, args.nEpochs, loss_, t1-t0) 398 | 399 | elif args.train_phase == 'pre_train_m_net': 
400 | loss_ = loss_ / (i + 1) 401 | L_alpha_ = L_alpha_ / (i + 1) 402 | L_composition_ = L_composition_ / (i + 1) 403 | log = "[{} / {}] loss: {:.5f} loss_a: {:.5f} loss_c: {:.5f} time: {:.0f}"\ 404 | .format(epoch, args.nEpochs, 405 | loss_, 406 | L_alpha_, 407 | L_composition_, 408 | t1 - t0) 409 | 410 | elif args.train_phase == 'end_to_end': 411 | loss_ = loss_ / (i+1) 412 | L_alpha_ = L_alpha_ / (i+1) 413 | L_composition_ = L_composition_ / (i+1) 414 | L_cross_ = L_cross_ / (i+1) 415 | 416 | log = "[{} / {}] loss: {:.5f} loss_a: {:.5f} loss_c: {:.5f} loss_t: {:.5f} time: {:.0f}"\ 417 | .format(epoch, args.nEpochs, 418 | loss_, 419 | L_alpha_, 420 | L_composition_, 421 | L_cross_, 422 | t1 - t0) 423 | print(log) 424 | trainlog.save_log(log) 425 | trainlog.save_model(model, epoch) 426 | 427 | if epoch % args.save_epoch == 0: 428 | trainlog.save_model(model, epoch, save_as=True) 429 | 430 | 431 | 432 | if __name__ == "__main__": 433 | main() 434 | --------------------------------------------------------------------------------