├── README.md ├── figs └── result.png ├── v1 ├── RAS.py ├── ResNet50.py ├── data.py ├── test.py └── train.py └── v2 ├── RAS.py ├── data.py ├── test.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # RAS-pytorch 2 | The PyTorch code for our TIP2020 paper: [Reverse Attention Based Residual Network for Salient Object Detection](https://ieeexplore.ieee.org/document/8966594) 3 | 4 | --- 5 | 6 | Notice 7 | --- 8 | We use ResNet50 as the backbone and add the [IoU loss](https://github.com/NathanUA/BASNet)[1] for better performance in this PyTorch implementation. The original Caffe version is [here](https://github.com/ShuhanChen/RAS_ECCV18). We provide two versions with different training strategies.
9 | - v1: The same training strategy as [CPD](https://github.com/wuzhe71/CPD)[2].
10 | - v2: The same training strategy as [F3Net](https://github.com/weijun88/F3Net)[3].
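The IoU loss mentioned in the notice is added to binary cross-entropy in both `train.py` scripts. A minimal, dependency-free sketch of that combination for a single flattened saliency map is given here. It assumes plain Python lists of logits and binary targets; the names `stable_sigmoid` and `bce_iou_loss_flat` are illustrative and not part of the repository, which computes the same quantities on batched tensors.

```python
import math


def stable_sigmoid(z):
    """Sigmoid that avoids overflow for large negative inputs."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bce_iou_loss_flat(logits, targets):
    """BCE-with-logits plus soft IoU loss for one flattened map.

    Illustrative helper (not in the repository); mirrors the formula
    used by bce_iou_loss in train.py for a single sample.
    """
    # Numerically stable BCE with logits, averaged over all pixels.
    bce = sum(max(z, 0.0) - z * t + math.log1p(math.exp(-abs(z)))
              for z, t in zip(logits, targets)) / len(logits)

    probs = [stable_sigmoid(z) for z in logits]
    inter = sum(p * t for p, t in zip(probs, targets))
    total = sum(p + t for p, t in zip(probs, targets))
    # Soft IoU with +1 smoothing, as in the training scripts.
    iou = 1.0 - (inter + 1.0) / (total - inter + 1.0)
    return bce + iou
```

A confident, correct prediction drives both terms toward zero, while a confident, wrong prediction is penalized by both the cross-entropy and the overlap term.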
11 | 12 | Usage 13 | --- 14 | Modify the dataset paths, then run:
15 | ``` 16 | Training: python3 train.py 17 | Testing: python3 test.py 18 | ``` 19 | 20 | Performance 21 | --- 22 | The code was tested on Ubuntu 18.04 (Python 3.6.9, PyTorch 1.5.0, torchvision 0.6.0, CUDA 10.2, cuDNN 7.6.5) with an RTX 2080Ti GPU. We select several recent SOTA methods for comparison. The evaluation code can be found [here](https://github.com/jiwei0921/Saliency-Evaluation-Toolbox). 23 | ![Quantitative comparisons](https://github.com/ShuhanChen/RAS-pytorch/blob/master/figs/result.png) 24 | 25 | 26 | Pre-trained models & Pre-computed saliency maps 27 | --- 28 | - v1: model [Baidu](https://pan.baidu.com/s/1O5QsWWOjhPMGOWIiIwSX3A)(bc3k) [Google](https://drive.google.com/file/d/1KHmKrAG1M_C0mYgSD8pz9fDmBn2LtoMJ/view?usp=sharing); smaps [Baidu](https://pan.baidu.com/s/13I2F0dPU5mPmklcxbex0Lw)(kp6t) [Google](https://drive.google.com/file/d/1lT_BkFMuD8kPVkjQRR7HVBDzQnY3VkfB/view?usp=sharing)
29 | - v2: model [Baidu](https://pan.baidu.com/s/1XB3VE175bhT_4urBULJ-IQ)(wbz1) [Google](https://drive.google.com/open?id=14WUbyPiKnEafiMu9CWn5EdTAJpF_VqLj); smaps [Baidu](https://pan.baidu.com/s/1HZWx6eqYq7bAUcBtkw75Sw)(j57z) [Google](https://drive.google.com/open?id=1RwyR6GRAiDxeywRT1VjLe3qa_z7uhUWa)
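Regarding the dataset paths from the Usage section: the training scripts read `.jpg` images from an `imgs/` folder and `.png` masks from a `gts/` folder under the dataset directory. Before training, a small check like the one below can confirm that every image has a matching mask. This is only a sketch; `paired_files` is a hypothetical helper, not something the repository provides.

```python
import os


def paired_files(image_dir, gt_dir):
    """Return sorted (image, mask) path pairs matched by file stem.

    Hypothetical helper for checking a dataset folder; it raises
    ValueError if any image lacks a mask or vice versa.
    """
    images = {os.path.splitext(f)[0]: os.path.join(image_dir, f)
              for f in os.listdir(image_dir) if f.endswith('.jpg')}
    gts = {os.path.splitext(f)[0]: os.path.join(gt_dir, f)
           for f in os.listdir(gt_dir) if f.endswith('.png')}

    # Stems present in only one of the two folders.
    unpaired = sorted(set(images) ^ set(gts))
    if unpaired:
        raise ValueError('unpaired files: %s' % ', '.join(unpaired))

    # Sort by stem so the pairing does not depend on os.listdir order.
    return [(images[k], gts[k]) for k in sorted(images)]
```

For example, `paired_files('/path/to/DUTS_train/imgs/', '/path/to/DUTS_train/gts/')` would return the aligned pairs, where the paths shown are placeholders for a local dataset location.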
30 | 31 | Citation 32 | --- 33 | ``` 34 | @article{chen2020tip, 35 | author={Chen, Shuhan and Tan, Xiuli and Wang, Ben and Lu, Huchuan and Hu, Xuelong and Fu, Yun}, 36 | journal={IEEE Transactions on Image Processing}, 37 | title={Reverse Attention-Based Residual Network for Salient Object Detection}, 38 | volume={29}, 39 | pages={3763-3776}, 40 | year={2020} 41 | } 42 | ``` 43 | ``` 44 | @inproceedings{chen2018eccv, 45 | author={Chen, Shuhan and Tan, Xiuli and Wang, Ben and Hu, Xuelong}, 46 | booktitle={European Conference on Computer Vision}, 47 | title={Reverse Attention for Salient Object Detection}, 48 | year={2018} 49 | } 50 | ``` 51 | 52 | Acknowledgements 53 | --- 54 | This code is built on [CPD](https://github.com/wuzhe71/CPD)[2] and [F3Net](https://github.com/weijun88/F3Net)[3]. We thank the authors for sharing their codes. 55 | 56 | Reference 57 | --- 58 | > [1] Xuebin Qin, Zichen Zhang, Chenyang Huang, Chao Gao, Masood Dehghan, Martin Jagersand. BASNet: Boundary-Aware Salient Object Detection. In CVPR, 2019.
59 | > [2] Zhe Wu, Li Su, Qingming Huang. Cascaded Partial Decoder for Fast and Accurate Salient Object Detection. In CVPR, 2019.
60 | > [3] Jun Wei, Shuhui Wang, Qingming Huang. F3Net: Fusion, Feedback and Focus for Salient Object Detection. In AAAI, 2020.
61 | -------------------------------------------------------------------------------- /figs/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShuhanChen/RAS-pytorch/67ea6c330e56b218e66e864b7c47d9a55f8c0217/figs/result.png -------------------------------------------------------------------------------- /v1/RAS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from ResNet50 import ResNet50 5 | import torchvision.models as models 6 | 7 | class MSCM(nn.Module): 8 | def __init__(self, in_channel, out_channel): 9 | super(MSCM, self).__init__() 10 | self.convert = nn.Conv2d(in_channel, out_channel, 1) 11 | self.branch1 = nn.Sequential( 12 | nn.Conv2d(out_channel, out_channel, 1), nn.ReLU(True), 13 | nn.Conv2d(out_channel, out_channel, 3, padding=1, dilation=1), nn.ReLU(True), 14 | ) 15 | self.branch2 = nn.Sequential( 16 | nn.Conv2d(out_channel, out_channel, 3, padding=1, dilation=1), nn.ReLU(True), 17 | nn.Conv2d(out_channel, out_channel, 3, padding=2, dilation=2), nn.ReLU(True), 18 | ) 19 | self.branch3 = nn.Sequential( 20 | nn.Conv2d(out_channel, out_channel, 5, padding=2, dilation=1), nn.ReLU(True), 21 | nn.Conv2d(out_channel, out_channel, 3, padding=4, dilation=4), nn.ReLU(True), 22 | ) 23 | self.branch4 = nn.Sequential( 24 | nn.Conv2d(out_channel, out_channel, 7, padding=3, dilation=1), nn.ReLU(True), 25 | nn.Conv2d(out_channel, out_channel, 3, padding=6, dilation=6), nn.ReLU(True), 26 | ) 27 | self.score = nn.Conv2d(out_channel*4, 1, 3, padding=1) 28 | 29 | def forward(self, x): 30 | x = self.convert(x) 31 | x1 = self.branch1(x) 32 | x2 = self.branch2(x) 33 | x3 = self.branch3(x) 34 | x4 = self.branch4(x) 35 | 36 | x = torch.cat((x1, x2, x3, x4), 1) 37 | x = self.score(x) 38 | 39 | return x 40 | 41 | class RA(nn.Module): 42 | def __init__(self, in_channel, out_channel): 43 | 
super(RA, self).__init__() 44 | self.convert = nn.Conv2d(in_channel, out_channel, 1) 45 | self.convs = nn.Sequential( 46 | nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.ReLU(True), 47 | nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.ReLU(True), 48 | nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.ReLU(True), 49 | nn.Conv2d(out_channel, 1, 3, padding=1), 50 | ) 51 | self.channel = out_channel 52 | 53 | def forward(self, x, y): 54 | a = torch.sigmoid(-y) 55 | x = self.convert(x) 56 | x = a.expand(-1, self.channel, -1, -1).mul(x) 57 | y = y + self.convs(x) 58 | 59 | return y 60 | 61 | class RAS(nn.Module): 62 | def __init__(self, channel=64): 63 | super(RAS, self).__init__() 64 | self.resnet = ResNet50() 65 | self.mscm = MSCM(2048, channel) 66 | self.ra1 = RA(64, channel) 67 | self.ra2 = RA(256, channel) 68 | self.ra3 = RA(512, channel) 69 | self.ra4 = RA(1024, channel) 70 | 71 | for m in self.modules(): 72 | if isinstance(m, nn.Conv2d): 73 | m.weight.data.normal_(std=0.01) 74 | elif isinstance(m, nn.BatchNorm2d): 75 | m.weight.data.fill_(1) 76 | m.bias.data.zero_() 77 | 78 | self.initialize_weights() 79 | 80 | def forward(self, x): 81 | x1, x2, x3, x4, x5 = self.resnet(x) 82 | x_size = x.size()[2:] 83 | x1_size = x1.size()[2:] 84 | x2_size = x2.size()[2:] 85 | x3_size = x3.size()[2:] 86 | x4_size = x4.size()[2:] 87 | 88 | y5 = self.mscm(x5) 89 | score5 = F.interpolate(y5, x_size, mode='bilinear', align_corners=True) 90 | 91 | y5_4 = F.interpolate(y5, x4_size, mode='bilinear', align_corners=True) 92 | y4 = self.ra4(x4, y5_4) 93 | score4 = F.interpolate(y4, x_size, mode='bilinear', align_corners=True) 94 | 95 | y4_3 = F.interpolate(y4, x3_size, mode='bilinear', align_corners=True) 96 | y3 = self.ra3(x3, y4_3) 97 | score3 = F.interpolate(y3, x_size, mode='bilinear', align_corners=True) 98 | 99 | y3_2 = F.interpolate(y3, x2_size, mode='bilinear', align_corners=True) 100 | y2 = self.ra2(x2, y3_2) 101 | score2 = F.interpolate(y2, x_size, 
mode='bilinear', align_corners=True) 102 | 103 | y2_1 = F.interpolate(y2, x1_size, mode='bilinear', align_corners=True) 104 | y1 = self.ra1(x1, y2_1) 105 | score1 = F.interpolate(y1, x_size, mode='bilinear', align_corners=True) 106 | 107 | return score1, score2, score3, score4, score5 108 | 109 | def initialize_weights(self): 110 | res50 = models.resnet50(pretrained=True) 111 | self.resnet.load_state_dict(res50.state_dict(), False) 112 | -------------------------------------------------------------------------------- /v1/ResNet50.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | 4 | 5 | def conv3x3(in_planes, out_planes, stride=1): 6 | """3x3 convolution with padding""" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 8 | padding=1, bias=False) 9 | 10 | 11 | class BasicBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self, inplanes, planes, stride=1, downsample=None): 15 | super(BasicBlock, self).__init__() 16 | self.conv1 = conv3x3(inplanes, planes, stride) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.conv2 = conv3x3(planes, planes) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.downsample = downsample 22 | self.stride = stride 23 | 24 | def forward(self, x): 25 | residual = x 26 | 27 | out = self.conv1(x) 28 | out = self.bn1(out) 29 | out = self.relu(out) 30 | 31 | out = self.conv2(out) 32 | out = self.bn2(out) 33 | 34 | if self.downsample is not None: 35 | residual = self.downsample(x) 36 | 37 | out += residual 38 | out = self.relu(out) 39 | 40 | return out 41 | 42 | 43 | class Bottleneck(nn.Module): 44 | expansion = 4 45 | 46 | def __init__(self, inplanes, planes, stride=1, downsample=None): 47 | super(Bottleneck, self).__init__() 48 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(planes) 50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 
stride=stride, 51 | padding=1, bias=False) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 54 | self.bn3 = nn.BatchNorm2d(planes * 4) 55 | self.relu = nn.ReLU(inplace=True) 56 | self.downsample = downsample 57 | self.stride = stride 58 | 59 | def forward(self, x): 60 | residual = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | out = self.conv2(out) 67 | out = self.bn2(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv3(out) 71 | out = self.bn3(out) 72 | 73 | if self.downsample is not None: 74 | residual = self.downsample(x) 75 | 76 | out += residual 77 | out = self.relu(out) 78 | 79 | return out 80 | 81 | 82 | class ResNet50(nn.Module): 83 | def __init__(self): 84 | self.inplanes = 64 85 | super(ResNet50, self).__init__() 86 | 87 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 88 | bias=False) 89 | self.bn1 = nn.BatchNorm2d(64) 90 | self.relu = nn.ReLU(inplace=True) 91 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 92 | self.layer1 = self._make_layer(Bottleneck, 64, 3) 93 | self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2) 94 | self.layer3 = self._make_layer(Bottleneck, 256, 6, stride=2) 95 | self.layer4 = self._make_layer(Bottleneck, 512, 3, stride=2) 96 | 97 | for m in self.modules(): 98 | if isinstance(m, nn.Conv2d): 99 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 100 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 101 | elif isinstance(m, nn.BatchNorm2d): 102 | m.weight.data.fill_(1) 103 | m.bias.data.zero_() 104 | 105 | def _make_layer(self, block, planes, blocks, stride=1): 106 | downsample = None 107 | if stride != 1 or self.inplanes != planes * block.expansion: 108 | downsample = nn.Sequential( 109 | nn.Conv2d(self.inplanes, planes * block.expansion, 110 | kernel_size=1, stride=stride, bias=False), 111 | nn.BatchNorm2d(planes * block.expansion), 112 | ) 113 | 114 | layers = [] 115 | layers.append(block(self.inplanes, planes, stride, downsample)) 116 | self.inplanes = planes * block.expansion 117 | for i in range(1, blocks): 118 | layers.append(block(self.inplanes, planes)) 119 | 120 | return nn.Sequential(*layers) 121 | 122 | def forward(self, x): 123 | x1 = self.conv1(x) # 1/2 124 | x1 = self.bn1(x1) 125 | x1 = self.relu(x1) 126 | x2 = self.maxpool(x1) # 1/4 127 | 128 | x2 = self.layer1(x2) # 1/4 129 | x3 = self.layer2(x2) # 1/8 130 | x4 = self.layer3(x3) # 1/16 131 | x5 = self.layer4(x4) # 1/32 132 | 133 | return x1, x2, x3, x4, x5 134 | -------------------------------------------------------------------------------- /v1/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import torch.utils.data as data 5 | import torchvision.transforms as transforms 6 | 7 | class SalObjDataset(data.Dataset): 8 | def __init__(self, image_path, gt_path, trainsize): 9 | self.images = [image_path + f for f in os.listdir(image_path) if f.endswith('.jpg')] 10 | self.gts = [gt_path + f for f in os.listdir(gt_path) if f.endswith('.png')] 11 | self.images = sorted(self.images) 12 | self.gts = sorted(self.gts) 13 | self.size = len(self.images) 14 | self.img_transform = transforms.Compose([ 15 | transforms.Resize((trainsize, trainsize)), 16 | transforms.ToTensor(), 17 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 18 | self.gt_transform = transforms.Compose([ 19 
| transforms.Resize((trainsize, trainsize)), 20 | transforms.ToTensor()]) 21 | 22 | def __getitem__(self, index): 23 | image = self.rgb_loader(self.images[index]) 24 | gt = self.binary_loader(self.gts[index]) 25 | image, gt = self.cv_random_flip(image, gt) 26 | image, gt = self.cv_random_rotate(image, gt) 27 | image = self.img_transform(image) 28 | gt = self.gt_transform(gt) 29 | return image, gt 30 | 31 | 32 | def rgb_loader(self, path): 33 | with open(path, 'rb') as f: 34 | img = Image.open(f) 35 | return img.convert('RGB') 36 | 37 | def binary_loader(self, path): 38 | with open(path, 'rb') as f: 39 | img = Image.open(f) 40 | return img.convert('L') 41 | 42 | def gray_loader(self, path): 43 | with open(path, 'rb') as f: 44 | img = Image.open(f) 45 | return img.convert('L') 46 | 47 | def cv_random_flip(self, img, gt): 48 | if np.random.randint(2)==0: 49 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 50 | gt = gt.transpose(Image.FLIP_LEFT_RIGHT) 51 | return img, gt 52 | 53 | def cv_random_rotate(self, img, gt): 54 | rotate_degree = np.random.random() * 2 * 10 - 10 55 | img = img.rotate(rotate_degree, Image.BILINEAR) 56 | gt = gt.rotate(rotate_degree, Image.NEAREST) 57 | return img, gt 58 | 59 | def __len__(self): 60 | return self.size 61 | 62 | 63 | def get_loader(image_path, gt_path, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=True): 64 | 65 | dataset = SalObjDataset(image_path, gt_path, trainsize) 66 | data_loader = data.DataLoader(dataset=dataset, 67 | batch_size=batchsize, 68 | shuffle=shuffle, 69 | num_workers=num_workers, 70 | pin_memory=pin_memory) 71 | return data_loader 72 | 73 | 74 | class test_dataset: 75 | def __init__(self, image_path, testsize): 76 | self.images = [image_path + f for f in os.listdir(image_path) if f.endswith('.jpg')] 77 | self.images = sorted(self.images) 78 | self.img_transform = transforms.Compose([ 79 | transforms.Resize((testsize, testsize)), 80 | transforms.ToTensor(), 81 | transforms.Normalize([0.485, 0.456, 
0.406], [0.229, 0.224, 0.225])]) 82 | self.size = len(self.images) 83 | self.index = 0 84 | 85 | def load_data(self): 86 | image = self.rgb_loader(self.images[self.index]) 87 | img_size = (image.size[1], image.size[0]) 88 | image = self.img_transform(image).unsqueeze(0) 89 | name = self.images[self.index].split('/')[-1] 90 | if name.endswith('.jpg'): 91 | name = name.split('.jpg')[0] + '.png' 92 | self.index += 1 93 | return image, img_size, name 94 | 95 | def rgb_loader(self, path): 96 | with open(path, 'rb') as f: 97 | img = Image.open(f) 98 | return img.convert('RGB') 99 | 100 | 101 | -------------------------------------------------------------------------------- /v1/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import time 4 | import numpy as np 5 | import pdb, os, argparse 6 | import cv2 7 | 8 | from RAS import RAS 9 | from data import test_dataset 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--testsize', type=int, default=352, help='testing size') 13 | opt = parser.parse_args() 14 | 15 | dataset_path = '/home/ipal/datasets/' 16 | 17 | model = RAS() 18 | model.load_state_dict(torch.load('./models/RAS.v1.pth')) 19 | 20 | model.cuda() 21 | model.eval() 22 | 23 | test_datasets = ['ECSSD', 'DUTS', 'DUT-OMRON', 'HKU-IS'] 24 | 25 | for dataset in test_datasets: 26 | save_path = '/home/ipal/evaluation/SaliencyMaps/' + dataset + '/RAS-v1/' 27 | if not os.path.exists(save_path): 28 | os.makedirs(save_path) 29 | image_root = dataset_path + dataset + '/imgs/' 30 | test_loader = test_dataset(image_root, opt.testsize) 31 | time_t = 0.0 32 | for i in range(test_loader.size): 33 | image, img_size, name = test_loader.load_data() 34 | image = image.cuda() 35 | time_start = time.time() 36 | res, _, _, _, _ = model(image) 37 | torch.cuda.synchronize() 38 | time_end = time.time() 39 | time_t = time_t + time_end - time_start 40 | res = F.interpolate(res, 
img_size, mode='bilinear', align_corners=True) 41 | res = res.sigmoid().data.cpu().numpy().squeeze() 42 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 43 | res = 255 * res 44 | cv2.imwrite(os.path.join(save_path + name[:-4] + '.png'), res) 45 | fps = test_loader.size / time_t 46 | print('FPS is %f' %(fps)) 47 | -------------------------------------------------------------------------------- /v1/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | import numpy as np 7 | import pdb, os, argparse 8 | from RAS import RAS 9 | from data import get_loader 10 | 11 | # set parameters 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--epoch', type=int, default=30, help='epoch number') 14 | parser.add_argument('--lr', type=float, default=5e-5, help='learning rate') 15 | parser.add_argument('--batchsize', type=int, default=10, help='training batch size') 16 | parser.add_argument('--trainsize', type=int, default=352, help='training dataset size') 17 | parser.add_argument('--decay_epoch', type=int, default=[26,], help='every n epochs decay learning rate') 18 | opt = parser.parse_args() 19 | 20 | # build models 21 | model = RAS() 22 | model.cuda() 23 | params = model.parameters() 24 | optimizer = torch.optim.Adam(params, opt.lr) 25 | 26 | # print the network information 27 | num_params = 0 28 | for p in model.parameters(): 29 | num_params += p.numel() 30 | print('RAS Structure') 31 | print(model) 32 | print("The number of parameters: {}".format(num_params)) 33 | 34 | # dataset path 35 | image_path = '/home/ipal/datasets/DUTS_train/imgs/' 36 | gt_path = '/home/ipal/datasets/DUTS_train/gts/' 37 | 38 | train_loader = get_loader(image_path, gt_path, batchsize=opt.batchsize, trainsize=opt.trainsize) 39 | total_step = len(train_loader) 40 | 41 | def bce_iou_loss(pred, gt): 42 | bce = 
F.binary_cross_entropy_with_logits(pred, gt, reduction='mean') 43 | 44 | pred = torch.sigmoid(pred) 45 | inter = (pred*gt).sum(dim=(2,3)) 46 | union = (pred+gt).sum(dim=(2,3)) 47 | iou = 1-(inter+1)/(union-inter+1) 48 | 49 | return (bce+iou).mean() 50 | 51 | def train(train_loader, model, optimizer, epoch): 52 | model.train() 53 | for i, pack in enumerate(train_loader, start=1): 54 | optimizer.zero_grad() 55 | image, gt = pack 56 | image = Variable(image).cuda() 57 | gt = Variable(gt).cuda() 58 | 59 | pred = model(image) 60 | loss1 = bce_iou_loss(pred[0], gt) 61 | loss2 = bce_iou_loss(pred[1], gt) 62 | loss3 = bce_iou_loss(pred[2], gt) 63 | loss4 = bce_iou_loss(pred[3], gt) 64 | loss5 = bce_iou_loss(pred[4], gt) 65 | loss_fuse = loss1 + loss2 + loss3 + loss4 + loss5 66 | loss = loss_fuse / opt.batchsize 67 | 68 | loss.backward() 69 | optimizer.step() 70 | 71 | if i % 20 == 0 or i == total_step: 72 | print('Learning rate: %g, epoch: [%2d/%2d], iter: [%5d/%5d] || Loss: %10.4f' % ( 73 | opt.lr, epoch, opt.epoch, i, total_step, loss.data)) 74 | print('Loss1: %10.4f' % (loss1.data / opt.batchsize)) 75 | print('Loss2: %10.4f' % (loss2.data / opt.batchsize)) 76 | print('Loss3: %10.4f' % (loss3.data / opt.batchsize)) 77 | print('Loss4: %10.4f' % (loss4.data / opt.batchsize)) 78 | print('Loss5: %10.4f' % (loss5.data / opt.batchsize)) 79 | 80 | save_path = 'models/' 81 | 82 | if not os.path.exists(save_path): 83 | os.makedirs(save_path) 84 | if epoch % 5 == 0: 85 | torch.save(model.state_dict(), save_path + 'RAS.v1' + '.%d' % epoch + '.pth') 86 | 87 | for epoch in range(1, opt.epoch+1): 88 | if epoch in opt.decay_epoch: 89 | opt.lr = opt.lr * 0.1 90 | params = model.parameters() 91 | optimizer = torch.optim.Adam(params, opt.lr) 92 | train(train_loader, model, optimizer, epoch) 93 | -------------------------------------------------------------------------------- /v2/RAS.py: -------------------------------------------------------------------------------- 1 | import numpy as np 
2 | import matplotlib.pyplot as plt 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | 8 | def weight_init(module): 9 | for n, m in module.named_children(): 10 | print('initialize: '+n) 11 | if isinstance(m, nn.Conv2d): 12 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 13 | if m.bias is not None: 14 | nn.init.zeros_(m.bias) 15 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)): 16 | nn.init.ones_(m.weight) 17 | if m.bias is not None: 18 | nn.init.zeros_(m.bias) 19 | elif isinstance(m, nn.Linear): 20 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 21 | if m.bias is not None: 22 | nn.init.zeros_(m.bias) 23 | elif isinstance(m, nn.Sequential): 24 | weight_init(m) 25 | elif isinstance(m, nn.ReLU): 26 | pass 27 | else: 28 | m.initialize() 29 | 30 | 31 | class Bottleneck(nn.Module): 32 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 33 | super(Bottleneck, self).__init__() 34 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 35 | self.bn1 = nn.BatchNorm2d(planes) 36 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=(3*dilation-1)//2, bias=False, dilation=dilation) 37 | self.bn2 = nn.BatchNorm2d(planes) 38 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False) 39 | self.bn3 = nn.BatchNorm2d(planes*4) 40 | self.downsample = downsample 41 | 42 | def forward(self, x): 43 | out = F.relu(self.bn1(self.conv1(x)), inplace=True) 44 | out = F.relu(self.bn2(self.conv2(out)), inplace=True) 45 | out = self.bn3(self.conv3(out)) 46 | if self.downsample is not None: 47 | x = self.downsample(x) 48 | return F.relu(out+x, inplace=True) 49 | 50 | 51 | class ResNet(nn.Module): 52 | def __init__(self): 53 | super(ResNet, self).__init__() 54 | self.inplanes = 64 55 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 56 | self.bn1 = 
nn.BatchNorm2d(64) 57 | self.layer1 = self.make_layer( 64, 3, stride=1, dilation=1) 58 | self.layer2 = self.make_layer(128, 4, stride=2, dilation=1) 59 | self.layer3 = self.make_layer(256, 6, stride=2, dilation=1) 60 | self.layer4 = self.make_layer(512, 3, stride=2, dilation=1) 61 | 62 | def make_layer(self, planes, blocks, stride, dilation): 63 | downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes*4, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes*4)) 64 | layers = [Bottleneck(self.inplanes, planes, stride, downsample, dilation=dilation)] 65 | self.inplanes = planes*4 66 | for _ in range(1, blocks): 67 | layers.append(Bottleneck(self.inplanes, planes, dilation=dilation)) 68 | return nn.Sequential(*layers) 69 | 70 | def forward(self, x): 71 | out1 = F.relu(self.bn1(self.conv1(x)), inplace=True) 72 | out1 = F.max_pool2d(out1, kernel_size=3, stride=2, padding=1) 73 | out2 = self.layer1(out1) 74 | out3 = self.layer2(out2) 75 | out4 = self.layer3(out3) 76 | out5 = self.layer4(out4) 77 | return out2, out3, out4, out5 78 | 79 | def initialize(self): 80 | res50 = models.resnet50(pretrained=True) 81 | self.load_state_dict(res50.state_dict(), False) 82 | 83 | 84 | class MSCM(nn.Module): 85 | def __init__(self, in_channel, out_channel): 86 | super(MSCM, self).__init__() 87 | self.convert = nn.Conv2d(in_channel, out_channel, 1) 88 | self.bn = nn.BatchNorm2d(out_channel) 89 | self.relu = nn.ReLU(True) 90 | self.branch1 = nn.Sequential( 91 | nn.Conv2d(out_channel, out_channel, 1), nn.BatchNorm2d(out_channel), nn.ReLU(True), 92 | nn.Conv2d(out_channel, out_channel, 3, padding=1, dilation=1), nn.BatchNorm2d(out_channel), nn.ReLU(True), 93 | ) 94 | self.branch2 = nn.Sequential( 95 | nn.Conv2d(out_channel, out_channel, 3, padding=1, dilation=1), nn.BatchNorm2d(out_channel), nn.ReLU(True), 96 | nn.Conv2d(out_channel, out_channel, 3, padding=2, dilation=2), nn.BatchNorm2d(out_channel), nn.ReLU(True), 97 | ) 98 | self.branch3 = nn.Sequential( 99 | 
nn.Conv2d(out_channel, out_channel, 5, padding=2, dilation=1), nn.BatchNorm2d(out_channel), nn.ReLU(True), 100 | nn.Conv2d(out_channel, out_channel, 3, padding=4, dilation=4), nn.BatchNorm2d(out_channel), nn.ReLU(True), 101 | ) 102 | self.branch4 = nn.Sequential( 103 | nn.Conv2d(out_channel, out_channel, 7, padding=3, dilation=1), nn.BatchNorm2d(out_channel), nn.ReLU(True), 104 | nn.Conv2d(out_channel, out_channel, 3, padding=6, dilation=6), nn.BatchNorm2d(out_channel), nn.ReLU(True), 105 | ) 106 | self.score = nn.Conv2d(out_channel*4, 1, 3, padding=1) 107 | 108 | def forward(self, x): 109 | x = self.relu(self.bn(self.convert(x))) 110 | x1 = self.branch1(x) 111 | x2 = self.branch2(x) 112 | x3 = self.branch3(x) 113 | x4 = self.branch4(x) 114 | x = torch.cat((x1, x2, x3, x4), 1) 115 | x = self.score(x) 116 | 117 | return x 118 | 119 | def initialize(self): 120 | weight_init(self) 121 | 122 | 123 | class RA(nn.Module): 124 | def __init__(self, in_channel, out_channel): 125 | super(RA, self).__init__() 126 | self.convert = nn.Conv2d(in_channel, out_channel, 1) 127 | self.bn = nn.BatchNorm2d(out_channel) 128 | self.relu = nn.ReLU(True) 129 | self.convs = nn.Sequential( 130 | nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.BatchNorm2d(out_channel), nn.ReLU(True), 131 | nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.BatchNorm2d(out_channel), nn.ReLU(True), 132 | nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.BatchNorm2d(out_channel), nn.ReLU(True), 133 | nn.Conv2d(out_channel, 1, 3, padding=1), 134 | ) 135 | self.channel = out_channel 136 | 137 | def forward(self, x, y): 138 | a = torch.sigmoid(-y) 139 | x = self.relu(self.bn(self.convert(x))) 140 | x = a.expand(-1, self.channel, -1, -1).mul(x) 141 | y = y + self.convs(x) 142 | 143 | return y 144 | 145 | def initialize(self): 146 | weight_init(self) 147 | 148 | class RAS(nn.Module): 149 | def __init__(self, cfg, channel=64): 150 | self.cfg = cfg 151 | super(RAS, self).__init__() 152 | self.bkbone = 
ResNet() 153 | self.mscm = MSCM(2048, channel) 154 | self.ra2 = RA(256, channel) 155 | self.ra3 = RA(512, channel) 156 | self.ra4 = RA(1024, channel) 157 | 158 | self.initialize() 159 | 160 | def forward(self, x): 161 | x2, x3, x4, x5 = self.bkbone(x) 162 | x_size = x.size()[2:] 163 | x2_size = x2.size()[2:] 164 | x3_size = x3.size()[2:] 165 | x4_size = x4.size()[2:] 166 | 167 | y5 = self.mscm(x5) 168 | score5 = F.interpolate(y5, x_size, mode='bilinear', align_corners=True) 169 | 170 | y5_4 = F.interpolate(y5, x4_size, mode='bilinear', align_corners=True) 171 | y4 = self.ra4(x4, y5_4) 172 | score4 = F.interpolate(y4, x_size, mode='bilinear', align_corners=True) 173 | 174 | y4_3 = F.interpolate(y4, x3_size, mode='bilinear', align_corners=True) 175 | y3 = self.ra3(x3, y4_3) 176 | score3 = F.interpolate(y3, x_size, mode='bilinear', align_corners=True) 177 | 178 | y3_2 = F.interpolate(y3, x2_size, mode='bilinear', align_corners=True) 179 | y2 = self.ra2(x2, y3_2) 180 | score2 = F.interpolate(y2, x_size, mode='bilinear', align_corners=True) 181 | 182 | return score2, score3, score4, score5 183 | 184 | def initialize(self): 185 | if self.cfg.snapshot: 186 | self.load_state_dict(torch.load(self.cfg.snapshot)) 187 | else: 188 | weight_init(self) 189 | -------------------------------------------------------------------------------- /v2/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | 7 | ########################### Data Augmentation ########################### 8 | class Normalize(object): 9 | def __init__(self, mean, std): 10 | self.mean = mean 11 | self.std = std 12 | 13 | def __call__(self, image, mask=None): 14 | image = (image - self.mean)/self.std 15 | if mask is None: 16 | return image 17 | else: 18 | mask /= 255 19 | return image, mask 20 | 21 | class RandomCrop(object): 22 | def __call__(self, image, mask): 23 | 
H,W,_ = image.shape 24 | randw = np.random.randint(W/8) 25 | randh = np.random.randint(H/8) 26 | offseth = 0 if randh == 0 else np.random.randint(randh) 27 | offsetw = 0 if randw == 0 else np.random.randint(randw) 28 | p0, p1, p2, p3 = offseth, H+offseth-randh, offsetw, W+offsetw-randw 29 | return image[p0:p1,p2:p3, :], mask[p0:p1,p2:p3] 30 | 31 | class RandomFlip(object): 32 | def __call__(self, image, mask): 33 | if np.random.randint(2)==0: 34 | return image[:,::-1,:], mask[:, ::-1] 35 | else: 36 | return image, mask 37 | 38 | class Resize(object): 39 | def __init__(self, H, W): 40 | self.H = H 41 | self.W = W 42 | 43 | def __call__(self, image, mask=None): 44 | image = cv2.resize(image, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR) 45 | if mask is None: 46 | return image 47 | else: 48 | mask = cv2.resize( mask, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR) 49 | return image, mask 50 | 51 | class ToTensor(object): 52 | def __call__(self, image, mask=None): 53 | image = torch.from_numpy(image) 54 | image = image.permute(2, 0, 1) 55 | if mask is None: 56 | return image 57 | else: 58 | mask = torch.from_numpy(mask) 59 | return image, mask 60 | 61 | 62 | ########################### Config File ########################### 63 | class Config(object): 64 | def __init__(self, **kwargs): 65 | self.kwargs = kwargs 66 | self.mean = np.array([[[124.55, 118.90, 102.94]]]) 67 | self.std = np.array([[[ 56.77, 55.97, 57.50]]]) 68 | print('\nParameters...') 69 | for k, v in self.kwargs.items(): 70 | print('%-10s: %s'%(k, v)) 71 | 72 | def __getattr__(self, name): 73 | if name in self.kwargs: 74 | return self.kwargs[name] 75 | else: 76 | return None 77 | 78 | 79 | ########################### Dataset Class ########################### 80 | class Data(Dataset): 81 | def __init__(self, cfg): 82 | self.cfg = cfg 83 | self.normalize = Normalize(mean=cfg.mean, std=cfg.std) 84 | self.randomcrop = RandomCrop() 85 | self.randomflip = RandomFlip() 86 | self.resize = 
Resize(352, 352)
        self.totensor = ToTensor()
        image_path = self.cfg.datapath+'/imgs/'
        # Sort both lists so images and masks pair up by index: os.listdir() order is
        # arbitrary. This assumes each image and its mask share the same base name.
        self.images = sorted(image_path + f for f in os.listdir(image_path) if f.endswith('.jpg'))
        if self.cfg.mode=='train':
            mask_path = self.cfg.datapath+'/gts/'
            self.masks = sorted(mask_path + f for f in os.listdir(mask_path) if f.endswith('.png'))

    def __getitem__(self, idx):
        image_name = self.images[idx]
        # OpenCV loads BGR; reverse the channel axis to get RGB.
        image = cv2.imread(image_name)[:,:,::-1].astype(np.float32)
        shape = image.shape[:2]

        if self.cfg.mode=='train':
            mask_name = self.masks[idx]
            mask = cv2.imread(mask_name, 0).astype(np.float32)
            image, mask = self.normalize(image, mask)
            image, mask = self.randomcrop(image, mask)
            image, mask = self.randomflip(image, mask)
            return image, mask
        else:
            image = self.normalize(image)
            image = self.resize(image)
            image = self.totensor(image)
            return image, shape, image_name

    def collate(self, batch):
        # Multi-scale training: pick one square size per batch.
        size = [224, 256, 288, 320, 352][np.random.randint(0, 5)]
        image, mask = [list(item) for item in zip(*batch)]
        for i in range(len(batch)):
            image[i] = cv2.resize(image[i], dsize=(size, size), interpolation=cv2.INTER_LINEAR)
            mask[i] = cv2.resize(mask[i], dsize=(size, size), interpolation=cv2.INTER_LINEAR)
        # NHWC -> NCHW for images; add a channel axis to masks.
        image = torch.from_numpy(np.stack(image, axis=0)).permute(0,3,1,2)
        mask = torch.from_numpy(np.stack(mask, axis=0)).unsqueeze(1)
        return image, mask

    def __len__(self):
        return len(self.images)

--------------------------------------------------------------------------------
/v2/test.py:
--------------------------------------------------------------------------------
import os
import sys
sys.path.insert(0, '../')
sys.dont_write_bytecode = True

import cv2
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import data
from RAS import RAS
import time

class Test(object):
    def __init__(self, Dataset, Network, path):
        ## dataset
        self.cfg = Dataset.Config(datapath=path, snapshot='./models/RAS.v2.pth', mode='test')
        self.data = Dataset.Data(self.cfg)
        self.loader = DataLoader(self.data, batch_size=1, shuffle=False, num_workers=8)
        ## network
        self.net = Network(self.cfg)
        self.net.train(False)
        self.net.cuda()

    def save(self):
        with torch.no_grad():
            time_t = 0.0

            for image, shape, name in self.loader:
                image = image.cuda().float()
                time_start = time.time()
                res, _, _, _ = self.net(image)
                # Wait for the GPU to finish before stopping the timer.
                torch.cuda.synchronize()
                time_end = time.time()
                time_t = time_t + time_end - time_start
                # Resize the prediction back to the original image size.
                res = F.interpolate(res, shape, mode='bilinear', align_corners=True)
                res = res.sigmoid().data.cpu().numpy().squeeze()
                # Min-max normalize to [0, 1], then scale to [0, 255].
                res = (res - res.min()) / (res.max() - res.min() + 1e-8)
                res = 255 * res
                save_path = '/home/ipal/evaluation/SaliencyMaps/'+ self.cfg.datapath.split('/')[-1]+'/RAS-v2/'
                if not os.path.exists(save_path):
                    os.makedirs(save_path)
                # name[0] is the full image path returned by the dataset;
                # keep only the base name without its extension.
                stem = os.path.splitext(os.path.basename(name[0]))[0]
                cv2.imwrite(os.path.join(save_path, stem+'.png'), res)
            fps = len(self.loader) / time_t
            print('FPS is %f' %(fps))


if __name__=='__main__':
    for path in ['/home/ipal/datasets/ECSSD', '/home/ipal/datasets/DUTS', '/home/ipal/datasets/DUT-OMRON', '/home/ipal/datasets/HKU-IS']:
        test = Test(data, RAS, path)
        test.save()

--------------------------------------------------------------------------------
/v2/train.py:
--------------------------------------------------------------------------------
import sys
import datetime
sys.path.insert(0, '../')
sys.dont_write_bytecode = True

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import data
from RAS import RAS
from apex import amp

def bce_iou_loss(pred, mask):
    # pred holds raw logits; the IoU term uses the sigmoid probabilities.
    bce = F.binary_cross_entropy_with_logits(pred, mask, reduction='mean')
    pred = torch.sigmoid(pred)
    inter = (pred*mask).sum(dim=(2,3))
    union = (pred+mask).sum(dim=(2,3))
    iou = 1-(inter+1)/(union-inter+1)

    return (bce+iou).mean()

def train(Dataset, Network):
    ## dataset
    cfg = Dataset.Config(datapath='/home/ipal/datasets/DUTS_train', savepath='./models', mode='train', batch=32, lr=0.05, momen=0.9, decay=5e-4, epoch=32)
    data = Dataset.Data(cfg)
    loader = DataLoader(data, collate_fn=data.collate, batch_size=cfg.batch, shuffle=True, num_workers=8)
    ## network
    net = Network(cfg)
    net.train(True)
    net.cuda()
    ## parameter
    base, head = [], []
    for name, param in net.named_parameters():
        if 'bkbone.conv1' in name or 'bkbone.bn1' in name:
            # Not passed to the optimizer, so these weights are never updated by SGD.
            print(name)
        elif 'bkbone' in name:
            base.append(param)
        else:
            head.append(param)
    optimizer = torch.optim.SGD([{'params':base}, {'params':head}], lr=cfg.lr, momentum=cfg.momen, weight_decay=cfg.decay, nesterov=True)
    net, optimizer = amp.initialize(net, optimizer, opt_level='O2')
    global_step = 0

    for epoch in range(cfg.epoch):
        # Triangular schedule; the backbone group runs at 0.1x the head rate.
        optimizer.param_groups[0]['lr'] = (1-abs((epoch+1)/(cfg.epoch+1)*2-1))*cfg.lr*0.1
        optimizer.param_groups[1]['lr'] = (1-abs((epoch+1)/(cfg.epoch+1)*2-1))*cfg.lr

        for step, (image, mask) in enumerate(loader):
            image, mask = image.cuda().float(), mask.cuda().float()
            out2, out3, out4, out5 = net(image)

            # Deep supervision: the same loss on every side output.
            loss2 = bce_iou_loss(out2, mask)
            loss3 = bce_iou_loss(out3, mask)
            loss4 = bce_iou_loss(out4, mask)
            loss5 = bce_iou_loss(out5, mask)
            loss = loss2 + loss3 + loss4 + loss5

            optimizer.zero_grad()
            with amp.scale_loss(loss, optimizer) as scale_loss:
                scale_loss.backward()
            optimizer.step()

            global_step += 1
            if step%10 == 0:
                # The logged lr is that of the backbone group (param_groups[0]).
                print('%s | step:%d/%d/%d | lr=%.6f | loss=%.6f'%(datetime.datetime.now(), global_step, epoch+1, cfg.epoch, optimizer.param_groups[0]['lr'], loss.item()))
                print('loss2=%.6f'%(loss2.item()))
                print('loss3=%.6f'%(loss3.item()))
                print('loss4=%.6f'%(loss4.item()))
                print('loss5=%.6f'%(loss5.item()))

        if (epoch + 1) % 8 == 0:
            torch.save(net.state_dict(), cfg.savepath+'/RAS.v2' + str(epoch+1) + '.pth')


if __name__=='__main__':
    train(data, RAS)

--------------------------------------------------------------------------------
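The per-epoch learning-rate rule in `v2/train.py` is a triangular schedule: the rate rises linearly for the first half of training and falls linearly for the second half, with the backbone group held at one tenth of the head rate. The sketch below isolates that formula in plain Python so the values can be inspected without a GPU. `triangular_factor` and `group_lrs` are hypothetical helper names introduced here for illustration; they are not part of the repository.

```python
def triangular_factor(epoch, total_epochs):
    """Scale factor from v2/train.py: rises linearly, then falls linearly."""
    return 1 - abs((epoch + 1) / (total_epochs + 1) * 2 - 1)


def group_lrs(epoch, total_epochs=32, base_lr=0.05):
    """Return (backbone_lr, head_lr); the backbone gets one tenth of the head rate."""
    head_lr = triangular_factor(epoch, total_epochs) * base_lr
    return head_lr * 0.1, head_lr


if __name__ == '__main__':
    # Print the rates at the first, the two middle, and the last epoch.
    for epoch in (0, 15, 16, 31):
        backbone_lr, head_lr = group_lrs(epoch)
        print('epoch %2d | backbone lr=%.6f | head lr=%.6f' % (epoch + 1, backbone_lr, head_lr))
```

With the defaults from `train.py` (32 epochs, `lr=0.05`), the factor starts at 2/33, peaks at 32/33 in epochs 16 and 17 (1-indexed), and returns to 2/33 in the last epoch. Because of the `+1` in the denominator, the head rate never quite reaches `cfg.lr` and never drops to zero.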