├── code
│   ├── utils
│   │   ├── __init__.py
│   │   └── evaluateFM.py
│   ├── FCMNet-main.zip
│   ├── README.md
│   ├── functions.py
│   ├── testdata.py
│   ├── model
│   │   ├── attention_module.py
│   │   ├── model_fusion.py
│   │   ├── model_depth.py
│   │   └── model_baseline.py
│   ├── dataset_loader.py
│   ├── train.py
│   ├── demo.py
│   └── saliency_metric.py
└── README.md

/code/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('utils')
3 | 
--------------------------------------------------------------------------------
/code/FCMNet-main.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XiaoJinNK/FCMNet/HEAD/code/FCMNet-main.zip
--------------------------------------------------------------------------------
/code/README.md:
--------------------------------------------------------------------------------
1 | # FCMNet
2 | 
3 | ## Prerequisites
4 | 
5 | Ubuntu 18
6 | 
7 | PyTorch 1.7
8 | 
9 | CUDA 10.0
10 | 
11 | Python 3.7
12 | 
13 | ## Train/Test
14 | 
15 | Train
16 | 
17 | In demo.py, set '--phase' to "train" and set '--param' to 'True' (resume from a checkpoint) or 'False' (train without loading a checkpoint), then run:
18 | 
19 | python demo.py
20 | 
21 | Test
22 | 
23 | In demo.py, set '--phase' to "test" and '--param' to 'True' (load the trained checkpoint), then run:
24 | 
25 | python demo.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FCMNet
2 | Jin, Xiao, et al. "FCMNet: Frequency-aware cross-modality attention networks for RGB-D salient object detection." Neurocomputing 491 (2022): 414-425.
3 | 
4 | @article{jin2022fcmnet,
5 |   title={FCMNet: Frequency-aware cross-modality attention networks for RGB-D salient object detection},
6 |   author={Jin, Xiao and Guo, Chunle and He, Zhen and Xu, Jing and Wang, Yongwei and Su, Yuting},
7 |   journal={Neurocomputing},
8 |   volume={491},
9 |   pages={414--425},
10 |   year={2022},
11 |   publisher={Elsevier}
12 | }
--------------------------------------------------------------------------------
/code/functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import torch
4 | # from scipy.misc import imresize
5 | from PIL import Image
6 | def imsave(file_name, img, img_size):
7 |     """
8 |     save a predicted saliency map as an image
9 |     :param file_name: 'image/folder/image_name'
10 |     :param img: h*w or 3*h*w torch tensor / numpy array
11 |     :param img_size: original (width, height) of the test image
12 |     :return: nothing
13 |     """
14 |     if isinstance(img, torch.Tensor):
15 |         img = img.detach().cpu().numpy()
16 |     ndim = len(img.shape)
17 |     assert ndim == 2 or ndim == 3, 'img must be a 2 or 3 dimensional array'
18 | 
19 |     # PIL's resize expects the target size as (width, height)
20 |     img = np.array(Image.fromarray(img).resize((int(img_size[0][0]), int(img_size[1][0]))))
21 |     # img = imresize(img, [img_size[1][0], img_size[0][0]], interp='nearest')
22 |     if ndim == 3:
23 |         plt.imsave(file_name, np.transpose(img, (1, 2, 0)))
24 |     else:
25 |         plt.imsave(file_name, img, cmap='gray')
--------------------------------------------------------------------------------
/code/testdata.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torchvision.transforms as transforms
4 | 
5 | class test_dataset:
6 |     def __init__(self, image_root, gt_root):
7 |         self.img_list = [os.path.splitext(f)[0] for f in os.listdir(gt_root) if f.endswith('.png')]
8 |         self.image_root = image_root
9 |         self.gt_root = gt_root
10 |         self.transform = transforms.Compose([
11 |             transforms.ToTensor(),
12 |         ])
13 |         self.gt_transform = transforms.ToTensor()
14 |         self.size = len(self.img_list)
15 |         self.index = 0
16 | 
17 |     def load_data(self):
18 |         # image = self.rgb_loader(self.images[self.index])
19 |         image = self.binary_loader(os.path.join(self.image_root, self.img_list[self.index] + '.png'))
20 |         gt = self.binary_loader(os.path.join(self.gt_root, self.img_list[self.index] + '.png'))
21 |         self.index += 1
22 |         return image, gt
23 | 
24 |     def rgb_loader(self, path):
25 |         with open(path, 'rb') as f:
26 |             img = Image.open(f)
27 |             return img.convert('RGB')
28 | 
29 |     def binary_loader(self, path):
30 |         with open(path, 'rb') as f:
31 |             img = Image.open(f)
32 |             return img.convert('L')
33 | 
--------------------------------------------------------------------------------
/code/utils/evaluateFM.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | # import cv2
4 | import matplotlib.pyplot as plt
5 | import PIL.Image as Image
6 | def get_FM(salpath, gtpath):
7 | 
8 |     gtdir = gtpath
9 |     saldir = salpath
10 | 
11 |     files = os.listdir(gtdir)
12 |     eps = np.finfo(float).eps
13 | 
14 |     m_pres = np.zeros(21)
15 |     m_recs = np.zeros(21)
16 |     m_fms = np.zeros(21)
17 |     m_thfm = 0
18 |     m_mea = 0
19 |     it = 1
20 |     for i, name in enumerate(files):
21 |         if not os.path.exists(gtdir + name):
22 |             print(gtdir + name, 'does not exist')
23 |             continue
24 |         gt = Image.open(gtdir + name)
25 |         gt = np.array(gt, dtype=np.uint8)
26 | 
27 |         mask = Image.open(saldir + name).convert('L')
28 |         mask = mask.resize((np.shape(gt)[1], np.shape(gt)[0]))
29 |         mask = np.array(mask, dtype=float)
30 |         # salmap = cv2.resize(salmap,(W,H))
31 | 
32 |         if len(mask.shape) != 2:
33 |             mask = mask[:, :, 0]
34 |         mask = (mask - mask.min()) / (mask.max() - mask.min() + eps)
35 |         gt[gt != 0] = 1
36 |         pres = []
37 |         recs = []
38 |         fms = []
39 |         mea = np.abs(gt - mask).mean()
40 |         # threshold fm
41 |         binary = np.zeros(mask.shape)
42 |         th = 2 * mask.mean()
43 |         if th > 1:
44 |             th = 1
45 |         binary[mask >= th] = 1
46 |         sb = (binary * gt).sum()
47 |         pre = sb / (binary.sum() + eps)
48 |         rec = sb / (gt.sum() + eps)
49 |         thfm = 1.3 * pre * rec / (0.3 * pre + rec + eps)
50 |         for th in np.linspace(0, 1, 21):
51 |             binary = np.zeros(mask.shape)
52 |             binary[mask >= th] = 1
53 |             pre = (binary * gt).sum() / (binary.sum() + eps)
54 |             rec = (binary * gt).sum() / (gt.sum() + eps)
55 |             fm = 1.3 * pre * rec / (0.3 * pre + rec + eps)
56 |             pres.append(pre)
57 |             recs.append(rec)
58 |             fms.append(fm)
59 |         fms = np.array(fms)
60 |         pres = np.array(pres)
61 |         recs = np.array(recs)
62 |         m_mea = m_mea * (it - 1) / it + mea / it
63 |         m_fms = m_fms * (it - 1) / it + fms / it
64 |         m_recs = m_recs * (it - 1) / it + recs / it
65 |         m_pres = m_pres * (it - 1) / it + pres / it
66 |         m_thfm = m_thfm * (it - 1) / it + thfm / it
67 |         it += 1
68 |     return m_thfm, m_mea
69 | 
70 | if __name__ == '__main__':
71 |     m_thfm, m_mea = get_FM()
72 |     print(m_thfm)
73 |     print(m_mea)
--------------------------------------------------------------------------------
/code/model/attention_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import math
4 | def get_1d_dct(i, freq, L):
5 |     result = math.cos(math.pi * freq * (i + 0.5) / L) / math.sqrt(L)
6 |     if freq == 0:
7 |         return result
8 |     else:
9 |         return result * math.sqrt(2)
10 | def
get_dct_weights(width,height,channel,fidx_u,fidx_v): 11 | dct_weights = torch.zeros(1, channel, width, height) 12 | c_part = channel // len(fidx_u) 13 | for i, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)): 14 | for t_x in range(width): 15 | for t_y in range(height): 16 | dct_weights[:, i*c_part: (i+1)*c_part, t_x, t_y] = get_1d_dct(t_x, u_x, width) * get_1d_dct(t_y, v_y, height) 17 | return dct_weights 18 | class FCABlock(nn.Module): 19 | """ 20 | FcaNet: Frequency Channel Attention Networks 21 | https://arxiv.org/pdf/2012.11879.pdf 22 | """ 23 | def __init__(self, channel,width,height,fidx_u, fidx_v, reduction=16): 24 | super(FCABlock, self).__init__() 25 | mid_channel = channel // reduction 26 | self.register_buffer('pre_computed_dct_weights', get_dct_weights(width,height,channel,fidx_u,fidx_v)) 27 | self.excitation = nn.Sequential( 28 | nn.Linear(channel, mid_channel, bias=False), 29 | nn.ReLU(inplace=True), 30 | nn.Linear(mid_channel, channel, bias=False), 31 | nn.Sigmoid() 32 | ) 33 | def forward(self, x): 34 | b, c, _, _ = x.size() 35 | y = torch.sum(x * self.pre_computed_dct_weights, dim=[2,3]) 36 | z = self.excitation(y).view(b, c, 1, 1) 37 | return x * z.expand_as(x) 38 | class SFCA(nn.Module): 39 | def __init__(self, in_channel,width,height,fidx_u,fidx_v): 40 | super(SFCA, self).__init__() 41 | 42 | fidx_u = [temp_u * (width // 8) for temp_u in fidx_u] 43 | fidx_v = [temp_v * (width // 8) for temp_v in fidx_v] 44 | self.FCA = FCABlock(in_channel, width, height, fidx_u, fidx_v) 45 | self.conv1 = nn.Conv2d(in_channel, 1, kernel_size=1, bias=False) 46 | self.norm = nn.Sigmoid() 47 | def forward(self, x): 48 | # FCA 49 | F_fca = self.FCA(x) 50 | #context attention 51 | con = self.conv1(x) # c,h,w -> 1,h,w 52 | con = self.norm(con) 53 | F_con = x * con 54 | return F_fca + F_con 55 | class FACMA(nn.Module): 56 | def __init__(self,in_channel,width,height,fidx_u,fidx_v): 57 | super(FACMA, self).__init__() 58 | self.sfca_depth = SFCA(in_channel, width, height, fidx_u, fidx_v) 59 | self.sfca_rgb = SFCA(in_channel, width, height, fidx_u, fidx_v) 60 | def forward(self, rgb, depth): 61 | out_d = self.sfca_depth(depth) 62 | out_d = rgb * out_d 63 | 64 | out_rgb = self.sfca_rgb(rgb) 65 | out_rgb = depth * out_rgb 66 | return out_rgb, out_d 67 | -------------------------------------------------------------------------------- /code/dataset_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL.Image 4 | import scipy.io as sio 5 | import torch 6 | from torch.utils import data 7 | 8 | class MyData(data.Dataset): 9 | """ 10 | load data in a folder 11 | """ 12 | mean_rgb = np.array([0.447, 0.407, 0.386]) 13 | std_rgb = np.array([0.244, 0.250, 0.253]) 14 | def __init__(self, root, transform=False): 15 | super(MyData, self).__init__() 16 | self.root = root 17 | self._transform = transform 18 | img_root = os.path.join(self.root, 'train_images') 19 | mask_root = os.path.join(self.root, 'train_masks') 20 | depth_root = os.path.join(self.root, 'train_depth') 21 | edge_root = os.path.join(self.root, 'train_edges') 22 | 23 | file_names = os.listdir(img_root) 24 | self.img_names = [] 25 | self.mask_names = [] 26 | self.edge_names = [] 27 | self.depth_names = [] 28 | for i, name in enumerate(file_names): 29 | if not name.endswith('.jpg'): 30 | continue 31 | self.mask_names.append( 32 | os.path.join(mask_root, name[:-4] + '.png') 33 | ) 34 | 35 | self.img_names.append( 36 | os.path.join(img_root, name) 37 | ) 38 | 
self.edge_names.append( 39 | os.path.join(edge_root, name[:-4] + '.png') 40 | ) 41 | self.depth_names.append( 42 | os.path.join(depth_root, name[:-4] + '.png') 43 | ) 44 | 45 | def __len__(self): 46 | return len(self.img_names) 47 | 48 | def __getitem__(self, index): 49 | # load image 50 | img_file = self.img_names[index] 51 | img = PIL.Image.open(img_file) 52 | img = np.array(img, dtype=np.uint8) 53 | # load label 54 | mask_file = self.mask_names[index] 55 | mask = PIL.Image.open(mask_file) 56 | mask = np.array(mask, dtype=np.int32) 57 | mask[mask != 0] = 1 58 | # load depth 59 | depth_file = self.depth_names[index] 60 | depth = PIL.Image.open(depth_file) 61 | depth = np.array(depth, dtype=np.uint8) 62 | # load edge 63 | edges_file = self.edge_names[index] 64 | edge = PIL.Image.open(edges_file) 65 | edge = np.array(edge, dtype=np.int32) 66 | edge[edge != 0] = 1 67 | 68 | if self._transform: 69 | return self.transform(img, mask, depth, edge) 70 | else: 71 | return img, mask, depth, edge 72 | 73 | def transform(self, img, mask, depth, edge): 74 | img = img.astype(np.float64)/255.0 75 | img -= self.mean_rgb 76 | img /= self.std_rgb 77 | img = img.transpose(2, 0, 1) # to verify 78 | img = torch.from_numpy(img).float() 79 | mask = torch.from_numpy(mask).long() 80 | depth = depth.astype(np.float64) / 255.0 81 | depth = torch.from_numpy(depth).float() 82 | edge = torch.from_numpy(edge).long() 83 | return img, mask, depth, edge 84 | 85 | 86 | class MyTestData(data.Dataset): 87 | """ 88 | load data in a folder 89 | """ 90 | mean_rgb = np.array([0.447, 0.407, 0.386]) 91 | std_rgb = np.array([0.244, 0.250, 0.253]) 92 | 93 | def __init__(self, root, transform=False): 94 | super(MyTestData, self).__init__() 95 | self.root = root 96 | self._transform = transform 97 | 98 | img_root = os.path.join(self.root, 'test_images') 99 | depth_root = os.path.join(self.root, 'test_depth') 100 | file_names = os.listdir(img_root) 101 | self.img_names = [] 102 | self.names = [] 103 | self.depth_names = [] 104 | 105 | for i, name in enumerate(file_names): 106 | if not name.endswith('.jpg'): 107 | continue 108 | self.img_names.append( 109 | os.path.join(img_root, name) 110 | ) 111 | self.names.append(name[:-4]) 112 | self.depth_names.append( 113 | # os.path.join(depth_root, name[:-4] + '_depth.png') 114 | os.path.join(depth_root, name[:-4] + '.png') 115 | ) 116 | 117 | def __len__(self): 118 | return len(self.img_names) 119 | 120 | def __getitem__(self, index): 121 | # load image 122 | img_file = self.img_names[index] 123 | img = PIL.Image.open(img_file) 124 | img_size = img.size 125 | img = np.array(img, dtype=np.uint8) 126 | 127 | # load focal 128 | depth_file = self.depth_names[index] 129 | depth = PIL.Image.open(depth_file) 130 | depth = np.array(depth, dtype=np.uint8) 131 | if self._transform: 132 | img, focal = self.transform(img, depth) 133 | return img, focal, self.names[index], img_size 134 | else: 135 | return img, depth, self.names[index], img_size 136 | 137 | def transform(self, img, depth): 138 | img = img.astype(np.float64)/255.0 139 | img -= self.mean_rgb 140 | img /= self.std_rgb 141 | img = img.transpose(2, 0, 1) 142 | img = torch.from_numpy(img).float() 143 | depth = depth.astype(np.float64)/255.0 144 | depth = torch.from_numpy(depth).float() 145 | 146 | return img, depth 147 | -------------------------------------------------------------------------------- /code/train.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.autograd import 
Variable 3 | import torch.nn.functional as F 4 | import torch 5 | import matplotlib.pylab as plt 6 | running_loss_final = 0 7 | 8 | def cross_entropy2d(input, target, weight=None, size_average=True): 9 | n, c, h, w = input.size() 10 | # print(n,c,h,w) 11 | input = input.transpose(1, 2).transpose(2, 3).contiguous() 12 | # print(input.shape,target.shape) 13 | input = input[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0] # 262144 #input = 2*256*256*2 14 | input = input.view(-1, c) 15 | mask = target >= 0 16 | target = target[mask] 17 | loss = F.cross_entropy(input, target, weight=weight, size_average=False) 18 | if size_average: 19 | loss /= mask.data.sum() 20 | return loss 21 | 22 | def cross_entropy2d_edge(input, target, reduction='sum'): 23 | assert (input.size() == target.size()) 24 | pos = torch.eq(target, 1).float() 25 | neg = torch.eq(target, 0).float() 26 | # ing = ((torch.gt(target, 0) & torch.lt(target, 1))).float() 27 | num_pos = torch.sum(pos) 28 | num_neg = torch.sum(neg) 29 | num_total = num_pos + num_neg 30 | 31 | alpha = num_neg / num_total 32 | beta = 1.1 * num_pos / num_total 33 | # target pixel = 1 -> weight beta 34 | # target pixel = 0 -> weight 1-beta 35 | weights = alpha * pos + beta * neg 36 | 37 | return F.binary_cross_entropy_with_logits(input, target, weights, reduction=reduction) 38 | 39 | 40 | class Trainer(object): 41 | def __init__(self, cuda, model_depth, model_baseline, model_fusion, optimizer_depth, optimizer_baseline, optimizer_ladder, 42 | train_loader, max_iter, snapshot, outpath, sshow, size_average=False): 43 | self.cuda = cuda 44 | self.model_depth = model_depth 45 | self.model_baseline = model_baseline 46 | self.model_fusion = model_fusion 47 | self.optim_depth = optimizer_depth 48 | self.optim_baseline = optimizer_baseline 49 | self.optim_ladder = optimizer_ladder 50 | self.train_loader = train_loader 51 | self.epoch = 0 52 | self.iteration = 0 53 | self.max_iter = max_iter 54 | self.snapshot = snapshot 55 | self.outpath = outpath 56 | self.sshow = sshow 57 | self.size_average = size_average 58 | 59 | def train_epoch(self): 60 | for batch_idx, (img, mask, depth, edge) in enumerate(self.train_loader): 61 | iteration = batch_idx + self.epoch * len(self.train_loader) 62 | if self.iteration != 0 and (iteration - 1) != self.iteration: 63 | continue # for resuming 64 | self.iteration = iteration 65 | if self.iteration >= self.max_iter: 66 | break 67 | if self.cuda: 68 | img, mask, depth, edge = img.cuda(), mask.cuda(), depth.cuda(), edge.cuda() 69 | img, mask, depth, edge = Variable(img), Variable(mask), Variable(depth), Variable(edge) 70 | n, c, h, w = img.size() # batch_size, channels, height, weight 71 | 72 | self.optim_depth.zero_grad() 73 | self.optim_baseline.zero_grad() 74 | self.optim_ladder.zero_grad() 75 | 76 | global running_loss_final 77 | depth = depth.view(n, 1, h, w).repeat(1, c, 1, 1) 78 | # depth = depth.view(n, 1, h, w) 79 | mask = mask.view(n, 1, h, w) 80 | edge = edge.view(n, 1, h, w) 81 | d2, d3, d4, d5 = self.model_depth(depth) 82 | h2, h3, h4, h5 = self.model_baseline(img) 83 | p2, p3, p4, p5, p = self.model_fusion(img,h2, h3, h4, h5, d2, d3, d4, d5) 84 | 85 | mask = mask.to(torch.float32) 86 | edge = edge.to(torch.float32) 87 | loss_e = cross_entropy2d_edge(p2, edge) 88 | loss_0 = F.binary_cross_entropy_with_logits(p3, mask, reduction='sum') 89 | loss_1 = F.binary_cross_entropy_with_logits(p4, mask, reduction='sum') 90 | loss_2 = F.binary_cross_entropy_with_logits(p5, mask, reduction='sum') 91 | loss_3 = 
F.binary_cross_entropy_with_logits(p , mask, reduction='sum') 92 | 93 | 94 | loss_all = loss_e + loss_0 + loss_1 + loss_2 + loss_3 95 | running_loss_final = loss_all.item() 96 | 97 | if iteration % self.sshow == (self.sshow - 1): 98 | print('\n [%3d, %6d, RGB-D Net loss: %.3f]' % ( 99 | self.epoch + 1, iteration + 1, running_loss_final / (n * self.sshow) )) 100 | 101 | running_loss_final = 0.0 102 | 103 | if iteration <= 200000: 104 | if iteration % self.snapshot == (self.snapshot - 1): 105 | savename_depth = ('%s/depth_snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) 106 | torch.save(self.model_depth.state_dict(), savename_depth) 107 | print('save: (snapshot: %d)' % (iteration + 1)) 108 | 109 | savename_baseline = ('%s/baseline_snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) 110 | torch.save(self.model_baseline.state_dict(), savename_baseline) 111 | print('save: (snapshot: %d)' % (iteration + 1)) 112 | 113 | savename_ladder = ('%s/ladder_snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) 114 | torch.save(self.model_fusion.state_dict(), savename_ladder) 115 | print('save: (snapshot: %d)' % (iteration + 1)) 116 | else: 117 | 118 | if iteration % 10000 == (10000 - 1): 119 | savename_depth = ('%s/depth_snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) 120 | torch.save(self.model_depth.state_dict(), savename_depth) 121 | print('save: (snapshot: %d)' % (iteration + 1)) 122 | 123 | savename_baseline = ('%s/baseline_snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) 124 | torch.save(self.model_baseline.state_dict(), savename_baseline) 125 | print('save: (snapshot: %d)' % (iteration + 1)) 126 | 127 | savename_ladder = ('%s/ladder_snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) 128 | torch.save(self.model_fusion.state_dict(), savename_ladder) 129 | print('save: (snapshot: %d)' % (iteration + 1)) 130 | 131 | if (iteration + 1) == self.max_iter: 132 | savename_depth = ('%s/depth_snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) 133 | torch.save(self.model_depth.state_dict(), savename_depth) 134 | print('save: (snapshot: %d)' % (iteration + 1)) 135 | 136 | savename_baseline = ('%s/baseline_snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) 137 | torch.save(self.model_baseline.state_dict(), savename_baseline) 138 | print('save: (snapshot: %d)' % (iteration + 1)) 139 | 140 | savename_ladder = ('%s/ladder_snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) 141 | torch.save(self.model_fusion.state_dict(), savename_ladder) 142 | print('save: (snapshot: %d)' % (iteration + 1)) 143 | 144 | loss_all.backward() 145 | self.optim_depth.step() 146 | self.optim_baseline.step() 147 | self.optim_ladder.step() 148 | 149 | def train(self): 150 | max_epoch = int(math.ceil(1. 
* self.max_iter / len(self.train_loader))) 151 | 152 | for epoch in range(max_epoch): 153 | self.epoch = epoch 154 | self.train_epoch() 155 | if self.iteration >= self.max_iter: 156 | break 157 | -------------------------------------------------------------------------------- /code/model/model_fusion.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import os 5 | import logging 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from attention_module import FACMA 10 | 11 | logger = logging.getLogger(__name__) 12 | class ASPP_module(nn.Module): 13 | def __init__(self, inplanes, planes, rate): 14 | super(ASPP_module, self).__init__() 15 | if rate == 1: 16 | kernel_size = 1 17 | padding = 0 18 | else: 19 | kernel_size = 3 20 | padding = rate 21 | self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 22 | stride=1, padding=padding, dilation=rate, bias=False) 23 | self.bn = nn.BatchNorm2d(planes) 24 | self.relu = nn.ReLU() 25 | def forward(self, x): 26 | x = self.atrous_convolution(x) 27 | x = self.bn(x) 28 | return self.relu(x) 29 | 30 | class ASPP(nn.Module): 31 | def __init__(self, inplanes, planes, rates): 32 | super(ASPP, self).__init__() 33 | 34 | self.aspp1 = ASPP_module(inplanes, planes, rate=rates[0]) 35 | self.aspp2 = ASPP_module(inplanes, planes, rate=rates[1]) 36 | self.aspp3 = ASPP_module(inplanes, planes, rate=rates[2]) 37 | self.aspp4 = ASPP_module(inplanes, planes, rate=rates[3]) 38 | 39 | self.relu = nn.ReLU() 40 | 41 | self.global_avg_pool = nn.Sequential( 42 | nn.AdaptiveAvgPool2d((1, 1)), 43 | nn.Conv2d(inplanes, planes, 1, stride=1, bias=False), 44 | nn.BatchNorm2d(planes), 45 | nn.ReLU() 46 | ) 47 | self.conv1 = nn.Conv2d(planes*5, planes, 1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(planes) 49 | 50 | def forward(self, x): 51 | x1 = self.aspp1(x) 52 | x2 = self.aspp2(x) 53 | x3 = self.aspp3(x) 54 | x4 = self.aspp4(x) 55 | x5 = self.global_avg_pool(x) 56 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 57 | 58 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 59 | x = self.conv1(x) 60 | x = self.bn1(x) 61 | x = self.relu(x) 62 | return x 63 | 64 | 65 | class WCMF(nn.Module): 66 | def __init__(self,channel=256): 67 | super(WCMF, self).__init__() 68 | self.conv_r1 = nn.Sequential(nn.Conv2d(channel, channel, 1, 1, 0), nn.BatchNorm2d(channel), nn.ReLU()) 69 | self.conv_d1 = nn.Sequential(nn.Conv2d(channel, channel, 1, 1, 0), nn.BatchNorm2d(channel), nn.ReLU()) 70 | 71 | self.conv_c1 = nn.Sequential(nn.Conv2d(2*channel, channel, 3, 1, 1), nn.BatchNorm2d(channel), nn.ReLU()) 72 | self.conv_c2 = nn.Sequential(nn.Conv2d(channel, 2, 3, 1, 1), nn.BatchNorm2d(2), nn.ReLU()) 73 | self.avgpool = nn.AdaptiveAvgPool2d((1,1)) 74 | def fusion(self,f1,f2,f_vec): 75 | 76 | w1 = f_vec[:, 0, :, :].unsqueeze(1) 77 | w2 = f_vec[:, 1, :, :].unsqueeze(1) 78 | out1 = (w1 * f1) + (w2 * f2) 79 | out2 = (w1 * f1) * (w2 * f2) 80 | return out1 + out2 81 | def forward(self,rgb,depth): 82 | Fr = self.conv_r1(rgb) 83 | Fd = self.conv_d1(depth) 84 | f = torch.cat([Fr, Fd],dim=1) 85 | f = self.conv_c1(f) 86 | f = self.conv_c2(f) 87 | # f = self.avgpool(f) 88 | Fo = self.fusion(Fr, Fd, f) 89 | return Fo 90 | 91 | class FusionNet(nn.Module): 92 | def __init__(self): 93 | super(FusionNet, self).__init__() 94 | fidx_u = [0, 0, 6, 0, 0, 1, 1, 4, 5, 1, 3, 0, 0, 0, 3, 2] 95 | 
fidx_v = [0, 1, 0, 5, 2, 0, 2, 0, 0, 6, 0, 4, 6, 3, 5, 2] 96 | 97 | self.FACMA1 = FACMA(128, 64, 64, fidx_u, fidx_v) 98 | self.FACMA2 = FACMA(256, 32, 32, fidx_u, fidx_v) 99 | self.FACMA3 = FACMA(512, 16, 16, fidx_u, fidx_v) 100 | self.FACMA4 = FACMA(512, 8, 8, fidx_u, fidx_v) 101 | 102 | 103 | self.WCMF2 = WCMF(128) 104 | self.WCMF3 = WCMF(256) 105 | self.WCMF4 = WCMF(512) 106 | self.WCMF5 = WCMF(512) 107 | 108 | self.relu = nn.ReLU(inplace=True) 109 | self.cp2 = nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1), self.relu, nn.Conv2d(128, 128, 3, 1, 1), self.relu, 110 | nn.Conv2d(128, 64, 3, 1, 1), self.relu) 111 | 112 | self.cp3 = nn.Sequential(nn.Conv2d(256, 128, 3, 1, 1), self.relu, nn.Conv2d(128, 128, 3, 1, 1), self.relu, 113 | nn.Conv2d(128, 64, 3, 1, 1), self.relu) 114 | 115 | self.cp4 = nn.Sequential(nn.Conv2d(512, 256, 5, 1, 2), self.relu, nn.Conv2d(256, 128, 5, 1, 2), self.relu, 116 | nn.Conv2d(128, 64, 3, 1, 1), self.relu) 117 | 118 | self.cp5 = nn.Sequential(nn.Conv2d(512, 256, 5, 1, 2), self.relu, nn.Conv2d(256, 128, 5, 1, 2), self.relu, 119 | nn.Conv2d(128, 64, 3, 1, 1), self.relu) 120 | 121 | rates = [1, 6, 12, 18] 122 | self.ASPP1 = ASPP(64, 64, rates) 123 | self.ASPP2 = ASPP(64, 64, rates) 124 | self.ASPP3 = ASPP(64, 64, rates) 125 | self.ASPP4 = ASPP(64, 64, rates) 126 | 127 | 128 | self.conv_2 = nn.Conv2d(64, 1, 3, 1, 1) 129 | 130 | self.conv_3 = nn.Conv2d(64, 1, 3, 1, 1) 131 | self.conv_4 = nn.Conv2d(64, 1, 3, 1, 1) 132 | self.conv_5 = nn.Conv2d(64, 1, 3, 1, 1) 133 | self.conv_o = nn.Conv2d(64, 1, 3, 1, 1) 134 | 135 | def forward(self, img, h2, h3, h4, h5, d2, d3, d4, d5): 136 | raw_size = img.size()[2:] 137 | 138 | rf2, rd2 = self.FACMA1(h2, d2) # 64*64*128 139 | rf3, rd3 = self.FACMA2(h3, d3) # 32*32*256 140 | rf4, rd4 = self.FACMA3(h4, d4) # 16*16*512 141 | rf5, rd5 = self.FACMA4(h5, d5) # 8 *8 *512 142 | 143 | F2 = self.WCMF2(rf2, rd2) 144 | F3 = self.WCMF3(rf3, rd3) 145 | F4 = self.WCMF4(rf4, rd4) 146 | F5 = self.WCMF5(rf5, rd5) 147 | 148 | F2 = self.cp2(F2) 149 | F3 = self.cp3(F3) 150 | F4 = self.cp4(F4) 151 | F5 = self.cp5(F5) 152 | # print("//////", F5.shape) 153 | F5_A = self.ASPP1(F5) 154 | F4_A = self.ASPP2(F4 + F.interpolate(F5_A, F4.shape[2:], mode='bilinear')) 155 | F3_A = self.ASPP3(F3 + F.interpolate(F5_A, F3.shape[2:], mode='bilinear') + F.interpolate(F4_A, F3.shape[2:], 156 | mode='bilinear')) 157 | F2_A = self.ASPP4(F2 + F.interpolate(F5_A, F2.shape[2:], mode='bilinear') + F.interpolate(F4_A, F2.shape[2:], 158 | mode='bilinear') + F.interpolate(F3_A, F2.shape[2:], mode='bilinear')) 159 | 160 | Fo_2 = F.interpolate(self.conv_2(F2), raw_size, mode='bilinear') 161 | Fo_3 = F.interpolate(self.conv_3(F3), raw_size, mode='bilinear') 162 | Fo_4 = F.interpolate(self.conv_4(F4), raw_size, mode='bilinear') 163 | Fo_5 = F.interpolate(self.conv_5(F5), raw_size, mode='bilinear') 164 | 165 | Fo = F.interpolate(self.conv_2(F2_A), raw_size, mode='bilinear') 166 | return Fo_2, Fo_3, Fo_4, Fo_5, Fo 167 | def init_weights(self): 168 | logger.info('=> init weights from normal distribution') 169 | for m in self.modules(): 170 | if isinstance(m, nn.Conv2d): 171 | nn.init.normal_(m.weight, std=0.001) 172 | for name, _ in m.named_parameters(): 173 | if name in ['bias']: 174 | nn.init.constant_(m.bias, 0) 175 | elif isinstance(m, nn.BatchNorm2d): 176 | nn.init.constant_(m.weight, 1) 177 | nn.init.constant_(m.bias, 0) 178 | elif isinstance(m, nn.ConvTranspose2d): 179 | nn.init.normal_(m.weight, std=0.001) 180 | for name, _ in m.named_parameters(): 181 | if name in ['bias']: 182 | 
nn.init.constant_(m.bias, 0)
183 | 
184 | 
185 | 
--------------------------------------------------------------------------------
/code/demo.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch
3 | from torch.autograd import Variable
4 | from torch.utils.data import DataLoader
5 | import numpy as np
6 | import torch.optim as optim
7 | from dataset_loader import MyData, MyTestData
8 | from functions import imsave
9 | import argparse
10 | from train import Trainer
11 | from model.model_depth import DepthNet
12 | from model.model_baseline import BaselineNet
13 | from model.model_fusion import FusionNet
14 | import time
15 | import torchvision
16 | from tqdm import tqdm
17 | from torchsummary import summary
18 | import os
19 | from testdata import test_dataset
20 | from saliency_metric import cal_mae, cal_fm, cal_sm, cal_em, cal_wfm
21 | 
22 | configurations = {
23 |     1: dict(
24 |         max_iteration=600000,
25 |         lr=1.0e-10,
26 |         momentum=0.99,
27 |         weight_decay=0.0005,
28 |         spshot=20000,
29 |         nclass=2,
30 |         sshow=10,
31 |     )
32 | }
33 | 
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument('--phase', type=str, default='train', help='train or test')
36 | parser.add_argument('--param', type=str, default='False', help="'True' to load a saved checkpoint, 'False' to start without one")
37 | parser.add_argument('--train_dataroot', type=str, default='/media/ubuntu/新加卷/Zhen He/RGB-D/rgb-d code/dataset/SOD-RGBD/train_data-augment', help=
38 |                     'path to train data')
39 | parser.add_argument('--test_dataroot', type=str, default='/media/ubuntu/新加卷/Zhen He/RGB-D/rgb-d code/dataset/SOD-RGBD/val/DUT-RGBD', help=
40 |                     'path to test data')
41 | parser.add_argument('--snapshot_root', type=str, default='./checkpoint', help='path to snapshot')
42 | parser.add_argument('--salmap_root', type=str, default='./sal_map/DUT-RGBD/', help='path to saliency map')
43 | parser.add_argument('-c', '--config', type=int, default=1, choices=configurations.keys())
44 | args = parser.parse_args()
45 | cfg = configurations
46 | cuda = torch.cuda.is_available()
47 | 
48 | # ---------------- dataset loader ----------------
49 | train_dataRoot = args.train_dataroot
50 | test_dataRoot = args.test_dataroot
51 | 
52 | if not os.path.exists(args.snapshot_root):
53 |     os.makedirs(args.snapshot_root)
54 | if not os.path.exists(args.salmap_root):
55 |     os.makedirs(args.salmap_root)
56 | if args.phase == 'train':
57 |     SnapRoot = args.snapshot_root  # checkpoint
58 |     train_loader = torch.utils.data.DataLoader(MyData(train_dataRoot, transform=True),
59 |                                                batch_size=4, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
60 | else:
61 |     MapRoot = args.salmap_root
62 |     test_loader = torch.utils.data.DataLoader(MyTestData(test_dataRoot, transform=True),
63 |                                               batch_size=1, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)
64 | print('data ready')
65 | # ---------------- train_data/test_data through nets ----------------
66 | start_epoch = 0
67 | start_iteration = 0
68 | model_depth = DepthNet()
69 | model_baseline = BaselineNet()
70 | model_fusion = FusionNet()
71 | # print(model_rgb)
72 | if str(args.param) == 'True':
73 |     # ckpt = str(ckpt)
74 |     ckpt = '60'
75 |     model_depth.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'depth_snapshot_iter_' + ckpt + '0000.pth')))
76 |     model_baseline.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'baseline_snapshot_iter_' + ckpt + '0000.pth')))
77 |     model_fusion.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'ladder_snapshot_iter_' + ckpt + '0000.pth')))
78 | else:
79 |     # model_depth.init_weights()
80 | vgg16_bn = torchvision.models.vgg16_bn(pretrained=True) 81 | model_depth.copy_params_from_vgg16_bn(vgg16_bn) 82 | model_baseline.copy_params_from_vgg16_bn(vgg16_bn) 83 | model_fusion.init_weights() 84 | if cuda: 85 | model_depth = model_depth.cuda() 86 | model_baseline = model_baseline.cuda() 87 | model_fusion = model_fusion.cuda() 88 | 89 | if args.phase == 'train': 90 | optimizer_depth = optim.SGD(model_depth.parameters(), lr=cfg[1]['lr'], momentum=cfg[1]['momentum'], weight_decay=cfg[1]['weight_decay']) 91 | optimizer_baseline = optim.SGD(model_baseline.parameters(), lr=cfg[1]['lr'], momentum=cfg[1]['momentum'], weight_decay=cfg[1]['weight_decay']) 92 | optimizer_ladder = optim.SGD(model_fusion.parameters(), lr=cfg[1]['lr'], momentum=cfg[1]['momentum'], weight_decay=cfg[1]['weight_decay']) 93 | 94 | training = Trainer( 95 | cuda=cuda, 96 | model_depth=model_depth, 97 | model_baseline=model_baseline, 98 | model_fusion=model_fusion, 99 | optimizer_depth=optimizer_depth, 100 | optimizer_baseline=optimizer_baseline, 101 | optimizer_ladder=optimizer_ladder, 102 | train_loader=train_loader, 103 | max_iter=cfg[1]['max_iteration'], 104 | snapshot=cfg[1]['spshot'], 105 | outpath=args.snapshot_root, 106 | sshow=cfg[1]['sshow'] 107 | ) 108 | training.epoch = start_epoch 109 | training.iteration = start_iteration 110 | training.train() 111 | else: 112 | res = [] 113 | for id, (data, depth, img_name, img_size) in enumerate(test_loader): 114 | # print('testing bach %d' % id) 115 | inputs = Variable(data).cuda() 116 | depth = Variable(depth).cuda() 117 | n, c, h, w = inputs.size() 118 | # depth = torch.unsqueeze(depth, 1) 119 | depth = depth.view(n, 1, h, w).repeat(1, c, 1, 1) 120 | # depth = depth.view(n, 1, h, w) 121 | torch.cuda.synchronize() 122 | start = time.time() 123 | model_fusion.eval() 124 | 125 | h2,h3, h4, h5 = model_baseline(inputs) 126 | d2,d3, d4, d5 = model_depth(depth) 127 | p2, p3, p4, p5, p = model_fusion(inputs, h2, h3, h4, h5, d2,d3, d4, d5) 128 | torch.cuda.synchronize() 129 | end = time.time() 130 | res.append(end - start) 131 | 132 | pred = torch.sigmoid(p) 133 | outputs = pred[0, 0].detach().cpu().numpy() 134 | imsave(os.path.join(MapRoot,img_name[0] + '.png'), outputs, img_size) 135 | # imsave(os.path.join(MapRoot,img_name[0][-1] + '.png'), outputs, img_size) 136 | time_sum = 0 137 | for i in res: 138 | time_sum += i 139 | print("FPS: %f" % (1.0 / (time_sum / len(res)))) 140 | 141 | # -------------------------- validation --------------------------- # 142 | sal_root = MapRoot 143 | gt_root = test_dataRoot + '/test_masks/' 144 | dataset = test_dataRoot.split('/')[-1] 145 | test_loader = test_dataset(sal_root, gt_root) 146 | mae, fm, sm, em, wfm = cal_mae(), cal_fm(test_loader.size), cal_sm(), cal_em(), cal_wfm() 147 | for i in tqdm(range(test_loader.size)): 148 | sal, gt = test_loader.load_data() 149 | if sal.size != gt.size: 150 | x, y = gt.size 151 | sal = sal.resize((x, y)) 152 | gt = np.asarray(gt, np.float32) 153 | gt /= (gt.max() + 1e-8) 154 | gt[gt > 0.5] = 1 155 | gt[gt != 1] = 0 156 | res = sal 157 | res = np.array(res) 158 | if res.max() == res.min(): 159 | res = res / 255 160 | else: 161 | res = (res - res.min()) / (res.max() - res.min()) 162 | mae.update(res, gt) 163 | sm.update(res, gt) 164 | fm.update(res, gt) 165 | em.update(res, gt) 166 | wfm.update(res, gt) 167 | 168 | MAE = mae.show() 169 | maxf, meanf, _, _ = fm.show() 170 | sm = sm.show() 171 | em = em.show() 172 | wfm = wfm.show() 173 | print( 174 | 'dataset: {} MAE: {:.4f} maxF: {:.4f} avgF: {:.4f} 
wfm: {:.4f} Sm: {:.4f} Em: {:.4f}'.format(dataset, MAE, maxf, 175 | meanf, wfm, sm,em)) 176 | # summary(model_baseline,input_size=(3,256,256)) 177 | # print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>") 178 | # summary(model_depth,input_size=(3,256,256)) 179 | # print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>") 180 | # summary(model_fusion,input_size=[(3,256,256),(128,64,64),(256,32,32),(512,16,16),(512,8,8),(128,64,64),(256,32,32),(512,16,16),(512,8,8)]) -------------------------------------------------------------------------------- /code/model/model_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | def get_upsampling_weight(in_channels, out_channels, kernel_size): 7 | """Make a 2D bilinear kernel suitable for upsampling""" 8 | factor = (kernel_size + 1) // 2 9 | if kernel_size % 2 == 1: 10 | center = factor - 1 11 | else: 12 | center = factor - 0.5 13 | og = np.ogrid[:kernel_size, :kernel_size] 14 | filt = (1 - abs(og[0] - center) / factor) * \ 15 | (1 - abs(og[1] - center) / factor) 16 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), 17 | dtype=np.float64) 18 | weight[range(in_channels), range(out_channels), :, :] = filt 19 | return torch.from_numpy(weight).float() 20 | #################################### Baseline Network ##################################### 21 | class DepthNet(nn.Module): 22 | def __init__(self): 23 | super(DepthNet, self).__init__() 24 | # original image's size = 256*256*3 25 | # conv1 26 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1) 27 | self.bn1_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 28 | self.relu1_1 = nn.ReLU(inplace=True) 29 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 30 | self.bn1_2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 31 | self.relu1_2 = nn.ReLU(inplace=True) 32 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 2 layers 33 | 34 | # conv2 35 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 36 | self.bn2_1 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 37 | self.relu2_1 = nn.ReLU(inplace=True) 38 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 39 | self.bn2_2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 40 | self.relu2_2 = nn.ReLU(inplace=True) 41 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 2 layers 42 | 43 | # conv3 44 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 45 | self.bn3_1 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 46 | self.relu3_1 = nn.ReLU(inplace=True) 47 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 48 | self.bn3_2 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 49 | self.relu3_2 = nn.ReLU(inplace=True) 50 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 51 | self.bn3_3 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 52 | self.relu3_3 = nn.ReLU(inplace=True) 53 | # self.conv3_4 = nn.Conv2d(256, 256, 3, padding=1) 54 | # self.bn3_4 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 55 | # self.relu3_4 = nn.ReLU(inplace=True) 56 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 4 layers 57 | 58 | # conv4 59 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 60 | self.bn4_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 61 | self.relu4_1 = nn.ReLU(inplace=True) 62 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 63 | self.bn4_2 = nn.BatchNorm2d(512, 
eps=1e-05, momentum=0.1, affine=True) 64 | self.relu4_2 = nn.ReLU(inplace=True) 65 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 66 | self.bn4_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 67 | self.relu4_3 = nn.ReLU(inplace=True) 68 | # self.conv4_4 = nn.Conv2d(512, 512, 3, padding=1) 69 | # self.bn4_4 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 70 | # self.relu4_4 = nn.ReLU(inplace=True) 71 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 4 layers 72 | 73 | # conv5 74 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 75 | self.bn5_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 76 | self.relu5_1 = nn.ReLU(inplace=True) 77 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 78 | self.bn5_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 79 | self.relu5_2 = nn.ReLU(inplace=True) 80 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 81 | self.bn5_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 82 | self.relu5_3 = nn.ReLU(inplace=True) 83 | # self.conv5_4 = nn.Conv2d(512, 512, 3, padding=1) 84 | # self.bn5_4 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 85 | # self.relu5_4 = nn.ReLU(inplace=True) 86 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32 4 layers 87 | self._initialize_weights() 88 | 89 | 90 | def _initialize_weights(self): 91 | for m in self.modules(): 92 | if isinstance(m, nn.Conv2d): 93 | # m.weight.data.zero_() 94 | nn.init.normal(m.weight.data, std=0.01) 95 | if m.bias is not None: 96 | m.bias.data.zero_() 97 | if isinstance(m, nn.ConvTranspose2d): 98 | assert m.kernel_size[0] == m.kernel_size[1] 99 | initial_weight = get_upsampling_weight(m.in_channels, m.out_channels, m.kernel_size[0]) 100 | m.weight.data.copy_(initial_weight) 101 | 102 | 103 | 104 | def forward(self, x): 105 | h = x 106 | 107 | h = self.relu1_1(self.bn1_1(self.conv1_1(h))) 108 | h = self.relu1_2(self.bn1_2(self.conv1_2(h))) 109 | h1 = self.pool1(h) # (128x128)*64 110 | 111 | h = self.relu2_1(self.bn2_1(self.conv2_1(h1))) 112 | h = self.relu2_2(self.bn2_2(self.conv2_2(h))) 113 | h2 = self.pool2(h) # (64x64)*128 114 | 115 | h = self.relu3_1(self.bn3_1(self.conv3_1(h2))) 116 | h = self.relu3_2(self.bn3_2(self.conv3_2(h))) 117 | h = self.relu3_3(self.bn3_3(self.conv3_3(h))) 118 | # h = self.relu3_4(self.bn3_4(self.conv3_4(h))) 119 | h3 = self.pool3(h)# (32x32)*256 120 | 121 | h = self.relu4_1(self.bn4_1(self.conv4_1(h3))) 122 | h = self.relu4_2(self.bn4_2(self.conv4_2(h))) 123 | h = self.relu4_3(self.bn4_3(self.conv4_3(h))) 124 | # h = self.relu4_4(self.bn4_4(self.conv4_4(h))) 125 | h4 = self.pool4(h)# (16x16)*512 126 | 127 | h = self.relu5_1(self.bn5_1(self.conv5_1(h4))) 128 | h = self.relu5_2(self.bn5_2(self.conv5_2(h))) 129 | h = self.relu5_3(self.bn5_3(self.conv5_3(h))) 130 | # h = self.relu5_4(self.bn5_4(self.conv5_4(h))) 131 | h5 = self.pool5(h)#(8x8)*512 132 | 133 | return h2,h3,h4,h5 134 | 135 | 136 | def copy_params_from_vgg16_bn(self, vgg16_bn): 137 | features = [ 138 | self.conv1_1, self.bn1_1, self.relu1_1, 139 | self.conv1_2, self.bn1_2, self.relu1_2, 140 | self.pool1, 141 | self.conv2_1, self.bn2_1, self.relu2_1, 142 | self.conv2_2, self.bn2_2, self.relu2_2, 143 | self.pool2, 144 | self.conv3_1, self.bn3_1, self.relu3_1, 145 | self.conv3_2, self.bn3_2, self.relu3_2, 146 | self.conv3_3, self.bn3_3, self.relu3_3, 147 | # self.conv3_4, self.bn3_4, self.relu3_4, 148 | self.pool3, 149 | self.conv4_1, self.bn4_1, self.relu4_1, 150 | self.conv4_2, self.bn4_2, self.relu4_2, 151 | 
self.conv4_3, self.bn4_3, self.relu4_3, 152 | # self.conv4_4, self.bn4_4, self.relu4_4, 153 | self.pool4, 154 | self.conv5_1, self.bn5_1, self.relu5_1, 155 | self.conv5_2, self.bn5_2, self.relu5_2, 156 | self.conv5_3, self.bn5_3, self.relu5_3, 157 | # self.conv5_4, self.bn5_4, self.relu5_4, 158 | ] 159 | for l1, l2 in zip(vgg16_bn.features, features): 160 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 161 | assert l1.weight.size() == l2.weight.size() 162 | assert l1.bias.size() == l2.bias.size() 163 | l2.weight.data = l1.weight.data 164 | l2.bias.data = l1.bias.data 165 | if isinstance(l1, nn.BatchNorm2d) and isinstance(l2, nn.BatchNorm2d): 166 | assert l1.weight.size() == l2.weight.size() 167 | assert l1.bias.size() == l2.bias.size() 168 | l2.weight.data = l1.weight.data 169 | l2.bias.data = l1.bias.data -------------------------------------------------------------------------------- /code/model/model_baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | def get_upsampling_weight(in_channels, out_channels, kernel_size): 7 | """Make a 2D bilinear kernel suitable for upsampling""" 8 | factor = (kernel_size + 1) // 2 9 | if kernel_size % 2 == 1: 10 | center = factor - 1 11 | else: 12 | center = factor - 0.5 13 | og = np.ogrid[:kernel_size, :kernel_size] 14 | filt = (1 - abs(og[0] - center) / factor) * \ 15 | (1 - abs(og[1] - center) / factor) 16 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), 17 | dtype=np.float64) 18 | weight[range(in_channels), range(out_channels), :, :] = filt 19 | return torch.from_numpy(weight).float() 20 | #################################### Baseline Network ##################################### 21 | class BaselineNet(nn.Module): 22 | def __init__(self): 23 | super(BaselineNet, self).__init__() 24 | # original image's size = 256*256*3 25 | # conv1 26 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1) 27 | self.bn1_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 28 | self.relu1_1 = nn.ReLU(inplace=True) 29 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 30 | self.bn1_2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 31 | self.relu1_2 = nn.ReLU(inplace=True) 32 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 2 layers 33 | 34 | # conv2 35 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 36 | self.bn2_1 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 37 | self.relu2_1 = nn.ReLU(inplace=True) 38 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 39 | self.bn2_2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 40 | self.relu2_2 = nn.ReLU(inplace=True) 41 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 2 layers 42 | 43 | # conv3 44 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 45 | self.bn3_1 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 46 | self.relu3_1 = nn.ReLU(inplace=True) 47 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 48 | self.bn3_2 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 49 | self.relu3_2 = nn.ReLU(inplace=True) 50 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 51 | self.bn3_3 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 52 | self.relu3_3 = nn.ReLU(inplace=True) 53 | # self.conv3_4 = nn.Conv2d(256, 256, 3, padding=1) 54 | # self.bn3_4 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 55 | # self.relu3_4 = nn.ReLU(inplace=True) 56 | self.pool3 = nn.MaxPool2d(2, 
stride=2, ceil_mode=True) # 1/8 4 layers 57 | 58 | # conv4 59 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 60 | self.bn4_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 61 | self.relu4_1 = nn.ReLU(inplace=True) 62 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 63 | self.bn4_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 64 | self.relu4_2 = nn.ReLU(inplace=True) 65 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 66 | self.bn4_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 67 | self.relu4_3 = nn.ReLU(inplace=True) 68 | # self.conv4_4 = nn.Conv2d(512, 512, 3, padding=1) 69 | # self.bn4_4 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 70 | # self.relu4_4 = nn.ReLU(inplace=True) 71 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 4 layers 72 | 73 | # conv5 74 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 75 | self.bn5_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 76 | self.relu5_1 = nn.ReLU(inplace=True) 77 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 78 | self.bn5_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 79 | self.relu5_2 = nn.ReLU(inplace=True) 80 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 81 | self.bn5_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 82 | self.relu5_3 = nn.ReLU(inplace=True) 83 | # self.conv5_4 = nn.Conv2d(512, 512, 3, padding=1) 84 | # self.bn5_4 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 85 | # self.relu5_4 = nn.ReLU(inplace=True) 86 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32 4 layers 87 | self._initialize_weights() 88 | 89 | 90 | def _initialize_weights(self): 91 | for m in self.modules(): 92 | if isinstance(m, nn.Conv2d): 93 | # m.weight.data.zero_() 94 | nn.init.normal(m.weight.data, std=0.01) 95 | if m.bias is not None: 96 | m.bias.data.zero_() 97 | if isinstance(m, nn.ConvTranspose2d): 98 | assert m.kernel_size[0] == m.kernel_size[1] 99 | initial_weight = get_upsampling_weight(m.in_channels, m.out_channels, m.kernel_size[0]) 100 | m.weight.data.copy_(initial_weight) 101 | 102 | 103 | 104 | def forward(self, x): 105 | h = x 106 | 107 | h = self.relu1_1(self.bn1_1(self.conv1_1(h))) 108 | h = self.relu1_2(self.bn1_2(self.conv1_2(h))) 109 | h1 = self.pool1(h) # (128x128)*64 110 | 111 | h = self.relu2_1(self.bn2_1(self.conv2_1(h1))) 112 | h = self.relu2_2(self.bn2_2(self.conv2_2(h))) 113 | h2 = self.pool2(h) # (64x64)*128 114 | 115 | h = self.relu3_1(self.bn3_1(self.conv3_1(h2))) 116 | h = self.relu3_2(self.bn3_2(self.conv3_2(h))) 117 | h = self.relu3_3(self.bn3_3(self.conv3_3(h))) 118 | # h = self.relu3_4(self.bn3_4(self.conv3_4(h))) 119 | h3 = self.pool3(h)# (32x32)*256 120 | 121 | h = self.relu4_1(self.bn4_1(self.conv4_1(h3))) 122 | h = self.relu4_2(self.bn4_2(self.conv4_2(h))) 123 | h = self.relu4_3(self.bn4_3(self.conv4_3(h))) 124 | # h = self.relu4_4(self.bn4_4(self.conv4_4(h))) 125 | h4 = self.pool4(h)# (16x16)*512 126 | 127 | h = self.relu5_1(self.bn5_1(self.conv5_1(h4))) 128 | h = self.relu5_2(self.bn5_2(self.conv5_2(h))) 129 | h = self.relu5_3(self.bn5_3(self.conv5_3(h))) 130 | # h = self.relu5_4(self.bn5_4(self.conv5_4(h))) 131 | h5 = self.pool5(h)#(8x8)*512 132 | 133 | return h2,h3,h4,h5 134 | 135 | 136 | def copy_params_from_vgg16_bn(self, vgg16_bn): 137 | features = [ 138 | self.conv1_1, self.bn1_1, self.relu1_1, 139 | self.conv1_2, self.bn1_2, self.relu1_2, 140 | self.pool1, 141 | self.conv2_1, self.bn2_1, self.relu2_1, 142 | self.conv2_2, self.bn2_2, self.relu2_2, 143 
| self.pool2, 144 | self.conv3_1, self.bn3_1, self.relu3_1, 145 | self.conv3_2, self.bn3_2, self.relu3_2, 146 | self.conv3_3, self.bn3_3, self.relu3_3, 147 | # self.conv3_4, self.bn3_4, self.relu3_4, 148 | self.pool3, 149 | self.conv4_1, self.bn4_1, self.relu4_1, 150 | self.conv4_2, self.bn4_2, self.relu4_2, 151 | self.conv4_3, self.bn4_3, self.relu4_3, 152 | # self.conv4_4, self.bn4_4, self.relu4_4, 153 | self.pool4, 154 | self.conv5_1, self.bn5_1, self.relu5_1, 155 | self.conv5_2, self.bn5_2, self.relu5_2, 156 | self.conv5_3, self.bn5_3, self.relu5_3, 157 | # self.conv5_4, self.bn5_4, self.relu5_4, 158 | ] 159 | for l1, l2 in zip(vgg16_bn.features, features): 160 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 161 | assert l1.weight.size() == l2.weight.size() 162 | assert l1.bias.size() == l2.bias.size() 163 | l2.weight.data = l1.weight.data 164 | l2.bias.data = l1.bias.data 165 | if isinstance(l1, nn.BatchNorm2d) and isinstance(l2, nn.BatchNorm2d): 166 | assert l1.weight.size() == l2.weight.size() 167 | assert l1.bias.size() == l2.bias.size() 168 | l2.weight.data = l1.weight.data 169 | l2.bias.data = l1.bias.data -------------------------------------------------------------------------------- /code/saliency_metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import ndimage 3 | from scipy.ndimage import convolve, distance_transform_edt as bwdist 4 | 5 | 6 | class cal_fm(object): 7 | # Fmeasure(maxFm,meanFm)---Frequency-tuned salient region detection(CVPR 2009) 8 | def __init__(self, num, thds=255): 9 | self.num = num 10 | self.thds = thds 11 | self.precision = np.zeros((self.num, self.thds)) 12 | self.recall = np.zeros((self.num, self.thds)) 13 | self.meanF = np.zeros((self.num,1)) 14 | self.idx = 0 15 | 16 | def update(self, pred, gt): 17 | if gt.max() != 0: 18 | prediction, recall, Fmeasure_temp = self.cal(pred, gt) 19 | self.precision[self.idx, :] = prediction 20 | self.recall[self.idx, :] = recall 21 | self.meanF[self.idx, :] = Fmeasure_temp 22 | self.idx += 1 23 | 24 | def cal(self, pred, gt): 25 | ########################meanF############################## 26 | th = 2 * pred.mean() 27 | if th > 1: 28 | th = 1 29 | binary = np.zeros_like(pred) 30 | binary[pred >= th] = 1 31 | hard_gt = np.zeros_like(gt) 32 | hard_gt[gt > 0.5] = 1 33 | tp = (binary * hard_gt).sum() 34 | if tp == 0: 35 | meanF = 0 36 | else: 37 | pre = tp / binary.sum() 38 | rec = tp / hard_gt.sum() 39 | meanF = 1.3 * pre * rec / (0.3 * pre + rec) 40 | ########################maxF############################## 41 | pred = np.uint8(pred * 255) 42 | target = pred[gt > 0.5] 43 | nontarget = pred[gt <= 0.5] 44 | targetHist, _ = np.histogram(target, bins=range(256)) 45 | nontargetHist, _ = np.histogram(nontarget, bins=range(256)) 46 | targetHist = np.cumsum(np.flip(targetHist), axis=0) 47 | nontargetHist = np.cumsum(np.flip(nontargetHist), axis=0) 48 | precision = targetHist / (targetHist + nontargetHist + 1e-8) 49 | recall = targetHist / np.sum(gt) 50 | return precision, recall, meanF 51 | 52 | def show(self): 53 | assert self.num == self.idx 54 | precision = self.precision.mean(axis=0) 55 | recall = self.recall.mean(axis=0) 56 | fmeasure = 1.3 * precision * recall / (0.3 * precision + recall + 1e-8) 57 | fmeasure_avg = self.meanF.mean(axis=0) 58 | return fmeasure.max(),fmeasure_avg[0],precision,recall 59 | 60 | 61 | class cal_mae(object): 62 | # mean absolute error 63 | def __init__(self): 64 | self.prediction = [] 65 | 66 | def 
update(self, pred, gt): 67 | score = self.cal(pred, gt) 68 | self.prediction.append(score) 69 | 70 | def cal(self, pred, gt): 71 | return np.mean(np.abs(pred - gt)) 72 | 73 | def show(self): 74 | return np.mean(self.prediction) 75 | 76 | 77 | class cal_sm(object): 78 | # Structure-measure: A new way to evaluate foreground maps (ICCV 2017) 79 | def __init__(self, alpha=0.5): 80 | self.prediction = [] 81 | self.alpha = alpha 82 | 83 | def update(self, pred, gt): 84 | gt = gt > 0.5 85 | score = self.cal(pred, gt) 86 | self.prediction.append(score) 87 | 88 | def show(self): 89 | return np.mean(self.prediction) 90 | 91 | def cal(self, pred, gt): 92 | y = np.mean(gt) 93 | if y == 0: 94 | score = 1 - np.mean(pred) 95 | elif y == 1: 96 | score = np.mean(pred) 97 | else: 98 | score = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt) 99 | return score 100 | 101 | def object(self, pred, gt): 102 | fg = pred * gt 103 | bg = (1 - pred) * (1 - gt) 104 | 105 | u = np.mean(gt) 106 | return u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, np.logical_not(gt)) 107 | 108 | def s_object(self, in1, in2): 109 | x = np.mean(in1[in2]) 110 | sigma_x = np.std(in1[in2]) 111 | return 2 * x / (pow(x, 2) + 1 + sigma_x + 1e-8) 112 | 113 | def region(self, pred, gt): 114 | [y, x] = ndimage.center_of_mass(gt) 115 | y = int(round(y)) + 1 116 | x = int(round(x)) + 1 117 | [gt1, gt2, gt3, gt4, w1, w2, w3, w4] = self.divideGT(gt, x, y) 118 | pred1, pred2, pred3, pred4 = self.dividePred(pred, x, y) 119 | 120 | score1 = self.ssim(pred1, gt1) 121 | score2 = self.ssim(pred2, gt2) 122 | score3 = self.ssim(pred3, gt3) 123 | score4 = self.ssim(pred4, gt4) 124 | 125 | return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4 126 | 127 | def divideGT(self, gt, x, y): 128 | h, w = gt.shape 129 | area = h * w 130 | LT = gt[0:y, 0:x] 131 | RT = gt[0:y, x:w] 132 | LB = gt[y:h, 0:x] 133 | RB = gt[y:h, x:w] 134 | 135 | w1 = x * y / area 136 | w2 = y * (w - x) / area 137 | w3 = (h - y) * x / area 138 | w4 = (h - y) * (w - x) / area 139 | 140 | return LT, RT, LB, RB, w1, w2, w3, w4 141 | 142 | def dividePred(self, pred, x, y): 143 | h, w = pred.shape 144 | LT = pred[0:y, 0:x] 145 | RT = pred[0:y, x:w] 146 | LB = pred[y:h, 0:x] 147 | RB = pred[y:h, x:w] 148 | 149 | return LT, RT, LB, RB 150 | 151 | def ssim(self, in1, in2): 152 | in2 = np.float32(in2) 153 | h, w = in1.shape 154 | N = h * w 155 | 156 | x = np.mean(in1) 157 | y = np.mean(in2) 158 | sigma_x = np.var(in1) 159 | sigma_y = np.var(in2) 160 | sigma_xy = np.sum((in1 - x) * (in2 - y)) / (N - 1) 161 | 162 | alpha = 4 * x * y * sigma_xy 163 | beta = (x * x + y * y) * (sigma_x + sigma_y) 164 | 165 | if alpha != 0: 166 | score = alpha / (beta + 1e-8) 167 | elif alpha == 0 and beta == 0: 168 | score = 1 169 | else: 170 | score = 0 171 | 172 | return score 173 | 174 | class cal_em(object): 175 | #Enhanced-alignment Measure for Binary Foreground Map Evaluation (IJCAI 2018) 176 | def __init__(self): 177 | self.prediction = [] 178 | 179 | def update(self, pred, gt): 180 | score = self.cal(pred, gt) 181 | self.prediction.append(score) 182 | 183 | def cal(self, pred, gt): 184 | th = 2 * pred.mean() 185 | if th > 1: 186 | th = 1 187 | FM = np.zeros(gt.shape) 188 | FM[pred >= th] = 1 189 | FM = np.array(FM,dtype=bool) 190 | GT = np.array(gt,dtype=bool) 191 | dFM = np.double(FM) 192 | if (sum(sum(np.double(GT)))==0): 193 | enhanced_matrix = 1.0-dFM 194 | elif (sum(sum(np.double(~GT)))==0): 195 | enhanced_matrix = dFM 196 | else: 197 | dGT = np.double(GT) 198 | 
align_matrix = self.AlignmentTerm(dFM, dGT) 199 | enhanced_matrix = self.EnhancedAlignmentTerm(align_matrix) 200 | [w, h] = np.shape(GT) 201 | score = sum(sum(enhanced_matrix))/ (w * h - 1 + 1e-8) 202 | return score 203 | def AlignmentTerm(self,dFM,dGT): 204 | mu_FM = np.mean(dFM) 205 | mu_GT = np.mean(dGT) 206 | align_FM = dFM - mu_FM 207 | align_GT = dGT - mu_GT 208 | align_Matrix = 2. * (align_GT * align_FM)/ (align_GT* align_GT + align_FM* align_FM + 1e-8) 209 | return align_Matrix 210 | def EnhancedAlignmentTerm(self,align_Matrix): 211 | enhanced = np.power(align_Matrix + 1,2) / 4 212 | return enhanced 213 | def show(self): 214 | return np.mean(self.prediction) 215 | class cal_wfm(object): 216 | def __init__(self, beta=1): 217 | self.beta = beta 218 | self.eps = 1e-6 219 | self.scores_list = [] 220 | 221 | def update(self, pred, gt): 222 | assert pred.ndim == gt.ndim and pred.shape == gt.shape 223 | assert pred.max() <= 1 and pred.min() >= 0 224 | assert gt.max() <= 1 and gt.min() >= 0 225 | 226 | gt = gt > 0.5 227 | if gt.max() == 0: 228 | score = 0 229 | else: 230 | score = self.cal(pred, gt) 231 | self.scores_list.append(score) 232 | 233 | def matlab_style_gauss2D(self, shape=(7, 7), sigma=5): 234 | """ 235 | 2D gaussian mask - should give the same result as MATLAB's 236 | fspecial('gaussian',[shape],[sigma]) 237 | """ 238 | m, n = [(ss - 1.) / 2. for ss in shape] 239 | y, x = np.ogrid[-m:m + 1, -n:n + 1] 240 | h = np.exp(-(x * x + y * y) / (2. * sigma * sigma)) 241 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 242 | sumh = h.sum() 243 | if sumh != 0: 244 | h /= sumh 245 | return h 246 | 247 | def cal(self, pred, gt): 248 | # [Dst,IDXT] = bwdist(dGT); 249 | Dst, Idxt = bwdist(gt == 0, return_indices=True) 250 | 251 | # %Pixel dependency 252 | # E = abs(FG-dGT); 253 | E = np.abs(pred - gt) 254 | # Et = E; 255 | # Et(~GT)=Et(IDXT(~GT)); %To deal correctly with the edges of the foreground region 256 | Et = np.copy(E) 257 | Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]] 258 | 259 | # K = fspecial('gaussian',7,5); 260 | # EA = imfilter(Et,K); 261 | # MIN_E_EA(GT & EA