├── model
│   ├── __init__.py
│   ├── ResNet.py
│   └── ResNet_models.py
├── utils
│   ├── __init__.py
│   ├── func.py
│   └── data.py
├── figure
│   ├── framework.png
│   ├── results1.png
│   ├── results2.png
│   └── results3.png
├── test_SCRN.py
├── train_SCRN.py
└── README.md

/model/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/figure/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuzhe71/SCRN/HEAD/figure/framework.png
--------------------------------------------------------------------------------
/figure/results1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuzhe71/SCRN/HEAD/figure/results1.png
--------------------------------------------------------------------------------
/figure/results2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuzhe71/SCRN/HEAD/figure/results2.png
--------------------------------------------------------------------------------
/figure/results3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuzhe71/SCRN/HEAD/figure/results3.png
--------------------------------------------------------------------------------
/test_SCRN.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np
import os
from scipy import misc

from utils.data import test_dataset
from model.ResNet_models import SCRN

model = SCRN()
model.load_state_dict(torch.load('./model/model.pth'))
model.cuda()
model.eval()

data_path = '/backup/materials/Dataset/SalientObject/dataset/'
# valset = ['ECSSD', 'HKUIS', 'PASCAL', 'DUT-OMRON', 'THUR15K', 'DUTS-TEST']
valset = ['ECSSD']
for dataset in valset:
    save_path = './saliency_maps/' + dataset + '/'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    image_root = data_path + dataset + '/images/'
    gt_root = data_path + dataset + '/gts/'
    test_loader = test_dataset(image_root, gt_root, testsize=352)

    with torch.no_grad():
        for i in range(test_loader.size):
            image, gt, name = test_loader.load_data()
            gt = np.array(gt).astype('float')
            gt = gt / (gt.max() + 1e-8)  # normalize the ground truth to [0, 1]
            image = Variable(image).cuda()

            res, edge = model(image)

            # resize the saliency logits back to the ground-truth resolution
            res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=True)
            res = res.sigmoid().data.cpu().numpy().squeeze()
            misc.imsave(save_path + name + '.png', res)
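# Note: `scipy.misc.imsave` was removed in SciPy 1.2, and `F.upsample` is a
# deprecated alias of `F.interpolate` on PyTorch >= 0.4.1. On newer
# environments, a drop-in replacement for the save call (an assumption, not
# part of the original requirements) would be:
#   import imageio
#   imageio.imwrite(save_path + name + '.png', (res * 255).astype(np.uint8))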
--------------------------------------------------------------------------------
/utils/func.py:
--------------------------------------------------------------------------------
# The edge code refers to 'Non-Local Deep Features for Salient Object Detection', CVPR 2017.
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np

# Sobel kernels for horizontal/vertical gradients, shaped for F.conv2d
fx = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).astype(np.float32)
fy = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).astype(np.float32)
fx = np.reshape(fx, (1, 1, 3, 3))
fy = np.reshape(fy, (1, 1, 3, 3))
fx = Variable(torch.from_numpy(fx)).cuda()
fy = Variable(torch.from_numpy(fy)).cuda()
contour_th = 1.5


def label_edge_prediction(label):
    # convert a binary saliency label to a binary edge label
    label = label.gt(0.5).float()
    label = F.pad(label, (1, 1, 1, 1), mode='replicate')
    label_fx = F.conv2d(label, fx)
    label_fy = F.conv2d(label, fy)
    label_grad = torch.sqrt(torch.mul(label_fx, label_fx) + torch.mul(label_fy, label_fy))
    label_grad = torch.gt(label_grad, contour_th).float()

    return label_grad


def pred_edge_prediction(pred):
    # infer a soft edge map from a saliency prediction
    pred = F.pad(pred, (1, 1, 1, 1), mode='replicate')
    pred_fx = F.conv2d(pred, fx)
    pred_fy = F.conv2d(pred, fy)
    pred_grad = (pred_fx * pred_fx + pred_fy * pred_fy).sqrt().tanh()

    return pred_fx, pred_fy, pred_grad


class AvgMeter(object):
    def __init__(self, num=40):
        self.num = num
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.losses = []

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.losses.append(val)

    def show(self):
        # mean over the last `num` recorded losses
        return np.mean(self.losses[max(len(self.losses) - self.num, 0):])
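# Example (sketch): deriving edge labels from a batch of masks; the shapes
# follow train_SCRN.py, i.e. (B, 1, H, W) float tensors on the GPU:
#   gts = (torch.rand(8, 1, 352, 352) > 0.5).float().cuda()
#   gt_edges = label_edge_prediction(gts)  # binary edge map, same shape as gts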
--------------------------------------------------------------------------------
/train_SCRN.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.autograd import Variable

import os, argparse
from datetime import datetime

from utils.data import get_loader
from utils.func import label_edge_prediction, AvgMeter

from model.ResNet_models import SCRN

parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=30, help='epoch number')
parser.add_argument('--lr', type=float, default=2e-3, help='learning rate')
parser.add_argument('--batchsize', type=int, default=8, help='batch size')
parser.add_argument('--trainsize', type=int, default=352, help='input size')
parser.add_argument('--trainset', type=str, default='DUTS-TRAIN', help='training dataset')
opt = parser.parse_args()

# data preparation; set your own data path here
data_path = '/SalientObject/dataset/'
image_root = data_path + opt.trainset + '/images/'
gt_root = data_path + opt.trainset + '/gts/'
train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
total_step = len(train_loader)

# build models
model = SCRN()
model.cuda()
params = model.parameters()
optimizer = torch.optim.SGD(params, opt.lr, momentum=0.9, weight_decay=5e-4)
scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
CE = torch.nn.BCEWithLogitsLoss()
size_rates = [0.75, 1, 1.25]  # multi-scale training

# training
for epoch in range(opt.epoch):
    scheduler.step()
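    # Note: stepping the scheduler at the top of the epoch follows the
    # PyTorch 0.4-era convention that this repo targets. On PyTorch >= 1.1,
    # `scheduler.step()` should instead be called after the epoch's optimizer
    # updates (at the bottom of this loop), otherwise the first value of the
    # learning-rate schedule is skipped.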
    model.train()
    loss_record1, loss_record2 = AvgMeter(), AvgMeter()
    for i, pack in enumerate(train_loader, start=1):
        for rate in size_rates:
            optimizer.zero_grad()

            images, gts = pack
            images = Variable(images).cuda()
            gts = Variable(gts).cuda()
            # edge labels derived on the fly from the saliency ground truth
            gt_edges = label_edge_prediction(gts)

            # multi-scale training samples (side lengths are multiples of 32)
            trainsize = int(round(opt.trainsize * rate / 32) * 32)
            if rate != 1:
                images = F.upsample(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
                gts = F.upsample(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
                gt_edges = F.upsample(gt_edges, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
            # forward
            pred_sal, pred_edge = model(images)

            loss1 = CE(pred_sal, gts)
            loss2 = CE(pred_edge, gt_edges)
            loss = loss1 + loss2
            loss.backward()

            optimizer.step()
            if rate == 1:
                loss_record1.update(loss1.data, opt.batchsize)
                loss_record2.update(loss2.data, opt.batchsize)

        if i % 1000 == 0 or i == total_step:
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss1: {:.4f}, Loss2: {:.4f}'.
                  format(datetime.now(), epoch, opt.epoch, i, total_step, loss_record1.show(), loss_record2.show()))

save_path = './models/'
if not os.path.exists(save_path):
    os.makedirs(save_path)
torch.save(model.state_dict(), save_path + opt.trainset + '_w.pth')
--------------------------------------------------------------------------------
/utils/data.py:
--------------------------------------------------------------------------------
import os
from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms


class SalObjDataset(data.Dataset):
    def __init__(self, image_root, gt_root, trainsize):
        self.trainsize = trainsize
        self.images = sorted([image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')])
        self.gts = sorted([gt_root + f for f in os.listdir(gt_root) if f.endswith('.png')])
        self.size = len(self.images)
        self.img_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        self.gt_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor()])

    def __getitem__(self, index):
        image = self.rgb_loader(self.images[index])
        gt = self.binary_loader(self.gts[index])
        image = self.img_transform(image)
        gt = self.gt_transform(gt)
        return image, gt

    def rgb_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')

    def __len__(self):
        return self.size


def get_loader(image_root, gt_root, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=True):
    dataset = SalObjDataset(image_root, gt_root, trainsize)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batchsize,
                                  shuffle=shuffle,
                                  num_workers=num_workers,
                                  pin_memory=pin_memory)
    return data_loader


class test_dataset:
    def __init__(self, image_root, gt_root, testsize):
        self.testsize = testsize
        self.images = sorted([image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')])
        self.gts = sorted([gt_root + f for f in os.listdir(gt_root) if f.endswith('.png')])
        self.img_transform = transforms.Compose([
            transforms.Resize((self.testsize, self.testsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        self.gt_transform = transforms.ToTensor()
        self.size = len(self.images)
        self.index = 0

    def load_data(self):
        image = self.rgb_loader(self.images[self.index])
        t_image = self.img_transform(image).unsqueeze(0)
        gt = self.binary_loader(self.gts[self.index])  # returned as a PIL image
        name = self.images[self.index].split('/')[-1][:-4]
        self.index += 1
        return t_image, gt, name

    def rgb_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')
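# Example (sketch): building the training loader; the dataset root below is a
# placeholder for your own path:
#   loader = get_loader('/data/DUTS-TRAIN/images/', '/data/DUTS-TRAIN/gts/',
#                       batchsize=8, trainsize=352)
#   images, gts = next(iter(loader))  # (8, 3, 352, 352) and (8, 1, 352, 352)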
--------------------------------------------------------------------------------
/model/ResNet.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import math


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet50(nn.Module):
    # ResNet-50 backbone shared by the two branches
    def __init__(self):
        self.inplanes = 64
        super(ResNet50, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(Bottleneck, 64, 3)
        self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)
        self.layer3 = self._make_layer(Bottleneck, 256, 6, stride=2)
        self.layer4 = self._make_layer(Bottleneck, 512, 3, stride=2)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)    # 1/2
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)  # 1/4

        x = self.layer1(x)
        x = self.layer2(x)   # 1/8
        x = self.layer3(x)   # 1/16
        x = self.layer4(x)   # 1/32

        return x
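# Note: this `forward` returns only the deepest (1/32) feature map. SCRN in
# ResNet_models.py does not call it; it invokes conv1/bn1/relu/maxpool and
# layer1-layer4 directly to collect all four intermediate feature maps.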
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SCRN
Code repository for our paper "Stacked Cross Refinement Network for Edge-Aware Salient Object Detection", ICCV 2019 poster. The [paper](http://openaccess.thecvf.com/content_ICCV_2019/papers/Wu_Stacked_Cross_Refinement_Network_for_Edge-Aware_Salient_Object_Detection_ICCV_2019_paper.pdf) and [supplementary material](http://openaccess.thecvf.com/content_ICCV_2019/supplemental/Wu_Stacked_Cross_Refinement_ICCV_2019_supplemental.pdf) are available.

# Change Log
2020.11.4: We updated the predicted [saliency maps](https://github.com/wuzhe71/SCRN#soc-saliency-maps) of 20 algorithms on the SOC test and validation sets!

# Framework
![image](https://github.com/wuzhe71/SCAN/blob/master/figure/framework.png)

# Experiments
1. Results on traditional datasets
![results1](https://github.com/wuzhe71/SCAN/blob/master/figure/results1.png)

2. Results on SOC (attribute-based performance, structural similarity scores); more comparisons can be found on the [SOC Leaderboard](http://dpfan.net/SOCBenchmark/)
![results3](https://github.com/wuzhe71/SCAN/blob/master/figure/results3.png)

# Usage
1. Requirements
    * PyTorch 0.4.0+
    * scipy
2. Clone the repo
```
git clone https://github.com/wuzhe71/SCRN.git
cd SCRN
```
3. Train/Test
    * Train
        * Download datasets: [DUTS](http://saliencydetection.net/duts/), [DUT-OMRON](http://saliencydetection.net/dut-omron/), [ECSSD](http://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/dataset.html), [HKU-IS](https://i.cs.hku.hk/~gbli/deep_saliency.html), [PASCAL-S](http://www.cbi.gatech.edu/salobj/), [THUR15K](https://mmcheng.net/gsal/), [SOC](http://dpfan.net/SOCBenchmark/)
        * Set your dataset path, then
        ```
        python train_SCRN.py
        ```
        * We only use multi-scale training for data augmentation, and the learning rate is set to 0.002. If you switch to single-scale training, the learning rate should be changed to 0.005.
    * Test
        * Download the pre-trained model from [google drive](https://drive.google.com/open?id=1PkGX9R-uTYpWBKX0lZRkE2qvvpz1-IiG) or [baidu yun](https://pan.baidu.com/s/1Gm-YptzsVnHU0a6YkdjQaQ) (code: ilhx) and put it in './model/'. This model is trained only on the DUTS training set and tested on the other datasets, including SOC and the DUTS test set. Set your dataset path, then (a minimal single-image sketch is shown below)
        ```
        python test_SCRN.py
        ```
        * You can also download the pre-computed saliency maps from [google drive](https://drive.google.com/open?id=1gRis5weSxuv9w6EZ23MPAnyDe-hUx07L) or [baidu yun](https://pan.baidu.com/s/1VHl_pWvbZGeAKgMwqFEHsw) (code: 8mty).
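4. Single-image sanity check (optional). A minimal sketch, assuming the pre-trained model sits at './model/model.pth' as above; the image path is a placeholder, and the 352x352 input size and normalization follow test_SCRN.py:
    ```
    import torch
    from PIL import Image
    import torchvision.transforms as transforms
    from model.ResNet_models import SCRN

    model = SCRN()
    model.load_state_dict(torch.load('./model/model.pth'))
    model.cuda().eval()

    transform = transforms.Compose([
        transforms.Resize((352, 352)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    image = transform(Image.open('example.jpg').convert('RGB')).unsqueeze(0).cuda()

    with torch.no_grad():
        sal, edge = model(image)
    sal = sal.sigmoid().squeeze().cpu().numpy()  # saliency map in [0, 1], input-sized
    ```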
# SOC saliency maps
In the paper, we compare SCRN with nine methods on the SOC validation set. Here we provide the saliency maps of 20 SOD methods on both the test and validation sets ([google drive](https://drive.google.com/file/d/10Jw1E4S6zQfeoa1SM3Aj93K0RnmAqRCg/view?usp=sharing) or [baidu yun](https://pan.baidu.com/s/1mWpE3jEVvGlb5VuSkSZ7jw) (code: wnjp)): [DSS](https://openaccess.thecvf.com/content_cvpr_2017/papers/Hou_Deeply_Supervised_Salient_CVPR_2017_paper.pdf), [NLDF](https://openaccess.thecvf.com/content_cvpr_2017/papers/Luo_Non-Local_Deep_Features_CVPR_2017_paper.pdf), [SRM](https://openaccess.thecvf.com/content_ICCV_2017/papers/Wang_A_Stagewise_Refinement_ICCV_2017_paper.pdf), [Amulet](https://openaccess.thecvf.com/content_ICCV_2017/papers/Zhang_Amulet_Aggregating_Multi-Level_ICCV_2017_paper.pdf), [DGRL](https://openaccess.thecvf.com/content_cvpr_2018/papers/Wang_Detect_Globally_Refine_CVPR_2018_paper.pdf), [BMPM](https://openaccess.thecvf.com/content_cvpr_2018/papers_backup/Zhang_A_Bi-Directional_Message_CVPR_2018_paper.pdf), [PiCANet-R](https://openaccess.thecvf.com/content_cvpr_2018/papers/Liu_PiCANet_Learning_Pixel-Wise_CVPR_2018_paper.pdf), [R3Net](https://www.ijcai.org/Proceedings/2018/0095.pdf), [C2S-Net](https://openaccess.thecvf.com/content_ECCV_2018/papers/Xin_Li_Contour_Knowledge_Transfer_ECCV_2018_paper.pdf), [RANet](https://openaccess.thecvf.com/content_ECCV_2018/papers/Shuhan_Chen_Reverse_Attention_for_ECCV_2018_paper.pdf), [CPD](https://openaccess.thecvf.com/content_CVPR_2019/papers/Wu_Cascaded_Partial_Decoder_for_Fast_and_Accurate_Salient_Object_Detection_CVPR_2019_paper.pdf), [AFN](https://openaccess.thecvf.com/content_CVPR_2019/papers/Feng_Attentive_Feedback_Network_for_Boundary-Aware_Salient_Object_Detection_CVPR_2019_paper.pdf), [BASNet](https://openaccess.thecvf.com/content_CVPR_2019/papers/Qin_BASNet_Boundary-Aware_Salient_Object_Detection_CVPR_2019_paper.pdf), [PoolNet](https://openaccess.thecvf.com/content_CVPR_2019/papers/Liu_A_Simple_Pooling-Based_Design_for_Real-Time_Salient_Object_Detection_CVPR_2019_paper.pdf), [SCRN](https://openaccess.thecvf.com/content_ICCV_2019/papers/Wu_Stacked_Cross_Refinement_Network_for_Edge-Aware_Salient_Object_Detection_ICCV_2019_paper.pdf), [SIBA](https://openaccess.thecvf.com/content_ICCV_2019/papers/Su_Selectivity_or_Invariance_Boundary-Aware_Salient_Object_Detection_ICCV_2019_paper.pdf), [EGNet](https://openaccess.thecvf.com/content_ICCV_2019/papers/Zhao_EGNet_Edge_Guidance_Network_for_Salient_Object_Detection_ICCV_2019_paper.pdf), [F3Net](https://aaai.org/ojs/index.php/AAAI/article/view/6916), [GCPANet](https://aaai.org/ojs/index.php/AAAI/article/view/6633), [MINet](https://openaccess.thecvf.com/content_CVPR_2020/papers/Pang_Multi-Scale_Interactive_Network_for_Salient_Object_Detection_CVPR_2020_paper.pdf).

# Citation
```
@InProceedings{Wu_2019_ICCV,
  author = {Wu, Zhe and Su, Li and Huang, Qingming},
  title = {Stacked Cross Refinement Network for Edge-Aware Salient Object Detection},
  booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
  month = {October},
  year = {2019}
}
```

# Contact Us
If you have any questions, please contact us (wuzh02@pcl.ac.cn).
--------------------------------------------------------------------------------
/model/ResNet_models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.models as models
from .ResNet import ResNet50


class BasicConv2d(nn.Module):
    # convolution + batch norm (no activation)
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super(BasicConv2d, self).__init__()
        self.conv_bn = nn.Sequential(
            nn.Conv2d(in_planes, out_planes,
                      kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_planes)
        )

    def forward(self, x):
        x = self.conv_bn(x)
        return x


class Reduction(nn.Module):
    # reduce a backbone feature map to `out_channel` channels
    def __init__(self, in_channel, out_channel):
        super(Reduction, self).__init__()
        self.reduce = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, 3, padding=1),
            BasicConv2d(out_channel, out_channel, 3, padding=1)
        )

    def forward(self, x):
        return self.reduce(x)


class conv_upsample(nn.Module):
    # upsample x to the spatial size of `target` (if needed), then 1x1 conv
    def __init__(self, channel):
        super(conv_upsample, self).__init__()
        self.conv = BasicConv2d(channel, channel, 1)

    def forward(self, x, target):
        if x.size()[2:] != target.size()[2:]:
            x = self.conv(F.upsample(x, size=target.size()[2:], mode='bilinear', align_corners=True))
        return x
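# In the Cross Refinement Unit below, the two branches exchange messages in
# different ways: each saliency feature is refined by concatenating the
# (upsampled) edge features of the same and coarser scales and adding a
# residual, while each edge feature is refined by element-wise multiplication
# with the corresponding saliency features.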
class DenseFusion(nn.Module):
    # Cross Refinement Unit
    def __init__(self, channel):
        super(DenseFusion, self).__init__()
        self.conv1 = conv_upsample(channel)
        self.conv2 = conv_upsample(channel)
        self.conv3 = conv_upsample(channel)
        self.conv4 = conv_upsample(channel)
        self.conv5 = conv_upsample(channel)
        self.conv6 = conv_upsample(channel)
        self.conv7 = conv_upsample(channel)
        self.conv8 = conv_upsample(channel)
        self.conv9 = conv_upsample(channel)
        self.conv10 = conv_upsample(channel)
        self.conv11 = conv_upsample(channel)
        self.conv12 = conv_upsample(channel)

        self.conv_f1 = nn.Sequential(
            BasicConv2d(5*channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )
        self.conv_f2 = nn.Sequential(
            BasicConv2d(4*channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )
        self.conv_f3 = nn.Sequential(
            BasicConv2d(3*channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )
        self.conv_f4 = nn.Sequential(
            BasicConv2d(2*channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )

        self.conv_f5 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )
        self.conv_f6 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )
        self.conv_f7 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )
        self.conv_f8 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )

    def forward(self, x_s1, x_s2, x_s3, x_s4, x_e1, x_e2, x_e3, x_e4):
        x_sf1 = x_s1 + self.conv_f1(torch.cat((x_s1, x_e1,
                                               self.conv1(x_e2, x_s1),
                                               self.conv2(x_e3, x_s1),
                                               self.conv3(x_e4, x_s1)), 1))
        x_sf2 = x_s2 + self.conv_f2(torch.cat((x_s2, x_e2,
                                               self.conv4(x_e3, x_s2),
                                               self.conv5(x_e4, x_s2)), 1))
        x_sf3 = x_s3 + self.conv_f3(torch.cat((x_s3, x_e3,
                                               self.conv6(x_e4, x_s3)), 1))
        x_sf4 = x_s4 + self.conv_f4(torch.cat((x_s4, x_e4), 1))

        x_ef1 = x_e1 + self.conv_f5(x_e1 * x_s1 *
                                    self.conv7(x_s2, x_e1) *
                                    self.conv8(x_s3, x_e1) *
                                    self.conv9(x_s4, x_e1))
        x_ef2 = x_e2 + self.conv_f6(x_e2 * x_s2 *
                                    self.conv10(x_s3, x_e2) *
                                    self.conv11(x_s4, x_e2))
        x_ef3 = x_e3 + self.conv_f7(x_e3 * x_s3 *
                                    self.conv12(x_s4, x_e3))
        x_ef4 = x_e4 + self.conv_f8(x_e4 * x_s4)

        return x_sf1, x_sf2, x_sf3, x_sf4, x_ef1, x_ef2, x_ef3, x_ef4


class ConcatOutput(nn.Module):
    def __init__(self, channel):
        super(ConcatOutput, self).__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1)

        self.conv_cat1 = nn.Sequential(
            BasicConv2d(2*channel, 2*channel, 3, padding=1),
            BasicConv2d(2*channel, channel, 1)
        )
        self.conv_cat2 = nn.Sequential(
            BasicConv2d(2*channel, 2*channel, 3, padding=1),
            BasicConv2d(2*channel, channel, 1)
        )
        self.conv_cat3 = nn.Sequential(
            BasicConv2d(2*channel, 2*channel, 3, padding=1),
            BasicConv2d(2*channel, channel, 1)
        )
        self.output = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            nn.Conv2d(channel, 1, 1)
        )

    def forward(self, x1, x2, x3, x4):
        x3 = torch.cat((x3, self.conv_upsample1(self.upsample(x4))), 1)
        x3 = self.conv_cat1(x3)

        x2 = torch.cat((x2, self.conv_upsample2(self.upsample(x3))), 1)
        x2 = self.conv_cat2(x2)

        x1 = torch.cat((x1, self.conv_upsample3(self.upsample(x2))), 1)
        x1 = self.conv_cat3(x1)

        x = self.output(x1)
        return x
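# ConcatOutput is a top-down, U-Net-style decoder: starting from the coarsest
# scale, each feature map is upsampled by 2x, fused with the next finer scale
# by concatenation, and finally projected to a single-channel prediction map.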
class SCRN(nn.Module):
    # Stacked Cross Refinement Network
    def __init__(self, channel=32):
        super(SCRN, self).__init__()
        self.resnet = ResNet50()
        self.reduce_s1 = Reduction(256, channel)
        self.reduce_s2 = Reduction(512, channel)
        self.reduce_s3 = Reduction(1024, channel)
        self.reduce_s4 = Reduction(2048, channel)

        self.reduce_e1 = Reduction(256, channel)
        self.reduce_e2 = Reduction(512, channel)
        self.reduce_e3 = Reduction(1024, channel)
        self.reduce_e4 = Reduction(2048, channel)

        self.df1 = DenseFusion(channel)
        self.df2 = DenseFusion(channel)
        self.df3 = DenseFusion(channel)
        self.df4 = DenseFusion(channel)

        self.output_s = ConcatOutput(channel)
        self.output_e = ConcatOutput(channel)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(std=0.01)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # overwrite the backbone weights with ImageNet-pretrained ResNet-50
        self.initialize_weights()

    def forward(self, x):
        size = x.size()[2:]
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)
        x1 = self.resnet.layer1(x)
        x2 = self.resnet.layer2(x1)
        x3 = self.resnet.layer3(x2)
        x4 = self.resnet.layer4(x3)

        # feature abstraction
        x_s1 = self.reduce_s1(x1)
        x_s2 = self.reduce_s2(x2)
        x_s3 = self.reduce_s3(x3)
        x_s4 = self.reduce_s4(x4)

        x_e1 = self.reduce_e1(x1)
        x_e2 = self.reduce_e2(x2)
        x_e3 = self.reduce_e3(x3)
        x_e4 = self.reduce_e4(x4)

        # four stacked cross refinement units
        x_s1, x_s2, x_s3, x_s4, x_e1, x_e2, x_e3, x_e4 = self.df1(x_s1, x_s2, x_s3, x_s4, x_e1, x_e2, x_e3, x_e4)
        x_s1, x_s2, x_s3, x_s4, x_e1, x_e2, x_e3, x_e4 = self.df2(x_s1, x_s2, x_s3, x_s4, x_e1, x_e2, x_e3, x_e4)
        x_s1, x_s2, x_s3, x_s4, x_e1, x_e2, x_e3, x_e4 = self.df3(x_s1, x_s2, x_s3, x_s4, x_e1, x_e2, x_e3, x_e4)
        x_s1, x_s2, x_s3, x_s4, x_e1, x_e2, x_e3, x_e4 = self.df4(x_s1, x_s2, x_s3, x_s4, x_e1, x_e2, x_e3, x_e4)

        # feature aggregation using a U-Net-style decoder
        pred_s = self.output_s(x_s1, x_s2, x_s3, x_s4)
        pred_e = self.output_e(x_e1, x_e2, x_e3, x_e4)

        pred_s = F.upsample(pred_s, size=size, mode='bilinear', align_corners=True)
        pred_e = F.upsample(pred_e, size=size, mode='bilinear', align_corners=True)

        return pred_s, pred_e

    def initialize_weights(self):
        res50 = models.resnet50(pretrained=True)
        self.resnet.load_state_dict(res50.state_dict(), False)
--------------------------------------------------------------------------------