├── CCAFNet.py ├── README.md ├── config.py ├── rgbd_dataset.py ├── test.py └── utils.py /CCAFNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | import torchvision.models as models 7 | 8 | class Separable_conv(nn.Module): 9 | def __init__(self, inp, oup): 10 | super(Separable_conv, self).__init__() 11 | 12 | self.conv = nn.Sequential( 13 | # dw 14 | nn.Conv2d(inp, inp, kernel_size=3, stride=1, padding=1, groups=inp, bias=False), 15 | nn.BatchNorm2d(inp), 16 | nn.ReLU(inplace=True), 17 | # pw 18 | nn.Conv2d(inp, oup, kernel_size=1), 19 | ) 20 | 21 | def forward(self, x): 22 | return self.conv(x) 23 | 24 | 25 | model = models.vgg16_bn(pretrained=True) 26 | model_urls = { 27 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 28 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 29 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 30 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 31 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 32 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 33 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 34 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 35 | } 36 | 37 | class vgg_rgb(nn.Module): 38 | def __init__(self, pretrained=True): 39 | super(vgg_rgb, self).__init__() 40 | self.features = nn.Sequential( 41 | nn.Conv2d(3, 64, 3, 1, 1), # first model 224*24*64 42 | nn.BatchNorm2d(64), 43 | nn.ReLU(inplace=True), 44 | nn.Conv2d(64, 64, 3, 1, 1), 45 | nn.BatchNorm2d(64), 46 | nn.ReLU(inplace=True), # [:6] 47 | nn.MaxPool2d(kernel_size=2, stride=2), 48 | nn.Conv2d(64, 128, 3, 1, 1), # second model 112*112*128 49 | nn.BatchNorm2d(128), 50 | nn.ReLU(inplace=True), 51 | nn.Conv2d(128, 128, 3, 1, 1), 52 | nn.BatchNorm2d(128), 53 | nn.ReLU(inplace=True), # [6:13] 54 | nn.MaxPool2d(kernel_size=2, stride=2), 55 | nn.Conv2d(128, 256, 3, 1, 1), # third model 56*56*256 56 | nn.BatchNorm2d(256), 57 | nn.ReLU(inplace=True), 58 | nn.Conv2d(256, 256, 3, 1, 1), 59 | nn.BatchNorm2d(256), 60 | nn.ReLU(inplace=True), 61 | nn.Conv2d(256, 256, 3, 1, 1), 62 | nn.BatchNorm2d(256), 63 | nn.ReLU(inplace=True), # [13:23] 64 | nn.MaxPool2d(kernel_size=2, stride=2), 65 | nn.Conv2d(256, 512, 3, 1, 1), # forth model 28*28*512 66 | nn.BatchNorm2d(512), 67 | nn.ReLU(inplace=True), 68 | nn.Conv2d(512, 512, 3, 1, 1), 69 | nn.BatchNorm2d(512), 70 | nn.ReLU(inplace=True), 71 | nn.Conv2d(512, 512, 3, 1, 1), 72 | nn.BatchNorm2d(512), 73 | nn.ReLU(inplace=True), # [13:33] 74 | nn.MaxPool2d(kernel_size=2, stride=2), 75 | nn.Conv2d(512, 512, 3, 1, 1), # fifth model 14*14*512 76 | nn.BatchNorm2d(512), 77 | nn.ReLU(inplace=True), 78 | nn.Conv2d(512, 512, 3, 1, 1), 79 | nn.BatchNorm2d(512), 80 | nn.ReLU(inplace=True), 81 | nn.Conv2d(512, 512, 3, 1, 1), 82 | nn.BatchNorm2d(512), 83 | nn.ReLU(inplace=True), # [33:43] 84 | ) 85 | 86 | if pretrained: 87 | pretrained_vgg = model_zoo.load_url(model_urls['vgg16_bn']) 88 | model_dict = {} 89 | state_dict = self.state_dict() 90 | for k, v in pretrained_vgg.items(): 91 | if k in state_dict: 92 | model_dict[k] = v 93 | # print(k, v) 94 | 95 | state_dict.update(model_dict) 96 | self.load_state_dict(state_dict) 97 | 98 | def forward(self, rgb): 99 | A1 = self.features[:6](rgb) 100 | A2 = self.features[6:13](A1) 
101 | A3 = self.features[13:23](A2) 102 | A4 = self.features[23:33](A3) 103 | A5 = self.features[33:43](A4) 104 | return A1, A2, A3, A4, A5 105 | 106 | 107 | class vgg_depth(nn.Module): 108 | def __init__(self, pretrained=True): 109 | super(vgg_depth, self).__init__() 110 | self.features = nn.Sequential( 111 | nn.Conv2d(3, 64, 3, 1, 1), # first model 224*224*64 112 | nn.BatchNorm2d(64), 113 | nn.ReLU(inplace=True), 114 | nn.Conv2d(64, 64, 3, 1, 1), 115 | nn.BatchNorm2d(64), 116 | nn.ReLU(inplace=True), # [:6] 117 | nn.MaxPool2d(kernel_size=2, stride=2), 118 | nn.Conv2d(64, 128, 3, 1, 1), # second model 112*112*128 119 | nn.BatchNorm2d(128), 120 | nn.ReLU(inplace=True), 121 | nn.Conv2d(128, 128, 3, 1, 1), 122 | nn.BatchNorm2d(128), 123 | nn.ReLU(inplace=True), # [6:13] 124 | nn.MaxPool2d(kernel_size=2, stride=2), 125 | nn.Conv2d(128, 256, 3, 1, 1), # third model 56*56*256 126 | nn.BatchNorm2d(256), 127 | nn.ReLU(inplace=True), 128 | nn.Conv2d(256, 256, 3, 1, 1), 129 | nn.BatchNorm2d(256), 130 | nn.ReLU(inplace=True), 131 | nn.Conv2d(256, 256, 3, 1, 1), 132 | nn.BatchNorm2d(256), 133 | nn.ReLU(inplace=True), # [13:23] 134 | nn.MaxPool2d(kernel_size=2, stride=2), 135 | nn.Conv2d(256, 512, 3, 1, 1), # forth model 28*28*512 136 | nn.BatchNorm2d(512), 137 | nn.ReLU(inplace=True), 138 | nn.Conv2d(512, 512, 3, 1, 1), 139 | nn.BatchNorm2d(512), 140 | nn.ReLU(inplace=True), 141 | nn.Conv2d(512, 512, 3, 1, 1), 142 | nn.BatchNorm2d(512), 143 | nn.ReLU(inplace=True), # [13:33] 144 | nn.MaxPool2d(kernel_size=2, stride=2), 145 | nn.Conv2d(512, 512, 3, 1, 1), # fifth model 14*14*512 146 | nn.BatchNorm2d(512), 147 | nn.ReLU(inplace=True), 148 | nn.Conv2d(512, 512, 3, 1, 1), 149 | nn.BatchNorm2d(512), 150 | nn.ReLU(inplace=True), 151 | nn.Conv2d(512, 512, 3, 1, 1), 152 | nn.BatchNorm2d(512), 153 | nn.ReLU(inplace=True), # [33:43] 154 | ) 155 | 156 | if pretrained: 157 | pretrained_vgg = model_zoo.load_url(model_urls['vgg16_bn']) 158 | model_dict = {} 159 | state_dict = self.state_dict() 160 | for k, v in pretrained_vgg.items(): 161 | if k in state_dict: 162 | model_dict[k] = v 163 | # print(k, v) 164 | 165 | state_dict.update(model_dict) 166 | self.load_state_dict(state_dict) 167 | 168 | def forward(self, thermal): 169 | A1_d = self.features[:6](thermal) 170 | A2_d = self.features[6:13](A1_d) 171 | A3_d = self.features[13:23](A2_d) 172 | A4_d = self.features[23:33](A3_d) 173 | A5_d = self.features[33:43](A4_d) 174 | return A1_d, A2_d, A3_d, A4_d, A5_d 175 | 176 | 177 | class Hsigmoid(nn.Module): 178 | def __init__(self, inplace=True): 179 | super(Hsigmoid, self).__init__() 180 | self.inplace = inplace 181 | 182 | def forward(self, x): 183 | return F.relu6(x + 3., inplace=self.inplace) / 6. 
184 | 185 | 186 | class Spatical_Fuse_attention3_GHOST(nn.Module): # 最终为rgb rgb, y为depth 加入恒等变化 187 | def __init__(self, in_channels,): 188 | super(Spatical_Fuse_attention3_GHOST, self).__init__() 189 | self.conv = nn.Conv2d(in_channels, 1, 3, 1, 1) 190 | self.active = Hsigmoid() 191 | 192 | def forward(self, x, y): 193 | input_y = self.conv(y) 194 | input_y = self.active(input_y) 195 | # return input_y 196 | return x + x * input_y 197 | 198 | class Channel_Fuse_attention2(nn.Module): # 最终为depth x为depth, y为rgb 加入恒等变化 199 | def __init__(self, channel, reduction=4): 200 | super(Channel_Fuse_attention2, self).__init__() 201 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 202 | self.fc = nn.Sequential( 203 | nn.Linear(channel, channel // reduction, bias=False), 204 | nn.Linear(channel // reduction, channel, bias=False), 205 | Hsigmoid() 206 | ) 207 | 208 | def forward(self, x, y): 209 | b, c, _, _ = x.size() 210 | y = self.avg_pool(y).view(b, c) 211 | y = self.fc(y).view(b, c, 1, 1) 212 | return x + x * y.expand_as(x) 213 | 214 | 215 | class Gatefusion3(nn.Module): 216 | def __init__(self, channel): 217 | super(Gatefusion3, self).__init__() 218 | self.channel = channel 219 | self.gate = nn.Sigmoid() 220 | 221 | def forward(self, x, y, fusion_up): 222 | first_fusion = torch.cat((x, y), dim=1) 223 | gate_fusion = self.gate(first_fusion) 224 | gate_fusion = torch.split(gate_fusion, self.channel, dim=1) 225 | fusion_x = gate_fusion[0] * x + x 226 | fusion_y = gate_fusion[1] * y + y 227 | fusion = fusion_x + fusion_y 228 | fusion = torch.abs((fusion - fusion_up)) * fusion + fusion 229 | return fusion 230 | 231 | class Gatefusion3_fusionup(nn.Module): 232 | def __init__(self, channel): 233 | super(Gatefusion3_fusionup, self).__init__() 234 | self.channel = channel 235 | self.gate = nn.Sigmoid() 236 | 237 | def forward(self, x, y): 238 | first_fusion = torch.cat((x, y), dim=1) 239 | gate_fusion = self.gate(first_fusion) 240 | gate_fusion = torch.split(gate_fusion, self.channel, dim=1) 241 | fusion_x = gate_fusion[0] * x + x 242 | fusion_y = gate_fusion[1] * y + y 243 | fusion = fusion_x + fusion_y 244 | return fusion 245 | 246 | class CCAFNet(nn.Module): 247 | def __init__(self, ): 248 | super(CCAFNet, self).__init__() 249 | # rgb,depth encode 250 | self.rgb_pretrained = vgg_rgb() 251 | self.depth_pretrained = vgg_depth() 252 | 253 | # rgb Fuse_model 254 | self.SAG1 = Spatical_Fuse_attention3_GHOST(64) 255 | self.SAG2 = Spatical_Fuse_attention3_GHOST(128) 256 | self.SAG3 = Spatical_Fuse_attention3_GHOST(256) 257 | 258 | # depth Fuse_model 259 | self.CAG4 = Channel_Fuse_attention2(512) 260 | self.CAG5 = Channel_Fuse_attention2(512) 261 | 262 | self.gatefusion5 = Gatefusion3_fusionup(512) 263 | self.gatefusion4 = Gatefusion3(512) 264 | self.gatefusion3 = Gatefusion3(256) 265 | self.gatefusion2 = Gatefusion3(128) 266 | self.gatefusion1 = Gatefusion3(64) 267 | 268 | 269 | # Upsample_model 270 | self.upsample1 = nn.Sequential(nn.Conv2d(288, 144, 3, 1, 1),nn.BatchNorm2d(144),nn.ReLU()) 271 | self.upsample2 = nn.Sequential(nn.Conv2d(448, 224,3,1,1),nn.BatchNorm2d(224),nn.ReLU(), 272 | nn.UpsamplingBilinear2d(scale_factor=2, )) 273 | self.upsample3 = nn.Sequential(nn.Conv2d(640, 320,3,1,1),nn.BatchNorm2d(320),nn.ReLU(), 274 | nn.UpsamplingBilinear2d(scale_factor=2, )) 275 | self.upsample4 = nn.Sequential(nn.Conv2d(768, 384,3,1,1),nn.BatchNorm2d(384),nn.ReLU(), 276 | nn.UpsamplingBilinear2d(scale_factor=2, )) 277 | self.upsample5 = nn.Sequential(nn.Conv2d(512, 256,3,1,1),nn.BatchNorm2d(256),nn.ReLU(), 278 | 
nn.UpsamplingBilinear2d(scale_factor=2, )) 279 | 280 | # duibi 281 | self.upsample5_4 = nn.Sequential(nn.Conv2d(512, 512,3,1,1),nn.BatchNorm2d(512),nn.ReLU(), 282 | nn.UpsamplingBilinear2d(scale_factor=2, )) 283 | self.upsample4_3 = nn.Sequential(nn.Conv2d(768, 256,3,1,1),nn.BatchNorm2d(256),nn.ReLU(), 284 | nn.UpsamplingBilinear2d(scale_factor=2, )) 285 | self.upsample3_2 = nn.Sequential(nn.Conv2d(640, 128,3,1,1),nn.BatchNorm2d(128),nn.ReLU(), 286 | nn.UpsamplingBilinear2d(scale_factor=2, )) 287 | self.upsample2_1 = nn.Sequential(nn.Conv2d(448, 64,3,1,1),nn.BatchNorm2d(64),nn.ReLU(), 288 | nn.UpsamplingBilinear2d(scale_factor=2, )) 289 | 290 | self.conv = nn.Conv2d(144, 1, 1) 291 | self.conv2 = nn.Conv2d(224, 1, 1) 292 | self.conv3 = nn.Conv2d(320, 1, 1) 293 | self.conv4 = nn.Conv2d(384, 1, 1) 294 | self.conv5 = nn.Conv2d(256, 1, 1) 295 | 296 | def forward(self, rgb, depth): 297 | # rgb 298 | A1, A2, A3, A4, A5 = self.rgb_pretrained(rgb) 299 | # depth 300 | A1_d, A2_d, A3_d, A4_d, A5_d = self.depth_pretrained(depth) 301 | 302 | SAG1_R = self.SAG1(A1, A1_d) 303 | SAG2_R = self.SAG2(A2, A2_d) 304 | SAG3_R = self.SAG3(A3, A3_d) 305 | 306 | CAG5_D = self.CAG5(A5_d, A5) 307 | CAG4_D = self.CAG4(A4_d, A4) 308 | 309 | F5 = self.gatefusion5(A5, CAG5_D) 310 | F5_UP = self.upsample5_4(F5) 311 | F5 = self.upsample5(F5) # 14*14 312 | F4 = self.gatefusion4(A4, CAG4_D, F5_UP) 313 | F4 = torch.cat((F4, F5), dim=1) 314 | F4_UP = self.upsample4_3(F4) 315 | F4 = self.upsample4(F4) # 28*28 316 | F3 = self.gatefusion3(SAG3_R, A3_d, F4_UP) 317 | F3 = torch.cat((F3, F4), dim=1) 318 | F3_UP = self.upsample3_2(F3) 319 | F3 = self.upsample3(F3) # 56*56 320 | F2 = self.gatefusion2(SAG2_R, A2_d, F3_UP) 321 | F2 = torch.cat((F2, F3), dim=1) 322 | F2_UP = self.upsample2_1(F2) 323 | F2 = self.upsample2(F2) # 112*112 324 | F1 = self.gatefusion1(SAG1_R, A1_d, F2_UP) 325 | F1 = torch.cat((F1, F2), dim=1) 326 | F1 = self.upsample1(F1) # 224*224 327 | out = self.conv(F1) 328 | 329 | out5 = self.conv5(F5) 330 | out4 = self.conv4(F4) 331 | out3 = self.conv3(F3) 332 | out2 = self.conv2(F2) 333 | 334 | if self.training: 335 | return out, out2, out3, out4, out5 336 | return out 337 | 338 | 339 | 340 | 341 | if __name__=='__main__': 342 | 343 | # model = ghost_net() 344 | # model.eval() 345 | model = CCAFNet() 346 | rgb = torch.randn(1, 3, 224, 224) 347 | depth = torch.randn(1, 3, 224, 224) 348 | out = model(rgb,depth) 349 | for i in out: 350 | print(i.shape) 351 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Code and result about CCAFNet(IEEE TMM)
2 | 'CCAFNet: Crossflow and Cross-scale Adaptive Fusion Network for Detecting Salient Objects in RGB-D Images' [IEEE TMM](https://ieeexplore.ieee.org/document/9424966) 3 | ![image](https://user-images.githubusercontent.com/38373305/134313486-f347b60a-3301-45f0-a22f-b9bdebf2b064.png) 4 | 5 | # Requirements 6 | Python 3.7, PyTorch 1.5.0+, CUDA 10.2, TensorboardX 2.1, opencv-python 7 | 8 | # Dataset and Evaluation tools 9 | RGB-D SOD datasets can be found at: http://dpfan.net/d3netbenchmark/ or https://github.com/jiwei0921/RGBD-SOD-datasets<br>
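For testing, `test.py` expects each dataset under the test root to contain `RGB/`, `GT/`, and `depth/` sub-folders, with dataset names following `test_datasets` in `test.py`. The layout below is inferred from the code rather than stated by the authors, and the training layout may differ:

```
test/
├── NJU2K
│   ├── RGB      # RGB images (.jpg)
│   ├── GT       # ground-truth masks (.jpg/.png)
│   └── depth    # depth maps (.bmp/.png)
├── STERE
├── DES
├── LFSD
├── NLPR
└── SIP
```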
10 | 11 | We use the MATLAB evaluation tools provided by Deng-Ping Fan, and we provide our test datasets: [Baidu Netdisk](https://pan.baidu.com/s/1tVJCWRwqIoZQ3KAplMSHsA) extraction code: zust 12 | 13 | # Result 14 | ![image](https://user-images.githubusercontent.com/38373305/134769121-0360bdc1-3504-432a-9869-b08d74ca562f.png) 15 | ![image](https://user-images.githubusercontent.com/38373305/134769150-068d21a5-f44f-47b6-a8cd-fd4540c5ae21.png) 16 | 17 | Test maps: [Baidu Netdisk](https://pan.baidu.com/s/1QcEAHlS8llyX-i3kX4npAA) extraction code: zust<br>
18 | Pretrained model download: [Baidu Netdisk](https://pan.baidu.com/s/1reGFvIYX7rZjzKuaDcs-3A) extraction code: zust<br>
19 | PS: we resize the testing data to 224 * 224 for quick evaluation; the resized data: [Baidu Netdisk](https://pan.baidu.com/s/1t5cES-RAnMCLJ76s9bwzmA) extraction code: zust<br>
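For readers who want to run a single RGB-D pair without the full loop in `test.py`, the sketch below condenses the test-time transforms from `rgbd_dataset.py` and the depth-replication step from `test.py`. The checkpoint filename and the two image paths are placeholders (use the pretrained weights linked above), and the import assumes the script is run from the repository root:

```python
# Minimal single-pair inference sketch distilled from test.py and rgbd_dataset.py.
# Paths below are placeholders, not files shipped with this repository.
import torch
import torchvision.transforms as transforms
from PIL import Image

from CCAFNet import CCAFNet

model = CCAFNet()
model.load_state_dict(torch.load('CCAFNet_epoch_best.pth', map_location='cpu'))  # placeholder checkpoint
model.eval()

rgb_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
depth_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485], [0.229])])

rgb = rgb_tf(Image.open('example_rgb.jpg').convert('RGB')).unsqueeze(0)      # 1 x 3 x 224 x 224
depth = depth_tf(Image.open('example_depth.png').convert('L')).unsqueeze(0)  # 1 x 1 x 224 x 224
depth = depth.repeat(1, 3, 1, 1)  # replicate the depth channel, as test.py does before the forward pass

with torch.no_grad():
    pred = torch.sigmoid(model(rgb, depth))  # in eval mode CCAFNet returns only the full-resolution map
    pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)  # min-max rescaling, as in test.py
```

`pred` is then a 1 x 1 x 224 x 224 saliency map in [0, 1] that can be saved with `plt.imsave(..., cmap='gray')` as in `test.py`.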
20 | 21 | # Citation 22 | @ARTICLE{9424966,
23 | author={Zhou, Wujie and Zhu, Yun and Lei, Jingsheng and Wan, Jian and Yu, Lu},
24 | journal={IEEE Transactions on Multimedia},
25 | title={CCAFNet: Crossflow and Cross-scale Adaptive Fusion Network for Detecting Salient Objects in RGB-D Images},
26 | year={2021},
27 | doi={10.1109/TMM.2021.3077767}}
28 | 29 | # Acknowledgement 30 | The implement of this project is based on the code of ‘Cascaded Partial Decoder for Fast and Accurate Salient Object Detection, CVPR2019’and 'BBS-Net: RGB-D Salient Object Detection with a Bifurcated Backbone Strategy Network' proposed by Wu et al and Deng et al. 31 | 32 | # Contact 33 | Please drop me an email for further problems or discussion: zzzyylink@gmail.com or wujiezhou@163.com 34 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | parser = argparse.ArgumentParser() 3 | # train/val 4 | parser.add_argument('--epoch', type=int, default=200, help='epoch number') 5 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate') 6 | parser.add_argument('--batchsize', type=int, default=10, help='training batch size') 7 | parser.add_argument('--trainsize', type=int, default=224, help='training dataset size') 8 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin') 9 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate') 10 | parser.add_argument('--decay_epoch', type=int, default=60, help='every n epochs decay learning rate') 11 | parser.add_argument('--load', type=str, default=None, help='train from checkpoints') 12 | parser.add_argument('--gpu_id', type=str, default='0', help='select gpu id') 13 | parser.add_argument('--train_root', type=str, default='/home/zy/PycharmProjects/zy/newdata/train/NJUNLPR', help='the train images root') 14 | parser.add_argument('--val_root', type=str, default='/home/zy/PycharmProjects/zy/newdata/val', help='the val images root') 15 | parser.add_argument('--save_path', type=str, default='/media/zy/shuju/RGBDweight/PVTbackbone_SC2/', help='the path to save models and logs') 16 | # test(predict) 17 | parser.add_argument('--testsize', type=int, default=224, help='testing size') 18 | parser.add_argument('--test_path',type=str,default='/home/zy/PycharmProjects/zy/newdata/test/',help='test dataset path') 19 | # parser.add_argument('--test_path',type=str,default='/home/zy/PycharmProjects/zy/DUT-RGBD/test_data/',help='test dataset path') 20 | opt = parser.parse_args() 21 | -------------------------------------------------------------------------------- /rgbd_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | import random 6 | import numpy as np 7 | from PIL import ImageEnhance 8 | import torch 9 | 10 | # several data augumentation strategies 11 | def cv_random_flip(img, label, depth): 12 | flip_flag = random.randint(0, 1) 13 | # flip_flag2= random.randint(0,1) 14 | # left right flip 15 | if flip_flag == 1: 16 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 17 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 18 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT) 19 | # top bottom flip 20 | # if flip_flag2==1: 21 | # img = img.transpose(Image.FLIP_TOP_BOTTOM) 22 | # label = label.transpose(Image.FLIP_TOP_BOTTOM) 23 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM) 24 | return img, label, depth 25 | 26 | 27 | def randomCrop(image, label, depth): 28 | border = 30 29 | image_width = image.size[0] 30 | image_height = image.size[1] 31 | crop_win_width = np.random.randint(image_width - border, image_width) 32 | crop_win_height = 
np.random.randint(image_height - border, image_height) 33 | random_region = ( 34 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, 35 | (image_height + crop_win_height) >> 1) 36 | return image.crop(random_region), label.crop(random_region), depth.crop(random_region) 37 | 38 | 39 | def randomRotation(image, label, depth): 40 | mode = Image.BICUBIC 41 | if random.random() > 0.8: 42 | random_angle = np.random.randint(-15, 15) 43 | image = image.rotate(random_angle, mode) 44 | label = label.rotate(random_angle, mode) 45 | depth = depth.rotate(random_angle, mode) 46 | return image, label, depth 47 | 48 | 49 | def colorEnhance(image): 50 | bright_intensity = random.randint(5, 15) / 10.0 51 | image = ImageEnhance.Brightness(image).enhance(bright_intensity) 52 | contrast_intensity = random.randint(5, 15) / 10.0 53 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity) 54 | color_intensity = random.randint(0, 20) / 10.0 55 | image = ImageEnhance.Color(image).enhance(color_intensity) 56 | sharp_intensity = random.randint(0, 30) / 10.0 57 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity) 58 | return image 59 | 60 | 61 | def randomGaussian(image, mean=0.1, sigma=0.35): 62 | def gaussianNoisy(im, mean=mean, sigma=sigma): 63 | for _i in range(len(im)): 64 | im[_i] += random.gauss(mean, sigma) 65 | return im 66 | 67 | img = np.asarray(image) 68 | width, height = img.shape 69 | img = gaussianNoisy(img[:].flatten(), mean, sigma) 70 | img = img.reshape([width, height]) 71 | return Image.fromarray(np.uint8(img)) 72 | 73 | 74 | def randomPeper(img): 75 | img = np.array(img) 76 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1]) 77 | for i in range(noiseNum): 78 | 79 | randX = random.randint(0, img.shape[0] - 1) 80 | 81 | randY = random.randint(0, img.shape[1] - 1) 82 | 83 | if random.randint(0, 1) == 0: 84 | 85 | img[randX, randY] = 0 86 | 87 | else: 88 | 89 | img[randX, randY] = 255 90 | return Image.fromarray(img) 91 | 92 | 93 | # dataset for training 94 | # The current loader is not using the normalized depth maps for training and test. If you use the normalized depth maps 95 | # (e.g., 0 represents background and 1 represents foreground.), the performance will be further improved. 
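The comment above notes that normalized depth maps can further improve performance. Purely as an illustration (this helper is not part of the original `rgbd_dataset.py`), a hypothetical per-image min-max transform could be appended to `depths_transform`; whether an inversion `1 - d` is also needed depends on the depth convention of each dataset:

```python
# Hypothetical helper, not defined anywhere in this repository.
class DepthMinMaxNorm:
    """Rescale a 1 x H x W depth tensor (output of ToTensor) to [0, 1] per image."""
    def __call__(self, depth):
        d_min, d_max = depth.min(), depth.max()
        return (depth - d_min) / (d_max - d_min + 1e-8)

# e.g. transforms.Compose([transforms.Resize((trainsize, trainsize)),
#                          transforms.ToTensor(),
#                          DepthMinMaxNorm()])
```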
96 | class SalObjDataset(data.Dataset): 97 | def __init__(self, image_root, gt_root, depth_root, trainsize): 98 | self.trainsize = trainsize 99 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') 100 | or f.endswith('.png')] 101 | # print(self.images) 102 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 103 | or f.endswith('.png')] 104 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 105 | or f.endswith('.png')] 106 | self.images = sorted(self.images) 107 | self.gts = sorted(self.gts) 108 | self.depths = sorted(self.depths) 109 | self.filter_files() 110 | self.size = len(self.images) 111 | self.img_transform = transforms.Compose([ 112 | transforms.Resize((self.trainsize, self.trainsize)), 113 | transforms.ToTensor(), 114 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 115 | self.gt_transform = transforms.Compose([ 116 | transforms.Resize((self.trainsize, self.trainsize)), 117 | transforms.ToTensor()]) 118 | self.depths_transform = transforms.Compose([ 119 | transforms.Resize((self.trainsize, self.trainsize)), 120 | transforms.ToTensor(), 121 | # transforms.Normalize([0.485], [0.229]) 122 | ]) 123 | 124 | def __getitem__(self, index): 125 | image = self.rgb_loader(self.images[index]) 126 | gt = self.binary_loader(self.gts[index]) 127 | depth = self.binary_loader(self.depths[index]) 128 | image, gt, depth = cv_random_flip(image, gt, depth) 129 | image, gt, depth = randomCrop(image, gt, depth) 130 | image, gt, depth = randomRotation(image, gt, depth) 131 | image = colorEnhance(image) 132 | # gt=randomGaussian(gt) 133 | gt = randomPeper(gt) 134 | # image, gt, depth = self.resize(image,gt, depth) 135 | image = self.img_transform(image) 136 | gt = self.gt_transform(gt) 137 | depth = self.depths_transform(depth) 138 | # depth = torch.div(depth.float(),255.0) # DUT 139 | 140 | return image, gt, depth 141 | 142 | def filter_files(self): 143 | assert len(self.images) == len(self.gts) and len(self.gts) == len(self.images) 144 | # print(len(self.images),len(self.gts),len(self.depths)) 145 | images = [] 146 | gts = [] 147 | depths = [] 148 | for img_path, gt_path, depth_path in zip(self.images, self.gts, self.depths): 149 | img = Image.open(img_path) 150 | gt = Image.open(gt_path) 151 | depth = Image.open(depth_path) 152 | if img.size == gt.size and gt.size == depth.size: 153 | # if img.size == gt.size: 154 | images.append(img_path) 155 | gts.append(gt_path) 156 | depths.append(depth_path) 157 | self.images = images 158 | self.gts = gts 159 | self.depths = depths 160 | 161 | def rgb_loader(self, path): 162 | # print(path) 163 | with open(path, 'rb') as f: 164 | img = Image.open(f) 165 | # print(img) 166 | return img.convert('RGB') 167 | 168 | def binary_loader(self, path): 169 | with open(path, 'rb') as f: 170 | img = Image.open(f) 171 | return img.convert('L') 172 | 173 | def resize(self, img, gt, depth): 174 | assert img.size == gt.size and gt.size == depth.size 175 | h = self.trainsize 176 | w = self.trainsize 177 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), depth.resize((w, h), 178 | Image.NEAREST) 179 | 180 | 181 | def __len__(self): 182 | return self.size 183 | 184 | 185 | # dataloader for training 186 | def get_loader(image_root, gt_root, depth_root, batchsize, trainsize, shuffle=True, num_workers=4, pin_memory=True): 187 | dataset = SalObjDataset(image_root, gt_root, depth_root, trainsize) 188 | data_loader = data.DataLoader(dataset=dataset, 
189 | batch_size=batchsize, 190 | shuffle=shuffle, 191 | num_workers=num_workers, 192 | pin_memory=pin_memory) 193 | return data_loader 194 | 195 | 196 | # test dataset and loader 197 | class test_dataset: 198 | def __init__(self, image_root, gt_root, depth_root, testsize): 199 | self.testsize = testsize 200 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 201 | 202 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 203 | or f.endswith('.png')] 204 | 205 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 206 | or f.endswith('.png')] 207 | 208 | self.images = sorted(self.images) 209 | # print(self.images) 210 | self.gts = sorted(self.gts) 211 | # print(self.gts) 212 | self.depths = sorted(self.depths) 213 | # print(self.depths) 214 | self.transform = transforms.Compose([ 215 | transforms.Resize((self.testsize, self.testsize)), 216 | transforms.ToTensor(), 217 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 218 | # self.gt_transform = transforms.ToTensor() 219 | self.gt_transform = transforms.Compose([ 220 | transforms.Resize((self.testsize, self.testsize)), 221 | transforms.ToTensor()]) 222 | self.depths_transform = transforms.Compose([ 223 | transforms.Resize((self.testsize, self.testsize)), 224 | transforms.ToTensor(), 225 | transforms.Normalize([0.485], [0.229]) 226 | ]) 227 | self.size = len(self.images) 228 | self.index = 0 229 | 230 | def load_data(self): 231 | image = self.rgb_loader(self.images[self.index]) 232 | gt = self.binary_loader(self.gts[self.index]) 233 | depth = self.binary_loader(self.depths[self.index]) 234 | # image, gt, depth = self.resize(image, gt, depth) 235 | image = self.transform(image).unsqueeze(0) 236 | gt = self.gt_transform(gt).unsqueeze(0) 237 | depth = self.depths_transform(depth) 238 | # depth = torch.div(depth.float(), 255.0) # DUT 239 | depth = depth.unsqueeze(0) 240 | name = self.images[self.index].split('/')[-1] 241 | if name.endswith('.jpg'): 242 | name = name.split('.jpg')[0] + '.png' 243 | self.index += 1 244 | self.index = self.index % self.size 245 | return image, gt, depth, name 246 | 247 | def rgb_loader(self, path): 248 | with open(path, 'rb') as f: 249 | img = Image.open(f) 250 | return img.convert('RGB') 251 | 252 | def binary_loader(self, path): 253 | with open(path, 'rb') as f: 254 | img = Image.open(f) 255 | return img.convert('L') 256 | 257 | def resize(self, img, gt, depth): 258 | # assert img.size == gt.size and gt.size == depth.size 259 | h = self.testsize 260 | w = self.testsize 261 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), depth.resize((w, h), 262 | Image.NEAREST) 263 | 264 | def __len__(self): 265 | return self.size 266 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import sys 4 | sys.path.append('./models') 5 | import numpy as np 6 | import os 7 | import cv2 8 | import matplotlib.pyplot as plt 9 | 10 | from rgbd.rgbd_models.CCAFNet import CCAFNet 11 | from config import opt 12 | from rgbd.rgbd_dataset import test_dataset 13 | from torch.cuda import amp 14 | 15 | 16 | dataset_path = opt.test_path 17 | 18 | #set device for test 19 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id 20 | print('USE GPU:', opt.gpu_id) 21 | 22 | #load the model 23 | model = CCAFNet() 24 | #Large epoch size may not 
generalize well. You can choose a good model to load according to the log file and pth files saved in ('./BBSNet_cpts/') when training. 25 | # model.load_state_dict(torch.load('/media/zy/shuju/TMMweight/TMMALLCFM/TMM_epoch_100.pth')) 26 | model.load_state_dict(torch.load('/media/zy/shuju/RGBDweight/PVTbackbone_SC/II_epoch_best.pth')) 27 | 28 | # model.load_state_dict(torch.load('/media/zy/shuju/TMMweight/vgg16plus/TMM_epoch_60.pth')) 29 | model.cuda() 30 | model.eval() 31 | 32 | #test 33 | test_mae = [] 34 | test_datasets = ['NJU2K','STERE','DES','LFSD','NLPR','SIP'] 35 | 36 | for dataset in test_datasets: 37 | mae_sum = 0 38 | save_path = '/home/zy/PycharmProjects/SOD/rgbd/rgbd_test_maps/CCAFNet/' + dataset + '/' 39 | 40 | if not os.path.exists(save_path): 41 | os.makedirs(save_path) 42 | image_root = dataset_path + dataset + '/RGB/' 43 | gt_root = dataset_path + dataset + '/GT/' 44 | depth_root=dataset_path +dataset +'/depth/' 45 | test_loader = test_dataset(image_root, gt_root,depth_root, opt.testsize) 46 | for i in range(test_loader.size): 47 | image, gt, depth, name = test_loader.load_data() 48 | gt = gt.cuda() 49 | image = image.cuda() 50 | # print(image.shape) 51 | n, c, h, w = image.size() 52 | depth = depth.cuda() 53 | depth = depth.view(n, h, w, 1).repeat(1, 1, 1, c) 54 | depth = depth.transpose(3, 1) 55 | depth = depth.transpose(3, 2) 56 | res = model(image, depth) 57 | predict = torch.sigmoid(res) 58 | predict = (predict - predict.min()) / (predict.max() - predict.min() + 1e-8) 59 | mae = torch.sum(torch.abs(predict - gt)) / torch.numel(gt) 60 | mae_sum = mae.item() + mae_sum 61 | predict = predict.data.cpu().numpy().squeeze() 62 | print('save img to: ', save_path + name) 63 | 64 | plt.imsave(save_path + name, arr=predict, cmap='gray') 65 | 66 | test_mae.append(mae_sum / test_loader.size) 67 | print('Test_mae:', test_mae) 68 | print('Test Done!') 69 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | def clip_gradient(optimizer, grad_clip): 2 | for group in optimizer.param_groups: 3 | for param in group['params']: 4 | if param.grad is not None: 5 | param.grad.data.clamp_(-grad_clip, grad_clip) 6 | 7 | 8 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30): 9 | decay = decay_rate ** (epoch // decay_epoch) 10 | for param_group in optimizer.param_groups: 11 | param_group['lr'] = decay*init_lr 12 | lr=param_group['lr'] 13 | return lr 14 | --------------------------------------------------------------------------------
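`train.py` is not among the files listed above. Purely for orientation, the sketch below shows one plausible way to wire together `get_loader` (rgbd_dataset.py), `clip_gradient`/`adjust_lr` (utils.py), the options in `config.py`, and the five training outputs of `CCAFNet`. The Adam optimizer, the `RGB/GT/depth` sub-folder names under `train_root`, and the BCE-with-logits deep-supervision loss are assumptions, not the authors' confirmed settings:

```python
# Rough training-loop sketch; optimizer, loss, and folder layout are assumptions (see note above).
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from CCAFNet import CCAFNet
from rgbd_dataset import get_loader
from utils import clip_gradient, adjust_lr
from config import opt

os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id
os.makedirs(opt.save_path, exist_ok=True)

model = CCAFNet().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
criterion = nn.BCEWithLogitsLoss()  # assumed; the model outputs logits (sigmoid is applied in test.py)

train_loader = get_loader(opt.train_root + '/RGB/', opt.train_root + '/GT/', opt.train_root + '/depth/',
                          batchsize=opt.batchsize, trainsize=opt.trainsize)

for epoch in range(1, opt.epoch + 1):
    adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
    model.train()
    for images, gts, depths in train_loader:
        images, gts = images.cuda(), gts.cuda()
        depths = depths.repeat(1, 3, 1, 1).cuda()  # loader yields 1-channel depth; the VGG branch expects 3
        optimizer.zero_grad()
        outs = model(images, depths)               # (out, out2, out3, out4, out5) in training mode
        # Side outputs come at lower resolutions, so upsample each to the GT size before the loss.
        loss = sum(criterion(F.interpolate(o, size=gts.shape[2:], mode='bilinear', align_corners=True), gts)
                   for o in outs)
        loss.backward()
        clip_gradient(optimizer, opt.clip)
        optimizer.step()
    torch.save(model.state_dict(), os.path.join(opt.save_path, 'CCAFNet_epoch_%d.pth' % epoch))
```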