├── README.md
├── datasets
│   ├── __init__.py
│   ├── saliency.py
│   └── unlabeled.py
├── evaluate.py
├── models
│   ├── __init__.py
│   ├── fcn.py
│   ├── net.py
│   └── vgg.py
├── test.py
└── train.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# lps
Code for the paper "Learning to Promote Saliency Detectors".

Environment: Python 2.7; PyTorch '0.5.0a0+54db14e'; two GTX 1080Ti GPUs.

## usage
Modify the paths to the prior maps, images and ground truth, then run test.py.
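For example, with the flags defined in test.py (the dataset paths below are placeholders; replace them with your own):

```
python test.py --img_dir /path/to/ECSSD/images \
               --prior_dir /path/to/results/ECSSD-Sal/SRM \
               --gt_dir /path/to/ECSSD/masks \
               --base vgg16 --b 12
```

Note that test.py loads the pre-trained weights from ./net.pth, so place the downloaded model there first.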
## pre-trained model
[download pre-trained model](https://pan.baidu.com/s/1mOMz6pXYsoJPgqE6hQxI1A)

## performance and results

This version gives slightly different results from the paper.

Prior maps are the results of other methods, either provided by their authors or obtained by running their code.
For example, results of SRM can be downloaded [here](https://github.com/TiantianWang/ICCV17_SRM). Results of applying our method on SRM can be downloaded from [Baidu Netdisk](https://pan.baidu.com/s/1T51KDP0NlLW971kDardZ6g) or [google drive](https://drive.google.com/open?id=1lufzjX2478U0W3-tbXaEMQGqnadljYQe)

### F-measure

Method |ECSSD | HKU-IS|PASCALS|DUTS-test|THUR|OMRON
--- | --- | --- | --- | --- | ---| ---
SRM |.8924 | .8739 | .7961 | .7591 |.7079|.7223
+Ours|.9102 | .9032 | .8054 | .7999 |.7299|.7338

### MAE

Method |ECSSD | HKU-IS|PASCALS|DUTS-test|THUR|OMRON
--- | --- | --- | --- | --- | ---| ---
SRM |.0542 | .0459 | .0852 |.0633|.0769|.0767
+Ours|.0416 | .0330 | .0729 |.0536|.0735|.0696

Results for other methods will be added later.

--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
from .saliency import Folder, PriorFolder
from .unlabeled import ImageFiles

--------------------------------------------------------------------------------
/datasets/saliency.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import PIL.Image as Image
import torch
from torch.utils import data
import pdb
import random


def rotated_rect_with_max_area(w, h, angle):
    """
    Given a rectangle of size wxh that has been rotated by 'angle' (in
    radians), computes the width and height of the largest possible
    axis-aligned rectangle (maximal area) within the rotated rectangle.
    """
    if w <= 0 or h <= 0:
        return 0, 0

    width_is_longer = w >= h
    side_long, side_short = (w, h) if width_is_longer else (h, w)

    # since the solutions for angle, -angle and 180-angle are all the same,
    # it suffices to look at the first quadrant and the absolute values of sin,cos:
    sin_a, cos_a = abs(np.sin(angle)), abs(np.cos(angle))
    if side_short <= 2. * sin_a * cos_a * side_long or abs(sin_a - cos_a) < 1e-10:
        # half constrained case: two crop corners touch the longer side,
        # the other two corners are on the mid-line parallel to the longer line
        x = 0.5 * side_short
        wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a)
    else:
        # fully constrained case: crop touches all 4 sides
        cos_2a = cos_a * cos_a - sin_a * sin_a
        wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a

    return wr, hr
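
# Quick sanity check of the function above (values approximate): rotating a
# 256x256 image by 10 degrees gives sin_a=0.174, cos_a=0.985, so
# 2*sin_a*cos_a*256 = 87.5 < 256 and the fully constrained branch applies:
# cos_2a = 0.940 and wr = hr = 256*(0.985 - 0.174)/0.940 = 221, i.e. the
# largest axis-aligned crop inside the rotated image is roughly 221x221 pixels.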

class BaseFolder(data.Dataset):
    def __init__(self, root, crop=None, rotate=None, flip=False,
                 mean=None, std=None):
        super(BaseFolder, self).__init__()
        self.mean, self.std = mean, std
        self.flip = flip
        self.rotate = rotate
        self.crop = crop
        img_dir = os.path.join(root, 'images')
        gt_dir = os.path.join(root, 'masks')
        names = ['.'.join(name.split('.')[:-1]) for name in os.listdir(gt_dir)]
        self.img_filenames = [os.path.join(img_dir, name+'.jpg') for name in names]
        self.gt_filenames = [os.path.join(gt_dir, name+'.png') for name in names]
        self.names = names

    def random_crop(self, *images):
        images = list(images)
        sz = [img.size for img in images]
        sz = set(sz)
        assert(len(sz)==1)
        w, h = sz.pop()
        th, tw = int(self.crop*h), int(self.crop*w)
        if w == tw and h == th:
            # nothing to crop; return the images unchanged
            return tuple(images)
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        results = [img.crop((j, i, j + tw, i + th)) for img in images]
        return tuple(results)

    def random_flip(self, *images):
        if self.flip and random.randint(0, 1):
            images = list(images)
            results = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in images]
            return tuple(results)
        else:
            return images

    def random_rotate(self, *images):
        images = list(images)
        sz = [img.size for img in images]
        sz = set(sz)
        assert(len(sz)==1)
        w, h = sz.pop()
        degree = random.randint(-1*self.rotate, self.rotate)
        images_r = [img.rotate(degree, expand=1) for img in images]
        w_b, h_b = images_r[0].size
        w_r, h_r = rotated_rect_with_max_area(w, h, np.radians(degree))
        ws = (w_b - w_r) / 2
        ws = max(ws, 0)
        hs = (h_b - h_r) / 2
        hs = max(hs, 0)
        we = ws + w_r
        he = hs + h_r
        we = min(we, w_b)
        he = min(he, h_b)
        results = [img.crop((ws, hs, we, he)) for img in images_r]
        return tuple(results)

    def __len__(self):
        return len(self.names)


class PriorFolder(BaseFolder):
    def __init__(self, root, prior_dir, size=256, crop=None, rotate=None, flip=False,
                 mean=None, std=None):
        super(PriorFolder, self).__init__(root, crop=crop, rotate=rotate, flip=flip, mean=mean, std=std)
        self.size = size
        self.pr_filenames = [os.path.join(prior_dir, name+'.png') for name in self.names]

    def __getitem__(self, index):
        # load image
        img_file = self.img_filenames[index]
        img = Image.open(img_file)
        gt_file = self.gt_filenames[index]
        name = self.names[index]
        gt = Image.open(gt_file)
        WW, HH = gt.size
        img = img.resize((WW, HH))
        pr_file = self.pr_filenames[index]
        pr = Image.open(pr_file)
        pr = pr.resize((WW, HH))
        if self.crop is not None:
            img, gt, pr = self.random_crop(img, gt, pr)
        if self.rotate is not None:
            img, gt, pr = self.random_rotate(img, gt, pr)
        if self.flip:
            img, gt, pr = self.random_flip(img, gt, pr)
        img, gt, pr = [_img.resize((self.size, self.size)) for _img in [img, gt, pr]]
        img = np.array(img, dtype=np.float64) / 255.0
        gt = np.array(gt, dtype=np.uint8)
        gt[gt != 0] = 1
        pr = np.array(pr, dtype=np.float64) / 255.0
        if len(img.shape) < 3:
            img = np.stack((img, img, img), 2)
        if img.shape[2] > 3:
            img = img[:, :, :3]
        if len(gt.shape) > 2:
            gt = gt[:, :, 0]
        if self.mean is not None:
            img -= self.mean
        if self.std is not None:
            img /= self.std
        img = img.transpose(2, 0, 1)
        img = torch.from_numpy(img).float()
        gt = torch.from_numpy(gt).float()
        pr = torch.from_numpy(pr).float()
        return img, gt, pr, name, WW, HH


class Folder(BaseFolder):
    def __init__(self, root, scales=[256], crop=None, rotate=None, flip=False,
                 mean=None, std=None):
        super(Folder, self).__init__(root, crop=crop, rotate=rotate, flip=flip, mean=mean, std=std)
        self.scales = scales

    def __getitem__(self, index):
        # load image
        img_file = self.img_filenames[index]
        img = Image.open(img_file)
        gt_file = self.gt_filenames[index]
        name = self.names[index]
        gt = Image.open(gt_file)
        WW, HH = gt.size
        img = img.resize((WW, HH))
        if self.crop is not None:
            img, gt = self.random_crop(img, gt)
        if self.rotate is not None:
            img, gt = self.random_rotate(img, gt)
        if self.flip:
            img, gt = self.random_flip(img, gt)
        max_size = max(self.scales)
        img = img.resize((max_size, max_size))
        gts = [gt.resize((s, s)) for s in self.scales]

        img = np.array(img, dtype=np.float64) / 255.0
        gts = [np.array(gt, dtype=np.uint8) for gt in gts]
        for gt in gts: gt[gt != 0] = 1
        if len(img.shape) < 3:
            img = np.stack((img, img, img), 2)
        if img.shape[2] > 3:
            img = img[:, :, :3]
        for i, gt in enumerate(gts):
            if len(gt.shape) > 2:
                gts[i] = gt[:, :, 0]
        if self.mean is not None:
            img -= self.mean
        if self.std is not None:
            img /= self.std
        img = img.transpose(2, 0, 1)
        img = torch.from_numpy(img).float()
        gts = [torch.from_numpy(gt).float() for gt in gts]
        return img, gts, name


def collate_more(data):
    images, gts, name = zip(*data)
    gts = list(map(list, zip(*gts)))

    images = torch.stack(images, 0)
    gts = [torch.stack(gt, 0) for gt in gts]

    return images, gts, name
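
# Note on collate_more: with Folder(..., scales=[64]*3+[128, 256]) and
# collate_fn=collate_more (the setup used in train.py), each batch comes out as
# (images, gts, names), where images is a (B, 3, 256, 256) tensor (the largest
# scale) and gts is a list of five (B, s, s) tensors, one per scale --
# zip(*gts) regroups the per-sample ground-truth lists by scale before stacking.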

if __name__ == "__main__":
    import matplotlib.pyplot as plt
    from datetime import datetime
    random.seed(datetime.now())
    sb = Folder('/home/zeng/data/datasets/saliency_Dataset/ECSSD',
                crop=None, rotate=10, flip=True)
    img, gt, _ = sb.__getitem__(random.randint(0, 1000))
    img = img.numpy().transpose((1, 2, 0))
    plt.imshow(img)
    plt.show()
    plt.imshow(gt[0])
    plt.show()

    # sb = PriorFolder('/home/zeng/data/datasets/saliency_Dataset/ECSSD',
    #                  '/home/zeng/data/datasets/saliency_Dataset/results/ECSSD-Sal/SRM',
    #                  crop=None, rotate=None, flip=True, size=256)
    # img, gt, pr, _, _, _ = sb.__getitem__(0)
    # img = img.numpy().transpose((1, 2, 0))
    # plt.imshow(img)
    # plt.show()
    # plt.imshow(gt)
    # plt.show()
    # plt.imshow(pr)
    # plt.show()

--------------------------------------------------------------------------------
/datasets/unlabeled.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import PIL.Image as Image
import torch
from torch.utils import data
import pdb
import random


class ImageFiles(data.Dataset):
    def __init__(self, img_dir, prior_dir,
                 size=256,
                 mean=None, std=None):
        super(ImageFiles, self).__init__()
        self.mean, self.std = mean, std
        self.size = size
        names = os.listdir(img_dir)
        names = ['.'.join(name.split('.')[:-1]) for name in names]
        self.img_filenames = list(map(lambda x: os.path.join(img_dir, x+'.jpg'), names))
        self.pr_filenames = list(map(lambda x: os.path.join(prior_dir, x+'.png'), names))
        self.names = names

    def __len__(self):
        return len(self.names)

    def __getitem__(self, index):
        # load image
        img_file = self.img_filenames[index]
        img = Image.open(img_file)
        pr_file = self.pr_filenames[index]
        pr = Image.open(pr_file)
        name = self.names[index]
        WW, HH = img.size
        img = img.resize((self.size, self.size))
        img = np.array(img, dtype=np.float64)/255
        pr = pr.resize((self.size, self.size))
        pr = np.array(pr, dtype=np.float64)/255
        if len(img.shape) < 3:
            img = np.stack((img, img, img), 2)
        if img.shape[2] > 3:
            img = img[:, :, :3]
        if self.mean is not None:
            img -= self.mean
        if self.std is not None:
            img /= self.std
        img = img.transpose(2, 0, 1)
        img = torch.from_numpy(img).float()
        pr = torch.from_numpy(pr).float()
        return img, pr, name, WW, HH


if __name__ == "__main__":
    # ImageFiles needs both an image directory and a prior-map directory;
    # the prior path below is a placeholder
    sb = ImageFiles('../../data/datasets/ILSVRC14VOC/images',
                    '../../data/datasets/ILSVRC14VOC/priors')
    pdb.set_trace()

--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import numpy as np
import os
import PIL.Image as Image
import pdb
from multiprocessing import Pool
from functools import partial
import matplotlib.pyplot as plt
eps = np.finfo(float).eps


def print_table():
    base_dir = '/home/zeng/data/datasets/saliency_Dataset'
    # algs = ['LEGS', 'RFCN', 'DCL', 'DHS', 'MCDL', 'MDF']
    algs = ['Ours-Seg', 'Ours-Seg-woSeg', 'Ours-N1-Seg', 'Ours-N1-Seg-woSeg']
    # datasets = ['ECSSD', 'PASCALS', 'SOD', 'MSRA5K', 'OMRON']
    datasets = ['ECSSD']
    for alg in algs:
        print(alg+'& ', end='')
        for i, dset in enumerate(datasets):
            input_dir = '{}/results/{}-Sal/{}'.format(base_dir, dset, alg)
            gt_dir = '{}/{}/masks'.format(base_dir, dset)
            output_dir = '{}/results/{}-npy'.format(base_dir, dset)
            maxfm, mae, _, _ = fm_and_mae(input_dir, gt_dir, output_dir, alg)
            if i != len(datasets)-1:
                print('%.3f&%.3f& ' % (round(maxfm, 3), round(mae, 3)), end='')
            else:
                print('%.3f&%.3f\\\\' % (round(maxfm, 3), round(mae, 3)), end='\n')
        print('\hline', end='\n')


def draw_curves():
    base_dir = '/home/zeng/data/datasets/saliency_Dataset'
    algs = ['BSCA', 'MR', 'HS', 'Ours', 'WSS', 'DRFI', 'LEGS', 'MCDL', 'MDF']
    datasets = ['ECSSD', 'PASCALS', 'SOD', 'OMRON']
    # color = iter(plt.cm.rainbow(np.linspace(0, 1, len(algs))))
    for dset in datasets:
        fig = plt.figure()
        ax = fig.add_subplot(111)
        for i, alg in enumerate(algs):
            sb = np.load('{}/results/{}-npy/{}.npz'.format(base_dir, dset, alg))
            ax.plot(sb['recs'], sb['pres'], linewidth=2, label=alg)
        ax.grid(True)
        ax.set_xlabel('Recall', fontsize=14)
        ax.set_ylabel('Precision', fontsize=14)
        handles, labels = ax.get_legend_handles_labels()
        lgd = ax.legend(handles, labels, loc='center left', bbox_to_anchor=(0.5, -0.5), ncol=8, fontsize=14)
        fig.savefig('%s.pdf' % dset, bbox_extra_artists=(lgd,), bbox_inches='tight')


def eva_one(param):
    input_name, gt_name = param
    mask = Image.open(input_name)
    gt = Image.open(gt_name)
    mask = mask.resize(gt.size)
    mask = np.array(mask, dtype=np.float)
    if len(mask.shape) != 2:
        mask = mask[:, :, 0]
    mask = (mask - mask.min()) / (mask.max()-mask.min()+eps)
    gt = np.array(gt, dtype=np.uint8)
    if len(gt.shape) > 2:
        gt = gt[:, :, 0]
    gt[gt != 0] = 1
    pres = []
    recs = []
    mea = np.abs(gt-mask).mean()
    # threshold fm
    binary = np.zeros(mask.shape)
    th = 2*mask.mean()
    if th > 1:
        th = 1
    binary[mask >= th] = 1
    sb = (binary * gt).sum()
    pre = sb / (binary.sum()+eps)
    rec = sb / (gt.sum()+eps)
    thfm = 1.3 * pre * rec / (0.3 * pre + rec + eps)
    for th in np.linspace(0, 1, 21):
        binary = np.zeros(mask.shape)
        binary[mask >= th] = 1
        pre = (binary * gt).sum() / (binary.sum()+eps)
        rec = (binary * gt).sum() / (gt.sum()+eps)
        pres.append(pre)
        recs.append(rec)
    pres = np.array(pres)
    recs = np.array(recs)
    return thfm, mea, recs, pres
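
# eva_one computes the F-measure with beta^2 = 0.3, the usual convention in the
# saliency literature: F = 1.3 * P * R / (0.3 * P + R). For instance, P = 0.8
# and R = 0.6 give F = 1.3*0.48 / (0.24 + 0.6) ~ 0.743. thfm uses an adaptive
# threshold of twice the mean saliency value (clipped to 1), while the 21 fixed
# thresholds in [0, 1] trace out the precision-recall curve.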

def fm_and_mae(input_dir, gt_dir, output_dir=None, name=None):
    if output_dir is not None and not os.path.exists(output_dir):
        os.mkdir(output_dir)

    filelist_gt = os.listdir(gt_dir)
    gt_format = filelist_gt[0].split('.')[-1]
    filelist_gt = ['.'.join(f.split('.')[:-1]) for f in filelist_gt]

    filelist_pred = os.listdir(input_dir)
    pred_format = filelist_pred[0].split('.')[-1]
    filelist_pred = ['.'.join(f.split('.')[:-1]) for f in filelist_pred]

    filelist = list(set(filelist_gt)&set(filelist_pred))

    inputlist = [os.path.join(input_dir, '.'.join([_name, pred_format])) for _name in filelist]
    gtlist = [os.path.join(gt_dir, '.'.join([_name, gt_format])) for _name in filelist]

    pool = Pool(4)
    results = pool.map(eva_one, zip(inputlist, gtlist))
    thfm, m_mea, m_recs, m_pres = list(map(list, zip(*results)))
    m_mea = np.array(m_mea).mean()
    m_pres = np.array(m_pres).mean(0)
    m_recs = np.array(m_recs).mean(0)
    thfm = np.array(thfm).mean()
    fms = 1.3 * m_pres * m_recs / (0.3*m_pres + m_recs + eps)
    maxfm = fms.max()
    if not (output_dir is None or name is None):
        np.savez('%s/%s.npz' % (output_dir, name), mea=m_mea, thfm=thfm, maxfm=maxfm, recs=m_recs, pres=m_pres, fms=fms)
    return thfm, m_mea, m_recs, m_pres


if __name__ == '__main__':
    # fm, mae, _, _ = fm_and_mae('/home/zeng/data/datasets/saliency_Dataset/results/ECSSD-Sal/WSS',
    #                            '/home/zeng/data/datasets/saliency_Dataset/ECSSD/masks')
    # print(fm)
    # print(mae)
    print_table()
    # draw_curves()

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from .net import Net
from .fcn import FCN

--------------------------------------------------------------------------------
/models/fcn.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd.variable import Variable

# from .densenet import *
# from .resnet import *
from .vgg import *

# from densenet import *
# from resnet import *
# from vgg import *

import numpy as np
import sys
thismodule = sys.modules[__name__]
import pdb

img_size = 256

dim_dict = {
    'vgg': [64, 128, 256, 512, 512]
}


def proc_vgg(model):
    def hook(module, input, output):
        model.feats[output.device.index] += [output]
    for m in model.features[:-1]:
        m[-2].register_forward_hook(hook)
    # dilation
    def remove_sequential(all_layers, network):
        for layer in network.children():
            if isinstance(layer, nn.Sequential):  # if sequential layer, apply recursively to layers in sequential layer
                remove_sequential(all_layers, layer)
            if list(layer.children()) == []:  # if leaf node, add it to list
                all_layers.append(layer)
    model.features[2][-1].stride = 1
    model.features[2][-1].kernel_size = 1
    all_layers = []
    remove_sequential(all_layers, model.features[3])
    for m in all_layers:
        if isinstance(m, nn.Conv2d):
            m.dilation = (2, 2)
            m.padding = (2, 2)

    model.features[3][-1].stride = 1
    model.features[3][-1].kernel_size = 1
    all_layers = []
    remove_sequential(all_layers, model.features[4])
    for m in all_layers:
        if isinstance(m, nn.Conv2d):
            m.dilation = (4, 4)
            m.padding = (4, 4)
    model.features[4][-1].stride = 1
    model.features[4][-1].kernel_size = 1
    return model


procs = {
    'vgg16': proc_vgg,
}


def get_upsampling_weight(in_channels, out_channels, kernel_size):
    """Make a 2D bilinear kernel suitable for upsampling"""
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * \
           (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
                      dtype=np.float64)
    weight[range(in_channels), range(out_channels), :, :] = filt
    return torch.from_numpy(weight).float()
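
# Sanity check for get_upsampling_weight: with kernel_size=4 (factor=2,
# center=1.5) the 1-D profile is [0.25, 0.75, 0.75, 0.25], and the 2-D kernel
# is its outer product -- the standard bilinear-interpolation filter used to
# initialize the ConvTranspose2d layers below.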

class FCN(nn.Module):
    def __init__(self, net):
        super(FCN, self).__init__()
        if 'vgg' in net.base:
            dims = dim_dict['vgg'][::-1]
        else:
            dims = dim_dict[net.base][::-1]
        self.preds = nn.ModuleList([nn.Conv2d(d, 1, kernel_size=1) for d in dims])
        self.upscales = nn.ModuleList([nn.ConvTranspose2d(1, 1, 1, 1, 0)]*2+[nn.ConvTranspose2d(1, 1, 4, 2, 1)]*2)
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                m.weight.data.normal_(0.0, 0.02)
                if m.bias is not None:
                    m.bias.data.fill_(0)
            if isinstance(m, nn.ConvTranspose2d):
                assert m.kernel_size[0] == m.kernel_size[1]
                initial_weight = get_upsampling_weight(
                    m.in_channels, m.out_channels, m.kernel_size[0])
                m.weight.data.copy_(initial_weight)
        self.feature = net.feature
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.requires_grad = False

    def forward(self, *data):
        x = data[0]
        self.feature.feats[x.device.index] = []
        x = self.feature(x)
        feats = self.feature.feats[x.device.index]
        feats += [x]
        feats = feats[::-1]
        pred = self.preds[0](feats[0])
        preds = [pred]
        for i in range(4):
            pred = self.preds[i+1](feats[i+1]+self.upscales[i](pred))
            preds += [pred]
        if self.training:
            return preds
        else:
            return pred
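
# Training sketch: in training mode FCN.forward returns all five side outputs,
# from the deepest level up (their sizes match the scales [64]*3+[128, 256]
# used in the commented-out first stage of train.py), and each one is
# supervised with binary cross-entropy against the ground truth at the matching
# scale; at test time only the final, finest prediction is returned.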

if __name__ == "__main__":
    # standalone smoke test: run from the models/ directory with the
    # non-relative imports at the top of this file enabled, so that Net is
    # importable; FCN wraps a backbone-holding Net
    from net import Net
    net = FCN(Net())
    net.cuda()
    x = torch.Tensor(2, 3, 256, 256).cuda()
    sb = net(x)
    pdb.set_trace()

--------------------------------------------------------------------------------
/models/net.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd.variable import Variable

# from .densenet import *
# from .resnet import *
from .vgg import *

# from densenet import *
# from resnet import *
# from vgg import *

import numpy as np
import sys
thismodule = sys.modules[__name__]
import pdb

img_size = 256

dim_dict = {
    'vgg': [64, 128, 256, 512, 512]
}


def proc_vgg(model):
    def hook(module, input, output):
        model.feats[output.device.index] += [output]
    for m in model.features[:-1]:
        m[-2].register_forward_hook(hook)
    # dilation
    def remove_sequential(all_layers, network):
        for layer in network.children():
            if isinstance(layer, nn.Sequential):  # if sequential layer, apply recursively to layers in sequential layer
                remove_sequential(all_layers, layer)
            if list(layer.children()) == []:  # if leaf node, add it to list
                all_layers.append(layer)
    model.features[2][-1].stride = 1
    model.features[2][-1].kernel_size = 1
    all_layers = []
    remove_sequential(all_layers, model.features[3])
    for m in all_layers:
        if isinstance(m, nn.Conv2d):
            m.dilation = (2, 2)
            m.padding = (2, 2)

    model.features[3][-1].stride = 1
    model.features[3][-1].kernel_size = 1
    all_layers = []
    remove_sequential(all_layers, model.features[4])
    for m in all_layers:
        if isinstance(m, nn.Conv2d):
            m.dilation = (4, 4)
            m.padding = (4, 4)
    model.features[4][-1].stride = 1
    model.features[4][-1].kernel_size = 1
    return model
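
# proc_vgg above rewires VGG16 into a dilated backbone: the pooling layers
# closing the last three stages are reduced to stride 1 (kernel 1), and the
# convolutions of stages 4 and 5 get dilation/padding of 2 and 4 to preserve
# their receptive fields, so the deeper feature maps consumed by the decoder
# stay at 1/4 of the input resolution.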

procs = {
    'vgg16': proc_vgg,
}


def get_upsampling_weight(in_channels, out_channels, kernel_size):
    """Make a 2D bilinear kernel suitable for upsampling"""
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * \
           (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
                      dtype=np.float64)
    weight[range(in_channels), range(out_channels), :, :] = filt
    return torch.from_numpy(weight).float()


class Net(nn.Module):
    def __init__(self, pretrained=True, base='vgg16'):
        super(Net, self).__init__()
        if 'vgg' in base:
            dims = dim_dict['vgg'][::-1]
        else:
            dims = dim_dict[base][::-1]
        self.base = base
        odims = [64]*5
        hdim = 512
        self.classifier = nn.Linear(512, 1)
        self.proc_feats_list = nn.ModuleList([
            nn.Sequential(nn.ConvTranspose2d(dims[0], dims[0], 8, 4, 2), nn.Conv2d(dims[0], odims[0], kernel_size=3, padding=1)),
            nn.Sequential(nn.ConvTranspose2d(dims[1], dims[1], 8, 4, 2), nn.Conv2d(dims[1], odims[1], kernel_size=3, padding=1)),
            nn.Sequential(nn.ConvTranspose2d(dims[2], dims[2], 8, 4, 2), nn.Conv2d(dims[2], odims[2], kernel_size=3, padding=1)),
            nn.Sequential(nn.ConvTranspose2d(dims[3], dims[3], 4, 2, 1), nn.Conv2d(dims[3], odims[3], kernel_size=3, padding=1)),
            # nn.Sequential(nn.Conv2d(dims[0], odims[0]*16, kernel_size=3, padding=1), nn.PixelShuffle(4)),
            # nn.Sequential(nn.Conv2d(dims[1], odims[1]*16, kernel_size=3, padding=1), nn.PixelShuffle(4)),
            # nn.Sequential(nn.Conv2d(dims[2], odims[2]*16, kernel_size=3, padding=1), nn.PixelShuffle(4)),
            # nn.Sequential(nn.Conv2d(dims[3], odims[3]*4, kernel_size=3, padding=1), nn.PixelShuffle(2)),
            nn.Conv2d(dims[4], dims[4], kernel_size=3, padding=1),
        ])
        self.proc_feats = nn.Conv2d(sum(odims), hdim, kernel_size=3, padding=1)
        self.proc_mul = nn.Conv2d(sum(odims), hdim, kernel_size=3, padding=1)
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                m.weight.data.normal_(0.0, 0.02)
                if m.bias is not None:
                    m.bias.data.fill_(0)
            if isinstance(m, nn.ConvTranspose2d):
                assert m.kernel_size[0] == m.kernel_size[1]
                initial_weight = get_upsampling_weight(
                    m.in_channels, m.out_channels, m.kernel_size[0])
                m.weight.data.copy_(initial_weight)
        self.feature = getattr(thismodule, base)(pretrained=pretrained)
        self.feature.feats = {}
        self.feature = procs[base](self.feature)
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.requires_grad = False

    def forward(self, x, prior):
        self.feature.feats[x.device.index] = []
        x = self.feature(x)
        feats = self.feature.feats[x.device.index]
        feats += [x]
        feats = feats[::-1]
        for i, p in enumerate(self.proc_feats_list): feats[i] = p(feats[i])
        feats = torch.cat(feats, 1)
        # prototypes of the foreground (c1) and background (c2) regions,
        # weighted by the prior map
        c1 = self.proc_mul(feats*prior).sum(3, keepdim=True).sum(2, keepdim=True)/(prior.sum())
        c2 = self.proc_mul(feats*(1-prior)).sum(3, keepdim=True).sum(2, keepdim=True)/((1-prior).sum())
        feats = self.proc_feats(feats)
        # per-pixel L2 distance to each prototype; the output is positive where
        # a pixel is closer to the foreground prototype than to the background one
        dist1 = (feats - c1) ** 2
        dist1 = torch.sqrt(dist1.sum(dim=1, keepdim=True))
        dist2 = (feats - c2) ** 2
        dist2 = torch.sqrt(dist2.sum(dim=1, keepdim=True))
        return dist2 - dist1


if __name__ == "__main__":
    net = Net()
    net.cuda()
    x = torch.Tensor(2, 3, 256, 256).cuda()
    z = torch.Tensor(2, 1, 256, 256).cuda()
    sb = net(x, z)
    pdb.set_trace()

--------------------------------------------------------------------------------
/models/vgg.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import math
import torch
from torch.autograd.variable import Variable
import pdb


__all__ = [
    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19',
]


model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}
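
# This file follows torchvision's VGG implementation, with one repo-specific
# change: after loading the (optionally pretrained) weights, each constructor
# below regroups model.features into five per-stage nn.Sequential blocks (one
# per pooling stage) held in an nn.ModuleList, and drops the classifier, so
# that Net/FCN can hook the output of every stage.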

class VGG(nn.Module):

    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        for f in self.features:
            x = f(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg11(pretrained=False, **kwargs):
    """VGG 11-layer model (configuration "A")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['A']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg11']))
    list_feature = list(model.features)
    _features = [nn.Sequential(*list_feature[:3]),
                 nn.Sequential(*list_feature[3:6]),
                 nn.Sequential(*list_feature[6:11]),
                 nn.Sequential(*list_feature[11:16]),
                 nn.Sequential(*list_feature[16:21])]
    model.features = nn.ModuleList(_features)
    model.classifier = None
    return model


def vgg11_bn(pretrained=False, **kwargs):
    """VGG 11-layer model (configuration "A") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))
    list_feature = list(model.features)
    _features = [nn.Sequential(*list_feature[:4]),
                 nn.Sequential(*list_feature[4:8]),
                 nn.Sequential(*list_feature[8:15]),
                 nn.Sequential(*list_feature[15:22]),
                 nn.Sequential(*list_feature[22:29])]
    model.features = nn.ModuleList(_features)
    model.classifier = None
    return model

def vgg13(pretrained=False, **kwargs):
    """VGG 13-layer model (configuration "B")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['B']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg13']))
    list_feature = list(model.features)
    _features = [nn.Sequential(*list_feature[:5]),
                 nn.Sequential(*list_feature[5:10]),
                 nn.Sequential(*list_feature[10:15]),
                 nn.Sequential(*list_feature[15:20]),
                 nn.Sequential(*list_feature[20:25])]
    model.features = nn.ModuleList(_features)
    model.classifier = None
    return model


def vgg13_bn(pretrained=False, **kwargs):
    """VGG 13-layer model (configuration "B") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
    list_feature = list(model.features)
    _features = [nn.Sequential(*list_feature[:7]),
                 nn.Sequential(*list_feature[7:14]),
                 nn.Sequential(*list_feature[14:21]),
                 nn.Sequential(*list_feature[21:28]),
                 nn.Sequential(*list_feature[28:35])]
    model.features = nn.ModuleList(_features)
    model.classifier = None
    return model


def vgg16(pretrained=False, **kwargs):
    """VGG 16-layer model (configuration "D")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['D']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
        # model.load_state_dict(torch.load('/home/crow/SPN.pytorch/demo/models/vgg16_from_caffe.pth'))
    list_feature = list(model.features)
    _features = [nn.Sequential(*list_feature[:5]),
                 nn.Sequential(*list_feature[5:10]),
                 nn.Sequential(*list_feature[10:17]),
                 nn.Sequential(*list_feature[17:24]),
                 nn.Sequential(*list_feature[24:31])]
    model.features = nn.ModuleList(_features)
    model.classifier = None
    return model


def vgg16_bn(pretrained=False, **kwargs):
    """VGG 16-layer model (configuration "D") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn']))
    list_feature = list(model.features)
    _features = [nn.Sequential(*list_feature[:7]),
                 nn.Sequential(*list_feature[7:14]),
                 nn.Sequential(*list_feature[14:24]),
                 nn.Sequential(*list_feature[24:34]),
                 nn.Sequential(*list_feature[34:44])]
    model.features = nn.ModuleList(_features)
    model.classifier = None
    return model

def vgg19(pretrained=False, **kwargs):
    """VGG 19-layer model (configuration "E")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['E']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
    list_feature = list(model.features)
    _features = [nn.Sequential(*list_feature[:5]),
                 nn.Sequential(*list_feature[5:10]),
                 nn.Sequential(*list_feature[10:19]),
                 nn.Sequential(*list_feature[19:28]),
                 nn.Sequential(*list_feature[28:37])]
    model.features = nn.ModuleList(_features)
    model.classifier = None
    return model


def vgg19_bn(pretrained=False, **kwargs):
    """VGG 19-layer model (configuration 'E') with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn']))
    list_feature = list(model.features)
    _features = [nn.Sequential(*list_feature[:7]),
                 nn.Sequential(*list_feature[7:14]),
                 nn.Sequential(*list_feature[14:27]),
                 nn.Sequential(*list_feature[27:40]),
                 nn.Sequential(*list_feature[40:53])]
    model.features = nn.ModuleList(_features)
    model.classifier = None
    return model


if __name__ == "__main__":
    vgg = vgg16(pretrained=True).cuda()
    x = torch.Tensor(2, 3, 256, 256).cuda()
    sb = vgg(Variable(x))
    pdb.set_trace()

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
# coding=utf-8
import torch
import torch.nn.functional as F
import torch.nn as nn
import os
import pdb
import numpy as np
from PIL import Image
import argparse

from datasets import ImageFiles
from models import Net
from evaluate import fm_and_mae

from tqdm import tqdm
import random

random.seed(1996)


home = os.path.expanduser("~")

parser = argparse.ArgumentParser()
parser.add_argument('--prior_dir', default='%s/data/datasets/saliency_Dataset/results/ECSSD-Sal/SRM' % home)  # prior maps
parser.add_argument('--img_dir', default='%s/data/datasets/saliency_Dataset/ECSSD/images' % home)  # images
parser.add_argument('--gt_dir', default='%s/data/datasets/saliency_Dataset/ECSSD/masks' % home)  # ground truth
parser.add_argument('--base', default='vgg16')  # backbone network
parser.add_argument('--img_size', type=int, default=256)  # image size
parser.add_argument('--b', type=int, default=12)  # batch size
opt = parser.parse_args()
print(opt)

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


def validate(loader, net, output_dir, gt_dir=None):
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    net.eval()
    loader = tqdm(loader, desc='validating')
    for ib, (data, prior, img_name, w, h) in enumerate(loader):
        with torch.no_grad():
            outputs = net(data.cuda(), prior[:, None].cuda())
            outputs = F.sigmoid(outputs)
            outputs = outputs.squeeze(1).cpu().numpy()
        outputs *= 255
        for ii, msk in enumerate(outputs):
            msk = Image.fromarray(msk.astype(np.uint8))
            msk = msk.resize((w[ii], h[ii]))
            msk.save('{}/{}.png'.format(output_dir, img_name[ii]), 'PNG')
    if gt_dir is not None:
        fm, mae, _, _ = fm_and_mae(output_dir, gt_dir)
        pfm, pmae, _, _ = fm_and_mae(opt.prior_dir, gt_dir)
        print('prior maps: %.4f, %.4f' % (pfm, pmae))
        print('refined maps: %.4f, %.4f' % (fm, mae))

def main():
    # models
    net = Net(base=opt.base)
    net = nn.DataParallel(net).cuda()
    sdict = torch.load('./net.pth')
    net.load_state_dict(sdict)
    val_loader = torch.utils.data.DataLoader(
        ImageFiles(opt.img_dir, opt.prior_dir, size=256,
                   mean=mean, std=std),
        batch_size=opt.b, shuffle=False, num_workers=4, pin_memory=True)
    validate(val_loader, net, 'results', opt.gt_dir)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# coding=utf-8
import torch
import torchvision
from torchvision import transforms
import torch.nn.functional as F
import torch.nn as nn
from torchvision.utils import make_grid
from torch.autograd import Variable
from tensorboardX import SummaryWriter
from datetime import datetime
import os
import pdb
import numpy as np
from PIL import Image
import argparse
import json

from datasets import PriorFolder, Folder
from datasets.saliency import collate_more
from models import Net, FCN
from evaluate import fm_and_mae

from tqdm import tqdm
import random

# random.seed(1996)


home = os.path.expanduser("~")

parser = argparse.ArgumentParser()
parser.add_argument('--train_dir', default='%s/data/datasets/saliency_Dataset/DUT-train' % home)  # training dataset
parser.add_argument('--prior_dir', default='%s/data/datasets/saliency_Dataset/results/ECSSD-Sal' % home)  # prior maps for validation
parser.add_argument('--val_dir', default='%s/data/datasets/saliency_Dataset/ECSSD' % home)  # validation dataset
parser.add_argument('--base', default='vgg16')  # backbone network
parser.add_argument('--img_size', type=int, default=256)  # image size
parser.add_argument('--b', type=int, default=12)  # batch size
parser.add_argument('--max', type=int, default=100000)  # max training iterations
opt = parser.parse_args()
print(opt)

name = 'Train_{}'.format(opt.base)

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# tensorboard writer
os.system('rm -rf ./runs_%s/*' % name)
writer = SummaryWriter('./runs_%s/' % name + datetime.now().strftime('%B%d %H:%M:%S'))
if not os.path.exists('./runs_%s' % name):
    os.mkdir('./runs_%s' % name)


def make_image_grid(img, mean, std):
    img = make_grid(img)
    for i in range(3):
        img[i] *= std[i]
        img[i] += mean[i]
    return img


def validate(loader, net, output_dir, gt_dir):
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    net.eval()
    loader = tqdm(loader, desc='validating')
    for ib, (data, lbl, prior, img_name, w, h) in enumerate(loader):
        with torch.no_grad():
            outputs = net(data.cuda(), prior[:, None].cuda())
            outputs = F.sigmoid(outputs)
            outputs = outputs.squeeze(1).cpu().numpy()
        outputs *= 255
        for ii, msk in enumerate(outputs):
            msk = Image.fromarray(msk.astype(np.uint8))
            msk = msk.resize((w[ii], h[ii]))
            msk.save('{}/{}.png'.format(output_dir, img_name[ii]), 'PNG')
    fm, mae, _, _ = fm_and_mae(output_dir, gt_dir)
    net.train()
    return fm, mae
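
# Training overview: the refinement net is trained without prior maps from
# other detectors -- instead, each ground-truth mask is corrupted into a
# synthetic prior by flipping every pixel with probability p/100 (p = 5 below)
# via (lbl + Bernoulli(p/100)) % 2, and the net learns to recover the clean
# mask from the image and the corrupted prior. At test time the same net then
# refines ("promotes") the maps of an existing detector such as SRM.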

def main():

    check_dir = '../LPSfiles/' + name

    if not os.path.exists(check_dir):
        os.mkdir(check_dir)

    # data
    val_loader = torch.utils.data.DataLoader(
        PriorFolder(opt.val_dir, opt.prior_dir, size=256,
                    mean=mean, std=std),
        batch_size=opt.b*3, shuffle=False, num_workers=4, pin_memory=True)
    train_loader = torch.utils.data.DataLoader(
        Folder(opt.train_dir, scales=[64]*3+[128, 256],
               crop=0.9, flip=True, rotate=None,
               mean=mean, std=std), collate_fn=collate_more,
        batch_size=opt.b*6, shuffle=True, num_workers=4, pin_memory=True)
    # models
    p = 5  # percentage of ground-truth pixels flipped to synthesize noisy priors
    net = Net(base=opt.base)
    fcn = FCN(net)
    net = nn.DataParallel(net).cuda()
    net.train()
    """
    # fcn = nn.DataParallel(fcn).cuda()
    # sdict = torch.load('/home/crow/LPSfiles/Train2_vgg16/fcn-iter13800.pth')
    # fcn.load_state_dict(sdict)
    fcn = nn.DataParallel(fcn).cuda()
    fcn.train()
    optimizer = torch.optim.Adam([
        {'params': fcn.parameters(), 'lr': 1e-4},
    ])
    logs = {'best_it': 0, 'best': 0}
    sal_data_iter = iter(train_loader)
    i_sal_data = 0
    for it in tqdm(range(opt.max)):
        # for it in tqdm(range(1)):
        # if it > 1000 and it % 100 == 0:
        #     optimizer.param_groups[0]['lr'] *= 0.5
        if i_sal_data >= len(train_loader):
            sal_data_iter = iter(train_loader)
            i_sal_data = 0
        data, lbls, _ = sal_data_iter.next()
        i_sal_data += 1
        data = data.cuda()
        lbls = [lbl.unsqueeze(1).cuda() for lbl in lbls]
        msks = fcn(data)
        loss = sum([F.binary_cross_entropy_with_logits(msk, lbl) for msk, lbl in zip(msks, lbls)])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if it % 10 == 0:
            writer.add_scalar('loss', loss.item(), it)
            image = make_image_grid(data[:6], mean, std)
            writer.add_image('Image', torchvision.utils.make_grid(image), it)
            big_msk = F.sigmoid(msks[-1]).expand(-1, 3, -1, -1)
            writer.add_image('msk', torchvision.utils.make_grid(big_msk.data[:6]), it)
            big_msk = lbls[-1].expand(-1, 3, -1, -1)
            writer.add_image('gt', torchvision.utils.make_grid(big_msk.data[:6]), it)
        # if it % 100 == 0:
        if it != 0 and it % 100 == 0:
            fm, mae = validate(val_loader, fcn, os.path.join(check_dir, 'results'),
                               os.path.join(opt.val_dir, 'masks'))
            print('loss: %.4f' % (loss.item()))
            print('best FM: %.4f at iteration %d, current FM: %.4f' % (logs['best'], logs['best_it'], fm))
            logs[it] = {'FM': fm}
            if fm > logs['best']:
                logs['best'] = fm
                logs['best_it'] = it
                torch.save(fcn.state_dict(), '%s/fcn-best.pth' % (check_dir))
            with open(os.path.join(check_dir, 'logs.json'), 'w') as outfile:
                json.dump(logs, outfile)
            torch.save(fcn.state_dict(), '%s/fcn-iter%d.pth' % (check_dir, it))
    """
    ###################################################################################################
    val_loader = torch.utils.data.DataLoader(
        PriorFolder(opt.val_dir, opt.prior_dir, size=256,
                    mean=mean, std=std),
        batch_size=opt.b, shuffle=False, num_workers=4, pin_memory=True)
    train_loader = torch.utils.data.DataLoader(
        Folder(opt.train_dir, scales=[256],
               crop=0.9, flip=True, rotate=None,
               mean=mean, std=std), collate_fn=collate_more,
        batch_size=opt.b, shuffle=True, num_workers=4, pin_memory=True)
    optimizer = torch.optim.Adam([
        {'params': net.parameters(), 'lr': 1e-4},
    ])
    logs = {'best_it': 0, 'best': 0}
    sal_data_iter = iter(train_loader)
    i_sal_data = 0
    for it in tqdm(range(opt.max)):
        # if it > 1000 and it % 100 == 0:
        #     optimizer.param_groups[0]['lr'] *= 0.5
        if i_sal_data >= len(train_loader):
            sal_data_iter = iter(train_loader)
            i_sal_data = 0
        data, lbl, _ = sal_data_iter.next()
        i_sal_data += 1
        data = data.cuda()
        lbl = lbl[0].unsqueeze(1)
        # corrupt the ground truth into a noisy prior: flip each pixel with probability p/100
        noisy_label = (lbl.numpy() + np.random.binomial(1, float(p) / 100.0, (256, 256))) % 2
        noisy_label = torch.Tensor(noisy_label).cuda()
        lbl = lbl.cuda()
        msk = net(data, noisy_label)
        loss = F.binary_cross_entropy_with_logits(msk, lbl)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if it % 10 == 0:
            writer.add_scalar('loss', loss.item(), it)
            image = make_image_grid(data[:6], mean, std)
            writer.add_image('Image', torchvision.utils.make_grid(image), it)
            big_msk = F.sigmoid(msk).expand(-1, 3, -1, -1)
            writer.add_image('msk', torchvision.utils.make_grid(big_msk.data[:6]), it)
            big_msk = lbl.expand(-1, 3, -1, -1)
            writer.add_image('gt', torchvision.utils.make_grid(big_msk.data[:6]), it)
        # if it % 200 == 0:
        if it != 0 and it % 100 == 0:
            fm, mae = validate(val_loader, net, os.path.join(check_dir, 'results'),
                               os.path.join(opt.val_dir, 'masks'))
            print('loss: %.4f' % (loss.item()))
            print('best FM: %.4f at iteration %d, current FM: %.4f' % (logs['best'], logs['best_it'], fm))
            logs[it] = {'FM': fm}
            if fm > logs['best']:
                logs['best'] = fm
                logs['best_it'] = it
                torch.save(net.state_dict(), '%s/net-best.pth' % (check_dir))
            with open(os.path.join(check_dir, 'logs.json'), 'w') as outfile:
                json.dump(logs, outfile)
            torch.save(net.state_dict(), '%s/net-iter%d.pth' % (check_dir, it))


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------