├── CCAFNet.py
├── README.md
├── config.py
├── rgbd_dataset.py
├── test.py
└── utils.py
/CCAFNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import torch.nn.functional as F
5 | import torch.utils.model_zoo as model_zoo
6 | import torchvision.models as models
7 |
8 | class Separable_conv(nn.Module):
9 | def __init__(self, inp, oup):
10 | super(Separable_conv, self).__init__()
11 |
12 | self.conv = nn.Sequential(
13 | # dw
14 | nn.Conv2d(inp, inp, kernel_size=3, stride=1, padding=1, groups=inp, bias=False),
15 | nn.BatchNorm2d(inp),
16 | nn.ReLU(inplace=True),
17 | # pw
18 | nn.Conv2d(inp, oup, kernel_size=1),
19 | )
20 |
21 | def forward(self, x):
22 | return self.conv(x)
23 |
24 |
25 | # model = models.vgg16_bn(pretrained=True)  # unused module-level download: both encoders below load the 'vgg16_bn' weights themselves via model_zoo
26 | model_urls = {
27 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
28 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
29 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
30 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
31 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
32 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
33 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
34 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
35 | }
36 |
37 | class vgg_rgb(nn.Module):
38 | def __init__(self, pretrained=True):
39 | super(vgg_rgb, self).__init__()
40 | self.features = nn.Sequential(
41 | nn.Conv2d(3, 64, 3, 1, 1), # first model 224*224*64
42 | nn.BatchNorm2d(64),
43 | nn.ReLU(inplace=True),
44 | nn.Conv2d(64, 64, 3, 1, 1),
45 | nn.BatchNorm2d(64),
46 | nn.ReLU(inplace=True), # [:6]
47 | nn.MaxPool2d(kernel_size=2, stride=2),
48 | nn.Conv2d(64, 128, 3, 1, 1), # second model 112*112*128
49 | nn.BatchNorm2d(128),
50 | nn.ReLU(inplace=True),
51 | nn.Conv2d(128, 128, 3, 1, 1),
52 | nn.BatchNorm2d(128),
53 | nn.ReLU(inplace=True), # [6:13]
54 | nn.MaxPool2d(kernel_size=2, stride=2),
55 | nn.Conv2d(128, 256, 3, 1, 1), # third model 56*56*256
56 | nn.BatchNorm2d(256),
57 | nn.ReLU(inplace=True),
58 | nn.Conv2d(256, 256, 3, 1, 1),
59 | nn.BatchNorm2d(256),
60 | nn.ReLU(inplace=True),
61 | nn.Conv2d(256, 256, 3, 1, 1),
62 | nn.BatchNorm2d(256),
63 | nn.ReLU(inplace=True), # [13:23]
64 | nn.MaxPool2d(kernel_size=2, stride=2),
65 | nn.Conv2d(256, 512, 3, 1, 1), # fourth model 28*28*512
66 | nn.BatchNorm2d(512),
67 | nn.ReLU(inplace=True),
68 | nn.Conv2d(512, 512, 3, 1, 1),
69 | nn.BatchNorm2d(512),
70 | nn.ReLU(inplace=True),
71 | nn.Conv2d(512, 512, 3, 1, 1),
72 | nn.BatchNorm2d(512),
73 | nn.ReLU(inplace=True), # [23:33]
74 | nn.MaxPool2d(kernel_size=2, stride=2),
75 | nn.Conv2d(512, 512, 3, 1, 1), # fifth model 14*14*512
76 | nn.BatchNorm2d(512),
77 | nn.ReLU(inplace=True),
78 | nn.Conv2d(512, 512, 3, 1, 1),
79 | nn.BatchNorm2d(512),
80 | nn.ReLU(inplace=True),
81 | nn.Conv2d(512, 512, 3, 1, 1),
82 | nn.BatchNorm2d(512),
83 | nn.ReLU(inplace=True), # [33:43]
84 | )
85 |
86 | if pretrained:
87 | pretrained_vgg = model_zoo.load_url(model_urls['vgg16_bn'])
88 | model_dict = {}
89 | state_dict = self.state_dict()
90 | for k, v in pretrained_vgg.items():
91 | if k in state_dict:
92 | model_dict[k] = v
93 | # print(k, v)
94 |
95 | state_dict.update(model_dict)
96 | self.load_state_dict(state_dict)
97 |
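# Stage outputs for a 224*224 input (per the layer comments above):
#   A1 224*224*64, A2 112*112*128, A3 56*56*256, A4 28*28*512, A5 14*14*512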
98 | def forward(self, rgb):
99 | A1 = self.features[:6](rgb)
100 | A2 = self.features[6:13](A1)
101 | A3 = self.features[13:23](A2)
102 | A4 = self.features[23:33](A3)
103 | A5 = self.features[33:43](A4)
104 | return A1, A2, A3, A4, A5
105 |
106 |
107 | class vgg_depth(nn.Module):
108 | def __init__(self, pretrained=True):
109 | super(vgg_depth, self).__init__()
110 | self.features = nn.Sequential(
111 | nn.Conv2d(3, 64, 3, 1, 1), # first model 224*224*64
112 | nn.BatchNorm2d(64),
113 | nn.ReLU(inplace=True),
114 | nn.Conv2d(64, 64, 3, 1, 1),
115 | nn.BatchNorm2d(64),
116 | nn.ReLU(inplace=True), # [:6]
117 | nn.MaxPool2d(kernel_size=2, stride=2),
118 | nn.Conv2d(64, 128, 3, 1, 1), # second model 112*112*128
119 | nn.BatchNorm2d(128),
120 | nn.ReLU(inplace=True),
121 | nn.Conv2d(128, 128, 3, 1, 1),
122 | nn.BatchNorm2d(128),
123 | nn.ReLU(inplace=True), # [6:13]
124 | nn.MaxPool2d(kernel_size=2, stride=2),
125 | nn.Conv2d(128, 256, 3, 1, 1), # third model 56*56*256
126 | nn.BatchNorm2d(256),
127 | nn.ReLU(inplace=True),
128 | nn.Conv2d(256, 256, 3, 1, 1),
129 | nn.BatchNorm2d(256),
130 | nn.ReLU(inplace=True),
131 | nn.Conv2d(256, 256, 3, 1, 1),
132 | nn.BatchNorm2d(256),
133 | nn.ReLU(inplace=True), # [13:23]
134 | nn.MaxPool2d(kernel_size=2, stride=2),
135 | nn.Conv2d(256, 512, 3, 1, 1), # fourth model 28*28*512
136 | nn.BatchNorm2d(512),
137 | nn.ReLU(inplace=True),
138 | nn.Conv2d(512, 512, 3, 1, 1),
139 | nn.BatchNorm2d(512),
140 | nn.ReLU(inplace=True),
141 | nn.Conv2d(512, 512, 3, 1, 1),
142 | nn.BatchNorm2d(512),
143 | nn.ReLU(inplace=True), # [23:33]
144 | nn.MaxPool2d(kernel_size=2, stride=2),
145 | nn.Conv2d(512, 512, 3, 1, 1), # fifth model 14*14*512
146 | nn.BatchNorm2d(512),
147 | nn.ReLU(inplace=True),
148 | nn.Conv2d(512, 512, 3, 1, 1),
149 | nn.BatchNorm2d(512),
150 | nn.ReLU(inplace=True),
151 | nn.Conv2d(512, 512, 3, 1, 1),
152 | nn.BatchNorm2d(512),
153 | nn.ReLU(inplace=True), # [33:43]
154 | )
155 |
156 | if pretrained:
157 | pretrained_vgg = model_zoo.load_url(model_urls['vgg16_bn'])
158 | model_dict = {}
159 | state_dict = self.state_dict()
160 | for k, v in pretrained_vgg.items():
161 | if k in state_dict:
162 | model_dict[k] = v
163 | # print(k, v)
164 |
165 | state_dict.update(model_dict)
166 | self.load_state_dict(state_dict)
167 |
168 | def forward(self, thermal):
169 | A1_d = self.features[:6](thermal)
170 | A2_d = self.features[6:13](A1_d)
171 | A3_d = self.features[13:23](A2_d)
172 | A4_d = self.features[23:33](A3_d)
173 | A5_d = self.features[33:43](A4_d)
174 | return A1_d, A2_d, A3_d, A4_d, A5_d
175 |
176 |
177 | class Hsigmoid(nn.Module):
178 | def __init__(self, inplace=True):
179 | super(Hsigmoid, self).__init__()
180 | self.inplace = inplace
181 |
182 | def forward(self, x):
183 | return F.relu6(x + 3., inplace=self.inplace) / 6.
184 |
185 |
186 | class Spatical_Fuse_attention3_GHOST(nn.Module):  # spatial gate: x is the RGB feature, y is the depth feature; an identity (residual) path is kept
187 | def __init__(self, in_channels,):
188 | super(Spatical_Fuse_attention3_GHOST, self).__init__()
189 | self.conv = nn.Conv2d(in_channels, 1, 3, 1, 1)
190 | self.active = Hsigmoid()
191 |
192 | def forward(self, x, y):
193 | input_y = self.conv(y)
194 | input_y = self.active(input_y)
195 | # return input_y
196 | return x + x * input_y
197 |
198 | class Channel_Fuse_attention2(nn.Module):  # channel gate: x is the depth feature, y is the RGB feature; an identity (residual) path is kept
199 | def __init__(self, channel, reduction=4):
200 | super(Channel_Fuse_attention2, self).__init__()
201 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
202 | self.fc = nn.Sequential(
203 | nn.Linear(channel, channel // reduction, bias=False),
204 | nn.Linear(channel // reduction, channel, bias=False),
205 | Hsigmoid()
206 | )
207 |
208 | def forward(self, x, y):
209 | b, c, _, _ = x.size()
210 | y = self.avg_pool(y).view(b, c)
211 | y = self.fc(y).view(b, c, 1, 1)
212 | return x + x * y.expand_as(x)
213 |
214 |
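# Gatefusion3: concatenate the two modality features, derive element-wise sigmoid gates, apply them
# residually to each branch, sum the branches, then re-weight the sum by its absolute difference from the
# upsampled coarser fusion (fusion_up). Gatefusion3_fusionup below is the same gate without the fusion_up
# term and is used at the coarsest level (level 5), where no coarser fusion exists yet.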
215 | class Gatefusion3(nn.Module):
216 | def __init__(self, channel):
217 | super(Gatefusion3, self).__init__()
218 | self.channel = channel
219 | self.gate = nn.Sigmoid()
220 |
221 | def forward(self, x, y, fusion_up):
222 | first_fusion = torch.cat((x, y), dim=1)
223 | gate_fusion = self.gate(first_fusion)
224 | gate_fusion = torch.split(gate_fusion, self.channel, dim=1)
225 | fusion_x = gate_fusion[0] * x + x
226 | fusion_y = gate_fusion[1] * y + y
227 | fusion = fusion_x + fusion_y
228 | fusion = torch.abs((fusion - fusion_up)) * fusion + fusion
229 | return fusion
230 |
231 | class Gatefusion3_fusionup(nn.Module):
232 | def __init__(self, channel):
233 | super(Gatefusion3_fusionup, self).__init__()
234 | self.channel = channel
235 | self.gate = nn.Sigmoid()
236 |
237 | def forward(self, x, y):
238 | first_fusion = torch.cat((x, y), dim=1)
239 | gate_fusion = self.gate(first_fusion)
240 | gate_fusion = torch.split(gate_fusion, self.channel, dim=1)
241 | fusion_x = gate_fusion[0] * x + x
242 | fusion_y = gate_fusion[1] * y + y
243 | fusion = fusion_x + fusion_y
244 | return fusion
245 |
246 | class CCAFNet(nn.Module):
247 | def __init__(self, ):
248 | super(CCAFNet, self).__init__()
249 | # rgb,depth encode
250 | self.rgb_pretrained = vgg_rgb()
251 | self.depth_pretrained = vgg_depth()
252 |
253 | # rgb Fuse_model
254 | self.SAG1 = Spatical_Fuse_attention3_GHOST(64)
255 | self.SAG2 = Spatical_Fuse_attention3_GHOST(128)
256 | self.SAG3 = Spatical_Fuse_attention3_GHOST(256)
257 |
258 | # depth Fuse_model
259 | self.CAG4 = Channel_Fuse_attention2(512)
260 | self.CAG5 = Channel_Fuse_attention2(512)
261 |
262 | self.gatefusion5 = Gatefusion3_fusionup(512)
263 | self.gatefusion4 = Gatefusion3(512)
264 | self.gatefusion3 = Gatefusion3(256)
265 | self.gatefusion2 = Gatefusion3(128)
266 | self.gatefusion1 = Gatefusion3(64)
267 |
268 |
269 | # Upsample_model
270 | self.upsample1 = nn.Sequential(nn.Conv2d(288, 144, 3, 1, 1),nn.BatchNorm2d(144),nn.ReLU())
271 | self.upsample2 = nn.Sequential(nn.Conv2d(448, 224,3,1,1),nn.BatchNorm2d(224),nn.ReLU(),
272 | nn.UpsamplingBilinear2d(scale_factor=2, ))
273 | self.upsample3 = nn.Sequential(nn.Conv2d(640, 320,3,1,1),nn.BatchNorm2d(320),nn.ReLU(),
274 | nn.UpsamplingBilinear2d(scale_factor=2, ))
275 | self.upsample4 = nn.Sequential(nn.Conv2d(768, 384,3,1,1),nn.BatchNorm2d(384),nn.ReLU(),
276 | nn.UpsamplingBilinear2d(scale_factor=2, ))
277 | self.upsample5 = nn.Sequential(nn.Conv2d(512, 256,3,1,1),nn.BatchNorm2d(256),nn.ReLU(),
278 | nn.UpsamplingBilinear2d(scale_factor=2, ))
279 |
280 | # cross-scale upsampling branch: produces the coarser fusion maps (F*_UP) passed as fusion_up to the Gatefusion3 modules
281 | self.upsample5_4 = nn.Sequential(nn.Conv2d(512, 512,3,1,1),nn.BatchNorm2d(512),nn.ReLU(),
282 | nn.UpsamplingBilinear2d(scale_factor=2, ))
283 | self.upsample4_3 = nn.Sequential(nn.Conv2d(768, 256,3,1,1),nn.BatchNorm2d(256),nn.ReLU(),
284 | nn.UpsamplingBilinear2d(scale_factor=2, ))
285 | self.upsample3_2 = nn.Sequential(nn.Conv2d(640, 128,3,1,1),nn.BatchNorm2d(128),nn.ReLU(),
286 | nn.UpsamplingBilinear2d(scale_factor=2, ))
287 | self.upsample2_1 = nn.Sequential(nn.Conv2d(448, 64,3,1,1),nn.BatchNorm2d(64),nn.ReLU(),
288 | nn.UpsamplingBilinear2d(scale_factor=2, ))
289 |
290 | self.conv = nn.Conv2d(144, 1, 1)
291 | self.conv2 = nn.Conv2d(224, 1, 1)
292 | self.conv3 = nn.Conv2d(320, 1, 1)
293 | self.conv4 = nn.Conv2d(384, 1, 1)
294 | self.conv5 = nn.Conv2d(256, 1, 1)
295 |
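# Forward overview (resolutions assume 224*224 inputs):
#   levels 1-3: SAG uses the depth feature to spatially gate the RGB feature (shallow, detail-rich stages);
#   levels 4-5: CAG uses the RGB feature to channel-wise gate the depth feature (deep, semantic stages);
#   each Gatefusion merges the (gated) RGB and depth features at its level, guided by the upsampled coarser
#   fusion F*_UP; `out` is the full-resolution prediction, out2-out5 are side outputs returned only in training mode.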
296 | def forward(self, rgb, depth):
297 | # rgb
298 | A1, A2, A3, A4, A5 = self.rgb_pretrained(rgb)
299 | # depth
300 | A1_d, A2_d, A3_d, A4_d, A5_d = self.depth_pretrained(depth)
301 |
302 | SAG1_R = self.SAG1(A1, A1_d)
303 | SAG2_R = self.SAG2(A2, A2_d)
304 | SAG3_R = self.SAG3(A3, A3_d)
305 |
306 | CAG5_D = self.CAG5(A5_d, A5)
307 | CAG4_D = self.CAG4(A4_d, A4)
308 |
309 | F5 = self.gatefusion5(A5, CAG5_D)
310 | F5_UP = self.upsample5_4(F5)
311 | F5 = self.upsample5(F5)  # 512 -> 256 channels, upsampled to 28*28
312 | F4 = self.gatefusion4(A4, CAG4_D, F5_UP)
313 | F4 = torch.cat((F4, F5), dim=1)
314 | F4_UP = self.upsample4_3(F4)
315 | F4 = self.upsample4(F4)  # 768 -> 384 channels, upsampled to 56*56
316 | F3 = self.gatefusion3(SAG3_R, A3_d, F4_UP)
317 | F3 = torch.cat((F3, F4), dim=1)
318 | F3_UP = self.upsample3_2(F3)
319 | F3 = self.upsample3(F3)  # 640 -> 320 channels, upsampled to 112*112
320 | F2 = self.gatefusion2(SAG2_R, A2_d, F3_UP)
321 | F2 = torch.cat((F2, F3), dim=1)
322 | F2_UP = self.upsample2_1(F2)
323 | F2 = self.upsample2(F2)  # 448 -> 224 channels, upsampled to 224*224
324 | F1 = self.gatefusion1(SAG1_R, A1_d, F2_UP)
325 | F1 = torch.cat((F1, F2), dim=1)
326 | F1 = self.upsample1(F1)  # 288 -> 144 channels, stays at 224*224
327 | out = self.conv(F1)
328 |
329 | out5 = self.conv5(F5)
330 | out4 = self.conv4(F4)
331 | out3 = self.conv3(F3)
332 | out2 = self.conv2(F2)
333 |
334 | if self.training:
335 | return out, out2, out3, out4, out5
336 | return out
337 |
338 |
339 |
340 |
341 | if __name__=='__main__':
342 |
343 | # model = ghost_net()
344 | # model.eval()
345 | model = CCAFNet()
346 | rgb = torch.randn(1, 3, 224, 224)
347 | depth = torch.randn(1, 3, 224, 224)
348 | out = model(rgb,depth)
349 | for i in out:
350 | print(i.shape)
351 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Code and results for CCAFNet (IEEE TMM)
2 | 'CCAFNet: Crossflow and Cross-scale Adaptive Fusion Network for Detecting Salient Objects in RGB-D Images' [IEEE TMM](https://ieeexplore.ieee.org/document/9424966)
3 | 
4 |
5 | # Requirements
6 | Python 3.7, PyTorch 1.5.0+, CUDA 10.2, TensorboardX 2.1, opencv-python
7 |
8 | # Dataset and Evaluate tools
9 | RGB-D SOD datasets can be found at: http://dpfan.net/d3netbenchmark/ or https://github.com/jiwei0921/RGBD-SOD-datasets
10 |
11 | We use the MATLAB evaluation tools provided by Dengping Fan, and we provide our test datasets: [Baidu Netdisk](https://pan.baidu.com/s/1tVJCWRwqIoZQ3KAplMSHsA) (extraction code: zust)
12 |
13 | # Result
14 | 
15 | 
16 |
17 | Test maps: [Baidu Netdisk](https://pan.baidu.com/s/1QcEAHlS8llyX-i3kX4npAA) (extraction code: zust)
18 | Pretrained model download: [Baidu Netdisk](https://pan.baidu.com/s/1reGFvIYX7rZjzKuaDcs-3A) (extraction code: zust)
19 | PS: we resize the testing data to 224 * 224 for quick evaluation, [Baidu Netdisk](https://pan.baidu.com/s/1t5cES-RAnMCLJ76s9bwzmA) (extraction code: zust)
20 |
21 | # Citation
22 | @ARTICLE{9424966,
23 | author={Zhou, Wujie and Zhu, Yun and Lei, Jingsheng and Wan, Jian and Yu, Lu},
24 | journal={IEEE Transactions on Multimedia},
25 | title={CCAFNet: Crossflow and Cross-scale Adaptive Fusion Network for Detecting Salient Objects in RGB-D Images},
26 | year={2021},
27 | doi={10.1109/TMM.2021.3077767}}
28 |
29 | # Acknowledgement
30 | The implementation of this project is based on the code of 'Cascaded Partial Decoder for Fast and Accurate Salient Object Detection' (CVPR 2019) by Wu et al. and 'BBS-Net: RGB-D Salient Object Detection with a Bifurcated Backbone Strategy Network' by Fan et al.
31 |
32 | # Contact
33 | Please drop me an email for further questions or discussion: zzzyylink@gmail.com or wujiezhou@163.com
34 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | parser = argparse.ArgumentParser()
3 | # train/val
4 | parser.add_argument('--epoch', type=int, default=200, help='epoch number')
5 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
6 | parser.add_argument('--batchsize', type=int, default=10, help='training batch size')
7 | parser.add_argument('--trainsize', type=int, default=224, help='training dataset size')
8 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
9 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate')
10 | parser.add_argument('--decay_epoch', type=int, default=60, help='decay the learning rate every n epochs')
11 | parser.add_argument('--load', type=str, default=None, help='train from checkpoints')
12 | parser.add_argument('--gpu_id', type=str, default='0', help='select gpu id')
13 | parser.add_argument('--train_root', type=str, default='/home/zy/PycharmProjects/zy/newdata/train/NJUNLPR', help='the train images root')
14 | parser.add_argument('--val_root', type=str, default='/home/zy/PycharmProjects/zy/newdata/val', help='the val images root')
15 | parser.add_argument('--save_path', type=str, default='/media/zy/shuju/RGBDweight/PVTbackbone_SC2/', help='the path to save models and logs')
16 | # test(predict)
17 | parser.add_argument('--testsize', type=int, default=224, help='testing size')
18 | parser.add_argument('--test_path',type=str,default='/home/zy/PycharmProjects/zy/newdata/test/',help='test dataset path')
19 | # parser.add_argument('--test_path',type=str,default='/home/zy/PycharmProjects/zy/DUT-RGBD/test_data/',help='test dataset path')
20 | opt = parser.parse_args()
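# Example invocation of test.py with these options (paths below are placeholders, not the defaults above):
#   python test.py --gpu_id 0 --testsize 224 --test_path /path/to/test/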
21 |
--------------------------------------------------------------------------------
/rgbd_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 | import random
6 | import numpy as np
7 | from PIL import ImageEnhance
8 | import torch
9 |
10 | # several data augmentation strategies
11 | def cv_random_flip(img, label, depth):
12 | flip_flag = random.randint(0, 1)
13 | # flip_flag2= random.randint(0,1)
14 | # left right flip
15 | if flip_flag == 1:
16 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
17 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
18 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT)
19 | # top bottom flip
20 | # if flip_flag2==1:
21 | # img = img.transpose(Image.FLIP_TOP_BOTTOM)
22 | # label = label.transpose(Image.FLIP_TOP_BOTTOM)
23 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM)
24 | return img, label, depth
25 |
26 |
27 | def randomCrop(image, label, depth):
28 | border = 30
29 | image_width = image.size[0]
30 | image_height = image.size[1]
31 | crop_win_width = np.random.randint(image_width - border, image_width)
32 | crop_win_height = np.random.randint(image_height - border, image_height)
33 | random_region = (
34 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1,
35 | (image_height + crop_win_height) >> 1)
36 | return image.crop(random_region), label.crop(random_region), depth.crop(random_region)
37 |
38 |
39 | def randomRotation(image, label, depth):
40 | mode = Image.BICUBIC
41 | if random.random() > 0.8:
42 | random_angle = np.random.randint(-15, 15)
43 | image = image.rotate(random_angle, mode)
44 | label = label.rotate(random_angle, mode)
45 | depth = depth.rotate(random_angle, mode)
46 | return image, label, depth
47 |
48 |
49 | def colorEnhance(image):
50 | bright_intensity = random.randint(5, 15) / 10.0
51 | image = ImageEnhance.Brightness(image).enhance(bright_intensity)
52 | contrast_intensity = random.randint(5, 15) / 10.0
53 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity)
54 | color_intensity = random.randint(0, 20) / 10.0
55 | image = ImageEnhance.Color(image).enhance(color_intensity)
56 | sharp_intensity = random.randint(0, 30) / 10.0
57 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity)
58 | return image
59 |
60 |
61 | def randomGaussian(image, mean=0.1, sigma=0.35):
62 | def gaussianNoisy(im, mean=mean, sigma=sigma):
63 | for _i in range(len(im)):
64 | im[_i] += random.gauss(mean, sigma)
65 | return im
66 |
67 | img = np.asarray(image)
68 | width, height = img.shape
69 | img = gaussianNoisy(img[:].flatten(), mean, sigma)
70 | img = img.reshape([width, height])
71 | return Image.fromarray(np.uint8(img))
72 |
73 |
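# salt-and-pepper noise: randomly sets roughly 0.15% of the pixels to 0 or 255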
74 | def randomPeper(img):
75 | img = np.array(img)
76 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1])
77 | for i in range(noiseNum):
78 |
79 | randX = random.randint(0, img.shape[0] - 1)
80 |
81 | randY = random.randint(0, img.shape[1] - 1)
82 |
83 | if random.randint(0, 1) == 0:
84 |
85 | img[randX, randY] = 0
86 |
87 | else:
88 |
89 | img[randX, randY] = 255
90 | return Image.fromarray(img)
91 |
92 |
93 | # dataset for training
94 | # The current loader does not use normalized depth maps for training and testing. If you use normalized depth maps
95 | # (e.g., 0 represents background and 1 represents foreground), performance will improve further.
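# A minimal sketch of one such preprocessing step (illustrative only; `normalize_depth` is not part of this repo):
# per-image rescaling of the raw depth map to [0, 1] before it is passed to the depth transform.
# def normalize_depth(depth_img):
#     d = np.asarray(depth_img, dtype=np.float32)
#     return (d - d.min()) / (d.max() - d.min() + 1e-8)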
96 | class SalObjDataset(data.Dataset):
97 | def __init__(self, image_root, gt_root, depth_root, trainsize):
98 | self.trainsize = trainsize
99 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')
100 | or f.endswith('.png')]
101 | # print(self.images)
102 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
103 | or f.endswith('.png')]
104 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
105 | or f.endswith('.png')]
106 | self.images = sorted(self.images)
107 | self.gts = sorted(self.gts)
108 | self.depths = sorted(self.depths)
109 | self.filter_files()
110 | self.size = len(self.images)
111 | self.img_transform = transforms.Compose([
112 | transforms.Resize((self.trainsize, self.trainsize)),
113 | transforms.ToTensor(),
114 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
115 | self.gt_transform = transforms.Compose([
116 | transforms.Resize((self.trainsize, self.trainsize)),
117 | transforms.ToTensor()])
118 | self.depths_transform = transforms.Compose([
119 | transforms.Resize((self.trainsize, self.trainsize)),
120 | transforms.ToTensor(),
121 | # transforms.Normalize([0.485], [0.229])
122 | ])
123 |
124 | def __getitem__(self, index):
125 | image = self.rgb_loader(self.images[index])
126 | gt = self.binary_loader(self.gts[index])
127 | depth = self.binary_loader(self.depths[index])
128 | image, gt, depth = cv_random_flip(image, gt, depth)
129 | image, gt, depth = randomCrop(image, gt, depth)
130 | image, gt, depth = randomRotation(image, gt, depth)
131 | image = colorEnhance(image)
132 | # gt=randomGaussian(gt)
133 | gt = randomPeper(gt)
134 | # image, gt, depth = self.resize(image,gt, depth)
135 | image = self.img_transform(image)
136 | gt = self.gt_transform(gt)
137 | depth = self.depths_transform(depth)
138 | # depth = torch.div(depth.float(),255.0) # DUT
139 |
140 | return image, gt, depth
141 |
142 | def filter_files(self):
143 | assert len(self.images) == len(self.gts) and len(self.gts) == len(self.depths)
144 | # print(len(self.images),len(self.gts),len(self.depths))
145 | images = []
146 | gts = []
147 | depths = []
148 | for img_path, gt_path, depth_path in zip(self.images, self.gts, self.depths):
149 | img = Image.open(img_path)
150 | gt = Image.open(gt_path)
151 | depth = Image.open(depth_path)
152 | if img.size == gt.size and gt.size == depth.size:
153 | # if img.size == gt.size:
154 | images.append(img_path)
155 | gts.append(gt_path)
156 | depths.append(depth_path)
157 | self.images = images
158 | self.gts = gts
159 | self.depths = depths
160 |
161 | def rgb_loader(self, path):
162 | # print(path)
163 | with open(path, 'rb') as f:
164 | img = Image.open(f)
165 | # print(img)
166 | return img.convert('RGB')
167 |
168 | def binary_loader(self, path):
169 | with open(path, 'rb') as f:
170 | img = Image.open(f)
171 | return img.convert('L')
172 |
173 | def resize(self, img, gt, depth):
174 | assert img.size == gt.size and gt.size == depth.size
175 | h = self.trainsize
176 | w = self.trainsize
177 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), depth.resize((w, h),
178 | Image.NEAREST)
179 |
180 |
181 | def __len__(self):
182 | return self.size
183 |
184 |
185 | # dataloader for training
186 | def get_loader(image_root, gt_root, depth_root, batchsize, trainsize, shuffle=True, num_workers=4, pin_memory=True):
187 | dataset = SalObjDataset(image_root, gt_root, depth_root, trainsize)
188 | data_loader = data.DataLoader(dataset=dataset,
189 | batch_size=batchsize,
190 | shuffle=shuffle,
191 | num_workers=num_workers,
192 | pin_memory=pin_memory)
193 | return data_loader
194 |
195 |
196 | # test dataset and loader
197 | class test_dataset:
198 | def __init__(self, image_root, gt_root, depth_root, testsize):
199 | self.testsize = testsize
200 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
201 |
202 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
203 | or f.endswith('.png')]
204 |
205 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
206 | or f.endswith('.png')]
207 |
208 | self.images = sorted(self.images)
209 | # print(self.images)
210 | self.gts = sorted(self.gts)
211 | # print(self.gts)
212 | self.depths = sorted(self.depths)
213 | # print(self.depths)
214 | self.transform = transforms.Compose([
215 | transforms.Resize((self.testsize, self.testsize)),
216 | transforms.ToTensor(),
217 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
218 | # self.gt_transform = transforms.ToTensor()
219 | self.gt_transform = transforms.Compose([
220 | transforms.Resize((self.testsize, self.testsize)),
221 | transforms.ToTensor()])
222 | self.depths_transform = transforms.Compose([
223 | transforms.Resize((self.testsize, self.testsize)),
224 | transforms.ToTensor(),
225 | transforms.Normalize([0.485], [0.229])
226 | ])
227 | self.size = len(self.images)
228 | self.index = 0
229 |
230 | def load_data(self):
231 | image = self.rgb_loader(self.images[self.index])
232 | gt = self.binary_loader(self.gts[self.index])
233 | depth = self.binary_loader(self.depths[self.index])
234 | # image, gt, depth = self.resize(image, gt, depth)
235 | image = self.transform(image).unsqueeze(0)
236 | gt = self.gt_transform(gt).unsqueeze(0)
237 | depth = self.depths_transform(depth)
238 | # depth = torch.div(depth.float(), 255.0) # DUT
239 | depth = depth.unsqueeze(0)
240 | name = self.images[self.index].split('/')[-1]
241 | if name.endswith('.jpg'):
242 | name = name.split('.jpg')[0] + '.png'
243 | self.index += 1
244 | self.index = self.index % self.size
245 | return image, gt, depth, name
246 |
247 | def rgb_loader(self, path):
248 | with open(path, 'rb') as f:
249 | img = Image.open(f)
250 | return img.convert('RGB')
251 |
252 | def binary_loader(self, path):
253 | with open(path, 'rb') as f:
254 | img = Image.open(f)
255 | return img.convert('L')
256 |
257 | def resize(self, img, gt, depth):
258 | # assert img.size == gt.size and gt.size == depth.size
259 | h = self.testsize
260 | w = self.testsize
261 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), depth.resize((w, h),
262 | Image.NEAREST)
263 |
264 | def __len__(self):
265 | return self.size
266 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import sys
4 | sys.path.append('./models')
5 | import numpy as np
6 | import os
7 | import cv2
8 | import matplotlib.pyplot as plt
9 |
10 | from CCAFNet import CCAFNet
11 | from config import opt
12 | from rgbd_dataset import test_dataset
13 | from torch.cuda import amp
14 |
15 |
16 | dataset_path = opt.test_path
17 |
18 | #set device for test
19 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id
20 | print('USE GPU:', opt.gpu_id)
21 |
22 | #load the model
23 | model = CCAFNet()
24 | # A model trained for too many epochs may not generalize well. Choose a checkpoint to load according to the training log and the .pth files saved during training.
25 | # model.load_state_dict(torch.load('/media/zy/shuju/TMMweight/TMMALLCFM/TMM_epoch_100.pth'))
26 | model.load_state_dict(torch.load('/media/zy/shuju/RGBDweight/PVTbackbone_SC/II_epoch_best.pth'))
27 |
28 | # model.load_state_dict(torch.load('/media/zy/shuju/TMMweight/vgg16plus/TMM_epoch_60.pth'))
29 | model.cuda()
30 | model.eval()
31 |
32 | #test
33 | test_mae = []
34 | test_datasets = ['NJU2K','STERE','DES','LFSD','NLPR','SIP']
35 |
36 | for dataset in test_datasets:
37 | mae_sum = 0
38 | save_path = '/home/zy/PycharmProjects/SOD/rgbd/rgbd_test_maps/CCAFNet/' + dataset + '/'
39 |
40 | if not os.path.exists(save_path):
41 | os.makedirs(save_path)
42 | image_root = dataset_path + dataset + '/RGB/'
43 | gt_root = dataset_path + dataset + '/GT/'
44 | depth_root=dataset_path +dataset +'/depth/'
45 | test_loader = test_dataset(image_root, gt_root,depth_root, opt.testsize)
46 | for i in range(test_loader.size):
47 | image, gt, depth, name = test_loader.load_data()
48 | gt = gt.cuda()
49 | image = image.cuda()
50 | # print(image.shape)
51 | n, c, h, w = image.size()
52 | depth = depth.cuda()
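# reshape the single-channel depth map to (n, h, w, 1), replicate it to 3 channels, and permute back to
# (n, 3, h, w) so it matches the 3-channel VGG-based depth encoder inside CCAFNet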
53 | depth = depth.view(n, h, w, 1).repeat(1, 1, 1, c)
54 | depth = depth.transpose(3, 1)
55 | depth = depth.transpose(3, 2)
56 | res = model(image, depth)
57 | predict = torch.sigmoid(res)
58 | predict = (predict - predict.min()) / (predict.max() - predict.min() + 1e-8)
59 | mae = torch.sum(torch.abs(predict - gt)) / torch.numel(gt)
60 | mae_sum = mae.item() + mae_sum
61 | predict = predict.data.cpu().numpy().squeeze()
62 | print('save img to: ', save_path + name)
63 |
64 | plt.imsave(save_path + name, arr=predict, cmap='gray')
65 |
66 | test_mae.append(mae_sum / test_loader.size)
67 | print('Test_mae:', test_mae)
68 | print('Test Done!')
69 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | def clip_gradient(optimizer, grad_clip):
2 | for group in optimizer.param_groups:
3 | for param in group['params']:
4 | if param.grad is not None:
5 | param.grad.data.clamp_(-grad_clip, grad_clip)
6 |
7 |
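# Step decay: lr = init_lr * decay_rate ** (epoch // decay_epoch).
# With the defaults in config.py (decay_rate=0.1, decay_epoch=60), the learning rate is divided by 10 every 60 epochs.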
8 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
9 | decay = decay_rate ** (epoch // decay_epoch)
10 | for param_group in optimizer.param_groups:
11 | param_group['lr'] = decay*init_lr
12 | lr=param_group['lr']
13 | return lr
14 |
--------------------------------------------------------------------------------