├── README.md
├── figs
└── result.png
├── v1
├── RAS.py
├── ResNet50.py
├── data.py
├── test.py
└── train.py
└── v2
├── RAS.py
├── data.py
├── test.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | # RAS-pytorch
2 | The PyTorch code for our TIP 2020 paper: [Reverse Attention Based Residual Network for Salient Object Detection](https://ieeexplore.ieee.org/document/8966594)
3 |
4 | ---
5 |
6 | Notice
7 | ---
8 | We use ResNet50 as the backbone and add the [IoU loss](https://github.com/NathanUA/BASNet)[1] for better performance in this PyTorch implementation. The original Caffe version is [here](https://github.com/ShuhanChen/RAS_ECCV18). We provide two versions with different training strategies:
9 | - v1: the same training strategy as [CPD](https://github.com/wuzhe71/CPD)[2].
10 | - v2: the same training strategy as [F3Net](https://github.com/weijun88/F3Net)[3].
11 |
12 | Usage
13 | ---
14 | Modify the dataset paths, then run:
15 | ```
16 | Training: python3 train.py
17 | Testing: python3 test.py
18 | ```
19 |
20 | Performance
21 | ---
22 | The code is tested in an Ubuntu 18.04 environment (Python 3.6.9, PyTorch 1.5.0, torchvision 0.6.0, CUDA 10.2, cuDNN 7.6.5) with an RTX 2080Ti GPU. We compare against several recent state-of-the-art (SOTA) methods. The evaluation code can be found [here](https://github.com/jiwei0921/Saliency-Evaluation-Toolbox).
23 | 
24 |
25 |
26 | Pre-trained models & Pre-computed saliency maps
27 | ---
28 | - v1: model [Baidu](https://pan.baidu.com/s/1O5QsWWOjhPMGOWIiIwSX3A)(bc3k) [Google](https://drive.google.com/file/d/1KHmKrAG1M_C0mYgSD8pz9fDmBn2LtoMJ/view?usp=sharing); smaps [Baidu](https://pan.baidu.com/s/13I2F0dPU5mPmklcxbex0Lw)(kp6t) [Google](https://drive.google.com/file/d/1lT_BkFMuD8kPVkjQRR7HVBDzQnY3VkfB/view?usp=sharing)
29 | - v2: model [Baidu](https://pan.baidu.com/s/1XB3VE175bhT_4urBULJ-IQ)(wbz1) [Google](https://drive.google.com/open?id=14WUbyPiKnEafiMu9CWn5EdTAJpF_VqLj); smaps [Baidu](https://pan.baidu.com/s/1HZWx6eqYq7bAUcBtkw75Sw)(j57z) [Google](https://drive.google.com/open?id=1RwyR6GRAiDxeywRT1VjLe3qa_z7uhUWa)
30 |
31 | Citation
32 | ---
33 | ```
34 | @article{chen2020tip,
35 | author={Chen, Shuhan and Tan, Xiuli and Wang, Ben and Lu, Huchuan and Hu, Xuelong and Fu, Yun},
36 | journal={IEEE Transactions on Image Processing},
37 | title={Reverse Attention-Based Residual Network for Salient Object Detection},
38 | volume={29},
39 | pages={3763-3776},
40 | year={2020}
41 | }
42 | ```
43 | ```
44 | @inproceedings{chen2018eccv,
45 | author={Chen, Shuhan and Tan, Xiuli and Wang, Ben and Hu, Xuelong},
46 | booktitle={European Conference on Computer Vision},
47 | title={Reverse Attention for Salient Object Detection},
48 | year={2018}
49 | }
50 | ```
51 |
52 | Acknowledgements
53 | ---
54 | This code is built on [CPD](https://github.com/wuzhe71/CPD)[2] and [F3Net](https://github.com/weijun88/F3Net)[3]. We thank the authors for sharing their code.
55 |
56 | Reference
57 | ---
58 | > [1] Xuebin Qin, Zichen Zhang, Chenyang Huang, Chao Gao, Masood Dehghan, Martin Jagersand. BASNet: Boundary-Aware Salient Object Detection. In CVPR, 2019.
59 | > [2] Zhe Wu, Li Su, Qingming Huang. Cascaded Partial Decoder for Fast and Accurate Salient Object Detection. In CVPR, 2019.
60 | > [3] Jun Wei, Shuhui Wang, Qingming Huang. F3Net: Fusion, Feedback and Focus for Salient Object Detection. In AAAI, 2020.
61 |
--------------------------------------------------------------------------------
/figs/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShuhanChen/RAS-pytorch/67ea6c330e56b218e66e864b7c47d9a55f8c0217/figs/result.png
--------------------------------------------------------------------------------
/v1/RAS.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from ResNet50 import ResNet50
5 | import torchvision.models as models
6 |
class MSCM(nn.Module):
    """Multi-scale context module.

    The input is projected to ``out_channel`` channels by a 1x1 conv and fed
    to four parallel branches. Each branch is a kxk conv (k = 1, 3, 5, 7)
    followed by a 3x3 conv dilated by 1, 2, 4 or 6, with ReLUs in between;
    all convs are size-preserving. The branch outputs are concatenated and
    reduced to a single-channel score map by a 3x3 conv.
    """

    def __init__(self, in_channel, out_channel):
        super(MSCM, self).__init__()
        c = out_channel

        def branch(first_kernel, dilation):
            # kxk conv (dilation 1) -> ReLU -> dilated 3x3 conv -> ReLU.
            return nn.Sequential(
                nn.Conv2d(c, c, first_kernel, padding=first_kernel // 2, dilation=1), nn.ReLU(True),
                nn.Conv2d(c, c, 3, padding=dilation, dilation=dilation), nn.ReLU(True),
            )

        self.convert = nn.Conv2d(in_channel, c, 1)
        self.branch1 = branch(1, 1)
        self.branch2 = branch(3, 2)
        self.branch3 = branch(5, 4)
        self.branch4 = branch(7, 6)
        self.score = nn.Conv2d(c * 4, 1, 3, padding=1)

    def forward(self, x):
        projected = self.convert(x)
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        fused = torch.cat([b(projected) for b in branches], 1)
        return self.score(fused)
40 |
class RA(nn.Module):
    """Reverse-attention refinement block.

    ``forward(x, y)`` takes a backbone feature map ``x`` and the current
    single-channel score logits ``y`` (already at ``x``'s resolution). The
    feature is projected to ``out_channel`` channels and weighted by
    ``sigmoid(-y)``, i.e. 1 - sigmoid(y), emphasising regions *not* yet
    predicted as salient. A small conv stack turns the weighted feature into a
    residual that is added to ``y``.
    """

    def __init__(self, in_channel, out_channel):
        super(RA, self).__init__()
        self.convert = nn.Conv2d(in_channel, out_channel, 1)
        layers = []
        for _ in range(3):
            layers.append(nn.Conv2d(out_channel, out_channel, 3, padding=1))
            layers.append(nn.ReLU(True))
        layers.append(nn.Conv2d(out_channel, 1, 3, padding=1))
        self.convs = nn.Sequential(*layers)
        self.channel = out_channel

    def forward(self, x, y):
        reverse = torch.sigmoid(-y)
        feat = self.convert(x)
        # Broadcast the 1-channel attention over all feature channels.
        gated = reverse.expand(-1, self.channel, -1, -1) * feat
        return y + self.convs(gated)
60 |
class RAS(nn.Module):
    """Reverse-attention saliency network (v1) on a ResNet-50 backbone.

    The deepest backbone feature goes through MSCM to give a coarse score
    map; RA blocks then refine it from deep (x4) to shallow (x1) features.
    ``forward`` returns five single-channel logit maps, each upsampled to the
    input resolution, ordered finest first:
    ``(score1, score2, score3, score4, score5)``.
    """

    def __init__(self, channel=64):
        super(RAS, self).__init__()
        self.resnet = ResNet50()
        self.mscm = MSCM(2048, channel)
        self.ra1 = RA(64, channel)
        self.ra2 = RA(256, channel)
        self.ra3 = RA(512, channel)
        self.ra4 = RA(1024, channel)

        # N(0, 0.01) for every conv (backbone included) and unit/zero BN
        # affine params; the backbone is overwritten just below.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, std=0.01)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

        self.initialize_weights()

    def forward(self, x):
        x1, x2, x3, x4, x5 = self.resnet(x)
        full_size = x.size()[2:]

        def up(t, size):
            return F.interpolate(t, size, mode='bilinear', align_corners=True)

        y = self.mscm(x5)
        scores = [up(y, full_size)]
        # Coarse-to-fine: resize the running logits to each stage, refine them.
        for block, feat in ((self.ra4, x4), (self.ra3, x3), (self.ra2, x2), (self.ra1, x1)):
            y = block(feat, up(y, feat.size()[2:]))
            scores.append(up(y, full_size))

        score5, score4, score3, score2, score1 = scores
        return score1, score2, score3, score4, score5

    def initialize_weights(self):
        """Copy torchvision's pretrained ResNet-50 weights into the backbone.

        Loaded non-strictly, so keys without a counterpart in the backbone
        (e.g. torchvision's classifier ``fc``) are ignored.
        """
        pretrained = models.resnet50(pretrained=True)
        self.resnet.load_state_dict(pretrained.state_dict(), False)
112 |
--------------------------------------------------------------------------------
/v1/ResNet50.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 |
4 |
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1 (size-preserving at stride 1)."""
    options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **options)
9 |
10 |
class BasicBlock(nn.Module):
    """Two-conv (3x3, 3x3) ResNet residual block; output has ``planes`` channels.

    ``downsample``, when given, is applied to the input so the shortcut
    matches the main path before the residual addition.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
41 |
42 |
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce, 3x3 (carries the stride), 1x1 expand.

    The output has ``planes * expansion`` channels. ``downsample``, when
    given, is applied to the input so the shortcut matches the main path's
    shape before the residual addition.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        wide = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, wide, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(wide)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(bn(conv(out)))
        out = self.bn3(self.conv3(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
80 |
81 |
class ResNet50(nn.Module):
    """ResNet-50 feature extractor without pooling/classifier head.

    ``forward`` returns five feature maps: the stem output (64 channels,
    1/2 resolution) and the outputs of ``layer1``..``layer4`` (256, 512,
    1024, 2048 channels at 1/4, 1/8, 1/16, 1/32 resolution). Submodules are
    named conv1, bn1, layer1..layer4, following torchvision's ResNet layout.
    """

    def __init__(self):
        super(ResNet50, self).__init__()
        self.inplanes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # (planes, number of blocks, stride of the first block) per stage.
        stage_specs = [(64, 3, 1), (128, 4, 2), (256, 6, 2), (512, 3, 2)]
        for idx, (planes, blocks, stride) in enumerate(stage_specs, start=1):
            setattr(self, 'layer%d' % idx,
                    self._make_layer(Bottleneck, planes, blocks, stride=stride))

        # He-style init: std = sqrt(2 / (k_h * k_w * out_channels)).
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks; only the first one may stride/project."""
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        x1 = self.relu(self.bn1(self.conv1(x)))  # 1/2
        x2 = self.layer1(self.maxpool(x1))       # 1/4
        x3 = self.layer2(x2)                     # 1/8
        x4 = self.layer3(x3)                     # 1/16
        x5 = self.layer4(x4)                     # 1/32
        return x1, x2, x3, x4, x5
134 |
--------------------------------------------------------------------------------
/v1/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 | import torch.utils.data as data
5 | import torchvision.transforms as transforms
6 |
class SalObjDataset(data.Dataset):
    """Training dataset of (RGB image, grayscale ground-truth mask) pairs.

    Images are the ``.jpg`` files in ``image_path`` and masks the ``.png``
    files in ``gt_path``. Both lists are sorted and paired by position, so
    every image must have a mask with the same file stem (checked at
    construction; ``ValueError`` otherwise).

    Each sample is randomly flipped left-right, rotated by a random angle in
    [-10, 10) degrees, and resized to ``trainsize`` x ``trainsize``; the image
    is then ImageNet-normalised.
    """

    def __init__(self, image_path, gt_path, trainsize):
        # os.path.join works whether or not the directory ends with a separator
        # (plain concatenation silently produced wrong paths without one).
        self.images = sorted(os.path.join(image_path, f)
                             for f in os.listdir(image_path) if f.endswith('.jpg'))
        self.gts = sorted(os.path.join(gt_path, f)
                          for f in os.listdir(gt_path) if f.endswith('.png'))
        self._check_pairs()
        self.size = len(self.images)
        self.img_transform = transforms.Compose([
            transforms.Resize((trainsize, trainsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        self.gt_transform = transforms.Compose([
            transforms.Resize((trainsize, trainsize)),
            transforms.ToTensor()])

    def _check_pairs(self):
        """Raise ValueError unless images and masks pair up one-to-one by file stem."""
        def stem(path):
            return os.path.splitext(os.path.basename(path))[0]

        if len(self.images) != len(self.gts):
            raise ValueError('found %d images but %d ground-truth masks'
                             % (len(self.images), len(self.gts)))
        for img, gt in zip(self.images, self.gts):
            if stem(img) != stem(gt):
                raise ValueError('image %s is paired with mismatched mask %s' % (img, gt))

    def __getitem__(self, index):
        image = self.rgb_loader(self.images[index])
        gt = self.binary_loader(self.gts[index])
        # Geometric augmentations must hit image and mask identically.
        image, gt = self.cv_random_flip(image, gt)
        image, gt = self.cv_random_rotate(image, gt)
        image = self.img_transform(image)
        gt = self.gt_transform(gt)
        return image, gt

    def rgb_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')

    def gray_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')

    def cv_random_flip(self, img, gt):
        """With probability 1/2, mirror image and mask left-right."""
        if np.random.randint(2) == 0:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            gt = gt.transpose(Image.FLIP_LEFT_RIGHT)
        return img, gt

    def cv_random_rotate(self, img, gt):
        """Rotate image (bilinear) and mask (nearest) by the same angle in [-10, 10)."""
        rotate_degree = np.random.random() * 2 * 10 - 10
        img = img.rotate(rotate_degree, Image.BILINEAR)
        gt = gt.rotate(rotate_degree, Image.NEAREST)
        return img, gt

    def __len__(self):
        return self.size
61 |
62 |
def get_loader(image_path, gt_path, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=True):
    """Wrap a ``SalObjDataset`` over the given folders in a ``DataLoader``."""
    return data.DataLoader(
        dataset=SalObjDataset(image_path, gt_path, trainsize),
        batch_size=batchsize,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
72 |
73 |
class test_dataset:
    """Sequential iterator over the ``.jpg`` images of a test folder.

    ``load_data`` returns the next image as a normalised 1 x 3 x testsize x
    testsize tensor, its original size as (height, width), and the output
    file name (the image name with ``.jpg`` replaced by ``.png``).
    """

    def __init__(self, image_path, testsize):
        # os.path.join tolerates a directory given with or without a trailing slash.
        self.images = sorted(os.path.join(image_path, f)
                             for f in os.listdir(image_path) if f.endswith('.jpg'))
        self.img_transform = transforms.Compose([
            transforms.Resize((testsize, testsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        self.size = len(self.images)
        self.index = 0

    def load_data(self):
        path = self.images[self.index]
        image = self.rgb_loader(path)
        img_size = (image.size[1], image.size[0])  # PIL gives (W, H)
        image = self.img_transform(image).unsqueeze(0)
        name = os.path.basename(path)
        # Replace only the trailing extension: splitting on '.jpg' truncated
        # stems that themselves contain '.jpg' and could collide.
        if name.endswith('.jpg'):
            name = name[:-len('.jpg')] + '.png'
        self.index += 1
        return image, img_size, name

    def rgb_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')
99 |
100 |
101 |
--------------------------------------------------------------------------------
/v1/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import time
4 | import numpy as np
5 | import pdb, os, argparse
6 | import cv2
7 |
8 | from RAS import RAS
9 | from data import test_dataset
10 |
parser = argparse.ArgumentParser()
parser.add_argument('--testsize', type=int, default=352, help='testing size')
opt = parser.parse_args()

dataset_path = '/home/ipal/datasets/'

model = RAS()
# Load onto the CPU first so the checkpoint does not require the GPU it was
# saved from; the model is moved to CUDA right after.
model.load_state_dict(torch.load('./models/RAS.v1.pth', map_location='cpu'))

model.cuda()
model.eval()

test_datasets = ['ECSSD', 'DUTS', 'DUT-OMRON', 'HKU-IS']

for dataset in test_datasets:
    save_path = '/home/ipal/evaluation/SaliencyMaps/' + dataset + '/RAS-v1/'
    os.makedirs(save_path, exist_ok=True)
    image_root = dataset_path + dataset + '/imgs/'
    test_loader = test_dataset(image_root, opt.testsize)
    time_t = 0.0
    # Pure inference: without no_grad every forward pass would build an
    # autograd graph, wasting GPU memory and time.
    with torch.no_grad():
        for _ in range(test_loader.size):
            image, img_size, name = test_loader.load_data()
            image = image.cuda()
            time_start = time.time()
            res, _, _, _, _ = model(image)
            torch.cuda.synchronize()  # wait for the GPU so the timing is real
            time_end = time.time()
            time_t = time_t + time_end - time_start
            # Finest output, resized to the original (H, W), min-max scaled to [0, 255].
            res = F.interpolate(res, img_size, mode='bilinear', align_corners=True)
            res = res.sigmoid().data.cpu().numpy().squeeze()
            res = (res - res.min()) / (res.max() - res.min() + 1e-8)
            res = 255 * res
            cv2.imwrite(os.path.join(save_path, name[:-4] + '.png'), res)
    if time_t > 0:
        fps = test_loader.size / time_t
        print('FPS is %f' % (fps))
    else:
        print('No images found in %s' % image_root)
47 |
--------------------------------------------------------------------------------
/v1/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 |
6 | import numpy as np
7 | import pdb, os, argparse
8 | from RAS import RAS
9 | from data import get_loader
10 |
# set parameters
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=30, help='epoch number')
parser.add_argument('--lr', type=float, default=5e-5, help='learning rate')
parser.add_argument('--batchsize', type=int, default=10, help='training batch size')
parser.add_argument('--trainsize', type=int, default=352, help='training dataset size')
# nargs='+' keeps opt.decay_epoch a list when given on the command line
# (e.g. --decay_epoch 20 26). With plain type=int a CLI value became a bare
# int and `epoch in opt.decay_epoch` raised TypeError.
parser.add_argument('--decay_epoch', type=int, nargs='+', default=[26],
                    help='epochs at which the learning rate is multiplied by 0.1')
opt = parser.parse_args()

# build models
model = RAS()
model.cuda()
params = model.parameters()
optimizer = torch.optim.Adam(params, opt.lr)

# print the network information
num_params = sum(p.numel() for p in model.parameters())
print('RAS Structure')
print(model)
print("The number of parameters: {}".format(num_params))

# dataset path
image_path = '/home/ipal/datasets/DUTS_train/imgs/'
gt_path = '/home/ipal/datasets/DUTS_train/gts/'

train_loader = get_loader(image_path, gt_path, batchsize=opt.batchsize, trainsize=opt.trainsize)
total_step = len(train_loader)
40 |
def bce_iou_loss(pred, gt):
    """Return BCE-with-logits plus a soft IoU loss, averaged to a scalar.

    ``pred`` holds logits and ``gt`` targets in [0, 1], same shape. The IoU
    term is computed per (sample, channel) over dims 2 and 3 with +1
    smoothing; the scalar (mean) BCE is added to every IoU term before the
    final mean.
    """
    prob = torch.sigmoid(pred)
    overlap = (prob * gt).sum(dim=(2, 3))
    total = (prob + gt).sum(dim=(2, 3))
    iou_term = 1 - (overlap + 1) / (total - overlap + 1)
    bce_term = F.binary_cross_entropy_with_logits(pred, gt, reduction='mean')
    return (bce_term + iou_term).mean()
50 |
def train(train_loader, model, optimizer, epoch):
    """Run one training epoch and checkpoint every 5th epoch.

    Every side output of ``model`` is supervised with ``bce_iou_loss``; the
    summed loss is divided by ``opt.batchsize`` before backpropagation.
    Progress is printed every 20 iterations and on the last one.
    """
    model.train()
    for i, (image, gt) in enumerate(train_loader, start=1):
        optimizer.zero_grad()
        # torch.autograd.Variable is a deprecated no-op wrapper; use tensors directly.
        image = image.cuda()
        gt = gt.cuda()

        preds = model(image)
        side_losses = [bce_iou_loss(pred, gt) for pred in preds]
        loss = sum(side_losses) / opt.batchsize

        loss.backward()
        optimizer.step()

        if i % 20 == 0 or i == total_step:
            print('Learning rate: %g, epoch: [%2d/%2d], iter: [%5d/%5d] || Loss: %10.4f' % (
                opt.lr, epoch, opt.epoch, i, total_step, loss.item()))
            for k, side in enumerate(side_losses, start=1):
                print('Loss%d: %10.4f' % (k, side.item() / opt.batchsize))

    save_path = 'models/'
    os.makedirs(save_path, exist_ok=True)
    if epoch % 5 == 0:
        torch.save(model.state_dict(), save_path + 'RAS.v1' + '.%d' % epoch + '.pth')
86 |
epochs = range(1, opt.epoch + 1)
for epoch in epochs:
    if epoch in opt.decay_epoch:
        # Decay the learning rate 10x and rebuild Adam with it; note this also
        # discards Adam's running moment estimates.
        opt.lr *= 0.1
        params = model.parameters()
        optimizer = torch.optim.Adam(params, opt.lr)
    train(train_loader, model, optimizer, epoch)
93 |
--------------------------------------------------------------------------------
/v2/RAS.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torchvision.models as models
7 |
def weight_init(module):
    """Initialise the direct children of ``module``, recursing into Sequentials.

    - Conv2d / Linear: Kaiming-normal weights (fan_in, ReLU gain), zero bias.
    - BatchNorm2d / InstanceNorm2d: unit weight, zero bias (when present).
    - Sequential: recurse. ReLU: left untouched.
    - Anything else must provide its own ``initialize()``, which is called
      (AttributeError otherwise).

    Prints the name of each child as it is visited.
    """
    for name, child in module.named_children():
        print('initialize: ' + name)
        if isinstance(child, nn.Conv2d):
            nn.init.kaiming_normal_(child.weight, mode='fan_in', nonlinearity='relu')
            if child.bias is not None:
                nn.init.zeros_(child.bias)
        elif isinstance(child, (nn.BatchNorm2d, nn.InstanceNorm2d)):
            # Norm layers built with affine=False (InstanceNorm2d's default)
            # have weight/bias set to None; there is nothing to initialise.
            if child.weight is not None:
                nn.init.ones_(child.weight)
            if child.bias is not None:
                nn.init.zeros_(child.bias)
        elif isinstance(child, nn.Linear):
            nn.init.kaiming_normal_(child.weight, mode='fan_in', nonlinearity='relu')
            if child.bias is not None:
                nn.init.zeros_(child.bias)
        elif isinstance(child, nn.Sequential):
            weight_init(child)
        elif isinstance(child, nn.ReLU):
            pass
        else:
            child.initialize()
29 |
30 |
class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1, output ``planes * 4`` channels).

    ``stride`` and ``dilation`` apply to the 3x3 conv. ``downsample``, when
    given, is applied to the input so the shortcut matches the main path.
    """

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # A 3x3 kernel dilated by d needs padding d to preserve spatial size.
        # The previous (3*d-1)//2 only matched for d = 1 and 2; larger dilations
        # changed the size and broke the residual addition.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes*4)
        self.downsample = downsample

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = F.relu(self.bn2(self.conv2(out)), inplace=True)
        out = self.bn3(self.conv3(out))
        if self.downsample is not None:
            x = self.downsample(x)
        return F.relu(out+x, inplace=True)
49 |
50 |
class ResNet(nn.Module):
    """ResNet-50 backbone returning the outputs of its four residual stages.

    ``forward`` returns ``(out2, out3, out4, out5)`` from layer1..layer4
    (256, 512, 1024, 2048 channels). ``initialize`` loads torchvision's
    pretrained ResNet-50 weights non-strictly.
    """

    def __init__(self):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # (planes, blocks, stride) for layer1..layer4.
        stages = ((64, 3, 1), (128, 4, 2), (256, 6, 2), (512, 3, 2))
        for idx, (planes, blocks, stride) in enumerate(stages, start=1):
            setattr(self, 'layer%d' % idx, self.make_layer(planes, blocks, stride=stride, dilation=1))

    def make_layer(self, planes, blocks, stride, dilation):
        """Build one stage; the first block always gets a 1x1 projection shortcut."""
        width = planes * 4
        shortcut = nn.Sequential(
            nn.Conv2d(self.inplanes, width, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(width),
        )
        first = Bottleneck(self.inplanes, planes, stride, shortcut, dilation=dilation)
        self.inplanes = width
        rest = [Bottleneck(self.inplanes, planes, dilation=dilation) for _ in range(1, blocks)]
        return nn.Sequential(first, *rest)

    def forward(self, x):
        stem = F.relu(self.bn1(self.conv1(x)), inplace=True)
        stem = F.max_pool2d(stem, kernel_size=3, stride=2, padding=1)
        out2 = self.layer1(stem)
        out3 = self.layer2(out2)
        out4 = self.layer3(out3)
        out5 = self.layer4(out4)
        return out2, out3, out4, out5

    def initialize(self):
        pretrained = models.resnet50(pretrained=True)
        self.load_state_dict(pretrained.state_dict(), False)
82 |
83 |
class MSCM(nn.Module):
    """Multi-scale context module (v2: BatchNorm after every conv).

    The input is projected to ``out_channel`` channels (1x1 conv + BN + ReLU)
    and fed to four parallel branches. Each branch is a kxk conv (k = 1, 3,
    5, 7) followed by a 3x3 conv dilated by 1, 2, 4 or 6, each with BN and
    ReLU; all convs are size-preserving. The concatenated branch outputs are
    reduced to a single-channel score map by a 3x3 conv.
    """

    def __init__(self, in_channel, out_channel):
        super(MSCM, self).__init__()
        c = out_channel

        def branch(first_kernel, dilation):
            return nn.Sequential(
                nn.Conv2d(c, c, first_kernel, padding=first_kernel // 2, dilation=1),
                nn.BatchNorm2d(c), nn.ReLU(True),
                nn.Conv2d(c, c, 3, padding=dilation, dilation=dilation),
                nn.BatchNorm2d(c), nn.ReLU(True),
            )

        self.convert = nn.Conv2d(in_channel, c, 1)
        self.bn = nn.BatchNorm2d(c)
        self.relu = nn.ReLU(True)
        self.branch1 = branch(1, 1)
        self.branch2 = branch(3, 2)
        self.branch3 = branch(5, 4)
        self.branch4 = branch(7, 6)
        self.score = nn.Conv2d(c * 4, 1, 3, padding=1)

    def forward(self, x):
        x = self.relu(self.bn(self.convert(x)))
        outs = [b(x) for b in (self.branch1, self.branch2, self.branch3, self.branch4)]
        return self.score(torch.cat(outs, 1))

    def initialize(self):
        weight_init(self)
121 |
122 |
class RA(nn.Module):
    """Reverse-attention refinement block (v2: BatchNorm after each hidden conv).

    ``forward(x, y)`` projects feature ``x`` to ``out_channel`` channels,
    weights it by ``sigmoid(-y)`` (1 - sigmoid(y), i.e. regions not yet
    predicted salient) and adds the conv-stack output as a residual to the
    single-channel logits ``y``.
    """

    def __init__(self, in_channel, out_channel):
        super(RA, self).__init__()
        self.convert = nn.Conv2d(in_channel, out_channel, 1)
        self.bn = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU(True)
        layers = []
        for _ in range(3):
            layers += [nn.Conv2d(out_channel, out_channel, 3, padding=1),
                       nn.BatchNorm2d(out_channel), nn.ReLU(True)]
        layers.append(nn.Conv2d(out_channel, 1, 3, padding=1))
        self.convs = nn.Sequential(*layers)
        self.channel = out_channel

    def forward(self, x, y):
        reverse = torch.sigmoid(-y)
        feat = self.relu(self.bn(self.convert(x)))
        gated = reverse.expand(-1, self.channel, -1, -1) * feat
        return y + self.convs(gated)

    def initialize(self):
        weight_init(self)
147 |
class RAS(nn.Module):
    """Reverse-attention saliency network (v2).

    The deepest ResNet feature goes through MSCM; RA blocks refine the score
    from layer3 down to layer1 features. ``forward`` returns four
    single-channel logit maps upsampled to the input size, finest first:
    ``(score2, score3, score4, score5)``.

    ``cfg.snapshot``, when set, is a checkpoint path loaded at construction;
    otherwise every submodule is initialised via ``weight_init``.
    """

    def __init__(self, cfg, channel=64):
        # Initialise nn.Module before assigning any attribute.
        super(RAS, self).__init__()
        self.cfg = cfg
        self.bkbone = ResNet()
        self.mscm = MSCM(2048, channel)
        self.ra2 = RA(256, channel)
        self.ra3 = RA(512, channel)
        self.ra4 = RA(1024, channel)

        self.initialize()

    def forward(self, x):
        x2, x3, x4, x5 = self.bkbone(x)
        x_size = x.size()[2:]
        x2_size = x2.size()[2:]
        x3_size = x3.size()[2:]
        x4_size = x4.size()[2:]

        y5 = self.mscm(x5)
        score5 = F.interpolate(y5, x_size, mode='bilinear', align_corners=True)

        y5_4 = F.interpolate(y5, x4_size, mode='bilinear', align_corners=True)
        y4 = self.ra4(x4, y5_4)
        score4 = F.interpolate(y4, x_size, mode='bilinear', align_corners=True)

        y4_3 = F.interpolate(y4, x3_size, mode='bilinear', align_corners=True)
        y3 = self.ra3(x3, y4_3)
        score3 = F.interpolate(y3, x_size, mode='bilinear', align_corners=True)

        y3_2 = F.interpolate(y3, x2_size, mode='bilinear', align_corners=True)
        y2 = self.ra2(x2, y3_2)
        score2 = F.interpolate(y2, x_size, mode='bilinear', align_corners=True)

        return score2, score3, score4, score5

    def initialize(self):
        if self.cfg.snapshot:
            # The freshly built parameters live on the CPU, so map the
            # checkpoint there too; this avoids requiring the GPU it was
            # saved from.
            self.load_state_dict(torch.load(self.cfg.snapshot, map_location='cpu'))
        else:
            weight_init(self)
189 |
--------------------------------------------------------------------------------
/v2/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import numpy as np
5 | from torch.utils.data import Dataset
6 |
7 | ########################### Data Augmentation ###########################
class Normalize(object):
    """Standardise an image with ``(image - mean) / std``; scale a mask by 1/255."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, mask=None):
        image = (image - self.mean)/self.std
        if mask is None:
            return image
        # Out-of-place division: the caller's array is left untouched, and
        # integer-typed masks no longer fail on in-place true division.
        return image, mask / 255
20 |
class RandomCrop(object):
    """Randomly trim up to about 1/8 of the height and width from an image/mask pair.

    ``image`` is indexed as (H, W, C) and ``mask`` as (H, W); both get the
    same crop window.
    """

    def __call__(self, image, mask):
        H, W, _ = image.shape
        # Integer bounds (W // 8 draws the same values as the old float W / 8).
        # max(1, .) keeps inputs smaller than 8 px valid: randint(1) is always 0
        # (no crop) instead of raising ValueError.
        randw = np.random.randint(max(1, W // 8))
        randh = np.random.randint(max(1, H // 8))
        offseth = 0 if randh == 0 else np.random.randint(randh)
        offsetw = 0 if randw == 0 else np.random.randint(randw)
        p0, p1, p2, p3 = offseth, H+offseth-randh, offsetw, W+offsetw-randw
        return image[p0:p1,p2:p3, :], mask[p0:p1,p2:p3]
30 |
class RandomFlip(object):
    """With probability 1/2, mirror an image/mask pair left-to-right (returns views)."""

    def __call__(self, image, mask):
        keep = np.random.randint(2) != 0
        if keep:
            return image, mask
        return image[:, ::-1, :], mask[:, ::-1]
37 |
class Resize(object):
    """Bilinearly resize an image (and optional mask) to ``H`` x ``W``."""

    def __init__(self, H, W):
        self.H, self.W = H, W

    def __call__(self, image, mask=None):
        target = (self.W, self.H)  # cv2 takes dsize as (width, height)
        image = cv2.resize(image, dsize=target, interpolation=cv2.INTER_LINEAR)
        if mask is not None:
            mask = cv2.resize( mask, dsize=target, interpolation=cv2.INTER_LINEAR)
            return image, mask
        return image
50 |
class ToTensor(object):
    """Wrap an (H, W, C) numpy image as a (C, H, W) tensor; a mask is wrapped as-is.

    ``torch.from_numpy`` shares memory with the input arrays.
    """

    def __call__(self, image, mask=None):
        chw = torch.from_numpy(image).permute(2, 0, 1)
        if mask is None:
            return chw
        return chw, torch.from_numpy(mask)
60 |
61 |
62 | ########################### Config File ###########################
class Config(object):
    """Attribute-style bag of keyword settings plus fixed normalisation stats.

    Every keyword passed to the constructor is readable as an attribute.
    Unknown (non-dunder) attributes read as ``None``, so optional settings
    such as ``snapshot`` can be tested with a plain ``if``. ``mean``/``std``
    are per-channel image statistics shaped (1, 1, 3).
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.mean = np.array([[[124.55, 118.90, 102.94]]])
        self.std = np.array([[[ 56.77, 55.97, 57.50]]])
        print('\nParameters...')
        for k, v in self.kwargs.items():
            print('%-10s: %s'%(k, v))

    def __getattr__(self, name):
        # Only reached when normal lookup fails. Dunder names must raise
        # AttributeError so pickle/copy protocol probes behave. kwargs is read
        # via __dict__ because instances created without __init__ (unpickling,
        # deepcopy, DataLoader workers under spawn) would otherwise recurse
        # into __getattr__('kwargs') forever.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return self.__dict__.get('kwargs', {}).get(name)
77 |
78 |
79 | ########################### Dataset Class ###########################
class Data(Dataset):
    """Saliency dataset configured by a ``Config``.

    Reads ``.jpg`` images from ``<datapath>/imgs/``. When ``cfg.mode ==
    'train'`` it also reads ``.png`` masks from ``<datapath>/gts/``, pairs
    them with images by file stem, and ``__getitem__`` returns augmented
    (image, mask) numpy arrays to be batched by ``collate``. Otherwise it
    returns a resized CHW tensor, the original (H, W) and the image path.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.normalize = Normalize(mean=cfg.mean, std=cfg.std)
        self.randomcrop = RandomCrop()
        self.randomflip = RandomFlip()
        self.resize = Resize(352, 352)
        self.totensor = ToTensor()
        image_path = self.cfg.datapath+'/imgs/'
        # os.listdir order is arbitrary and differs between directories, yet
        # images[idx] is paired with masks[idx]: sort both lists.
        self.images = sorted(image_path + f for f in os.listdir(image_path) if f.endswith('.jpg'))
        if self.cfg.mode=='train':
            mask_path = self.cfg.datapath+'/gts/'
            self.masks = sorted(mask_path + f for f in os.listdir(mask_path) if f.endswith('.png'))
            self._check_pairs()

    def _check_pairs(self):
        """Raise ValueError unless images and masks pair up one-to-one by file stem."""
        def stem(path):
            return os.path.splitext(os.path.basename(path))[0]

        if len(self.images) != len(self.masks):
            raise ValueError('found %d images but %d masks' % (len(self.images), len(self.masks)))
        for img, msk in zip(self.images, self.masks):
            if stem(img) != stem(msk):
                raise ValueError('image %s is paired with mismatched mask %s' % (img, msk))

    def __getitem__(self, idx):
        image_name = self.images[idx]
        # cv2 loads BGR; reverse the channel axis to RGB.
        image = cv2.imread(image_name)[:,:,::-1].astype(np.float32)
        shape = image.shape[:2]

        if self.cfg.mode=='train':
            mask_name = self.masks[idx]
            mask = cv2.imread(mask_name, 0).astype(np.float32)
            image, mask = self.normalize(image, mask)
            image, mask = self.randomcrop(image, mask)
            image, mask = self.randomflip(image, mask)
            return image, mask
        else:
            image = self.normalize(image)
            image = self.resize(image)
            image = self.totensor(image)
            return image, shape, image_name

    def collate(self, batch):
        """Resize a training batch to one randomly chosen square size and stack it.

        The size is drawn per batch from 224..352 (multi-scale training).
        Returns (N, 3, size, size) images and (N, 1, size, size) masks.
        """
        size = [224, 256, 288, 320, 352][np.random.randint(0, 5)]
        image, mask = [list(item) for item in zip(*batch)]
        for i in range(len(batch)):
            image[i] = cv2.resize(image[i], dsize=(size, size), interpolation=cv2.INTER_LINEAR)
            mask[i] = cv2.resize(mask[i], dsize=(size, size), interpolation=cv2.INTER_LINEAR)
        image = torch.from_numpy(np.stack(image, axis=0)).permute(0,3,1,2)
        mask = torch.from_numpy(np.stack(mask, axis=0)).unsqueeze(1)
        return image, mask

    def __len__(self):
        return len(self.images)
124 |
--------------------------------------------------------------------------------
/v2/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, '../')
4 | sys.dont_write_bytecode = True
5 |
6 | import cv2
7 | import numpy as np
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.utils.data import DataLoader
13 | import data
14 | from RAS import RAS
15 | import time
16 |
class Test(object):
    """Run a trained network over one dataset and save per-image saliency maps.

    Maps are written as 8-bit PNGs to
    ``/home/ipal/evaluation/SaliencyMaps/<dataset dir name>/RAS-v2/<image stem>.png``.
    """

    def __init__(self, Dataset, Network, path):
        ## dataset
        self.cfg = Dataset.Config(datapath=path, snapshot='./models/RAS.v2.pth', mode='test')
        self.data = Dataset.Data(self.cfg)
        self.loader = DataLoader(self.data, batch_size=1, shuffle=False, num_workers=8)
        ## network
        self.net = Network(self.cfg)
        self.net.train(False)
        self.net.cuda()

    @staticmethod
    def _map_name(image_name):
        """Return ``'<stem>.png'`` for an image path or a bare file stem."""
        return os.path.splitext(os.path.basename(image_name))[0] + '.png'

    def save(self):
        """Predict every image, write its map, and print the forward-pass FPS."""
        # NOTE(review): machine-specific output root; adjust for your setup.
        dataset_name = os.path.basename(os.path.normpath(self.cfg.datapath))
        save_path = os.path.join('/home/ipal/evaluation/SaliencyMaps', dataset_name, 'RAS-v2')
        os.makedirs(save_path, exist_ok=True)

        with torch.no_grad():
            time_t = 0.0
            for image, shape, name in self.loader:
                image = image.cuda().float()
                time_start = time.time()
                res, _, _, _ = self.net(image)
                # Wait for queued GPU work so the timing covers the forward pass.
                torch.cuda.synchronize()
                time_t += time.time() - time_start
                # Default collation turns the (H, W) tuple into 1-element
                # tensors (batch_size=1); pass plain ints to F.interpolate.
                size = (int(shape[0]), int(shape[1]))
                res = F.interpolate(res, size, mode='bilinear', align_corners=True)
                res = res.sigmoid().data.cpu().numpy().squeeze()
                # Min-max stretch to [0, 1] (the epsilon guards a flat map),
                # then round to 8-bit for PNG output.
                res = (res - res.min()) / (res.max() - res.min() + 1e-8)
                res = (255 * res).round().astype(np.uint8)
                # name[0] is the full image path returned by Data; keep only
                # its stem so the map lands directly inside save_path.
                cv2.imwrite(os.path.join(save_path, self._map_name(name[0])), res)

        if time_t > 0:
            print('FPS is %f' %(len(self.loader) / time_t))
49 |
50 |
if __name__ == '__main__':
    # Evaluate the v2 checkpoint on each benchmark dataset in turn.
    benchmarks = (
        '/home/ipal/datasets/ECSSD',
        '/home/ipal/datasets/DUTS',
        '/home/ipal/datasets/DUT-OMRON',
        '/home/ipal/datasets/HKU-IS',
    )
    for benchmark in benchmarks:
        Test(data, RAS, benchmark).save()
55 |
--------------------------------------------------------------------------------
/v2/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import datetime
3 | sys.path.insert(0, '../')
4 | sys.dont_write_bytecode = True
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.utils.data import DataLoader
10 | import data
11 | from RAS import RAS
12 | from apex import amp
13 |
def bce_iou_loss(pred, mask):
    """Return binary cross-entropy plus a smoothed soft-IoU loss.

    ``pred`` holds raw logits and ``mask`` the target map, with the same
    shape. The IoU term sums over dims 2 and 3, i.e. it assumes
    (N, C, H, W) tensors, and is averaged over the remaining dims.
    """
    # Pixel-wise BCE on logits, already reduced to a scalar mean.
    bce_term = F.binary_cross_entropy_with_logits(pred, mask, reduction='mean')

    prob = pred.sigmoid()
    spatial = (2, 3)
    overlap = torch.sum(prob * mask, dim=spatial)
    total = torch.sum(prob + mask, dim=spatial)
    # total - overlap is the soft union; the +1 terms smooth empty masks.
    iou_term = 1 - (overlap + 1) / (total - overlap + 1)

    return (bce_term + iou_term).mean()
22 |
def train(Dataset, Network):
    """Train ``Network`` on the DUTS training set with deep supervision.

    Args:
        Dataset: data module providing ``Config`` and ``Data`` (this repo's
            v2 ``data`` module).
        Network: model class built as ``Network(cfg)`` whose forward returns
            four logit maps; each one is scored with ``bce_iou_loss``.

    Uses Nesterov SGD with a triangular per-epoch LR schedule and apex AMP at
    opt level O2. A checkpoint is written to ``cfg.savepath`` every 8 epochs.
    """
    import os  # only needed to create the checkpoint directory

    ## dataset
    cfg = Dataset.Config(datapath='/home/ipal/datasets/DUTS_train', savepath='./models', mode='train', batch=32, lr=0.05, momen=0.9, decay=5e-4, epoch=32)
    # Named ``dataset`` rather than ``data`` so it does not shadow the data module.
    dataset = Dataset.Data(cfg)
    loader = DataLoader(dataset, collate_fn=dataset.collate, batch_size=cfg.batch, shuffle=True, num_workers=8)
    # Create the checkpoint directory up front. Otherwise torch.save fails
    # only after the first 8 epochs of training.
    os.makedirs(cfg.savepath, exist_ok=True)
    ## network
    net = Network(cfg)
    net.train(True)
    net.cuda()
    ## parameter
    # Backbone parameters (group 0) use 0.1x the head LR (group 1). The
    # backbone stem (bkbone.conv1 / bkbone.bn1) is left out of the optimizer,
    # so its weights are never stepped. In train mode its BN running
    # statistics may still update.
    base, head = [], []
    for name, param in net.named_parameters():
        if 'bkbone.conv1' in name or 'bkbone.bn1' in name:
            print(name)
        elif 'bkbone' in name:
            base.append(param)
        else:
            head.append(param)
    optimizer = torch.optim.SGD([{'params':base}, {'params':head}], lr=cfg.lr, momentum=cfg.momen, weight_decay=cfg.decay, nesterov=True)
    net, optimizer = amp.initialize(net, optimizer, opt_level='O2')
    global_step = 0

    for epoch in range(cfg.epoch):
        # Triangular schedule: rises linearly to the peak at mid-training,
        # then falls linearly back towards zero.
        scale = 1-abs((epoch+1)/(cfg.epoch+1)*2-1)
        optimizer.param_groups[0]['lr'] = scale*cfg.lr*0.1
        optimizer.param_groups[1]['lr'] = scale*cfg.lr

        for step, (image, mask) in enumerate(loader):
            image, mask = image.cuda().float(), mask.cuda().float()
            out2, out3, out4, out5 = net(image)

            # Deep supervision: every side output is scored against the same mask.
            loss2 = bce_iou_loss(out2, mask)
            loss3 = bce_iou_loss(out3, mask)
            loss4 = bce_iou_loss(out4, mask)
            loss5 = bce_iou_loss(out5, mask)
            loss = loss2 + loss3 + loss4 + loss5

            optimizer.zero_grad()
            with amp.scale_loss(loss, optimizer) as scale_loss:
                scale_loss.backward()
            optimizer.step()

            global_step += 1
            if step%10 == 0:
                # The logged lr is the backbone group's (0.1x the head lr).
                print('%s | step:%d/%d/%d | lr=%.6f | loss=%.6f'%(datetime.datetime.now(), global_step, epoch+1, cfg.epoch, optimizer.param_groups[0]['lr'], loss.item()))
                print('loss2=%.6f'%(loss2.item()))
                print('loss3=%.6f'%(loss3.item()))
                print('loss4=%.6f'%(loss4.item()))
                print('loss5=%.6f'%(loss5.item()))

        if (epoch + 1) % 8 == 0:
            # NOTE(review): this saves e.g. 'RAS.v232.pth' for epoch 32, while
            # test.py loads './models/RAS.v2.pth'. Rename/copy before testing.
            torch.save(net.state_dict(), cfg.savepath+'/RAS.v2' + str(epoch+1) + '.pth')
74 |
75 |
if __name__ == '__main__':
    # Train RAS using the v2 data pipeline from the local data module.
    train(Dataset=data, Network=RAS)
78 |
--------------------------------------------------------------------------------