├── bibtex
├── utils.py
├── README.md
├── test.py
├── model
│   ├── vgg16.py
│   └── model.py
├── train.py
└── data_loader.py

--------------------------------------------------------------------------------
/bibtex:
--------------------------------------------------------------------------------
@inproceedings{zhang2020Select,
  title={Select, Supplement and Focus for RGB-D Saliency Detection},
  author={Zhang, Miao and Ren, Weisong and Piao, Yongri and Rong, Zhengkun and Lu, Huchuan},
  booktitle={CVPR},
  year={2020}
}

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
def clip_gradient(optimizer, grad_clip):
    # Clamp every parameter gradient into [-grad_clip, grad_clip].
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)


def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
    # Step decay: lr = init_lr * decay_rate ** (epoch // decay_epoch).
    # Setting the rate from init_lr (rather than multiplying the current
    # value in place) keeps repeated calls from compounding the decay.
    decay = decay_rate ** (epoch // decay_epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = init_lr * decay

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CVPR_SSF-RGBD
Code for SSF-RGBD, our CVPR 2020 paper 'Select, Supplement and Focus for RGB-D Saliency Detection' by Miao Zhang, Weisong Ren, Yongri Piao, Zhengkun Rong and [Huchuan Lu](http://ice.dlut.edu.cn/lu/publications.html).

## Usage Instructions
Requirements
* Windows 10
* PyTorch 0.4.1
* CUDA 9.0
* Cudnn 7.6.0
* Python 3.6.5
* Numpy 1.16.4

## Training and Testing Datasets
Training dataset: [download_link](https://pan.baidu.com/s/1dv6cw3TfW4ZBaUsMC-tN1g) code: 5fms. || Test dataset: [download_link](https://pan.baidu.com/s/13RHAF7VMvMvP5YtSj1ovHA) code: a934.

## Testing
* Download the pretrained model from [download_link](https://pan.baidu.com/s/1sZH4Wh_-nne-nMvDQvSyZw) code: 8zw8.
* Modify the path of the testing dataset in test.py.
* Run test.py to produce saliency maps (see the inference sketch at the end of this README).

## Results
Saliency maps generated by the model can be downloaded from:
* [**DUT-RGBD**](https://pan.baidu.com/s/1Fk35_f4HKkkDVuTGo3qVrQ) code: w92o
* [**NLPR**](https://pan.baidu.com/s/1Tuv-2cfhq8BvmWky1yhL7w) code: c9rw
* [**NJUD**](https://pan.baidu.com/s/1eMKC6DSsnevG8jkectjd7A) code: ufun
* [**STEREO**](https://pan.baidu.com/s/15YrSUVV1kE5r6YzjUvH6Xw) code: hhkb
* [**LFSD**](https://pan.baidu.com/s/1mbnu3H_j8pGmzGBr1Ea6TA) code: d86o
* [**RGBD-135**](https://pan.baidu.com/s/1tOaDjvDFZUQTScrLatccBQ) code: v975

## Contact and Questions
Contact: Weisong Ren. Email: beatlescoco@mail.dlut.edu.cn.
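## Inference Sketch
A minimal single-image inference sketch mirroring test.py, using stand-in tensors and a placeholder checkpoint path (building model_VGG downloads torchvision's pretrained VGG16 weights on first use). The model expects the normalized RGB image, the depth map replicated to three channels, and the raw single-channel depth; a CUDA device is required because the saliency decoder moves intermediate masks onto the GPU.

```python
import torch
import torch.nn.functional as F
from model.model import model_VGG

model = model_VGG().cuda().eval()
model.load_state_dict(torch.load('your_checkpoint.pth'))  # placeholder path

image = torch.randn(1, 3, 256, 256).cuda()  # normalized RGB (stand-in tensor)
depth = torch.rand(1, 1, 256, 256).cuda()   # single-channel depth in [0, 1]
depth3 = depth.repeat(1, 3, 1, 1)           # replicate depth to 3 channels

with torch.no_grad():
    _, sal, _, _ = model(image, depth3, depth)
sal = torch.sigmoid(F.interpolate(sal, size=(256, 256),
                                  mode='bilinear', align_corners=False))
```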
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import time
import numpy as np
import os, argparse
from scipy import misc  # scipy.misc.imsave requires scipy < 1.2
from model.model import model_VGG
from data_loader import test_dataset

parser = argparse.ArgumentParser()
parser.add_argument('--testsize', type=int, default=256, help='testing size')
opt = parser.parse_args()

dataset_path = ''
model = model_VGG()
model.load_state_dict(torch.load(''))
model.cuda()
model.eval()

test_datasets = ['\\DUT-RGBD\\test_data']
# test_datasets = ['\\LFSD']
# test_datasets = ['\\NJUD\\test_data']
# test_datasets = ['\\NLPR\\test_data']
# test_datasets = ['\\RGBD135']
# test_datasets = ['\\SSD']
# test_datasets = ['\\STEREO']
# test_datasets = ['\\DUTS-TEST']

total_time = 0.0
for dataset in test_datasets:
    save_path = '' + dataset + '\\results\\'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    image_root = dataset_path + dataset + '\\images\\'
    gt_root = dataset_path + dataset + '\\gts\\'
    depth_root = dataset_path + dataset + '\\depths\\'
    test_loader = test_dataset(image_root, gt_root, depth_root, opt.testsize)
    for i in range(test_loader.size):
        image, gt, depth, name = test_loader.load_data()
        gt = np.asarray(gt, np.float32)
        gt /= (gt.max() + 1e-8)
        depth /= (depth.max() + 1e-8)
        image = image.cuda()
        depth = depth.cuda()
        n, c, h, w = image.size()
        # Replicate the single-channel depth map into a 3-channel tensor
        # for the VGG depth stream.
        depth1 = depth.view(n, h, w, 1).repeat(1, 1, 1, c)
        depth1 = depth1.transpose(3, 2)
        depth1 = depth1.transpose(2, 1)
        time_start = time.time()
        with torch.no_grad():
            _, res, _, _ = model(image, depth1, depth)
        res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False)
        total_time += time.time() - time_start
        res = res.sigmoid().cpu().numpy().squeeze()
        res = (res - res.min()) / (res.max() - res.min() + 1e-8)
        print(name)
        misc.imsave(save_path + name, res)
print('total inference cost:', total_time, 's')
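A small self-check (a toy sketch, not part of the repo) confirming that the view/repeat/transpose sequence in test.py is just a channel-wise repeat of the depth map:

```python
import torch

# Toy shapes: n=1 sample, c=3 channels, 4x4 spatial.
depth = torch.rand(1, 1, 4, 4)
n, c, h, w = 1, 3, 4, 4
a = depth.view(n, h, w, 1).repeat(1, 1, 1, c).transpose(3, 2).transpose(2, 1)
b = depth.repeat(1, c, 1, 1)  # equivalent one-liner
assert torch.equal(a, b)
```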
--------------------------------------------------------------------------------
/model/vgg16.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchvision as tv


class VGG_Pr(nn.Module):
    # VGG16 backbone used by both the RGB and depth streams;
    # the pooling layer sits at the front of each block.
    def __init__(self):
        super(VGG_Pr, self).__init__()
        conv1 = nn.Sequential()
        conv1.add_module('conv1_1', nn.Conv2d(3, 64, 3, 1, 1))
        conv1.add_module('relu1_1', nn.ReLU(inplace=True))
        conv1.add_module('conv1_2', nn.Conv2d(64, 64, 3, 1, 1))
        conv1.add_module('relu1_2', nn.ReLU(inplace=True))
        self.conv1 = conv1

        conv2 = nn.Sequential()
        conv2.add_module('pool1', nn.AvgPool2d(2, stride=2))
        conv2.add_module('conv2_1', nn.Conv2d(64, 128, 3, 1, 1))
        conv2.add_module('relu2_1', nn.ReLU())
        conv2.add_module('conv2_2', nn.Conv2d(128, 128, 3, 1, 1))
        conv2.add_module('relu2_2', nn.ReLU())
        self.conv2 = conv2

        conv3 = nn.Sequential()
        conv3.add_module('pool2', nn.AvgPool2d(2, stride=2))
        conv3.add_module('conv3_1', nn.Conv2d(128, 256, 3, 1, 1))
        conv3.add_module('relu3_1', nn.ReLU())
        conv3.add_module('conv3_2', nn.Conv2d(256, 256, 3, 1, 1))
        conv3.add_module('relu3_2', nn.ReLU())
        conv3.add_module('conv3_3', nn.Conv2d(256, 256, 3, 1, 1))
        conv3.add_module('relu3_3', nn.ReLU())
        self.conv3 = conv3

        conv4_1 = nn.Sequential()
        conv4_1.add_module('pool3_1', nn.AvgPool2d(2, stride=2))
        conv4_1.add_module('conv4_1_1', nn.Conv2d(256, 512, 3, 1, 1))
        conv4_1.add_module('relu4_1_1', nn.ReLU())
        conv4_1.add_module('conv4_2_1', nn.Conv2d(512, 512, 3, 1, 1))
        conv4_1.add_module('relu4_2_1', nn.ReLU())
        conv4_1.add_module('conv4_3_1', nn.Conv2d(512, 512, 3, 1, 1))
        conv4_1.add_module('relu4_3_1', nn.ReLU())
        self.conv4_1 = conv4_1

        conv5_1 = nn.Sequential()
        conv5_1.add_module('pool4_1', nn.AvgPool2d(2, stride=2))
        conv5_1.add_module('conv5_1_1', nn.Conv2d(512, 512, 3, 1, 1))
        conv5_1.add_module('relu5_1_1', nn.ReLU())
        conv5_1.add_module('conv5_2_1', nn.Conv2d(512, 512, 3, 1, 1))
        conv5_1.add_module('relu5_2_1', nn.ReLU())
        conv5_1.add_module('conv5_3_1', nn.Conv2d(512, 512, 3, 1, 1))
        conv5_1.add_module('relu5_3_1', nn.ReLU())
        self.conv5_1 = conv5_1

        # pre_train = torch.load('./torch/models/CPD_VGG16.pth')
        vgg_16 = tv.models.vgg16(pretrained=True)
        self._initialize_weights(vgg_16)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x1 = self.conv4_1(x)
        x1 = self.conv5_1(x1)
        return x1

    def _initialize_weights(self, vgg_16):
        # Our layer list is ordered exactly like torchvision's
        # vgg16.features, so a positional zip copies each conv's weights.
        features = [
            self.conv1.conv1_1, self.conv1.relu1_1,
            self.conv1.conv1_2, self.conv1.relu1_2,
            self.conv2.pool1,
            self.conv2.conv2_1, self.conv2.relu2_1,
            self.conv2.conv2_2, self.conv2.relu2_2,
            self.conv3.pool2,
            self.conv3.conv3_1, self.conv3.relu3_1,
            self.conv3.conv3_2, self.conv3.relu3_2,
            self.conv3.conv3_3, self.conv3.relu3_3,
            self.conv4_1.pool3_1,
            self.conv4_1.conv4_1_1, self.conv4_1.relu4_1_1,
            self.conv4_1.conv4_2_1, self.conv4_1.relu4_2_1,
            self.conv4_1.conv4_3_1, self.conv4_1.relu4_3_1,
            self.conv5_1.pool4_1,
            self.conv5_1.conv5_1_1, self.conv5_1.relu5_1_1,
            self.conv5_1.conv5_2_1, self.conv5_1.relu5_2_1,
            self.conv5_1.conv5_3_1, self.conv5_1.relu5_3_1,
        ]
        for l1, l2 in zip(vgg_16.features, features):
            if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
                assert l1.weight.size() == l2.weight.size()
                assert l1.bias.size() == l2.bias.size()
                l2.weight.data = l1.weight.data
                l2.bias.data = l1.bias.data
            if isinstance(l1, nn.BatchNorm2d) and isinstance(l2, nn.BatchNorm2d):
                assert l1.weight.size() == l2.weight.size()
                assert l1.bias.size() == l2.bias.size()
                l2.weight.data = l1.weight.data
                l2.bias.data = l1.bias.data
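A quick shape sketch (assuming torchvision can fetch the pretrained VGG16 weights): the four stride-2 average pools reduce a 256x256 input to the 16x16, 512-channel feature map consumed by the decoders.

```python
import torch
from model.vgg16 import VGG_Pr

net = VGG_Pr().eval()  # downloads torchvision's pretrained VGG16 on first use
with torch.no_grad():
    feat = net(torch.randn(1, 3, 256, 256))
print(feat.shape)  # torch.Size([1, 512, 16, 16])
```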
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import os, argparse
from datetime import datetime
from model.model import model_VGG
from data_loader import get_loader
from utils import clip_gradient, adjust_lr

bce_loss = torch.nn.BCELoss(size_average=True)
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=50, help='epoch number')
parser.add_argument('--lr', type=float, default=3e-5, help='learning rate')
parser.add_argument('--batchsize', type=int, default=20, help='training batch size')
parser.add_argument('--trainsize', type=int, default=256, help='training image size')
parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
parser.add_argument('--decay_rate', type=float, default=0.3, help='decay rate of learning rate')
parser.add_argument('--decay_epoch', type=int, default=45, help='decay the learning rate every n epochs')
parser.add_argument('--resume', action='store_true', default=True, help='resume training from saved checkpoints')
parser.add_argument('--start_epoch', default=37, type=int)
parser.add_argument('--total_depth', type=int, default=10, help='number of depth intervals')
parser.add_argument('--total_length', type=int, default=4, help='length of regions')
parser.add_argument('--total_width', type=int, default=4, help='width of regions')
opt = parser.parse_args()

print('Learning Rate: {}'.format(opt.lr))
model = model_VGG()
model.cuda()
params = model.parameters()
optimizer = torch.optim.Adam(params, opt.lr)
# image roots #
image_root = ''
gt_root = ''
depth_root = ''
boundary_root = ''
pre_check_root = ''  # directory holding saved checkpoints

train_loader = get_loader(image_root, gt_root, depth_root, boundary_root,
                          batchsize=opt.batchsize, trainsize=opt.trainsize)
total_step = len(train_loader)
CE = torch.nn.BCEWithLogitsLoss(reduce=False)   # per-pixel loss map
BCE = torch.nn.BCEWithLogitsLoss()

def train(train_loader, model, optimizer, epoch):
    model.train()
    for i, pack in enumerate(train_loader, start=1):
        optimizer.zero_grad()
        images, gts, depth, bdrs = pack
        images = images.cuda()
        gts = gts.cuda()
        depth = depth.cuda()
        n, c, h, w = images.size()
        # Replicate the single-channel depth map into 3 channels for the
        # VGG depth stream.
        depth1 = depth.view(n, h, w, 1).repeat(1, 1, 1, c)
        depth1 = depth1.transpose(3, 1)
        depth1 = depth1.transpose(3, 2)
        bdrs = bdrs.cuda()
        det_dps, dets, bdr_p, atts5 = model(images, depth1, depth)
        loss_bdr = BCE(bdr_p, bdrs)
        # Boundary disagreement map: the region where the pooled predicted
        # boundary and the pooled ground-truth boundary do not overlap.
        max_pool1 = nn.MaxPool2d(4, stride=None)
        max_pool2 = nn.MaxPool2d(4, stride=None)
        max_pool3 = nn.MaxPool2d(2, stride=None)
        upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        detts = torch.sigmoid(bdr_p)
        gtts = bdrs * 1
        result_pool = max_pool1(detts)
        result_pool = upsample(result_pool)
        result_pool2 = max_pool2(gtts)
        result_pool2 = upsample(result_pool2)
        result_ = torch.max(result_pool, result_pool2)
        result1 = result_pool * result_pool2
        result = result_ - result1
        result_p = max_pool3(result)
        resultp = upsample2(result_p)
        loss_sal_depth = BCE(det_dps, gts)
        loss_sal = CE(det_dps, gts)
        loss_sals = CE(dets, gts)
        loss_sal_RGB = BCE(dets, gts)
        loss3 = torch.mul(loss_sals, resultp).mean()
        n_, _, _, _ = gts.size()
        res_f = torch.zeros((n_, 1, 256, 256))
        loss_att = bce_loss(atts5, gts)
        # Depth decomposition: split the depth range into total_depth
        # intervals and weight each interval mask by how much of the
        # salient ground truth it covers (see the note after this file).
        for jj in range(opt.total_depth):
            res = depth * 255
            target = gts * 255
            res1 = (res >= (255.0 / opt.total_depth) * jj)
            res1 = res1.type(torch.FloatTensor).cuda()
            res3 = (res <= (255.0 / opt.total_depth) * (jj + 1))
            res3 = res3.type(torch.FloatTensor).cuda()
            res2 = res * res1 * res3
            res2[res2 > 0] = 255
            res_sim = res2 * (target / 255)
            res_res = res2
            total = target.mean(dim=3)
            total = total.mean(dim=2)
            res_sim = res_sim.mean(dim=3)
            weight = torch.div(res_sim.mean(dim=2), total)
            weight = torch.unsqueeze(weight, -1)
            weight = torch.unsqueeze(weight, -1)
            res_f = res_f.cuda()
            res__ = torch.mul(res_res, weight)
            res_f = res_f + res__
        res_f = res_f / 255
        # Per-pixel saliency loss, re-weighted towards depth-indicated
        # hard regions plus the boundary-disagreement regions.
        pre_hard_region = torch.mul(loss_sal, res_f).mean()
        loss_hard_region = pre_hard_region + loss3
        loss = loss_sal_depth + loss_sal_RGB + loss_bdr + loss_att + loss_hard_region * 0.3
        loss.backward()
        clip_gradient(optimizer, opt.clip)
        optimizer.step()
        if i % 5 == 0 or i == total_step:
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], '
                  'loss_sal_depth: {:.4f} loss_sal_RGB: {:0.4f} loss_bdr: {:0.4f} '
                  'loss_att: {:0.4f} loss_hard_region: {:0.4f} Loss: {:0.4f} Step: {:0.4f}'.
                  format(datetime.now(), epoch, opt.epoch, i, total_step,
                         loss_sal_depth.data, loss_sal_RGB.data, loss_bdr.data,
                         loss_att.data, loss_hard_region.data, loss.data,
                         i + (epoch - 1) * total_step))
    save_path = ''
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    if (epoch + 1) % 1 == 0:
        torch.save(model.state_dict(), save_path + '%d' % epoch + '_w.pth')

for epoch in range(opt.start_epoch + 1, opt.epoch):
    if opt.resume and epoch != 1:
        print("\nloading parameters")
        model.load_state_dict(torch.load(pre_check_root + '%d' % (epoch - 1) + '_w.pth'))
    print(epoch)
    adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
    train(train_loader, model, optimizer, epoch)
print("Training finished.")
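In symbols, with $T$ = opt.total_depth intervals and depth $d$ normalized to $[0, 1]$, the loop above computes, for ground-truth map $G$ and each interval $j$:

$$M_j = \mathbb{1}\!\left[\tfrac{j}{T} \le d \le \tfrac{j+1}{T}\right], \qquad w_j = \frac{\operatorname{mean}(M_j \odot G)}{\operatorname{mean}(G)}, \qquad R = \sum_{j=0}^{T-1} w_j M_j,$$

and the hard-region term is $\operatorname{mean}(\ell \odot R)$ for the per-pixel BCE map $\ell$, so depth slices that overlap the salient object more heavily contribute more to the loss.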
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
import os
from PIL import Image
import torch.utils.data as data
import torch
import torchvision.transforms as transforms

class SalObjDataset(data.Dataset):
    def __init__(self, image_root, gt_root, depth_root, boundary_root, trainsize):
        self.trainsize = trainsize
        self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
        self.gts = [gt_root + f for f in os.listdir(gt_root)
                    if f.endswith('.jpg') or f.endswith('.png')]
        self.bdrs = [boundary_root + f for f in os.listdir(boundary_root)
                     if f.endswith('.jpg') or f.endswith('.png')]
        self.depth = [depth_root + f for f in os.listdir(depth_root)
                      if f.endswith('.jpg') or f.endswith('.png')]
        self.depth = sorted(self.depth)
        self.images = sorted(self.images)
        self.gts = sorted(self.gts)
        self.bdrs = sorted(self.bdrs)
        self.filter_files()
        self.size = len(self.images)
        self.img_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        self.gt_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor()])
        self.depth_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor()])
        self.bdrs_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor()])

    def __getitem__(self, index):
        image = self.rgb_loader(self.images[index])
        gt = self.binary_loader(self.gts[index])
        depth = self.binary_loader(self.depth[index])
        bdrs = self.binary_loader(self.bdrs[index])
        image = self.img_transform(image)
        gt = self.gt_transform(gt)
        bdrs = self.bdrs_transform(bdrs)
        depth = self.depth_transform(depth)
        # depth = torch.div(depth.float(), 255.0)
        return image, gt, depth, bdrs

    def filter_files(self):
        # Keep only samples whose image, GT, depth and boundary maps all
        # share the same spatial size.
        assert len(self.images) == len(self.gts)
        images, gts, depth, bdrs = [], [], [], []
        for img_path, gt_path, depth_path, bdr_path in zip(self.images, self.gts, self.depth, self.bdrs):
            img = Image.open(img_path)
            gt = Image.open(gt_path)
            depth_ = Image.open(depth_path)
            bdr_ = Image.open(bdr_path)
            if img.size == gt.size == depth_.size == bdr_.size:
                images.append(img_path)
                gts.append(gt_path)
                depth.append(depth_path)
                bdrs.append(bdr_path)
        self.images = images
        self.gts = gts
        self.depth = depth
        self.bdrs = bdrs

    def rgb_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            # return img.convert('1')
            return img.convert('L')

    def resize(self, img, gt):
        assert img.size == gt.size
        w, h = img.size
        if h < self.trainsize or w < self.trainsize:
            h = max(h, self.trainsize)
            w = max(w, self.trainsize)
            return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST)
        else:
            return img, gt

    def __len__(self):
        return self.size


def get_loader(image_root, gt_root, depth_root, boundary_root, batchsize, trainsize,
               shuffle=True, pin_memory=True):
    dataset = SalObjDataset(image_root, gt_root, depth_root, boundary_root, trainsize)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batchsize,
                                  shuffle=shuffle,
                                  pin_memory=pin_memory)
    return data_loader


class test_dataset:
    def __init__(self, image_root, gt_root, depth_root, testsize):
        self.testsize = testsize
        self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
        self.gts = [gt_root + f for f in os.listdir(gt_root)
                    if f.endswith('.jpg') or f.endswith('.png')]
        self.depth = [depth_root + f for f in os.listdir(depth_root)
                      if f.endswith('.jpg') or f.endswith('.png')]
        self.images = sorted(self.images)
        self.gts = sorted(self.gts)
        self.depth = sorted(self.depth)
        self.transform = transforms.Compose([
            transforms.Resize((self.testsize, self.testsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        self.gt_transform = transforms.ToTensor()
        self.depth_transform = transforms.Compose([
            transforms.Resize((self.testsize, self.testsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.485], [0.229])])
        self.size = len(self.images)
        self.index = 0

    def load_data(self):
        image = self.rgb_loader(self.images[self.index])
        image = self.transform(image).unsqueeze(0)
        gt = self.binary_loader(self.gts[self.index])
        depth = self.binary_loader(self.depth[self.index])
        depth = self.depth_transform(depth).unsqueeze(0)
        # depth = torch.div(depth.float(), 255.0)
        name = self.images[self.index].split('\\')[-1]
        if name.endswith('.jpg'):
            name = name.split('.jpg')[0] + '.png'
        self.index += 1
        return image, gt, depth, name

    def rgb_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')
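A usage sketch for the training loader; the four roots are placeholder directories that must each contain aligned .jpg/.png files with identical spatial sizes (filter_files drops mismatches).

```python
from data_loader import get_loader

# Placeholder paths; the roots must end with a path separator, since
# file names are appended by plain string concatenation.
loader = get_loader(image_root='data/images/', gt_root='data/gts/',
                    depth_root='data/depths/', boundary_root='data/bdrs/',
                    batchsize=20, trainsize=256)
for images, gts, depth, bdrs in loader:
    # images: (B, 3, 256, 256); gts, depth, bdrs: (B, 1, 256, 256)
    break
```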
--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import numpy as np
import scipy.stats as st
from model.vgg16 import VGG_Pr

############################ feature extraction block ##################################
class RFB(nn.Module):
    # Receptive-field block: four parallel branches with dilation rates
    # 1/3/5/7, concatenated and fused with a residual 1x1 projection.
    def __init__(self, in_channel, out_channel):
        super(RFB, self).__init__()
        self.relu = nn.ReLU(True)
        self.branch0 = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 1),
        )
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 1),
            nn.Conv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
            nn.Conv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
            nn.Conv2d(out_channel, out_channel, 3, padding=3, dilation=3)
        )
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 1),
            nn.Conv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
            nn.Conv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
            nn.Conv2d(out_channel, out_channel, 3, padding=5, dilation=5)
        )
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 1),
            nn.Conv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
            nn.Conv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
            nn.Conv2d(out_channel, out_channel, 3, padding=7, dilation=7)
        )
        self.conv_cat = nn.Conv2d(4 * out_channel, out_channel, 3, padding=1)
        self.conv_res = nn.Conv2d(in_channel, out_channel, 1)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(std=0.01)
                m.bias.data.fill_(0)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        x_cat = torch.cat((x0, x1, x2, x3), 1)
        x_cat = self.conv_cat(x_cat)
        x = self.relu(x_cat + self.conv_res(x))
        return x
########################################################################################
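A shape sketch for RFB: the dilated branches preserve spatial size while the channel count drops from in_channel to out_channel, matching how model_VGG reduces each backbone feature to 32 channels.

```python
import torch
from model.model import RFB

block = RFB(512, 32).eval()
with torch.no_grad():
    out = block(torch.randn(1, 512, 16, 16))
print(out.shape)  # torch.Size([1, 32, 16, 16])
```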
def gkern(kernlen=16, nsig=3):
    # Discrete 2-D Gaussian kernel, normalized to sum to 1.
    interval = (2 * nsig + 1.) / kernlen
    x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1)
    kern1d = np.diff(st.norm.cdf(x))
    kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
    kernel = kernel_raw / kernel_raw.sum()
    return kernel

def min_max_norm(in_):
    # Per-sample min-max normalization over the spatial dimensions.
    max_ = in_.max(3)[0].max(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_)
    min_ = in_.min(3)[0].min(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_)
    in_ = in_ - min_
    return in_.div(max_ - min_ + 1e-8)

class HA(nn.Module):
    # Holistic attention module: blurs the attention map with a fixed
    # Gaussian kernel and gates the features with the enlarged map.
    def __init__(self):
        super(HA, self).__init__()
        gaussian_kernel = np.float32(gkern(31, 4))
        gaussian_kernel = gaussian_kernel[np.newaxis, np.newaxis, ...]
        self.gaussian_kernel = Parameter(torch.from_numpy(gaussian_kernel))

    def forward(self, attention, x):
        soft_attention = F.conv2d(attention, self.gaussian_kernel, padding=15)
        soft_attention = min_max_norm(soft_attention)
        x = torch.mul(x, soft_attention.max(attention))
        return x, soft_attention

######################## aggregation of three-level depth features ####################
class decoder_d(nn.Module):
    def __init__(self, channel):
        super(decoder_d, self).__init__()
        self.relu = nn.ReLU(True)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_upsample4 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv_upsample5 = nn.Conv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat4 = nn.Conv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat5 = nn.Conv2d(3 * channel, 3 * channel, 3, padding=1)
        self.conv5_1 = nn.Conv2d(3 * channel, 3 * channel, 3, padding=1)
        self.conv5_2 = nn.Conv2d(3 * channel, 1, 1)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(std=0.01)
                m.bias.data.fill_(0)

    def forward(self, x3, x4, x5):
        # x3: 1/4 scale, x4: 1/8, x5: 1/16
        x3_1 = x3
        x4_1 = x4
        x5_1 = x5
        x4_2 = self.conv_concat4(torch.cat((x4_1, self.conv_upsample4(self.upsample(x5_1))), 1))
        x3_2 = self.conv_concat5(torch.cat((x3_1, self.conv_upsample5(self.upsample(x4_2))), 1))
        x = self.conv5_2(self.conv5_1(x3_2))
        return x
########################################################################################
################################### boundary decoder ###################################
class decoder_b(nn.Module):
    def __init__(self, channel):
        super(decoder_b, self).__init__()
        self.relu = nn.ReLU(True)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_1 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv_2 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv_3 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv_4 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv_5 = nn.Conv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_cat1 = nn.Conv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_cat2 = nn.Conv2d(3 * channel, 3 * channel, 3, padding=1)
        self.conv3_1 = nn.Conv2d(3 * channel, 3 * channel, 3, padding=1)
        self.conv3_2 = nn.Conv2d(3 * channel, 1, 1)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(std=0.01)
                m.bias.data.fill_(0)

    def forward(self, x3, x4, x5):
        # x3: 1/4 scale, x4: 1/8, x5: 1/16
        x5_1 = x5
        x4_1 = x4
        x3_1 = x3
        x4_1 = self.conv_1(self.upsample(x5_1)) + x4_1
        x3_1 = self.conv_2(self.upsample(self.upsample(x5_1))) + self.conv_3(self.upsample(x4_1)) + x3_1
        x4_2 = self.conv_cat1(torch.cat((x4_1, self.conv_4(self.upsample(x5_1))), 1))
        x3_2 = self.conv_cat2(torch.cat((x3_1, self.conv_5(self.upsample(x4_2))), 1))
        x = self.conv3_1(x3_2)
        x = self.conv3_2(x)
        return x
########################################################################################
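A shape sketch for decoder_b: it fuses the three RFB-reduced levels (1/4, 1/8 and 1/16 of the input) into a single-channel boundary logit map at 1/4 scale.

```python
import torch
from model.model import decoder_b

dec = decoder_b(channel=32).eval()
x3 = torch.randn(1, 32, 64, 64)  # 1/4 scale
x4 = torch.randn(1, 32, 32, 32)  # 1/8 scale
x5 = torch.randn(1, 32, 16, 16)  # 1/16 scale
with torch.no_grad():
    b = dec(x3, x4, x5)
print(b.shape)  # torch.Size([1, 1, 64, 64])
```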
class AttentionLayer(nn.Module):
    # Squeeze-and-excitation style channel attention: global average
    # pooling followed by a two-layer MLP and a sigmoid gate.
    def __init__(self, channel, reduction=2, multiply=True):
        super(AttentionLayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )
        self.multiply = multiply

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        if self.multiply:
            return x * y
        else:
            return y
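A quick sketch of AttentionLayer's shape-preserving behaviour (it returns the input reweighted per channel):

```python
import torch
from model.model import AttentionLayer

gate = AttentionLayer(channel=32).eval()
with torch.no_grad():
    y = gate(torch.randn(2, 32, 64, 64))
print(y.shape)  # torch.Size([2, 32, 64, 64])
```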
################################## saliency decoder block ##############################
class decoder_s(nn.Module):
    def __init__(self, channel):
        super(decoder_s, self).__init__()
        self.relu = nn.ReLU(True)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_upsample1 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv_upsample2 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv_upsample3 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv_upsample4 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv_upsample5 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv_upsample6 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv_upsample7 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv_upsample8 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv_upsample9 = nn.Conv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_upsample10 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv_upsample11 = nn.Conv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_upsample12 = nn.Conv2d(channel, channel, 3, padding=1)
        self.channel_att5 = AttentionLayer(channel, reduction=2)
        self.channel_att4 = AttentionLayer(channel, reduction=2)
        self.channel_att3 = AttentionLayer(channel, reduction=2)
        self.channel_ratt5 = AttentionLayer(channel, reduction=2)
        self.channel_ratt4 = AttentionLayer(channel, reduction=2)
        self.channel_ratt3 = AttentionLayer(channel, reduction=2)
        self.channel_reatt5 = AttentionLayer(channel, reduction=2)
        self.channel_reatt4 = AttentionLayer(channel, reduction=2)
        self.channel_reatt3 = AttentionLayer(channel, reduction=2)
        self.channel_rdatt5 = AttentionLayer(channel, reduction=2)
        self.channel_rdatt4 = AttentionLayer(channel, reduction=2)
        self.channel_rdatt3 = AttentionLayer(channel, reduction=2)
        self.conv_concat3 = nn.Conv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat4 = nn.Conv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat5 = nn.Conv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat6 = nn.Conv2d(3 * channel, 3 * channel, 3, padding=1)
        self.conv_concat7 = nn.Conv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat8 = nn.Conv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat9 = nn.Conv2d(2 * channel, channel, 1)
        self.conv_concat10 = nn.Conv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat11 = nn.Conv2d(2 * channel, channel, 1)
        self.conv_concat12 = nn.Conv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat13 = nn.Conv2d(2 * channel, channel, 1)
        self.conv_concat14 = nn.Conv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat15 = nn.Conv2d(2 * channel, channel, 1)
        self.conv_concat16 = nn.Conv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat17 = nn.Conv2d(2 * channel, channel, 1)
        self.conv_concat18 = nn.Conv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat19 = nn.Conv2d(2 * channel, channel, 1)
        self.conv_concat20 = nn.Conv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat21 = nn.Conv2d(2 * channel, channel, 1)
        self.conv5_1 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv5_2 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv5_r = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv5_3 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv5_4 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv5_5 = nn.Conv2d(channel, 1, 1)
        self.conv5_6 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv5_7 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv5_8 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv5_9 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv5_be = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv5_bd = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv4_be = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv4_bd = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv3_be = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv3_bd = nn.Conv2d(channel, channel, 3, padding=1)
        self.convcat5 = nn.Conv2d(2 * channel, channel, 3, padding=1)
        self.convadd5 = nn.Conv2d(channel, channel, 3, padding=1)
        self.convcat52 = nn.Conv2d(2 * channel, channel, 3, padding=1)
        self.convcat53 = nn.Conv2d(2 * channel, channel, 3, padding=1)
        self.convcat54 = nn.Conv2d(channel, channel, 3, padding=1)
        self.convadd52 = nn.Conv2d(channel, channel, 3, padding=1)
        self.convcat4 = nn.Conv2d(2 * channel, channel, 3, padding=1)
        self.convadd4 = nn.Conv2d(channel, channel, 3, padding=1)
        self.convcat42 = nn.Conv2d(2 * channel, channel, 3, padding=1)
        self.convcat43 = nn.Conv2d(2 * channel, channel, 3, padding=1)
        self.convcat44 = nn.Conv2d(channel, channel, 3, padding=1)
        self.convadd42 = nn.Conv2d(channel, channel, 3, padding=1)
        self.convcat32 = nn.Conv2d(2 * channel, channel, 3, padding=1)
        self.convcat33 = nn.Conv2d(2 * channel, channel, 3, padding=1)
        self.convcat34 = nn.Conv2d(channel, channel, 3, padding=1)
        self.convadd32 = nn.Conv2d(channel, channel, 3, padding=1)
        self.convcat3 = nn.Conv2d(2 * channel, channel, 3, padding=1)
        self.convadd3 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv5_res1 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv5_res2 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv5_red1 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv5_red2 = nn.Conv2d(channel, channel, 3, padding=1)

        self.conv4_1 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv4_2 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv4_3 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv4_4 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv4_8 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv4_9 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv4_r = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv4_res1 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv4_res2 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv4_red1 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv4_red2 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv3_1 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv3_2 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv3_3 = nn.Conv2d(3 * channel, 3 * channel, 3, padding=1)
        self.conv3_4 = nn.Conv2d(3 * channel, channel, 3, padding=1)
        self.conv3_5 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv3_6 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv3_7 = nn.Conv2d(channel, 1, 1)
        self.conv3_8 = nn.Conv2d(channel, 1, 3, padding=1)
        self.conv3_9 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv3_10 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv3_11 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv3_12 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv3_13 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv3_14 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv3_res1 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv3_res2 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv3_red1 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv3_red2 = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv3_r = nn.Conv2d(channel, channel, 3, padding=1)
        self.conv2_1 = nn.Conv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv2_2 = nn.Conv2d(2 * channel, 1, 3, padding=1)
        self.relu5_1 = nn.ReLU(True)
        self.relu5_2 = nn.ReLU(True)
        self.relu5_6 = nn.ReLU(True)
        self.relu5_7 = nn.ReLU(True)
        self.relu5_8 = nn.ReLU(True)
        self.relu5_9 = nn.ReLU(True)
        self.relu5_r = nn.ReLU(True)
        self.relu4_1 = nn.ReLU(True)
        self.relu4_2 = nn.ReLU(True)
        self.relu4_3 = nn.ReLU(True)
        self.relu4_4 = nn.ReLU(True)
        self.relu4_8 = nn.ReLU(True)
        self.relu4_9 = nn.ReLU(True)
        self.relu4_r = nn.ReLU(True)
        self.relu3_1 = nn.ReLU(True)
        self.relu3_2 = nn.ReLU(True)
        self.relu3_3 = nn.ReLU(True)
        self.relu3_4 = nn.ReLU(True)
        self.relu3_5 = nn.ReLU(True)
        self.relu3_6 = nn.ReLU(True)
        self.relu3_7 = nn.ReLU(True)
        self.relu3_8 = nn.ReLU(True)
        self.relu3_9 = nn.ReLU(True)
        self.relu3_10 = nn.ReLU(True)
        self.relu3_13 = nn.ReLU(True)
        self.relu3_14 = nn.ReLU(True)
        self.relu3_r = nn.ReLU(True)
        self.pool2_1 = nn.AvgPool2d(2, stride=2)
        self.pool4_1 = nn.AvgPool2d(2, stride=2)
        self.pool4_2 = nn.AvgPool2d(2, stride=2)
        self.pool5_1 = nn.AvgPool2d(2, stride=2)
        self.pool5_2 = nn.AvgPool2d(2, stride=2)
        self.pool5_4 = nn.AvgPool2d(2, stride=2)
        self.pool5_3 = nn.AvgPool2d(2, stride=2)
        self.pool5_5 = nn.AvgPool2d(2, stride=2)
        self.pool3_1 = nn.AvgPool2d(2, stride=2)
        self.pool3_2 = nn.AvgPool2d(2, stride=2)
        self.maxpool5 = nn.MaxPool2d(2, stride=2)
        self.maxpool4 = nn.MaxPool2d(4, stride=4)
        self.maxpool3 = nn.MaxPool2d(4, stride=4)
        self.pool_depth = nn.AvgPool2d(2, stride=2)
        self.pool_depth2 = nn.AvgPool2d(2, stride=2)
        self.sigmoid3_1 = nn.Sigmoid()
        self.sigmoid5_1 = nn.Sigmoid()
        self.HA = HA()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(std=0.01)
                m.bias.data.fill_(0)
    def forward(self, x3, x4, x5, x3_2, x4_2, x5_2, x3_3, x4_3, x5_3, depth_):
        # x3/x4/x5: RGB features; x*_2: boundary features; x*_3: depth
        # features (detached); depth_: raw depth map at input resolution.
        depth_ = self.pool_depth(self.pool_depth2(depth_))
        x5_s_1 = x5
        x5_s_1 = self.relu5_8(self.conv5_8(x5_s_1))
        x5_s_1 = self.relu5_9(self.conv5_9(x5_s_1))
        x5_d = x5_3
        x5_d = self.conv5_1(x5_d)
        x5_d = self.relu5_1(x5_d)
        x5_c = self.conv5_2(x5_d)
        x5_c = self.relu5_2(x5_c)
        x5_b = x5_2
        x5_b = self.conv5_6(x5_b)
        x5_b = self.relu5_6(x5_b)
        x5_b = self.conv5_7(x5_b)
        x5_b = self.relu5_7(x5_b)
        x5_sal = self.conv5_3(x5_s_1)
        x5_sal = self.conv5_4(x5_sal)
        x_att5 = self.conv5_5(x5_sal)
        x_att5 = self.sigmoid5_1(x_att5)
        n_, _, _, _ = x_att5.size()
        # Depth decomposition (same scheme as in train.py): weight each of
        # the 10 depth intervals by its overlap with the coarse attention.
        res_f = torch.zeros((n_, 1, 64, 64))   # assumes a 256x256 input
        res_dsf = torch.zeros((n_, 1, 64, 64))
        for jj in range(10):
            res = depth_ * 255
            target = self.upsample(self.upsample(x_att5))
            target = target * 255
            res1 = (res >= (255.0 / 10) * jj)
            res1 = res1.type(torch.FloatTensor).cuda()
            res3 = (res <= (255.0 / 10) * (jj + 1))
            res3 = res3.type(torch.FloatTensor).cuda()
            res2 = res * res1 * res3
            res2[res2 > 0] = 255
            res_sim = res2 * (target / 255)
            res_bi = res2
            res_res = res2
            total = target.mean(dim=3)
            total = total.mean(dim=2)
            total_d = res_bi.mean(dim=3)
            total_d = total_d.mean(dim=2)
            res_sim = res_sim.mean(dim=3)
            res_sim = res_sim.mean(dim=2)
            weight = torch.div(res_sim, total)
            weight = torch.unsqueeze(weight, -1)
            weight = torch.unsqueeze(weight, -1)
            weight_d1 = torch.div(res_sim, total_d + 1e-4)
            weight_d = torch.unsqueeze(weight_d1, -1)
            weight_d = torch.unsqueeze(weight_d, -1)
            weight_d = weight_d * (weight_d / (weight_d + 1e-4))
            res_f = res_f.cuda()
            res_dsf = res_dsf.cuda()
            res__ = torch.mul(res_res, weight)
            res_dsf2 = torch.mul(res_res, weight_d)
            res_f = res_f + res__
            res_dsf = res_dsf + res_dsf2
        res_f = res_f / 255
        res_dsf = res_dsf / 255
        res_df5 = self.pool5_3(self.pool5_4(res_f))
        res_f5 = self.pool5_2(self.pool5_1(res_dsf))
        x5_reatt = self.upsample(self.maxpool5(x_att5))
        x5_res2 = x5_s_1 * x5_reatt
        x5_res2 = self.channel_reatt5(x5_res2)
        x5_res = x5_s_1 * res_f5
        x5_ratt = self.channel_ratt5(x5_res)
        x5_ratt = self.conv5_res1(x5_ratt)
        x5_ratt = self.conv5_res2(x5_ratt)
        x5_res = x5_ratt
        x5_stru = x5_s_1 * x5_b
        x5_stru = self.conv5_be(x5_stru)
        x5_stru = self.relu(x5_stru + x5_s_1)
        x5_rf = torch.cat((x5_res2, x5_res), 1)
        x5_rf1 = self.convcat5(x5_rf)
        x5_rf = x5_rf1 + x5_stru
        x5_rf = self.convadd5(x5_rf)
        x5_red = x5_c * x_att5
        x5_att = self.channel_att5(x5_red)
        x5_att = self.conv5_red1(x5_att)
        x5_att = self.conv5_red2(x5_att)
        x5_red2 = x5_c * res_df5
        x5_red2 = self.channel_rdatt5(x5_red2)
        x5_strud = x5_c * x5_b
        x5_strud = self.conv5_bd(x5_strud)
        x5_strud = self.relu(x5_strud + x5_c)
        x5_df = torch.cat((x5_att, x5_red2), 1)
        x5_df = self.convcat52(x5_df)
        x5_df = x5_df + x5_strud
        x5_df = self.convadd52(x5_df)
        x5_f = torch.cat((x5_rf, x5_df), 1)
        x5_f = self.convcat53(x5_f)
        x5_f = self.convcat54(x5_f)
        x5_s_r = self.conv5_r(x5_f)
        x5_s_r = self.relu5_r(x5_s_r)
        x5_s = x5_f + x5_s_r
        x4_s_1 = x4
        x4_s_1 = self.conv4_8(x4_s_1)
        x4_s_1 = self.relu4_8(x4_s_1)
        x4_s_1 = self.conv4_9(x4_s_1)
        x4_s_1 = self.relu4_9(x4_s_1)
        x4_d = x4_3
        x4_d = self.conv4_1(x4_d)
        x4_d = self.relu4_1(x4_d)
        x4_c = self.conv4_2(x4_d)
        x4_c = self.relu4_2(x4_c)
        x4_b = x4_2
        x4_b = self.conv4_3(x4_b)
        x4_b = self.relu4_3(x4_b)
        x4_b = self.conv4_4(x4_b)
        x4_b = self.relu4_4(x4_b)
        x4_reatt = self.upsample(self.upsample(self.maxpool4(self.upsample(1 - x_att5))))
        x4_res2 = x4_s_1 * x4_reatt
        x4_res2 = self.channel_reatt4(x4_res2)
        x4_res = x4_s_1 * self.pool4_1(res_dsf)
        x4_res = self.channel_ratt4(x4_res)
        x4_res = self.conv4_res1(x4_res)
        x4_res = self.conv4_res2(x4_res)
        x4_stru = x4_s_1 * x4_b
        x4_stru = self.conv4_be(x4_stru)
        x4_stru = self.relu(x4_stru + x4_s_1)
        x4_rf = torch.cat((x4_res2, x4_res), 1)
        x4_rf = self.convcat4(x4_rf)
        x4_rf = x4_rf + x4_stru
        x4_rf = self.convadd4(x4_rf)
        x4_red = x4_c * self.upsample(x_att5)
        x4_att = self.channel_att4(x4_red)
        x4_att = self.conv4_red1(x4_att)
        x4_att = self.conv4_red2(x4_att)
        x4_red2 = x4_c * self.pool4_2(res_f)
        x4_red2 = self.channel_rdatt4(x4_red2)
        x4_strud = x4_c * x4_b
        x4_strud = self.conv4_bd(x4_strud)
        x4_strud = self.relu(x4_strud + x4_c)
        x4_df = torch.cat((x4_att, x4_red2), 1)
        x4_df = self.convcat42(x4_df)
        x4_df = x4_df + x4_strud
        x4_df = self.convadd42(x4_df)
        x4_f = torch.cat((x4_rf, x4_df), 1)
        x4_f = self.convcat43(x4_f)
        x4_f = self.convcat44(x4_f)
        x4_s_r = self.conv4_r(x4_f)
        x4_s_r = self.relu4_r(x4_s_r)
        x4_s = x4_f + x4_s_r
        x3_s_1 = x3
        x3_s_1 = self.conv3_13(x3_s_1)
        x3_s_1 = self.relu3_13(x3_s_1)
        x3_s_1 = self.conv3_14(x3_s_1)
        x3_s_1 = self.relu3_14(x3_s_1)
        x3_d = x3_3
        x3_d = self.conv3_9(x3_d)
        x3_d = self.relu3_7(x3_d)
        x3_d = self.conv3_10(x3_d)
        x3_c = self.relu3_8(x3_d)
        x3_b = x3_2
        x3_b = self.conv3_11(x3_b)
        x3_b = self.relu3_9(x3_b)
        x3_b = self.conv3_12(x3_b)
        x3_b = self.relu3_10(x3_b)
        x3_red = x3_c * self.upsample(self.upsample(x_att5))
        x3_att = self.channel_att3(x3_red)
        x3_att = self.conv3_red1(x3_att)
        x3_att = self.conv3_red2(x3_att)
        x3_res = x3_s_1 * res_dsf
        x3_res = self.channel_ratt3(x3_res)
        x3_res = self.conv3_res1(x3_res)
        x3_res = self.conv3_res2(x3_res)
        x3_reatt = self.upsample(self.upsample(self.maxpool3(self.upsample(self.upsample(1 - x_att5)))))
        x3_res2 = x3_s_1 * x3_reatt
        x3_res2 = self.channel_reatt3(x3_res2)
        x3_stru = x3_s_1 * x3_b
        x3_stru = self.conv3_be(x3_stru)
        x3_stru = self.relu(x3_stru + x3_s_1)
        x3_rf = torch.cat((x3_res2, x3_res), 1)
        x3_rf1 = self.convcat3(x3_rf)
        x3_rf = x3_rf1 + x3_stru
        x3_rf = self.convadd3(x3_rf)
        # self-modal attention
        x3_red2 = x3_c * res_f
        x3_red2 = self.channel_rdatt3(x3_red2)
        x3_strud = x3_c * x3_b
        x3_strud = self.conv3_bd(x3_strud)
        x3_strud = self.relu(x3_strud + x3_c)
        x3_df = torch.cat((x3_att, x3_red2), 1)
        x3_df1 = self.convcat32(x3_df)
        x3_df = x3_df1 + x3_strud
        x3_df = self.convadd32(x3_df)
        x3_f = torch.cat((x3_rf, x3_df), 1)
        x3_f = self.convcat33(x3_f)
        x3_f = self.convcat34(x3_f)
        x3_s_r = self.conv3_r(x3_f)
        x3_s_r = self.relu3_r(x3_s_r)
        x3_s = x3_f + x3_s_r
        # Top-down aggregation of the three refined levels.
        x3_s = x3_s + self.conv_upsample6(self.upsample(x4_s)) + self.conv_upsample7(self.upsample(self.upsample(x5_s)))
        x4_s = x4_s + self.conv_upsample12(self.upsample(x5_s))
        x4_s_2 = torch.cat((x4_s, self.conv_upsample8(self.upsample(x5_s))), 1)
        x4_s_2 = self.conv_concat5(x4_s_2)
        x3_s_2 = torch.cat((x3_s, self.conv_upsample9(self.upsample(x4_s_2))), 1)
        x3_s_2 = self.conv_concat6(x3_s_2)
        x3_s_2 = self.conv3_3(x3_s_2)
        x3_s_2 = self.conv3_4(x3_s_2)
        x_attention = self.conv3_8(x3_s_2)

        return x_attention, x_att5
class model_VGG(nn.Module):
    def __init__(self, channel=32):
        super(model_VGG, self).__init__()
        self.vgg = VGG_Pr()
        self.vgg_d = VGG_Pr()
        self.macb3_1 = RFB(256, channel)
        self.macb4_1 = RFB(512, channel)
        self.macb5_1 = RFB(512, channel)
        self.agg1 = decoder_d(channel)
        self.macb3_2 = RFB(256, channel)
        self.macb4_2 = RFB(512, channel)
        self.macb5_2 = RFB(512, channel)
        self.agg2 = decoder_s(channel)
        self.macb3_3 = RFB(256, channel)
        self.macb4_3 = RFB(512, channel)
        self.macb5_3 = RFB(512, channel)
        self.agg3 = decoder_b(channel)
        self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x, x_d, depth_):
        # depth VGG stream
        x1 = self.vgg_d.conv1(x_d)
        x2 = self.vgg_d.conv2(x1)
        x3 = self.vgg_d.conv3(x2)
        x4 = self.vgg_d.conv4_1(x3)
        x5 = self.vgg_d.conv5_1(x4)
        x3_d = self.macb3_1(x3)
        x4_d = self.macb4_1(x4)
        x5_d = self.macb5_1(x5)
        saliency_d = self.agg1(x3_d, x4_d, x5_d)
        # RGB VGG stream
        x1_s = self.vgg.conv1(x)
        x2_s = self.vgg.conv2(x1_s)
        x3_s = self.vgg.conv3(x2_s)
        x4_s = self.vgg.conv4_1(x3_s)
        x5_s = self.vgg.conv5_1(x4_s)
        x3_s1 = self.macb3_2(x3_s)
        x4_s1 = self.macb4_2(x4_s)
        x5_s1 = self.macb5_2(x5_s)
        x3_b = self.macb3_3(x3_s)
        x4_b = self.macb4_3(x4_s)
        x5_b = self.macb5_3(x5_s)
        boundary = self.agg3(x3_b, x4_b, x5_b)
        detection, x_att5 = self.agg2(x3_s1, x4_s1, x5_s1, x3_b, x4_b, x5_b,
                                      x3_d.detach(), x4_d.detach(), x5_d.detach(), depth_)
        return (self.upsample1(saliency_d), self.upsample1(detection),
                self.upsample1(boundary), self.upsample1(self.upsample1(x_att5)))
--------------------------------------------------------------------------------
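An end-to-end shape sketch with stand-in tensors (requires a CUDA device, since decoder_s moves its internal masks to the GPU, and downloads the pretrained VGG16 weights when the model is built): all four outputs come back at the 256x256 input resolution.

```python
import torch
from model.model import model_VGG

net = model_VGG().cuda().eval()
rgb = torch.randn(1, 3, 256, 256).cuda()
depth = torch.rand(1, 1, 256, 256).cuda()
depth3 = depth.repeat(1, 3, 1, 1)
with torch.no_grad():
    sal_d, sal, boundary, att = net(rgb, depth3, depth)
print(sal_d.shape, sal.shape, boundary.shape, att.shape)
# each: torch.Size([1, 1, 256, 256])
```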