├── .gitignore
├── README.md
├── data
│   ├── AdobeData.py
│   ├── __init__.py
│   └── py_adobe_data.py
├── docs
│   ├── exampl1.jpg
│   ├── example2.jpg
│   └── single_alpah_prediction_loss.jpg
├── models
│   ├── __init__.py
│   ├── encoder_decoder.py
│   ├── matting_loss.py
│   ├── pretrain_keys_pair.py
│   ├── py_encoder_decoder.py
│   └── py_loss.py
├── train_encoder_decoder.py
└── utils
    ├── __init__.py
    ├── learning_vis_test.py
    └── visulization.py

/.gitignore:
--------------------------------------------------------------------------------
*.out
*.log
*.txt
*.png
data/adobe_data
checkpoints
*.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
#### This is an implementation of the paper [Deep Image Matting](http://arxiv.org/abs/1703.03872)
- Deep Image Matting is a learning-based method that predicts the alpha matte of an image from the composited (merged) image and its trimap; the foreground, background and ground-truth alpha are used to build the training data and losses.
- The project is implemented in PyTorch. The main data, model, loss and training components are:
  - [data/py_adobe_data.py](https://github.com/hudengjunai/DeepImageMatting/blob/master/data/py_adobe_data.py): online fg/bg/alpha composition, combining COCO 2014 train images with the Adobe Matting dataset.
  - [models/py_encoder_decoder.py](https://github.com/hudengjunai/DeepImageMatting/blob/master/models/py_encoder_decoder.py): the model definition, a VGG encoder with an unpooling/convolution decoder.
  - [train_encoder_decoder.py](https://github.com/hudengjunai/DeepImageMatting/blob/master/train_encoder_decoder.py): the training stage definitions: encoder-decoder, refine-head and over-all, three stages in total.
  - [utils/visulization.py](https://github.com/hudengjunai/DeepImageMatting/blob/master/utils/visulization.py): the loss and image visualization module.


#### Project overview
- Dataset: training samples are composited online as `merged = alpha*fg + (1-alpha)*bg` from an Adobe foreground/alpha pair and a random COCO background. The expected storage layout is shown below; adjust the data paths in [data/py_adobe_data.py](https://github.com/hudengjunai/DeepImageMatting/blob/master/data/py_adobe_data.py).
  1. Copy the foreground and alpha folders of the Combined Dataset (the Training set together with the "Other" set) into the locations below.
```python
self.a_path = './data/adobe_data/trainval/alpha'   # alpha mattes: collect the alphas of the Train set and the Other set here
self.fg_path = './data/adobe_data/trainval/fg/'    # likewise, the foreground images, 439 in total
self.bg_path = '/data/jh/notebooks/hehao/datasets/coco/train2014/'  # the coco path, points to the COCO dataset
```
### Testing individual modules
##### Data module
The data module can be functionally tested on its own:
```bash
DeepImageMatting$ python data/py_adobe_data.py
```
```python
data[:,:,0:3] = image                  # first three channels: RGB, normalized (sub-mean, div-std)
data[:,:,3] = torch.tensor(trimap)     # last channel: the trimap, values (0,128,255), 255 marks foreground

label[:,:,0:3] = torch.tensor(bg)      # background, values in (0,255)
label[:,:,3:6] = torch.tensor(fg)      # foreground, values in (0,255)
label[:,:,6:9] = torch.tensor(merged)  # composited image (0,255); same content as data[0:3] but not normalized, which makes the loss easier to compute
label[:,:,9:10] = torch.tensor(alpha.reshape(self.size,self.size,1))  # ground-truth alpha matte
label[:,:,10] = torch.tensor(mask)     # mask of the unknown region
```
##### Model module
The model part contains the encoder-decoder and the encoder-decoder plus refine-head residual structure; both can be tested with:
```bash
DeepImageMatting$ python models/py_encoder_decoder.py  # test the model forward pass
```
The paper is not very explicit about unpooling versus deconvolution. Following the SegNet design, the max-pooling indices of the VGG encoder are stored and reused to place values during upsampling. There are several upsampling options: learnable bilinear interpolation, fixed bilinear interpolation, transposed convolution (deconvolution with stride greater than 1), and reverse max-pooling; the current implementation uses reverse max-pooling (max unpooling), as sketched below.
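A minimal sketch of index-preserving max unpooling (the scheme used by `segnetDown*`/`segnetUp` in models/py_encoder_decoder.py); the tensor sizes here are only illustrative:
```python
import torch
from torch import nn

# encoder side: max pooling that also returns the argmax indices
pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
# decoder side: unpooling that scatters values back to those indices, zeros elsewhere
unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

x = torch.randn(1, 64, 320, 320)             # e.g. features after the first conv block
pooled, indices = pool(x)                    # (1, 64, 160, 160) plus an index map
restored = unpool(pooled, indices, output_size=x.size())
print(restored.shape)                        # torch.Size([1, 64, 320, 320])
```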
##### Loading pretrained weights
The encoder of the matting network uses the VGG16-BN weights, which have to be copied into the current model; see the `load_vggbn` function in models/py_encoder_decoder.py.

##### Loss functions
There are two losses, the alpha-prediction loss and the compositional loss proposed in the paper; the exact computation can be found in the loss classes.
```bash
DeepImageMatting$ python models/py_loss.py  # test the loss functions
```
The paper proposes two losses:
- alpha loss: the squared difference between the predicted matte and the ground-truth alpha, restricted to the unknown region. See `AlphaPredLoss`.
- compose loss: the squared difference between the image composited with the predicted matte and the real composited image. See `ComposeLoss`.
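A compact sketch of both losses, following the ε-smoothed formulation used in models/py_loss.py (the small ε keeps the square root differentiable at zero); tensors and shapes are illustrative:
```python
import torch

eps = 1e-6

def alpha_prediction_loss(a_pred, a_gt, mask):
    # difference counted only inside the unknown (trimap == 128) region
    d = mask * (a_pred - a_gt)
    return torch.sqrt((d ** 2).sum() + eps ** 2)

def compositional_loss(a_pred, fg, bg, gt_merged, mask):
    # composite with the predicted matte: C = a*F + (1-a)*B, images in 0..255
    pred_merged = a_pred * fg + (1 - a_pred) * bg
    d = mask * (gt_merged - pred_merged) / 255.0
    return torch.sqrt((d ** 2).sum() + eps ** 2)

a_pred = torch.rand(4, 1, 320, 320)
a_gt = torch.rand(4, 1, 320, 320)
mask = (torch.rand(4, 1, 320, 320) > 0.5).float()
print(alpha_prediction_loss(a_pred, a_gt, mask).item())
```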
##### Training module
All three training stages are handled by the custom `Trainer` class, which covers initialization (data, model, losses and training policy), training and validation.
```bash
DeepImageMatting$ python train_encoder_decoder.py --params
```
Here the `stage` argument selects the training stage:
- stage one (stage=0): train the encoder-decoder, a SegNet-like structure.
- stage two (stage=1): train the refine_head, a small alpha refinement module.
- stage three (stage=2): train everything end to end; the overall structure resembles a residual block.

### Training notes
#### Doing list (experiments to run)
- During the first (encoder-decoder) stage, the alpha-prediction loss and the compositional loss have different magnitudes. Adding the two losses as the paper's authors propose did not converge here and kept oscillating; inspecting the gradients showed they were very small. Following the experiments section of the paper, training with the alpha-prediction loss alone does converge.
![avatar](./docs/single_alpah_prediction_loss.jpg)
![avatar](./docs/exampl1.jpg)
![avatar](./docs/example2.jpg)

- Visualize the foreground, background, composited image and alpha values during training.

- Compare the effect of bilinear upsampling versus unpooling on training.
[Medium post on unpooling and deconvolution](https://towardsdatascience.com/review-deconvnet-unpooling-layer-semantic-segmentation-55cf8a6e380e)
[Introduction to unpooling](https://jinzequn.github.io/2018/01/28/deconv-and-unpool/)
```python
# unpooling and deconvolution are used in SegNet-style networks
```
- Tune the trans module that links the encoder and decoder (kernel sizes); it currently uses one 3x3 and one 1x1 convolution (see the trans part of py_encoder_decoder.py):
```python
class transMap(nn.Module):
    def __init__(self):
        super(transMap,self).__init__()
        self.conv1 = conv2DBatchNormRelu(512,512,3,1,1)
        self.conv2 = conv2DBatchNormRelu(512,512,1,1,0)

    def forward(self, x):
        return self.conv2(self.conv1(x))
```
- The encoder input is the composited RGB image plus the trimap channel. The RGB channels go through the usual to-tensor and normalize steps. The trimap takes the values (0,128,255) in this dataset, which is a different numeric scale, so it is normalized as well: divide by 255, subtract a mean of 0.5, divide by a std of 0.5 (a short sketch is included at the end of this README).
- Try skip connections:
combine shallow encoder features, in the style of the skip connections in U-Net and SegNet-like segmentation networks, to improve matting quality.
- Combine tricks from recent segmentation networks such as BiSeNet to improve matting quality.

#### Note
- If you have any questions, contact hudengjunai@gmail.com

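A minimal sketch of how the four-channel network input is assembled, including the trimap normalization described above (it mirrors the tensors built in data/py_adobe_data.py; the array contents are random placeholders):
```python
import numpy as np
import torch

mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

merged = np.random.randint(0, 256, (320, 320, 3)).astype(np.float32)      # composited RGB, 0..255
trimap = np.random.choice([0, 128, 255], (320, 320)).astype(np.float32)   # three-valued trimap

rgb = torch.from_numpy(merged).div(255)
rgb = rgb.sub(mean).div(std)                                # per-channel ImageNet normalization

tri = torch.from_numpy(trimap).div(255).sub(0.5).div(0.5)   # (0,128,255) -> roughly (-1, 0, 1)

data = torch.empty(320, 320, 4)
data[:, :, 0:3] = rgb
data[:, :, 3] = tri
data = data.permute(2, 0, 1).contiguous()                   # HWC -> CHW for the network
print(data.shape)                                           # torch.Size([4, 320, 320])
```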
--------------------------------------------------------------------------------
/data/AdobeData.py:
--------------------------------------------------------------------------------
import mxnet as mx
from mxnet.gluon.data import Dataset,DataLoader
from mxnet.image import imread
from PIL import Image
import os
import numpy as np
import cv2
import math
from mxnet import nd
import mxnet.gluon.data.vision.transforms as T

# ImageNet statistics, matching the defaults used in data/py_adobe_data.py
default_transform = T.Compose([T.ToTensor(),
                               T.Normalize(mean=(0.485, 0.456, 0.406),
                                           std=(0.229, 0.224, 0.225))])

class AdobeDataset(Dataset):
    """the adobe dataset used to load the train and test data;
    get item will return the bg_img, trimap and alpha"""


    def __init__(self,usage,size=320,transform = default_transform):
        self.kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        self.transform = transform
        self.size = size
        self.usage = usage
        filename = '{}_names.txt'.format(usage)
        with open(filename,'r') as f:
            self.names = f.read().splitlines()
        np.random.shuffle(self.names)
        if self.usage in ['train','valid']:
            self.a_path = './data/adobe_train/alpha/'
            self.fg_path = './data/adobe_train/fg'
            self.bg_path = './data/adobe_train/bg'
            fg_names = 'training_fg_names.txt'
            bg_names = 'training_bg_names.txt'
            with open(fg_names,'r') as f:
                self.fg_files = f.read().splitlines()
            with open(bg_names,'r') as f:
                self.bg_files = f.read().splitlines()

        elif self.usage=='test':
            self.a_path=" "
            pass

    def __len__(self):
        return len(self.names)


    def process(self,im_name,bg_name):
        im = cv2.imread(os.path.join(self.fg_path,im_name))
        a = cv2.imread(os.path.join(self.a_path,im_name),0)   # read the alpha as a single channel
        h,w = im.shape[:2]
        bg = cv2.imread(os.path.join(self.bg_path,bg_name))
        bh,bw = bg.shape[:2]
        wratio = w/bw
        hratio = h/bh
        ratio = max(wratio,hratio)
        if ratio>1:
            bg = cv2.resize(src=bg,
                            dsize=(math.ceil(bw * ratio), math.ceil(bh * ratio)),
                            interpolation=cv2.INTER_CUBIC)
        return self.compose(im,bg,a,w,h)

    def compose(self,fg,bg,a,w,h):
        fg = np.array(fg,np.float32)
        bg_h,bg_w = bg.shape[:2]

        x =0
        if bg_w>w:
            x = np.random.randint(0,bg_w-w)
        y=0
        if bg_h>h:
            y = np.random.randint(0,bg_h-h)
        bg = np.array(bg[y:y+h,x:x+w],np.float32)

        # generate alpha and merged image
        alpha = np.zeros((h,w,1),np.float32)
        alpha[:,:,0] = a/255.0
        im = alpha*fg + (1-alpha)*bg
        im = im.astype(np.uint8)

        # generate trimap
        fg_tr = np.array(np.equal(a,255).astype(np.float32))
        un_tr = np.array(np.not_equal(a,0).astype(np.float32))
        un_tr = cv2.dilate(un_tr,self.kernel,
                           iterations = np.random.randint(1,20))
        trimap = fg_tr*255+(un_tr-fg_tr)*128



        return im,alpha,fg,bg,trimap


    def __getitem__(self, item):
        """get the x and y
        x is the [merged[0:3],trimap[3] ],
        y is the [bg[0:3],fg[3:6],alpha[6] ]"""
        name = self.names[item]
        fcount,bcount = [int(x) for x in name.split('.')[0].split('_')]
        im_name = self.fg_files[fcount]
        bg_name = self.bg_files[bcount]

        merged,alpha,fg,bg,trimap = self.process(im_name,bg_name)
        x = nd.empty((self.size,self.size,4),dtype=np.float32)
        y = nd.empty((self.size,self.size,7),dtype=np.float32)

        if self.transform:
            merged = self.transform(nd.array(merged))


        x[:,:,0:3] = nd.array(merged)
        x[:,:,-1] = nd.array(trimap)

        y[:,:,0:3] = nd.array(bg)
        y[:,:,3:6] = nd.array(fg)
        y[:,:,-1] = nd.array(alpha)
        return x,y


if __name__=='__main__':
    "this is the test for dataset"
    train_dataset = AdobeDataset(usage='train')
    for i,(x,y) in enumerate(train_dataset):
        print(x.shape)
        print(y.shape)
        if i ==3:
            break


--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
from .py_adobe_data import get_train_val_dataloader
--------------------------------------------------------------------------------
/data/py_adobe_data.py:
--------------------------------------------------------------------------------
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
import cv2
import PIL
from PIL import Image
import math
import torchvision.transforms as T
import numpy as np
import os
import torch
import random
import matplotlib.pyplot as plt
mean=(0.485, 0.456, 0.406)
std=[0.229, 0.224, 0.225]
default_transform = T.Compose([T.ToTensor(),
                               T.Normalize(mean=mean,std=std)])




class AdobeDataset(Dataset):

    def __init__(self,usage,size=320,transform=default_transform,stand = True):
        super(AdobeDataset,self).__init__()
        self.kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        self.transform = transform
        self.size = size
        self.usage = usage
        self.stand = stand
        filename = './data/adobe_data/{}_names.txt'.format(usage)  # just stores the image index names
        with open(filename, 'r') as f:
            self.names = f.read().splitlines()
        np.random.shuffle(self.names)
        if self.usage in ['train', 'valid']:
            self.a_path = './data/adobe_data/trainval/alpha'
            self.fg_path = './data/adobe_data/trainval/fg/'
            self.bg_path = '/data/jh/notebooks/hehao/datasets/coco/train2014/'  # the coco path
            fg_names = './data/adobe_data/training_fg_names.txt'  # the file listing all the foreground file names
            bg_names = './data/adobe_data/training_bg_names.txt'  # the file listing all the background file names
            with open(fg_names, 'r') as f:
                self.fg_files = f.read().splitlines()
            with open(bg_names, 'r') as f:
                self.bg_files = f.read().splitlines()
        self.unknown_code = 128

    def __len__(self):
        return len(self.names)

    def process(self,im_name,bg_name):
        fg = cv2.imread(os.path.join(self.fg_path, im_name))
        a = cv2.imread(os.path.join(self.a_path, im_name),0)
        bg = cv2.imread(os.path.join(self.bg_path, bg_name))
        fg = cv2.cvtColor(fg,cv2.COLOR_BGR2RGB)
        bg = cv2.cvtColor(bg,cv2.COLOR_BGR2RGB)

        h, w = fg.shape[:2]
        bh, bw = bg.shape[:2]
        wratio = w / bw
        hratio = h / bh
        ratio = max(wratio, hratio)
        if ratio > 1:  # need to enlarge the bg image
            bg = cv2.resize(src=bg,
                            dsize=(math.ceil(bw * ratio), math.ceil(bh * ratio)),
                            interpolation=cv2.INTER_CUBIC)
        return self.compose(fg, bg, a, w, h)
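    # Editor's note on the composition step below: compose() places the foreground over a random
    # COCO crop using the ground-truth alpha (merged = alpha*fg + (1-alpha)*bg), then builds the
    # trimap from that alpha: pixels with alpha == 255 are definite foreground (255), pixels with
    # alpha > 0 are dilated by a random number of iterations to form the unknown band (128), and
    # everything else stays background (0). The random dilation width is what makes the unknown
    # region vary between samples.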
    def compose(self, fg, bg, a, w, h):
        fg = np.array(fg, np.float32)
        bg_h, bg_w = bg.shape[:2]

        x = 0
        if bg_w > w:
            x = np.random.randint(0, bg_w - w)
        y = 0
        if bg_h > h:
            y = np.random.randint(0, bg_h - h)
        bg = np.array(bg[y:y + h, x:x + w], np.float32)

        # generate alpha and merged image
        alpha = np.zeros((h, w, 1), np.float32)
        alpha[:, :, 0] = a / 255.0
        im = alpha * fg + (1 - alpha) * bg


        # generate trimap
        fg_tr = np.array(np.equal(a, 255).astype(np.float32))
        un_tr = np.array(np.not_equal(a, 0).astype(np.float32))
        un_tr = cv2.dilate(un_tr, self.kernel,
                           iterations=np.random.randint(1, 20))
        trimap = fg_tr * 255 + (un_tr - fg_tr) * 128

        return im, alpha, fg, bg, trimap  # channels-last RGB, float32, values in 0..255


    def __getitem__(self, item):
        """get the x and y
        x is the [merged[0:3],trimap[3] ],
        y is the [bg[0:3],fg[3:6],merged[6:9],alpha[9],mask[10] ]"""
        name = self.names[item]
        fcount,bcount = [int(x) for x in name.split('.')[0].split('_')]
        im_name = self.fg_files[fcount]
        bg_name = self.bg_files[bcount]

        merged,alpha,fg,bg,trimap = self.process(im_name,bg_name)  # all float32, RGB channels last, max 255


        data = torch.empty(size=(self.size,self.size,4),dtype=torch.float32)


        # safe crop and resize
        # Flip arrays left to right randomly (prob=1:1)
        if np.random.random_sample() > 0.5:
            merged = np.fliplr(merged)
            alpha = np.fliplr(alpha)
            fg = np.fliplr(fg)
            bg = np.fliplr(bg)
            trimap = np.fliplr(trimap)

        # generate a crop that contains part of the trimap's unknown region
        different_sizes= [(320, 320), (480, 480), (640, 640)]
        scale_crop = random.choice(different_sizes)
        x,y = self.random_choice(trimap,scale_crop)

        merged = self.safe_crop(merged,x,y,crop_size=scale_crop,fixed=(self.size,self.size))
        alpha = self.safe_crop(alpha,x,y,crop_size=scale_crop,fixed=(self.size,self.size))
        fg = self.safe_crop(fg,x,y,crop_size=scale_crop,fixed=(self.size,self.size))
        bg = self.safe_crop(bg,x,y,crop_size=scale_crop,fixed=(self.size,self.size))
        trimap = self.safe_crop(trimap,x,y,crop_size=scale_crop,fixed=(self.size,self.size))
        mask = np.equal(trimap, 128).astype(np.float32)

        image =torch.tensor(merged).div(255)
        image.sub_(torch.tensor(mean)).div_(torch.tensor(std))

        data[:,:,0:3] = image  # first three channels: rgb, normalized (sub-mean, div-std)
        data[:,:,3] = torch.tensor(trimap).div(255).sub_(0.5).div_(0.5)  # last channel is the trimap

        if self.stand:
            label = torch.empty(size=(self.size, self.size, 11), dtype=torch.float32)
            label[:,:,0:3] = torch.tensor(bg)
            label[:,:,3:6] = torch.tensor(fg)
            label[:,:,6:9] = torch.tensor(merged)
            label[:,:,9:10] = torch.tensor(alpha.reshape(self.size,self.size,1))
            label[:,:,10] = torch.tensor(mask)

            data = data.transpose(0, 1).transpose(0, 2).contiguous()
            label = label.transpose(0, 1).transpose(0, 2).contiguous()
            return data,label
        else:
            bg = torch.tensor(bg.transpose(2,0,1))
            fg = torch.tensor(fg.transpose(2,0,1))
            merged = torch.tensor(merged.transpose(2,0,1))
            alpha = torch.tensor(alpha.reshape(1,self.size,self.size))
            mask = torch.tensor(mask.reshape(1,self.size,self.size))
            return data,bg,fg,merged,alpha,mask

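    # Editor's note: random_choice() below picks a crop origin centered on a randomly chosen
    # unknown-region (trimap == 128) pixel, clamped so the crop stays inside the image; this keeps
    # the hard, semi-transparent band inside every training crop. safe_crop() then pads with zeros
    # if the crop runs over the border and finally resizes to the fixed network input size.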
    def random_choice(self,trimap, crop_size):
        crop_height, crop_width = crop_size
        y_indices, x_indices = np.where(trimap == self.unknown_code)
        num_unknowns = len(y_indices)
        tri_h,tri_w = trimap.shape[:2]
        x, y = 0, 0
        if num_unknowns > 0:
            ix = np.random.choice(range(num_unknowns))
            center_x = x_indices[ix]
            center_y = y_indices[ix]  # sampled center, so the crop region contains unknown pixels
            x = max(0, center_x - int(crop_width / 2))
            y = max(0, center_y - int(crop_height / 2))
            x = min(x,tri_w-crop_width)
            y = min(y,tri_h-crop_height)  # precondition: tri_w and tri_h must be larger than the crop size
        return x, y

    def safe_crop(self,mat, x, y, crop_size,fixed):
        crop_height, crop_width = crop_size
        if len(mat.shape) == 2 :
            ret = np.zeros((crop_height, crop_width), np.float32)
        else:
            channels = mat.shape[2]
            ret = np.zeros((crop_height, crop_width, channels), np.float32)
        crop = mat[y:y + crop_height, x:x + crop_width]
        h, w = crop.shape[:2]
        ret[0:h, 0:w] = crop
        if crop_size != fixed:
            ret = cv2.resize(ret, dsize=fixed, interpolation=cv2.INTER_NEAREST)
        return ret

def get_train_val_dataloader(batch_size,num_workers,stand=True):
    train_dataset = AdobeDataset(usage='train',stand=stand)
    val_dataset = AdobeDataset(usage='valid',stand=stand)
    train_loader = DataLoader(train_dataset,batch_size=batch_size,num_workers=num_workers)
    valid_loader = DataLoader(val_dataset,batch_size=batch_size,num_workers=num_workers)
    return train_loader,valid_loader

if __name__=='__main__':
    import time
    """ this is the dataset test;
    run it from the DeepMatting_MXNet root directory"""
    train_dataset = AdobeDataset(usage='train')
    for i,(x,y) in enumerate(train_dataset):
        print(x.shape)
        print(y.shape)
        if i==3:
            break
    valid_dataset = AdobeDataset(usage='valid')
    for i,(x,y) in enumerate(valid_dataset):
        print(x.shape)
        print(y.shape)
        if i==3:
            break
    print("test valid dataset finished")
    category = [True,False]
    for std in category:
        train_loader,valid_loader = get_train_val_dataloader(4,10,stand=std)
        start = time.time()
        for i,data in enumerate(train_loader):
            if i==100:
                break
        duration = time.time()-start
        print("time used: {0} seconds for stand={1} mode".format(duration,str(std)))
--------------------------------------------------------------------------------
/docs/exampl1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hudengjunai/Deep-Image-Matting/9dea68d2620dcd7010c90e0577c262d19121e4da/docs/exampl1.jpg
--------------------------------------------------------------------------------
/docs/example2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hudengjunai/Deep-Image-Matting/9dea68d2620dcd7010c90e0577c262d19121e4da/docs/example2.jpg
--------------------------------------------------------------------------------
/docs/single_alpah_prediction_loss.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hudengjunai/Deep-Image-Matting/9dea68d2620dcd7010c90e0577c262d19121e4da/docs/single_alpah_prediction_loss.jpg
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from .py_encoder_decoder import Encoder_Decoder
from .py_loss import ComposeLoss,AlphaPredLoss
from .pretrain_keys_pair import pairs_keys
--------------------------------------------------------------------------------
/models/encoder_decoder.py:
--------------------------------------------------------------------------------
from mxnet.gluon import nn
from mxnet import autograd
from mxnet import nd
from mxnet.initializer import Xavier,Zero
from gluoncv.model_zoo.model_store import get_model_file
from mxnet import ndarray
vgg_spec = [(2,2,3,3,3),(64,128,256,512,512)]


class Matting_Encoder(nn.HybridBlock):
    def __init__(self,spec=vgg_spec):
        super(Matting_Encoder,self).__init__()
        layers,channels= spec
        features = nn.HybridSequential(prefix='')
        for i, num in enumerate(layers):
            for _ in range(num):
                features.add(nn.Conv2D(channels=channels[i], kernel_size=3, padding=1, strides=1,
                                       weight_initializer=Zero(),
                                       bias_initializer='zeros'))
                features.add(nn.BatchNorm())
                features.add(nn.Activation('relu'))
            features.add(nn.MaxPool2D(strides=2))
        self.features = features

    def hybrid_forward(self, F, x, *args, **kwargs):
        return self.features(x)

    def load_vgg_encoder_params(self):
        """load the vgg16_bn params; the extra alpha/trimap channel of conv1 keeps its zero initialization"""
        vgg_file = get_model_file('vgg%d%s' % (16, '_bn'))
        loaded = ndarray.load(vgg_file)
        params = self._collect_params_with_prefix()
        for name in loaded:
            if name in params:
                params[name]._load_init(loaded[name])


class Encoder_Decoder(nn.HybridBlock):
    """this is the Deep Image Matting encoder decoder structure for alpha matting"""
    def __init__(self,spec,stage):
        super(Encoder_Decoder,self).__init__()
        self.stage = stage
        self.encoder = Matting_Encoder(spec=spec)
        self.decoder = Matting_Decoder()

        self.refine = nn.HybridSequential(prefix='refine')
        channels = [64,64,64,1]
        with self.refine.name_scope():
            for i,c in enumerate(channels):
                self.refine.add(nn.Conv2D(channels=c,kernel_size=3,padding=1,
                                          weight_initializer=Xavier(),
                                          bias_initializer='zeros'))
                self.refine.add(nn.BatchNorm())
                self.refine.add(nn.Activation('relu'))


    def hybrid_forward(self, F,rgbt,fg,bg):
        if self.stage==1:
            feature = self.encoder(rgbt)
            alpha = self.decoder(feature)
            return [alpha]
        elif self.stage==2:
            feature = self.encoder(rgbt)
            alpha = self.decoder(feature)
            ref_input = nd.concat(rgbt[:,0:3,:,:],alpha,dim=1)
            alpha = self.refine(ref_input)
            return [alpha]
        else:
            feature = self.encoder(rgbt)
            alpha = self.decoder(feature)
            ref_input = nd.concat(rgbt[:, 0:3, :, :], alpha, dim=1)
            alpha2 = self.refine(ref_input)
            return [alpha,alpha2]




class Matting_Decoder(nn.HybridBlock):
    def __init__(self):
        super(Matting_Decoder,self).__init__()
        self.trans = nn.HybridSequential(prefix='')
        self.trans.add(nn.Conv2D(channels=512,kernel_size=1))
        self.trans.add(nn.BatchNorm())

        channels = [512,256,128,64,64]
        self.dec_layers = []
        for i,c in enumerate(channels):
            block = nn.HybridSequential(prefix='decove_{0}'.format(6-i))
            block.add(nn.Conv2D(channels=c,kernel_size=5,padding=2,
                                weight_initializer=Xavier(rnd_type='gaussian', factor_type='out', magnitude=2),
                                bias_initializer='zeros'))
            block.add(nn.BatchNorm())
            block.add(nn.Activation('relu'))
            self.dec_layers.append(block)

        self.alpha_block = nn.HybridSequential()
        self.alpha_block.add(nn.Conv2D(channels=1,kernel_size=5,padding=2,
                                       weight_initializer=Xavier(rnd_type='gaussian', factor_type='out', magnitude=2),
                                       bias_initializer='zeros'))
        self.alpha_block.add(nn.BatchNorm())
        self.alpha_block.add(nn.Activation('relu'))


    def hybrid_forward(self, F, x, *args, **kwargs):
        out = self.trans(x)
        for layer in self.dec_layers:
            out = layer(out)
            _,_,h,w = out.shape
            out = F.contrib.BilinearResize2D(out, height=h * 2, width=w * 2)  # upsample x2 after each decoder block
        out = self.alpha_block(out)
        return out




--------------------------------------------------------------------------------
/models/matting_loss.py:
--------------------------------------------------------------------------------
from mxnet import nd
from mxnet.gluon import nn


class Compose_Loss(nn.Block):
    def __init__(self,eps):
        super(Compose_Loss,self).__init__()
        self.eps = eps

    def forward(self,fg,bg,pred,mask,merg):
        c = fg*pred +(1-pred)*bg
        dis = mask*(c-merg)
        l = nd.sqrt(self.eps + nd.square(dis).sum(0))
        return l

class AlphaPre_Loss(nn.Block):
    def __init__(self,eps):
        super(AlphaPre_Loss,self).__init__()
        self.eps = eps

    def forward(self,pre,aph,mask):
        pass

class AlphaRef_Loss(nn.Block):
    """this is the alpha refinement loss"""
    def __init__(self,eps):
        super(AlphaRef_Loss,self).__init__()
        self.eps = eps

    def forward(self, pred,aph,mask):
        pass


--------------------------------------------------------------------------------
/models/pretrain_keys_pair.py:
--------------------------------------------------------------------------------
import os
pairs_keys = [('features.0.weight', 'down1.conv1.cbr_unit.0.weight'),
              ('features.0.bias', 'down1.conv1.cbr_unit.0.bias'),
              ('features.1.weight', 'down1.conv1.cbr_unit.1.weight'),
              ('features.1.bias', 'down1.conv1.cbr_unit.1.bias'),
              ('features.1.running_mean', 'down1.conv1.cbr_unit.1.running_mean'),
              ('features.1.running_var', 'down1.conv1.cbr_unit.1.running_var'),
              ('features.3.weight', 'down1.conv2.cbr_unit.0.weight'),
              ('features.3.bias', 'down1.conv2.cbr_unit.0.bias'),
              ('features.4.weight', 'down1.conv2.cbr_unit.1.weight'),
              ('features.4.bias', 'down1.conv2.cbr_unit.1.bias'),
              ('features.4.running_mean', 'down1.conv2.cbr_unit.1.running_mean'),
              ('features.4.running_var', 'down1.conv2.cbr_unit.1.running_var'),

              ('features.7.weight', 'down2.conv1.cbr_unit.0.weight'),
              ('features.7.bias', 'down2.conv1.cbr_unit.0.bias'),
              ('features.8.weight', 'down2.conv1.cbr_unit.1.weight'),
              ('features.8.bias', 'down2.conv1.cbr_unit.1.bias'),
              ('features.8.running_mean', 'down2.conv1.cbr_unit.1.running_mean'),
              ('features.8.running_var', 'down2.conv1.cbr_unit.1.running_var'),
              ('features.10.weight', 'down2.conv2.cbr_unit.0.weight'),
              ('features.10.bias', 'down2.conv2.cbr_unit.0.bias'),
              ('features.11.weight', 'down2.conv2.cbr_unit.1.weight'),
              ('features.11.bias', 'down2.conv2.cbr_unit.1.bias'),
              ('features.11.running_mean', 'down2.conv2.cbr_unit.1.running_mean'),
              ('features.11.running_var', 'down2.conv2.cbr_unit.1.running_var'),

              ('features.14.weight', 'down3.conv1.cbr_unit.0.weight'),
              ('features.14.bias', 'down3.conv1.cbr_unit.0.bias'),
              ('features.15.weight', 'down3.conv1.cbr_unit.1.weight'),
              ('features.15.bias', 'down3.conv1.cbr_unit.1.bias'),
              ('features.15.running_mean', 'down3.conv1.cbr_unit.1.running_mean'),
              ('features.15.running_var', 'down3.conv1.cbr_unit.1.running_var'),
              ('features.17.weight', 'down3.conv2.cbr_unit.0.weight'),
              ('features.17.bias', 'down3.conv2.cbr_unit.0.bias'),
              ('features.18.weight', 'down3.conv2.cbr_unit.1.weight'),
              ('features.18.bias', 'down3.conv2.cbr_unit.1.bias'),
              ('features.18.running_mean', 'down3.conv2.cbr_unit.1.running_mean'),
              ('features.18.running_var', 'down3.conv2.cbr_unit.1.running_var'),
              ('features.20.weight', 'down3.conv3.cbr_unit.0.weight'),
              ('features.20.bias', 'down3.conv3.cbr_unit.0.bias'),
              ('features.21.weight', 'down3.conv3.cbr_unit.1.weight'),
              ('features.21.bias', 'down3.conv3.cbr_unit.1.bias'),
              ('features.21.running_mean', 'down3.conv3.cbr_unit.1.running_mean'),
              ('features.21.running_var', 'down3.conv3.cbr_unit.1.running_var'),

              ('features.24.weight', 'down4.conv1.cbr_unit.0.weight'),
              ('features.24.bias', 'down4.conv1.cbr_unit.0.bias'),
              ('features.25.weight', 'down4.conv1.cbr_unit.1.weight'),
              ('features.25.bias', 'down4.conv1.cbr_unit.1.bias'),
              ('features.25.running_mean', 'down4.conv1.cbr_unit.1.running_mean'),
              ('features.25.running_var', 'down4.conv1.cbr_unit.1.running_var'),
              ('features.27.weight', 'down4.conv2.cbr_unit.0.weight'),
              ('features.27.bias', 'down4.conv2.cbr_unit.0.bias'),
              ('features.28.weight', 'down4.conv2.cbr_unit.1.weight'),
              ('features.28.bias', 'down4.conv2.cbr_unit.1.bias'),
              ('features.28.running_mean', 'down4.conv2.cbr_unit.1.running_mean'),
              ('features.28.running_var', 'down4.conv2.cbr_unit.1.running_var'),
              ('features.30.weight', 'down4.conv3.cbr_unit.0.weight'),
              ('features.30.bias', 'down4.conv3.cbr_unit.0.bias'),
              ('features.31.weight', 'down4.conv3.cbr_unit.1.weight'),
              ('features.31.bias', 'down4.conv3.cbr_unit.1.bias'),
              ('features.31.running_mean', 'down4.conv3.cbr_unit.1.running_mean'),
              ('features.31.running_var', 'down4.conv3.cbr_unit.1.running_var'),

              ('features.34.weight', 'down5.conv1.cbr_unit.0.weight'),
              ('features.34.bias', 'down5.conv1.cbr_unit.0.bias'),
              ('features.35.weight', 'down5.conv1.cbr_unit.1.weight'),
              ('features.35.bias', 'down5.conv1.cbr_unit.1.bias'),
              ('features.35.running_mean', 'down5.conv1.cbr_unit.1.running_mean'),
              ('features.35.running_var', 'down5.conv1.cbr_unit.1.running_var'),
              ('features.37.weight', 'down5.conv2.cbr_unit.0.weight'),
              ('features.37.bias', 'down5.conv2.cbr_unit.0.bias'),
              ('features.38.weight', 'down5.conv2.cbr_unit.1.weight'),
              ('features.38.bias', 'down5.conv2.cbr_unit.1.bias'),
              ('features.38.running_mean', 'down5.conv2.cbr_unit.1.running_mean'),
              ('features.38.running_var', 'down5.conv2.cbr_unit.1.running_var'),
              ('features.40.weight', 'down5.conv3.cbr_unit.0.weight'),
              ('features.40.bias', 'down5.conv3.cbr_unit.0.bias'),
              ('features.41.weight', 'down5.conv3.cbr_unit.1.weight'),
              ('features.41.bias', 'down5.conv3.cbr_unit.1.bias'),
              ('features.41.running_mean', 'down5.conv3.cbr_unit.1.running_mean'),
              ('features.41.running_var', 'down5.conv3.cbr_unit.1.running_var')]



if __name__=='__main__':
    print(len(pairs_keys))
    for p in pairs_keys:
        print(p)
--------------------------------------------------------------------------------
/models/py_encoder_decoder.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
if __name__=='__main__':
    from pretrain_keys_pair import pairs_keys
else:
    from .pretrain_keys_pair import pairs_keys

class conv2DBatchNormRelu(nn.Module):
    def __init__(
        self,
        in_channels,
        n_filters,
        k_size,
        stride,
        padding,
        bias=True,
        dilation=1,
        is_batchnorm=True,
    ):
        super(conv2DBatchNormRelu, self).__init__()

        conv_mod = nn.Conv2d(
            int(in_channels),
            int(n_filters),
            kernel_size=k_size,
            padding=padding,
            stride=stride,
            bias=bias,
            dilation=dilation
        )

        if is_batchnorm:
            self.cbr_unit = nn.Sequential(
                conv_mod, nn.BatchNorm2d(int(n_filters)), nn.ReLU(inplace=True)
            )
        else:
            self.cbr_unit = nn.Sequential(conv_mod, nn.ReLU(inplace=True))

    def forward(self, inputs):
        outputs = self.cbr_unit(inputs)
        return outputs

class segnetUp(nn.Module):
    def __init__(self,in_size,out_size):
        super(segnetUp,self).__init__()
        self.unpool = nn.MaxUnpool2d(2, 2)
        self.conv1 = conv2DBatchNormRelu(in_size,out_size,5,1,2)  # kernel_size, stride, padding

    def forward(self, inputs,indices,output_shape):
        outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape)
        outputs = self.conv1(outputs)
        return outputs

class segnetDown2(nn.Module):
    def __init__(self, in_size, out_size):
        super(segnetDown2, self).__init__()
        self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
        self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        unpooled_shape = outputs.size()
        outputs, indices = self.maxpool_with_argmax(outputs)
        return outputs, indices, unpooled_shape


class segnetDown3(nn.Module):
    def __init__(self, in_size, out_size):
        super(segnetDown3, self).__init__()
        self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
        self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
        self.conv3 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        outputs = self.conv3(outputs)
        unpooled_shape = outputs.size()
        outputs, indices = self.maxpool_with_argmax(outputs)
        return outputs, indices, unpooled_shape

class transMap(nn.Module):
    def __init__(self):
        super(transMap,self).__init__()
        self.conv1 = conv2DBatchNormRelu(512,512,3,1,1)
        self.conv2 = conv2DBatchNormRelu(512,512,1,1,0)

    def forward(self, x):
        return self.conv2(self.conv1(x))

class refineHead(nn.Module):
    def __init__(self):
        super(refineHead,self).__init__()
        self.conv1 = conv2DBatchNormRelu(4,64,3,1,1)
        self.conv2 = conv2DBatchNormRelu(64,64,3,1,1)
        self.conv3 = conv2DBatchNormRelu(64,64,3,1,1)
        self.alpha = nn.Sequential(
            nn.Conv2d(in_channels=64,out_channels=1,kernel_size=3,padding=1),
            nn.Sigmoid())

    def forward(self, x):
        x = self.conv3(self.conv2(self.conv1(x)))
        return self.alpha(x)


class Encoder_Decoder(nn.Module):
    """the deep image matting encoder decoder structure
    stage 0: train only the encoder decoder
    stage 1: train only the refine head
    stage 2: train the encoder decoder and refine_head together"""

    def __init__(self,stage=1):
        super(Encoder_Decoder,self).__init__()
        self.down1 = segnetDown2(4,64)
        self.down2 = segnetDown2(64,128)
        self.down3 = segnetDown3(128,256)
        self.down4 = segnetDown3(256,512)
        self.down5 = segnetDown3(512,512)

        self.trans = transMap()

        self.deconv5 = segnetUp(512,512)
        self.deconv4 = segnetUp(512,256)
        self.deconv3 = segnetUp(256,128)
        self.deconv2 = segnetUp(128,64)
        self.deconv1 = segnetUp(64,3)

        self.rawalpha = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=1,kernel_size=5,padding=2,stride=1),
            nn.Sigmoid())

        self.refine_head = refineHead()
        self.stage = stage

    def forward(self, x):
        down1,indices_1,unpool_shape1 = self.down1(x)
        down2,indices_2,unpool_shape2 = self.down2(down1)
        down3,indices_3,unpool_shape3 = self.down3(down2)
        down4,indices_4,unpool_shape4 = self.down4(down3)
        down5,indices_5,unpool_shape5 = self.down5(down4)

        trans = self.trans(down5)

        up5 = self.deconv5(trans,indices_5,unpool_shape5)
        up4 = self.deconv4(up5,indices_4,unpool_shape4)
        up3 = self.deconv3(up4,indices_3,unpool_shape3)
        up2 = self.deconv2(up3,indices_2,unpool_shape2)
        up1 = self.deconv1(up2,indices_1,unpool_shape1)
        raw_alpha = self.rawalpha(up1)
        if self.stage==0:
            return raw_alpha,0
        else:
            refine_in = torch.cat((x[:,:3,:,:],raw_alpha),1)
            refine_alpha = self.refine_head(refine_in)
            return raw_alpha,refine_alpha
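    # Editor's note on load_vggbn() below: the down-blocks mirror the VGG16-BN feature layers, so
    # the pretrained weights can be copied across using the name pairs in pretrain_keys_pair.py.
    # The only special case is the very first convolution: the matting network takes a 4-channel
    # input (RGB + trimap), so its filter tensor is zeroed and only the first three input channels
    # are overwritten with the VGG weights, leaving the trimap channel initialized to zero.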
    def load_vggbn(self,file):
        state_dict = torch.load(file)
        origin_keys = state_dict.keys()

        struct_dict = self.state_dict()  # torch.load('./checkpoints/struct.pth')
        p0 = pairs_keys[0]
        # first conv filters
        origin_conv0 = state_dict[p0[0]]
        addChn_conv0 = struct_dict[p0[1]]
        addChn_conv0.data.zero_()  # set the filters of the extra (trimap) channel to zero
        addChn_conv0.data[:,0:3,:,:]=origin_conv0.data

        # copy all the other parameters
        for p in pairs_keys[1:]:
            k1,k2 = p
            struct_dict[k2].data = state_dict[k1].data
        self.load_state_dict(struct_dict)



if __name__=='__main__':
    model = Encoder_Decoder(stage=0)
    # this is a pre-stored random parameter file
    torch.save(model.state_dict(),'./checkpoints/struct.pth')
    model.load_vggbn('./checkpoints/vgg16_bn-6c64b313.pth')
    x = torch.rand(2,4,320,320)
    y = model(x)
--------------------------------------------------------------------------------
/models/py_loss.py:
--------------------------------------------------------------------------------
import torch
from torch import autograd
from torch import nn
import numpy


class ComposeLoss(nn.Module):
    """compute the ComposeLoss between gt_merged and the predicted merged image"""
    def __init__(self,eps=1e-6,cuda=False):
        super(ComposeLoss,self).__init__()
        self.eps = torch.tensor(eps)

    def cuda(self):
        self.eps = self.eps.cuda()



    def forward(self,a_pred,label):
        #a_pred,fg,bg,gt_merge,mask):
        """
        compute the compose loss
        :param a_pred: the encoder_decoder predicted alpha
        :param fg: the foreground image [0,255.0] float32
        :param bg: the background image [0,255.0] float32
        :param gt_merge: the image merged with the gt alpha [0,255.0] float32
        :param mask: the unknown region mask, binary (0,1)
        :return: the compose loss over the matting region
        """
        fg = label[:,:3,:,:]
        bg = label[:,3:6,:,:]
        gt_merged = label[:,6:9,:,:]
        mask = label[:,-1:,:,:]
        prd_comp = a_pred*fg+(1-a_pred)*bg
        dis = mask*(gt_merged - prd_comp)/255
        loss = torch.sqrt(torch.pow(dis,2).sum()+torch.pow(self.eps,2))
        return loss


class AlphaPredLoss(nn.Module):
    """compute the loss between the predicted alpha and the gt alpha"""
    def __init__(self,eps=1e-6):
        super(AlphaPredLoss,self).__init__()
        self.eps = torch.tensor(eps)

    def cuda(self):
        self.eps = self.eps.cuda()



    def forward(self, a_pred,label):
        #a_pred,a_gt,mask):
        """
        compute the encoder-decoder or refine head alpha loss
        :param a_pred: the encoder decoder or refine_head output alpha, values in (0,1) float32
        :param a_gt: the ground-truth alpha, values in (0,1) float32
        :param mask: the unknown region mask
        :return:
        """
        a_gt = label[:,-2:-1,:,:]
        mask = label[:,-1:,:,:]
        dis = mask*(a_pred-a_gt)  # per-pixel difference, counted only in the unknown region

        loss = torch.sqrt(torch.pow(dis,2).sum() + torch.pow(self.eps,2))
        return loss




if __name__=='__main__':
    print("test the training loss")

    a_pred = torch.rand((4,1,30,30),dtype=torch.float32)
    label = torch.rand((4,11,30,30),dtype=torch.float32)

    alpha_loss = AlphaPredLoss()
    loss = alpha_loss(a_pred,label)
    print(loss.item())

    comp_loss = ComposeLoss()
    loss2 = comp_loss(a_pred,label)
    print(loss2.item())

--------------------------------------------------------------------------------
/train_encoder_decoder.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from data import get_train_val_dataloader
from models import Encoder_Decoder
import argparse
import torch.optim as optim
import os
from models import ComposeLoss,AlphaPredLoss
from utils.visulization import Visulizer
import time
from torch import autograd


def parse_args():
    """Training options for the matting experiments"""
    parser = argparse.ArgumentParser(description='Pytorch learning args')
    parser.add_argument('--stage',type=int,default=0)
    parser.add_argument('--batch_size',type=int,default=16)
    parser.add_argument('--num_workers',type=int,default=10)
    parser.add_argument('--crop_size',type=int,default=320)
    parser.add_argument('--epochs',type=int,default=50)
    parser.add_argument('--lr',type=float,default=0.0001)
    parser.add_argument('--wd',type=float,default=1e-5)
    parser.add_argument('--momentum',type=float,default=0.9)
    parser.add_argument('--gpu',type=str,default='2')
    parser.add_argument('--pretrain_model',type=str,default=None)
    parser.add_argument('--eps',type=float,default=1e-6)
    parser.add_argument('--lmd',type=float,default=0.5)
    parser.add_argument('--last_epoch',type=int,default=-1)
    parser.add_argument('--freq',type=int,default=20)
    parser.add_argument('--debug',action='store_true', default= False,help='if debug mode')
    parser.add_argument('--env',type=str,default='super_mali')
    args = parser.parse_args()
    return args

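# Editor's note on the stages handled by Trainer below (see also the README):
#   stage 0 ("encoder_decoder"): trains the SegNet-style encoder-decoder starting from VGG16-BN
#            weights, using the alpha-prediction loss (the compositional term is currently disabled);
#   stage 1 ("refine_head"):     only the refinement head parameters are optimized, supervised by the
#            alpha-prediction loss on the refined output;
#   stage 2 ("over_all"):        fine-tunes the whole network end to end with Adam and a cosine
#            learning-rate schedule.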
class Trainer(object):
    """the unified trainer for the encoder-decoder, refine-head and overall stages"""
    model_app={0:"encoder_decoder",
               1:"refine_head",
               2:"over_all"}
    # training stage: encoder_decoder, refine_head or over_all
    def __init__(self,args):
        self.args = args
        os.environ['CUDA_VISIBLE_DEVICES']=str(self.args.gpu)
        self.stage = args.stage
        self.model_name = self.model_app[args.stage]
        self.freq = self.args.freq
        self.train_loader,self.valid_loader = get_train_val_dataloader(batch_size=args.batch_size,
                                                                       num_workers=args.num_workers)
        self.model = Encoder_Decoder(stage=args.stage)
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        base_lr = self.args.lr
        if self.stage==0:
            if not self.args.pretrain_model:
                self.model.load_vggbn('./checkpoints/vgg16_bn-6c64b313.pth')
            else:
                self.model.load_state_dict(torch.load(self.args.pretrain_model))
            self.loss =[ComposeLoss(eps=self.args.eps),AlphaPredLoss(eps=self.args.eps)]
            self.loss_lambda=[torch.tensor(self.args.lmd),torch.tensor(1-self.args.lmd)]

            self.trainer = optim.SGD([
                {'params': self.model.down1.parameters(),'lr':1*base_lr},
                {'params': self.model.down2.parameters(),'lr':1*base_lr},
                {'params': self.model.down3.parameters(), 'lr': 1*base_lr},
                {'params': self.model.down4.parameters(), 'lr': 1*base_lr},
                {'params': self.model.down5.parameters(), 'lr': 1*base_lr},
                {'params': self.model.trans.parameters(), 'lr': 1*base_lr},
                {'params': self.model.deconv5.parameters(), 'lr': 1*base_lr},
                {'params': self.model.deconv4.parameters(), 'lr': 1*base_lr},
                {'params': self.model.deconv3.parameters(), 'lr': 1*base_lr},
                {'params': self.model.deconv2.parameters(), 'lr': 1*base_lr},
                {'params': self.model.deconv1.parameters(), 'lr': 1*base_lr},
                {'params': self.model.rawalpha.parameters(),'lr':1*base_lr}
                ],
                lr=self.args.lr,weight_decay=self.args.wd,momentum=self.args.momentum)
            self.lr_schedular = optim.lr_scheduler.MultiStepLR(self.trainer,
                                                               milestones=[5,10,30],
                                                               gamma=0.5,
                                                               last_epoch=self.args.last_epoch)
            self.metrics = []

        elif self.stage==1:
            self.model.load_state_dict(torch.load(self.args.pretrain_model))
            self.loss=[AlphaPredLoss(eps=self.args.eps)]
            self.loss_lambda =[torch.tensor(1)]
            self.trainer = optim.SGD([
                {'params':self.model.refine_head.parameters(),'lr':1*base_lr}
                ],
                lr=self.args.lr,weight_decay=self.args.wd,momentum=self.args.momentum)
            self.lr_schedular = optim.lr_scheduler.MultiStepLR(self.trainer,
                                                               milestones=[3,10,30],
                                                               gamma=0.2,
                                                               last_epoch=self.args.last_epoch)
        else:
            self.model.load_state_dict(torch.load(self.args.pretrain_model))
            self.loss = [AlphaPredLoss(eps=self.args.eps)]
            self.loss_lambda=[torch.tensor(1)]
            self.trainer = optim.Adam(self.model.parameters(),lr=self.args.lr)
            self.lr_schedular = optim.lr_scheduler.CosineAnnealingLR(self.trainer,T_max=2)
        if torch.cuda.is_available():
            self.loss_lambda = [x.cuda() for x in self.loss_lambda]
            for x in self.loss:
                x.cuda()
        self.vis = Visulizer(env='{0}_{1}_{2}_{3}'.format('matting',self.model_name,time.strftime('%m_%d'),self.args.env))
        self.vis.log(str(self.args))


    def training(self,epoch):
        self.model.train(mode=True)
        train_loss = 0.0
        total_loss,prev_loss = 0,0
        self.lr_schedular.step()
        for i,(data,label) in enumerate(self.train_loader):
            if torch.cuda.is_available():
                data,label = data.cuda(),label.cuda()
            self.trainer.zero_grad()
            al_pred = self.model(data)
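            # Editor's note: in stage 0 only the alpha-prediction loss is applied; the weighted
            # compositional term (loss1 below) is commented out because, as noted in the README,
            # the two losses have different magnitudes and their sum oscillated without converging.
            # Stages 1 and 2 supervise the refined alpha (al_pred[1]) with the alpha loss alone.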
            if self.stage==0:
                #loss1 = self.loss_lambda[0]*self.loss[0](al_pred[0],label)  # compose loss
                loss2 = self.loss_lambda[1]*self.loss[1](al_pred[0],label)   # alpha mse loss
                l_loss = loss2  # + loss1 (disabled, see note above)
            elif self.stage==1:
                l_loss = self.loss_lambda[0]*self.loss[0](al_pred[1],label)
            else:
                l_loss = self.loss_lambda[0]*self.loss[0](al_pred[1],label)

            l_loss.backward()
            if self.args.debug:
                params = [p for p in self.model.parameters()]
                grad = torch.tensor(0.0).cuda()
                for param in params:
                    if not param.grad is None:
                        grad += torch.sum(torch.abs(param.grad))
                    else:
                        print("none grad")
                print("the grad of this iter",grad,"loss",l_loss.item())

            self.trainer.step()
            total_loss += l_loss.item()
            if i%self.args.freq==(self.freq-1):
                step_loss = total_loss - prev_loss
                self.vis.plot('fre_loss',step_loss/self.freq)
                prev_loss = total_loss
            # visualize intermediate training results
            if self.stage==0 and i%(self.freq*2)==(self.freq*2-1):
                bg = label[:, :3, :, :]
                fg = label[:, 3:6, :, :]
                compose = al_pred[0]*fg+(1-al_pred[0])*bg
                for j,(alpha,y,pre_compose) in enumerate(zip(al_pred[0],label,compose)):
                    self.vis.img('bg_{0}'.format(j),y[0:3].detach().cpu().numpy())
                    self.vis.img('fg_{0}'.format(j),y[3:6].detach().cpu().numpy())
                    self.vis.img('merged_{0}'.format(j),y[6:9].detach().cpu().numpy())
                    self.vis.img('gt_alpha_{0}'.format(j),y[9:10].detach().cpu().numpy())
                    self.vis.img('compose_{0}'.format(j),pre_compose.detach().cpu().numpy())
                    self.vis.img('alpha_{0}'.format(j),alpha.detach().cpu().numpy())
                    break

            if self.args.debug and i//self.freq==1:
                break
        self.vis.plot("total_loss",total_loss)
        self.vis.log("training epoch {0} finished ".format(epoch))




    def validation(self,epoch):
        mse = 0.0
        sad = 0.0
        self.model.train(mode=False)
        mse_total,mse_pre = 0,0
        with torch.no_grad():
            for i,(data,label) in enumerate(self.valid_loader):
                if torch.cuda.is_available():
                    data,label = data.cuda(),label.cuda()
                a_pred = self.model(data)
                mse = self.metric_mse(a_pred,label)
                sad = self.metric_sad(a_pred,label)
                mse_total += mse
                if i%self.args.freq ==(self.args.freq-1):
                    self.vis.log('mse_alpha {0}'.format(mse_total/(i+1)))
                    mse_pre = mse_total
                if self.args.debug and i//self.freq==1:
                    break
        self.vis.log('the validation of epoch {0}'.format(epoch))


    def save_model(self,epoch):
        file_name = './checkpoints/{0}_{1}_{2}_{3}.params'.format(self.model_name,time.strftime('%m_%d'),str(epoch),self.args.env)
        torch.save(self.model.state_dict(), file_name)

    def metric_mse(self,alpha_pred,label):
        """
        compute the mean squared error of the predicted alpha
        :param alpha_pred: the predicted alpha values (0,1), N,1,H,W
        :param label: the label fg,bg,mask,alpha_gt
        :return: mse_error
        """
        return 0

    def metric_sad(self,alpha_pred,label):
        """
        the sum of absolute differences between two alpha maps
        :param alpha_pred:
        :param label:
        :return:
        """
        return 0



if __name__=='__main__':
    """this is the main training logic"""
    args = parse_args()
    print("Starting Epoch",args.last_epoch)
    trainer = Trainer(args)
    for epoch in range(args.last_epoch,args.epochs):
        trainer.training(epoch)
        trainer.validation(epoch)
        trainer.save_model(epoch)
    trainer.vis.log('training finished')
    exit(0)
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
from .visulization import Visulizer
--------------------------------------------------------------------------------
/utils/learning_vis_test.py:
--------------------------------------------------------------------------------
from data import get_train_val_dataloader
from utils.visulization import Visulizer


vis = Visulizer(env='main')
vis.log('this is a learning procedure visualization')

train_loader,valid_loader = get_train_val_dataloader(batch_size=4,num_workers=3)
for i,(data,label) in enumerate(train_loader):
    # editor's note: the original file ends mid-statement here; this is a minimal, assumed
    # completion that just pushes the composited images of a few batches to visdom
    vis.img('merged_{0}'.format(i), label[:, 6:9].numpy())
    if i == 3:
        break
--------------------------------------------------------------------------------
/utils/visulization.py:
--------------------------------------------------------------------------------
import visdom
import time
import numpy as np
class Visulizer(object):
    """a small wrapper around visdom that logs training traces to a web dashboard"""
    def __init__(self,host="http://hpc3.yud.io",port=8088,env='street'):
        self.vis = visdom.Visdom(server=host,port=port,env=env)
        self.host = host
        self.port = port
        self.env = env
        self.index ={}
        self.log_text=""

    def reinit(self,env='default'):
        self.vis = visdom.Visdom(server=self.host,port=self.port,env=env)
        self.env = env
        return self

    def plot(self,name,y):
        """plot loss:1.0"""
        x = self.index.get(name,0)
        self.vis.line(Y=np.array([y]),X=np.array([x]),
                      win=name,
                      opts=dict(title=name),
                      update=None if x==0 else 'append')
        self.index[name] = x+1

    def img(self,name,img_,**kwargs):
        """
        :param name: the window name
        :param img_: image data, e.g. shapes t.Tensor(64,64), Tensor(3,64,64), Tensor(100,1,64,64)
        :param kwargs:
        :return:
        """
        # img_ should be a numpy.ndarray here, not a torch Tensor
        self.vis.images(img_,
                        win=name,
                        opts=dict(title=name),
                        **kwargs)
    def log(self,info,win='log_text'):
        """usage: self.log({'loss':1,'lr':0.0001})"""
        self.log_text += ('[{time}] {info} <br>'.format(
            time=time.strftime('%m-%d %H:%M:%S'),
            info=info))
        self.vis.text(self.log_text, win)

    def delete_env(self,env):
        self.vis.delete_env(env)

if __name__=='__main__':
    """start the visdom server first with:
    nohup python -m visdom.server -port 8088 &"""
    viz = Visulizer(host='http://hpc3.yud.io',port=8088,env='street')
    viz.log("this is a start")
    viz.plot('loss',2.3)
    viz.plot('loss',2.2)
    viz.plot('loss',2.1)

    viz.img('origin',np.random.random((10,3,224,224)))
--------------------------------------------------------------------------------