├── .gitignore
├── README.md
├── dataset.py
├── imgs
│   └── BiANet_logo.png
├── models
│   ├── BiANet_res2_50.py
│   ├── BiANet_res50.py
│   ├── BiANet_vgg11.py
│   ├── BiANet_vgg16.py
│   ├── res2net_v1b.py
│   └── resnet_conv1.py
├── test.py
├── test.sh
└── util.py

/.gitignore:
--------------------------------------------------------------------------------
*__pycache__
/param
/Testset
/SalMaps
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
<div align="center">

  <img src="imgs/BiANet_logo.png" alt="Logo">

  <h3>Bilateral Attention Network for RGB-D Salient Object Detection</h3>

  Published in IEEE Transactions on Image Processing (TIP)
  <br>
  [Paper 📄]
  [ArXiv 🌐]
  [Homepage 🏠] »

</div>
***

## Prerequisites
#### Environments
* PyTorch >= 1.0
* Ubuntu 18.04


## Usage
1. Download the [model parameters](#model-parameters-and-prediction-results) and [datasets](http://dpfan.net/d3netbenchmark/)
2. Configure `test.sh`

```
--backbones vgg16+vgg11+res50+res2_50 (multiple items are connected with '+')
--datasets dataset1+dataset2+dataset3
--param_root param (pretrained model path)
--input_root your_data_root (organized by subfolders, one per dataset)
--save_root your_output_root
```

3. Run:
```
sh test.sh
```

## Model parameters and prediction results
| | Model parameters | Prediction results |
| ---- | ---- | ---- |
| **VGG-16** | [[Google Drive]](https://drive.google.com/file/d/1yfE2-4GH-QJo5JvvJbKRwXgzaRQ5e8h_/view?usp=sharing) [[Baidu Pan (bfrn)]](https://pan.baidu.com/s/1gXkDYUU0wxzM2EjyBoO6Yg) | [[Google Drive]](https://drive.google.com/file/d/1BI43wDAT9lON-8mKK6X00j-AmcZnwoZG/view?usp=sharing) [[Baidu Pan (k01w)]](https://pan.baidu.com/s/1lFPPf9LynKlBx2tOyoP_2A) |
| VGG-11 | [[Google Drive]](https://drive.google.com/file/d/1TdTvZmPIbPfaX_BYI7dNTUoMI7IVXvFe/view?usp=sharing) [[Baidu Pan (2a5c)]](https://pan.baidu.com/s/1Usr-SNCPZADyISaIXPEZxA) | [[Google Drive]](https://drive.google.com/file/d/14aP1634QFjc0wQu8Unjme0lsmaJtlnFp/view?usp=sharing) [[Baidu Pan (d0t7)]](https://pan.baidu.com/s/1U-7hkmvfN8Pjj0pnC8VLGQ) |
| ResNet-50 | [[Google Drive]](https://drive.google.com/file/d/13vHFAR44v2bojEJppoB058QV0Vc9-Tm7/view?usp=sharing) [[Baidu Pan (o9l2)]](https://pan.baidu.com/s/1m0p7IN4GB2BWCcoj6kM_lw) | [[Google Drive]](https://drive.google.com/file/d/1CFgXVlB-jmHArTv6kdK-CZvQ6nuEpve3/view?usp=sharing) [[Baidu Pan (dqw1)]](https://pan.baidu.com/s/1KJUy4cu4dpVfdF5Nqw2uOw) |
| Res2Net-50 | [[Google Drive]](https://drive.google.com/file/d/1DppyXLs_toFi6bM5ZbGWip35BxLGfw4y/view?usp=sharing) [[Baidu Pan (k761)]](https://pan.baidu.com/s/1ycs9SI5bmIKBUbcNsrR7qQ) | [[Google Drive]](https://drive.google.com/file/d/1at-K6DfKNP2Gnao9f0v9agmzADkgt0Ik/view?usp=sharing) [[Baidu Pan (h3t9)]](https://pan.baidu.com/s/1YHVrDEl1-dCHgS2Fuc1Qzw) |

## Citation
```
@article{zhang2020bianet,
  title={Bilateral attention network for rgb-d salient object detection},
  author={Zhang, Zhao and Lin, Zheng and Xu, Jun and Jin, Wenda and Lu, Shao-Ping and Fan, Deng-Ping},
  journal={IEEE Transactions on Image Processing (TIP)},
  volume={30},
  pages={1949-1961},
  doi={10.1109/TIP.2021.3049959},
  year={2021},
}
```

## Contact
If you have any questions, feel free to contact me via `zzhang🥳mail😲nankai😲edu😲cn`.
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import os
from PIL import Image
import torch
import random
import numpy as np
from torch.utils import data
from torchvision import transforms
from torchvision.transforms import functional as F


class ImageData(data.Dataset):
    def __init__(self, rgb_root, dep_root, transform):

        self.rgb_path = list(
            map(lambda x: os.path.join(rgb_root, x), os.listdir(rgb_root)))
        # Depth maps share the RGB file names, but with a .png extension
        self.dep_path = list(
            map(
                lambda x: os.path.join(dep_root,
                                       x.split('/')[-1][:-3] + 'png'),
                self.rgb_path))

        self.transform = transform

    def __getitem__(self, item):

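        # Depth maps are typically single-channel on disk; convert('RGB')
        # below replicates the channel so both streams share one input format.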
        rgb = Image.open(self.rgb_path[item]).convert('RGB')
        dep = Image.open(self.dep_path[item]).convert('RGB')
        [w, h] = dep.size  # PIL's size is (width, height)
        imsize = [h, w]

        [rgb, dep] = self.transform(rgb, dep)

        return rgb, dep, self.rgb_path[item].split('/')[-1], imsize

    def __len__(self):
        return len(self.rgb_path)


class FixedResize(object):
    def __init__(self, size):
        self.size = (size, size)  # square (w, h) for PIL's resize

    def __call__(self, rgb, dep):

        assert rgb.size == dep.size

        rgb = rgb.resize(self.size, Image.BILINEAR)
        dep = dep.resize(self.size, Image.BILINEAR)

        return rgb, dep


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""
    def __call__(self, rgb, dep):

        return F.to_tensor(rgb), F.to_tensor(dep)


class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.
    Args:
        mean (tuple): means for each channel.
        std (tuple): standard deviations for each channel.
    """
    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
        self.mean = mean
        self.std = std

    def __call__(self, rgb, dep):

        dep = F.normalize(dep, self.mean, self.std)

        rgb = F.normalize(rgb, self.mean, self.std)

        return rgb, dep


class RandomHorizontalFlip(object):
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, rgb, dep):
        if random.random() < self.p:
            rgb = rgb.transpose(Image.FLIP_LEFT_RIGHT)
            dep = dep.transpose(Image.FLIP_LEFT_RIGHT)

        return rgb, dep


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, rgb, dep):
        for t in self.transforms:
            rgb, dep = t(rgb, dep)
        return rgb, dep

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string


def get_loader(rgb_root,
               dep_root,
               img_size,
               batch_size=1,
               num_thread=1,
               pin=False):
    test_transform = Compose([
        FixedResize(img_size),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    dataset = ImageData(rgb_root, dep_root, transform=test_transform)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=num_thread,
                                  pin_memory=pin)
    return data_loader
--------------------------------------------------------------------------------
/imgs/BiANet_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zzhanghub/bianet/0d557772b944ba2847a1bf83b0ef89752b2d6f7e/imgs/BiANet_logo.png
--------------------------------------------------------------------------------
/models/BiANet_res2_50.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import init
from models.res2net_v1b import res2net50_v1b


# RGB Stream (Res2Net-50 backbone)
class RGB_Stream(nn.Module):
    def __init__(self):
        super(RGB_Stream, self).__init__()
        self.backbone = res2net50_v1b(pretrained=True)
        self.toplayer = nn.Sequential(
            nn.MaxPool2d(2, stride=2),
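            # Top layer: pools layer4's 2048-channel output once more and
            # compresses it to 32 channels, giving a sixth, coarsest feature
            # level used for the global prediction. (Note: padding=3 with a
            # 5x5 kernel grows each spatial dim by 2 per conv; 'same' padding
            # would be 2. Kept as in the released code.)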
15 | nn.Conv2d(2048, 32, kernel_size=5, stride=1, padding=3), 16 | nn.BatchNorm2d(32), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=3), 19 | nn.BatchNorm2d(32), 20 | nn.ReLU(inplace=True), 21 | ) 22 | 23 | def forward(self, rgb): 24 | rgb = self.backbone.conv1(rgb) 25 | rgb = self.backbone.bn1(rgb) 26 | rgb = self.backbone.relu(rgb) 27 | rgb1 = rgb 28 | rgb = self.backbone.maxpool(rgb) 29 | rgb2 = self.backbone.layer1(rgb) 30 | rgb3 = self.backbone.layer2(rgb2) 31 | rgb4 = self.backbone.layer3(rgb3) 32 | rgb5 = self.backbone.layer4(rgb4) 33 | rgb6 = self.toplayer(rgb5) 34 | 35 | return [rgb1, rgb2, rgb3, rgb4, rgb5, rgb6] 36 | 37 | 38 | # Depth Stream (VGG16) 39 | class Dep_Stream(nn.Module): 40 | def __init__(self): 41 | super(Dep_Stream, self).__init__() 42 | self.backbone = res2net50_v1b(pretrained=True) 43 | self.toplayer = nn.Sequential( 44 | nn.MaxPool2d(2, stride=2), 45 | nn.Conv2d(2048, 32, kernel_size=5, stride=1, padding=3), 46 | nn.BatchNorm2d(32), 47 | nn.ReLU(inplace=True), 48 | nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=3), 49 | nn.BatchNorm2d(32), 50 | nn.ReLU(inplace=True), 51 | ) 52 | 53 | def forward(self, dep): 54 | dep = self.backbone.conv1(dep) 55 | dep = self.backbone.bn1(dep) 56 | dep = self.backbone.relu(dep) 57 | dep1 = dep 58 | dep = self.backbone.maxpool(dep) 59 | dep2 = self.backbone.layer1(dep) 60 | dep3 = self.backbone.layer2(dep2) 61 | dep4 = self.backbone.layer3(dep3) 62 | dep5 = self.backbone.layer4(dep4) 63 | dep6 = self.toplayer(dep5) 64 | return [dep1, dep2, dep3, dep4, dep5, dep6] 65 | 66 | 67 | class Pred_Layer(nn.Module): 68 | def __init__(self, in_c=32): 69 | super(Pred_Layer, self).__init__() 70 | self.enlayer = nn.Sequential( 71 | nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), 72 | nn.BatchNorm2d(32), 73 | nn.ReLU(inplace=True), 74 | ) 75 | self.outlayer = nn.Sequential( 76 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), ) 77 | 78 | def forward(self, x): 79 | x = self.enlayer(x) 80 | x = self.outlayer(x) 81 | return x 82 | 83 | 84 | # BAM 85 | class BAM(nn.Module): 86 | def __init__(self, in_c): 87 | super(BAM, self).__init__() 88 | self.reduce = nn.Conv2d(in_c * 2, 32, 1) 89 | self.ff_conv = nn.Sequential( 90 | nn.Conv2d(32, 32, 3, 1, 1), 91 | nn.BatchNorm2d(32), 92 | nn.ReLU(inplace=True), 93 | ) 94 | self.bf_conv = nn.Sequential( 95 | nn.Conv2d(32, 32, 3, 1, 1), 96 | nn.BatchNorm2d(32), 97 | nn.ReLU(inplace=True), 98 | ) 99 | self.rgbd_pred_layer = Pred_Layer(32 * 2) 100 | 101 | def forward(self, rgb_feat, dep_feat, pred): 102 | feat = torch.cat((rgb_feat, dep_feat), 1) 103 | feat = self.reduce(feat) 104 | [_, _, H, W] = feat.size() 105 | pred = torch.sigmoid( 106 | F.interpolate(pred, 107 | size=(H, W), 108 | mode='bilinear', 109 | align_corners=True)) 110 | ff_feat = self.ff_conv(feat * pred) 111 | bf_feat = self.bf_conv(feat * (1 - pred)) 112 | new_pred = self.rgbd_pred_layer(torch.cat((ff_feat, bf_feat), 1)) 113 | return new_pred 114 | 115 | 116 | # FF 117 | class FF(nn.Module): 118 | def __init__(self, in_c): 119 | super(FF, self).__init__() 120 | self.reduce = nn.Conv2d(in_c, 32, 1) 121 | self.ff_conv = nn.Sequential( 122 | nn.Conv2d(32, 32, k, 1, k // 2), 123 | nn.BatchNorm2d(32), 124 | nn.ReLU(inplace=True), 125 | ) 126 | self.rgbd_pred_layer = Pred_Layer(32) 127 | 128 | def forward(self, rgb_feat, dep_feat, pred): 129 | feat = torch.cat((rgb_feat, dep_feat), 1) 130 | [_, _, H, W] = feat.size() 131 | pred = torch.sigmoid( 132 | F.interpolate(pred, 133 | size=(H, W), 134 | 
mode='bilinear', 135 | align_corners=True)) 136 | ff_feat = self.ff_conv(feat * pred) 137 | new_pred = self.rgbd_pred_layer(ff_feat) 138 | return new_pred 139 | 140 | 141 | # BF 142 | class BF(nn.Module): 143 | def __init__(self, in_c): 144 | super(BF, self).__init__() 145 | self.reduce = nn.Conv2d(in_c * 2, 32, 1) 146 | self.bf_conv = nn.Sequential( 147 | nn.Conv2d(32, 32, 3, 1, 1), 148 | nn.BatchNorm2d(32), 149 | nn.ReLU(inplace=True), 150 | ) 151 | self.rgbd_pred_layer = Pred_Layer(32) 152 | 153 | def forward(self, rgb_feat, dep_feat, pred): 154 | feat = torch.cat((rgb_feat, dep_feat), 1) 155 | [_, _, H, W] = feat.size() 156 | pred = torch.sigmoid( 157 | F.interpolate(pred, 158 | size=(H, W), 159 | mode='bilinear', 160 | align_corners=True)) 161 | bf_feat = self.bf_conv(feat * (1 - pred)) 162 | new_pred = self.rgbd_pred_layer(bf_feat) 163 | return new_pred 164 | 165 | 166 | # ASPP for MBAM 167 | class ASPP(nn.Module): 168 | def __init__(self, in_c): 169 | super(ASPP, self).__init__() 170 | 171 | self.aspp1 = nn.Sequential( 172 | nn.Conv2d(in_c * 2, 32, 1, 1), 173 | nn.BatchNorm2d(32), 174 | nn.ReLU(inplace=True), 175 | ) 176 | self.aspp2 = nn.Sequential( 177 | nn.Conv2d(in_c * 2, 32, 3, 1, padding=3, dilation=3), 178 | nn.BatchNorm2d(32), 179 | nn.ReLU(inplace=True), 180 | ) 181 | 182 | self.aspp3 = nn.Sequential( 183 | nn.Conv2d(in_c * 2, 32, 3, 1, padding=5, dilation=5), 184 | nn.BatchNorm2d(32), 185 | nn.ReLU(inplace=True), 186 | ) 187 | self.aspp4 = nn.Sequential( 188 | nn.Conv2d(in_c * 2, 32, 3, 1, padding=7, dilation=7), 189 | nn.BatchNorm2d(32), 190 | nn.ReLU(inplace=True), 191 | ) 192 | 193 | def forward(self, x): 194 | x1 = self.aspp1(x) 195 | x2 = self.aspp2(x) 196 | x3 = self.aspp3(x) 197 | x4 = self.aspp4(x) 198 | x = torch.cat((x1, x2, x3, x4), dim=1) 199 | 200 | return x 201 | 202 | 203 | # MBAM 204 | class MBAM(nn.Module): 205 | def __init__(self, in_c): 206 | super(MBAM, self).__init__() 207 | self.ff_conv = ASPP(in_c) 208 | self.bf_conv = ASPP(in_c) 209 | self.rgbd_pred_layer = Pred_Layer(32 * 8) 210 | 211 | def forward(self, rgb_feat, dep_feat, pred): 212 | feat = torch.cat((rgb_feat, dep_feat), 1) 213 | [_, _, H, W] = feat.size() 214 | pred = torch.sigmoid( 215 | F.interpolate(pred, 216 | size=(H, W), 217 | mode='bilinear', 218 | align_corners=True)) 219 | 220 | ff_feat = self.ff_conv(feat * pred) 221 | bf_feat = self.bf_conv(feat * (1 - pred)) 222 | new_pred = self.rgbd_pred_layer(torch.cat((ff_feat, bf_feat), 1)) 223 | return new_pred 224 | 225 | 226 | class BiANet(nn.Module): 227 | def __init__(self): 228 | super(BiANet, self).__init__() 229 | 230 | # two-streams 231 | self.rgb_stream = RGB_Stream() 232 | self.dep_stream = Dep_Stream() 233 | 234 | # Global Pred 235 | self.rgb_global = Pred_Layer(32) 236 | self.dep_global = Pred_Layer(32) 237 | self.rgbd_global = Pred_Layer(32 * 2) 238 | 239 | # Shor-Conection 240 | self.bams = nn.ModuleList([ 241 | BAM(64), 242 | BAM(256), 243 | MBAM(512), 244 | MBAM(1024), 245 | MBAM(2048), 246 | ]) 247 | 248 | def _upsample_add(self, x, y): 249 | [_, _, H, W] = y.size() 250 | return F.interpolate( 251 | x, size=(H, W), mode='bilinear', align_corners=True) + y 252 | 253 | def forward(self, rgb, dep): 254 | [_, _, H, W] = rgb.size() 255 | rgb_feats = self.rgb_stream(rgb) 256 | dep_feats = self.dep_stream(dep) 257 | 258 | # Gloabl Prediction 259 | rgb_pred = self.rgb_global(rgb_feats[5]) 260 | dep_pred = self.dep_global(dep_feats[5]) 261 | rgbd_pred = self.rgbd_global(torch.cat((rgb_feats[5], dep_feats[5]), 262 | 1)) 263 | preds = [ 
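            # Side outputs, each upsampled to the input resolution:
            # [rgb-only global, depth-only global, fused rgbd global], with
            # five progressively refined maps appended in the loop below;
            # preds[-1] is the final saliency map.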
264 | torch.sigmoid( 265 | F.interpolate(rgb_pred, 266 | size=(H, W), 267 | mode='bilinear', 268 | align_corners=True)), 269 | torch.sigmoid( 270 | F.interpolate(dep_pred, 271 | size=(H, W), 272 | mode='bilinear', 273 | align_corners=True)), 274 | torch.sigmoid( 275 | F.interpolate(rgbd_pred, 276 | size=(H, W), 277 | mode='bilinear', 278 | align_corners=True)), 279 | ] 280 | 281 | p = rgbd_pred 282 | for idx in [4, 3, 2, 1, 0]: 283 | _p = self.bams[idx](rgb_feats[idx], dep_feats[idx], p) 284 | p = self._upsample_add(p, _p) 285 | preds.append( 286 | torch.sigmoid( 287 | F.interpolate(p, 288 | size=(H, W), 289 | mode='bilinear', 290 | align_corners=True))) 291 | return preds 292 | -------------------------------------------------------------------------------- /models/BiANet_res50.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | from models.resnet_conv1 import resnet50 6 | 7 | 8 | # RGB Stream (VGG16) 9 | class RGB_Stream(nn.Module): 10 | def __init__(self): 11 | super(RGB_Stream, self).__init__() 12 | self.backbone = resnet50(pretrained=True) 13 | self.toplayer = nn.Sequential( 14 | nn.MaxPool2d(2, stride=2), 15 | nn.Conv2d(2048, 32, kernel_size=5, stride=1, padding=3), 16 | nn.BatchNorm2d(32), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=3), 19 | nn.BatchNorm2d(32), 20 | nn.ReLU(inplace=True), 21 | ) 22 | 23 | def forward(self, rgb): 24 | rgb = self.backbone.relu1(self.backbone.bn1(self.backbone.conv1(rgb))) 25 | rgb = self.backbone.relu2(self.backbone.bn2(self.backbone.conv2(rgb))) 26 | rgb = self.backbone.relu3(self.backbone.bn3(self.backbone.conv3(rgb))) 27 | rgb1 = rgb 28 | rgb = self.backbone.maxpool(rgb) 29 | rgb2 = self.backbone.layer1(rgb) 30 | rgb3 = self.backbone.layer2(rgb2) 31 | rgb4 = self.backbone.layer3(rgb3) 32 | rgb5 = self.backbone.layer4(rgb4) 33 | rgb6 = self.toplayer(rgb5) 34 | 35 | return [rgb1, rgb2, rgb3, rgb4, rgb5, rgb6] 36 | 37 | 38 | # Depth Stream (VGG16) 39 | class Dep_Stream(nn.Module): 40 | def __init__(self): 41 | super(Dep_Stream, self).__init__() 42 | self.backbone = resnet50(pretrained=True) 43 | self.toplayer = nn.Sequential( 44 | nn.MaxPool2d(2, stride=2), 45 | nn.Conv2d(2048, 32, kernel_size=5, stride=1, padding=3), 46 | nn.BatchNorm2d(32), 47 | nn.ReLU(inplace=True), 48 | nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=3), 49 | nn.BatchNorm2d(32), 50 | nn.ReLU(inplace=True), 51 | ) 52 | 53 | def forward(self, dep): 54 | dep = self.backbone.relu1(self.backbone.bn1(self.backbone.conv1(dep))) 55 | dep = self.backbone.relu2(self.backbone.bn2(self.backbone.conv2(dep))) 56 | dep = self.backbone.relu3(self.backbone.bn3(self.backbone.conv3(dep))) 57 | dep1 = dep 58 | dep = self.backbone.maxpool(dep) 59 | dep2 = self.backbone.layer1(dep) 60 | dep3 = self.backbone.layer2(dep2) 61 | dep4 = self.backbone.layer3(dep3) 62 | dep5 = self.backbone.layer4(dep4) 63 | dep6 = self.toplayer(dep5) 64 | return [dep1, dep2, dep3, dep4, dep5, dep6] 65 | 66 | 67 | class Pred_Layer(nn.Module): 68 | def __init__(self, in_c=32): 69 | super(Pred_Layer, self).__init__() 70 | self.enlayer = nn.Sequential( 71 | nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1), 72 | nn.BatchNorm2d(32), 73 | nn.ReLU(inplace=True), 74 | ) 75 | self.outlayer = nn.Sequential( 76 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), ) 77 | 78 | def forward(self, x): 79 | x = self.enlayer(x) 80 | x = 
self.outlayer(x) 81 | return x 82 | 83 | 84 | # BAM 85 | class BAM(nn.Module): 86 | def __init__(self, in_c): 87 | super(BAM, self).__init__() 88 | self.reduce = nn.Conv2d(in_c * 2, 32, 1) 89 | self.ff_conv = nn.Sequential( 90 | nn.Conv2d(32, 32, 3, 1, 1), 91 | nn.BatchNorm2d(32), 92 | nn.ReLU(inplace=True), 93 | ) 94 | self.bf_conv = nn.Sequential( 95 | nn.Conv2d(32, 32, 3, 1, 1), 96 | nn.BatchNorm2d(32), 97 | nn.ReLU(inplace=True), 98 | ) 99 | self.rgbd_pred_layer = Pred_Layer(32 * 2) 100 | 101 | def forward(self, rgb_feat, dep_feat, pred): 102 | feat = torch.cat((rgb_feat, dep_feat), 1) 103 | feat = self.reduce(feat) 104 | [_, _, H, W] = feat.size() 105 | pred = torch.sigmoid( 106 | F.interpolate(pred, 107 | size=(H, W), 108 | mode='bilinear', 109 | align_corners=True)) 110 | ff_feat = self.ff_conv(feat * pred) 111 | bf_feat = self.bf_conv(feat * (1 - pred)) 112 | new_pred = self.rgbd_pred_layer(torch.cat((ff_feat, bf_feat), 1)) 113 | return new_pred 114 | 115 | 116 | # FF 117 | class FF(nn.Module): 118 | def __init__(self, in_c): 119 | super(FF, self).__init__() 120 | self.reduce = nn.Conv2d(in_c, 32, 1) 121 | self.ff_conv = nn.Sequential( 122 | nn.Conv2d(32, 32, k, 1, k // 2), 123 | nn.BatchNorm2d(32), 124 | nn.ReLU(inplace=True), 125 | ) 126 | self.rgbd_pred_layer = Pred_Layer(32) 127 | 128 | def forward(self, rgb_feat, dep_feat, pred): 129 | feat = torch.cat((rgb_feat, dep_feat), 1) 130 | [_, _, H, W] = feat.size() 131 | pred = torch.sigmoid( 132 | F.interpolate(pred, 133 | size=(H, W), 134 | mode='bilinear', 135 | align_corners=True)) 136 | ff_feat = self.ff_conv(feat * pred) 137 | new_pred = self.rgbd_pred_layer(ff_feat) 138 | return new_pred 139 | 140 | 141 | # BF 142 | class BF(nn.Module): 143 | def __init__(self, in_c): 144 | super(BF, self).__init__() 145 | self.reduce = nn.Conv2d(in_c * 2, 32, 1) 146 | self.bf_conv = nn.Sequential( 147 | nn.Conv2d(32, 32, 3, 1, 1), 148 | nn.BatchNorm2d(32), 149 | nn.ReLU(inplace=True), 150 | ) 151 | self.rgbd_pred_layer = Pred_Layer(32) 152 | 153 | def forward(self, rgb_feat, dep_feat, pred): 154 | feat = torch.cat((rgb_feat, dep_feat), 1) 155 | [_, _, H, W] = feat.size() 156 | pred = torch.sigmoid( 157 | F.interpolate(pred, 158 | size=(H, W), 159 | mode='bilinear', 160 | align_corners=True)) 161 | bf_feat = self.bf_conv(feat * (1 - pred)) 162 | new_pred = self.rgbd_pred_layer(bf_feat) 163 | return new_pred 164 | 165 | 166 | # ASPP for MBAM 167 | class ASPP(nn.Module): 168 | def __init__(self, in_c): 169 | super(ASPP, self).__init__() 170 | 171 | self.aspp1 = nn.Sequential( 172 | nn.Conv2d(in_c * 2, 32, 1, 1), 173 | nn.BatchNorm2d(32), 174 | nn.ReLU(inplace=True), 175 | ) 176 | self.aspp2 = nn.Sequential( 177 | nn.Conv2d(in_c * 2, 32, 3, 1, padding=3, dilation=3), 178 | nn.BatchNorm2d(32), 179 | nn.ReLU(inplace=True), 180 | ) 181 | 182 | self.aspp3 = nn.Sequential( 183 | nn.Conv2d(in_c * 2, 32, 3, 1, padding=5, dilation=5), 184 | nn.BatchNorm2d(32), 185 | nn.ReLU(inplace=True), 186 | ) 187 | self.aspp4 = nn.Sequential( 188 | nn.Conv2d(in_c * 2, 32, 3, 1, padding=7, dilation=7), 189 | nn.BatchNorm2d(32), 190 | nn.ReLU(inplace=True), 191 | ) 192 | 193 | def forward(self, x): 194 | x1 = self.aspp1(x) 195 | x2 = self.aspp2(x) 196 | x3 = self.aspp3(x) 197 | x4 = self.aspp4(x) 198 | x = torch.cat((x1, x2, x3, x4), dim=1) 199 | 200 | return x 201 | 202 | 203 | # MBAM 204 | class MBAM(nn.Module): 205 | def __init__(self, in_c): 206 | super(MBAM, self).__init__() 207 | self.ff_conv = ASPP(in_c) 208 | self.bf_conv = ASPP(in_c) 209 | 
self.rgbd_pred_layer = Pred_Layer(32 * 8) 210 | 211 | def forward(self, rgb_feat, dep_feat, pred): 212 | feat = torch.cat((rgb_feat, dep_feat), 1) 213 | [_, _, H, W] = feat.size() 214 | pred = torch.sigmoid( 215 | F.interpolate(pred, 216 | size=(H, W), 217 | mode='bilinear', 218 | align_corners=True)) 219 | 220 | ff_feat = self.ff_conv(feat * pred) 221 | bf_feat = self.bf_conv(feat * (1 - pred)) 222 | new_pred = self.rgbd_pred_layer(torch.cat((ff_feat, bf_feat), 1)) 223 | return new_pred 224 | 225 | 226 | class BiANet(nn.Module): 227 | def __init__(self): 228 | super(BiANet, self).__init__() 229 | 230 | # two-streams 231 | self.rgb_stream = RGB_Stream() 232 | self.dep_stream = Dep_Stream() 233 | 234 | # Global Pred 235 | self.rgb_global = Pred_Layer(32) 236 | self.dep_global = Pred_Layer(32) 237 | self.rgbd_global = Pred_Layer(32 * 2) 238 | 239 | # Shor-Conection 240 | self.bams = nn.ModuleList([ 241 | BAM(128), 242 | BAM(256), 243 | MBAM(512), 244 | MBAM(1024), 245 | MBAM(2048), 246 | ]) 247 | 248 | def _upsample_add(self, x, y): 249 | [_, _, H, W] = y.size() 250 | return F.interpolate( 251 | x, size=(H, W), mode='bilinear', align_corners=True) + y 252 | 253 | def forward(self, rgb, dep): 254 | [_, _, H, W] = rgb.size() 255 | rgb_feats = self.rgb_stream(rgb) 256 | dep_feats = self.dep_stream(dep) 257 | 258 | # Gloabl Prediction 259 | rgb_pred = self.rgb_global(rgb_feats[5]) 260 | dep_pred = self.dep_global(dep_feats[5]) 261 | rgbd_pred = self.rgbd_global(torch.cat((rgb_feats[5], dep_feats[5]), 262 | 1)) 263 | preds = [ 264 | torch.sigmoid( 265 | F.interpolate(rgb_pred, 266 | size=(H, W), 267 | mode='bilinear', 268 | align_corners=True)), 269 | torch.sigmoid( 270 | F.interpolate(dep_pred, 271 | size=(H, W), 272 | mode='bilinear', 273 | align_corners=True)), 274 | torch.sigmoid( 275 | F.interpolate(rgbd_pred, 276 | size=(H, W), 277 | mode='bilinear', 278 | align_corners=True)), 279 | ] 280 | 281 | p = rgbd_pred 282 | for idx in [4, 3, 2, 1, 0]: 283 | _p = self.bams[idx](rgb_feats[idx], dep_feats[idx], p) 284 | p = self._upsample_add(p, _p) 285 | preds.append( 286 | torch.sigmoid( 287 | F.interpolate(p, 288 | size=(H, W), 289 | mode='bilinear', 290 | align_corners=True))) 291 | return preds 292 | -------------------------------------------------------------------------------- /models/BiANet_vgg11.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | 6 | backbone = { 7 | 'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'] 8 | } 9 | 10 | 11 | # VGG16 12 | def vgg(cfg, i=3, batch_norm=False): 13 | layers = [] 14 | in_channels = i 15 | for v in cfg: 16 | if v == 'M': 17 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 18 | else: 19 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 20 | if batch_norm: 21 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 22 | else: 23 | layers += [conv2d, nn.ReLU(inplace=True)] 24 | in_channels = v 25 | return layers 26 | 27 | 28 | # VGG16 with Side Outputs 29 | class VGG_Sout(nn.Module): 30 | def __init__(self, extract=[1, 4, 9, 14, 19]): 31 | super(VGG_Sout, self).__init__() 32 | self.vgg = nn.ModuleList(vgg(cfg=backbone['vgg11'])) 33 | self.extract = extract 34 | 35 | def forward(self, x): 36 | souts = [] 37 | for idx in range(len(self.vgg)): 38 | x = self.vgg[idx](x) 39 | if idx in self.extract: 40 | souts.append(x) 41 | 42 | return souts, x 43 | 44 | 45 | # Global 
Sliency (A new block following VGG-16 for predict global saliency map) 46 | class GSLayer(nn.Module): 47 | def __init__(self, in_channel, channel, k): 48 | super(GSLayer, self).__init__() 49 | self.conv1x1 = nn.Conv2d(in_channel, channel, 1) 50 | self.convs = nn.Sequential(nn.Conv2d(channel, channel, k, 1, k // 2), 51 | nn.ReLU(inplace=True), 52 | nn.Conv2d(channel, channel, k, 1, k // 2), 53 | nn.ReLU(inplace=True), 54 | nn.Conv2d(channel, channel, k, 1, k // 2), 55 | nn.ReLU(inplace=True)) 56 | self.out_layer = nn.Conv2d(channel, 1, 1) 57 | 58 | def forward(self, x): 59 | x = self.conv1x1(x) 60 | x = self.convs(x) 61 | out = self.out_layer(x) 62 | return out 63 | 64 | 65 | # Original Attention 66 | class OriAtt(nn.Module): 67 | def __init__(self): 68 | super(OriAtt, self).__init__() 69 | 70 | def forward(self, sout, pred): 71 | return sout.mul(torch.sigmoid(pred)) 72 | 73 | 74 | # Reverse Attention 75 | class RevAtt(nn.Module): 76 | def __init__(self): 77 | super(RevAtt, self).__init__() 78 | 79 | def forward(self, sout, pred): 80 | return sout.mul(1 - torch.sigmoid(pred)) 81 | 82 | 83 | # ASPP block 84 | class ASPP(nn.Module): 85 | def __init__(self, in_channel, channel): 86 | super(ASPP, self).__init__() 87 | 88 | self.aspp1 = nn.Sequential( 89 | nn.Conv2d(in_channel, channel, 1, 1), 90 | nn.ReLU(inplace=True), 91 | ) 92 | self.aspp2 = nn.Sequential( 93 | nn.Conv2d(in_channel, channel, 3, 1, padding=3, dilation=3), 94 | nn.ReLU(inplace=True), 95 | ) 96 | 97 | self.aspp3 = nn.Sequential( 98 | nn.Conv2d(in_channel, channel, 3, 1, padding=5, dilation=5), 99 | nn.ReLU(inplace=True), 100 | ) 101 | self.aspp4 = nn.Sequential( 102 | nn.Conv2d(in_channel, channel, 3, 1, padding=7, dilation=7), 103 | nn.ReLU(inplace=True), 104 | ) 105 | 106 | def forward(self, x): 107 | x1 = self.aspp1(x) 108 | x2 = self.aspp2(x) 109 | x3 = self.aspp3(x) 110 | x4 = self.aspp4(x) 111 | # x4 = F.interpolate(x4, size=x3.size()[2:], mode='bilinear', align_corners=True) 112 | x = torch.cat((x1, x2, x3, x4), dim=1) 113 | 114 | return x 115 | 116 | 117 | # Output residual (Dual-stream Attention) 118 | class ResiLayer(nn.Module): 119 | def __init__(self, in_channel, channel, k): 120 | super(ResiLayer, self).__init__() 121 | self.conv1x1 = nn.Conv2d(in_channel, channel, 1) 122 | 123 | self.rev_att = RevAtt() 124 | self.rev_conv = nn.Sequential( 125 | nn.Conv2d(channel, channel, k, 1, k // 2), 126 | # nn.BatchNorm2d(channel), 127 | nn.ReLU(inplace=True), 128 | ) 129 | 130 | self.ori_att = OriAtt() 131 | self.ori_conv = nn.Sequential( 132 | nn.Conv2d(channel, channel, k, 1, k // 2), 133 | # nn.BatchNorm2d(channel), 134 | nn.ReLU(inplace=True), 135 | ) 136 | 137 | self.out_layer = nn.Sequential( 138 | nn.Conv2d(channel * 2, channel, k, 1, k // 2), 139 | nn.ReLU(inplace=True), 140 | nn.Conv2d(channel, 1, 3, 1, 1), 141 | ) 142 | 143 | def forward(self, sout, pred): 144 | sout = self.conv1x1(sout) 145 | 146 | sout_rev = self.rev_att(sout, pred) 147 | sout_rev = self.rev_conv(sout_rev) 148 | 149 | sout_ori = self.ori_att(sout, pred) 150 | sout_ori = self.ori_conv(sout_ori) 151 | 152 | return self.out_layer(torch.cat((sout_ori, sout_rev), 1)) 153 | 154 | 155 | # Multi-Scaled Attention Residual Prediction 156 | class PResiLayer(nn.Module): 157 | def __init__(self, in_channel, channel, k): 158 | super(PResiLayer, self).__init__() 159 | # self.conv1x1 = nn.Conv2d(in_channel, channel, 1) 160 | 161 | self.rev_att = RevAtt() 162 | self.ori_att = OriAtt() 163 | 164 | self.ori_aspp = ASPP(in_channel, channel) 165 | self.rev_aspp = 
ASPP(in_channel, channel) 166 | 167 | self.out_layer = nn.Sequential( 168 | nn.Conv2d(channel * 8, channel, k, 1, k // 2), 169 | nn.ReLU(inplace=True), 170 | # nn.Dropout(0.5), 171 | nn.Conv2d(channel, 1, 3, 1, 1), 172 | ) 173 | 174 | def forward(self, sout, pred): 175 | # sout = self.conv1x1(sout) 176 | 177 | sout_rev = self.rev_att(sout, pred) 178 | sout_rev = self.rev_aspp(sout_rev) 179 | 180 | sout_ori = self.ori_att(sout, pred) 181 | sout_ori = self.ori_aspp(sout_ori) 182 | 183 | sout_cat = torch.cat((sout_ori, sout_rev), 1) 184 | 185 | return self.out_layer(sout_cat) 186 | 187 | 188 | # Top-Down Stream for dual att 189 | class TDLayer(nn.Module): 190 | def __init__(self, in_channel, channel, k): 191 | super(TDLayer, self).__init__() 192 | self.resi_layer = ResiLayer(in_channel, channel, k) 193 | 194 | def forward(self, sout, pred): 195 | pred = nn.functional.interpolate(pred, 196 | size=sout.size()[2:], 197 | mode='bilinear', 198 | align_corners=True) 199 | residual = self.resi_layer(sout, pred) 200 | return pred + residual 201 | 202 | 203 | # Top-Down Stream for Multi-scaled Bi att 204 | class PTDLayer(nn.Module): 205 | def __init__(self, in_channel, channel, k): 206 | super(PTDLayer, self).__init__() 207 | self.resi_layer = PResiLayer(in_channel, channel, k) 208 | 209 | def forward(self, sout, pred): 210 | pred = nn.functional.interpolate(pred, 211 | size=sout.size()[2:], 212 | mode='bilinear', 213 | align_corners=True) 214 | residual = self.resi_layer(sout, pred) 215 | return pred + residual 216 | 217 | 218 | # CANet Modele 219 | class BiANet(nn.Module): 220 | def __init__(self): 221 | super(BiANet, self).__init__() 222 | self.rgb_sout = VGG_Sout() 223 | self.rgb_gs = GSLayer(512, 256, k=5) 224 | 225 | self.dep_sout = VGG_Sout() 226 | self.dep_gs = GSLayer(512, 256, k=5) 227 | 228 | self.rgbd_gs = GSLayer(1024, 256, k=5) 229 | 230 | self.td_layers = nn.ModuleList([ 231 | PTDLayer(1024, 32, 3), 232 | PTDLayer(1024, 32, 3), 233 | PTDLayer(512, 32, 3), 234 | TDLayer(256, 32, 3), 235 | TDLayer(128, 32, 3), 236 | ]) 237 | 238 | def forward(self, rgb, dep): 239 | [_, _, h, w] = rgb.size() 240 | 241 | rgb_souts, rgb_x = self.rgb_sout(rgb) 242 | dep_souts, dep_x = self.dep_sout(dep) 243 | 244 | rgb_pred = self.rgb_gs(rgb_x) # global saliency 245 | dep_pred = self.dep_gs(dep_x) # global saliency 246 | 247 | rgbd_souts = [] # cat rgb_souts and dep_souts 248 | for idx in range(len(rgb_souts)): 249 | rgbd_souts.append(torch.cat((rgb_souts[idx], dep_souts[idx]), 1)) 250 | 251 | rgbd_preds = [] 252 | rgbd_preds.append(self.rgbd_gs(torch.cat((rgb_x, dep_x), 253 | 1))) # global saliency 254 | 255 | for idx in range(len(rgbd_souts)): 256 | rgbd_preds.append(self.td_layers[idx](rgbd_souts[-(idx + 1)], 257 | rgbd_preds[idx])) 258 | 259 | scaled_preds = [] 260 | scaled_preds.append( 261 | torch.sigmoid( 262 | nn.functional.interpolate(rgb_pred, 263 | size=(h, w), 264 | mode='bilinear', 265 | align_corners=True))) 266 | scaled_preds.append( 267 | torch.sigmoid( 268 | nn.functional.interpolate(dep_pred, 269 | size=(h, w), 270 | mode='bilinear', 271 | align_corners=True))) 272 | 273 | for idx in range(len(rgbd_preds) - 1): 274 | scaled_preds.append( 275 | torch.sigmoid( 276 | nn.functional.interpolate(rgbd_preds[idx], 277 | size=(h, w), 278 | mode='bilinear', 279 | align_corners=True))) 280 | scaled_preds.append(torch.sigmoid(rgbd_preds[-1])) 281 | 282 | # rgb_gs, dep_gs, rgbd(from top to down), final pred is scaled_preds[-1] 283 | return scaled_preds 284 | 285 | 286 | # weight init 287 | def 
xavier(param): 288 | init.xavier_uniform_(param) 289 | 290 | 291 | def weights_init(m): 292 | if isinstance(m, nn.Conv2d): 293 | xavier(m.weight.data) 294 | elif isinstance(m, nn.BatchNorm2d): 295 | init.constant_(m.weight, 1) 296 | init.constant_(m.bias, 0) 297 | -------------------------------------------------------------------------------- /models/BiANet_vgg16.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | 6 | backbone = { 7 | 'vgg16': [ 8 | 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 9 | 512, 512, 512, 'M' 10 | ] 11 | } 12 | 13 | 14 | # VGG16 backbone 15 | def vgg(cfg, i=3, batch_norm=False): 16 | layers = [] 17 | in_channels = i 18 | for v in cfg: 19 | if v == 'M': 20 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 21 | else: 22 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 23 | if batch_norm: 24 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 25 | else: 26 | layers += [conv2d, nn.ReLU(inplace=True)] 27 | in_channels = v 28 | return layers 29 | 30 | 31 | # VGG16 with Side Outputs 32 | class VGG_Sout(nn.Module): 33 | def __init__(self, extract=[3, 8, 15, 22, 29]): 34 | super(VGG_Sout, self).__init__() 35 | self.vgg = nn.ModuleList(vgg(cfg=backbone['vgg16'])) 36 | self.extract = extract 37 | 38 | def forward(self, x): 39 | souts = [] 40 | for idx in range(len(self.vgg)): 41 | x = self.vgg[idx](x) 42 | if idx in self.extract: 43 | souts.append(x) 44 | 45 | return souts, x 46 | 47 | 48 | # Global Sliency (A new block following VGG-16 for predict global saliency map) 49 | class GSLayer(nn.Module): 50 | def __init__(self, in_channel, channel, k): 51 | super(GSLayer, self).__init__() 52 | self.conv1x1 = nn.Conv2d(in_channel, channel, 1) 53 | self.convs = nn.Sequential(nn.Conv2d(channel, channel, k, 1, k // 2), 54 | nn.ReLU(inplace=True), 55 | nn.Conv2d(channel, channel, k, 1, k // 2), 56 | nn.ReLU(inplace=True), 57 | nn.Conv2d(channel, channel, k, 1, k // 2), 58 | nn.ReLU(inplace=True)) 59 | self.out_layer = nn.Conv2d(channel, 1, 1) 60 | 61 | def forward(self, x): 62 | x = self.conv1x1(x) 63 | x = self.convs(x) 64 | out = self.out_layer(x) 65 | return out 66 | 67 | 68 | # Foreground Attention 69 | class OriAtt(nn.Module): 70 | def __init__(self): 71 | super(OriAtt, self).__init__() 72 | 73 | def forward(self, sout, pred): 74 | return sout.mul(torch.sigmoid(pred)) 75 | 76 | 77 | # Background Attention 78 | class RevAtt(nn.Module): 79 | def __init__(self): 80 | super(RevAtt, self).__init__() 81 | 82 | def forward(self, sout, pred): 83 | return sout.mul(1 - torch.sigmoid(pred)) 84 | 85 | 86 | # MBAM 87 | class ASPP(nn.Module): 88 | def __init__(self, in_channel, channel): 89 | super(ASPP, self).__init__() 90 | 91 | self.aspp1 = nn.Sequential( 92 | nn.Conv2d(in_channel, channel, 1, 1), 93 | nn.ReLU(inplace=True), 94 | ) 95 | self.aspp2 = nn.Sequential( 96 | nn.Conv2d(in_channel, channel, 3, 1, padding=3, dilation=3), 97 | nn.ReLU(inplace=True), 98 | ) 99 | 100 | self.aspp3 = nn.Sequential( 101 | nn.Conv2d(in_channel, channel, 3, 1, padding=5, dilation=5), 102 | nn.ReLU(inplace=True), 103 | ) 104 | self.aspp4 = nn.Sequential( 105 | nn.Conv2d(in_channel, channel, 3, 1, padding=7, dilation=7), 106 | nn.ReLU(inplace=True), 107 | ) 108 | 109 | def forward(self, x): 110 | x1 = self.aspp1(x) 111 | x2 = self.aspp2(x) 112 | x3 = self.aspp3(x) 113 | x4 = self.aspp4(x) 114 | x = torch.cat((x1, x2, 
x3, x4), dim=1) 115 | 116 | return x 117 | 118 | 119 | # Output residual 120 | class ResiLayer(nn.Module): 121 | def __init__(self, in_channel, channel, k): 122 | super(ResiLayer, self).__init__() 123 | self.conv1x1 = nn.Conv2d(in_channel, channel, 1) 124 | 125 | self.rev_att = RevAtt() 126 | self.rev_conv = nn.Sequential( 127 | nn.Conv2d(channel, channel, k, 1, k // 2), 128 | nn.ReLU(inplace=True), 129 | ) 130 | 131 | self.ori_att = OriAtt() 132 | self.ori_conv = nn.Sequential( 133 | nn.Conv2d(channel, channel, k, 1, k // 2), 134 | nn.ReLU(inplace=True), 135 | ) 136 | 137 | self.out_layer = nn.Sequential( 138 | nn.Conv2d(channel * 2, channel, k, 1, k // 2), 139 | nn.ReLU(inplace=True), 140 | nn.Conv2d(channel, 1, 3, 1, 1), 141 | ) 142 | 143 | def forward(self, sout, pred): 144 | sout = self.conv1x1(sout) 145 | 146 | sout_rev = self.rev_att(sout, pred) 147 | sout_rev = self.rev_conv(sout_rev) 148 | 149 | sout_ori = self.ori_att(sout, pred) 150 | sout_ori = self.ori_conv(sout_ori) 151 | 152 | return self.out_layer(torch.cat((sout_ori, sout_rev), 1)) 153 | 154 | 155 | # Multi-Scaled Residual 156 | class PResiLayer(nn.Module): 157 | def __init__(self, in_channel, channel, k): 158 | super(PResiLayer, self).__init__() 159 | self.rev_att = RevAtt() 160 | self.ori_att = OriAtt() 161 | 162 | self.ori_aspp = ASPP(in_channel, channel) 163 | self.rev_aspp = ASPP(in_channel, channel) 164 | 165 | self.out_layer = nn.Sequential( 166 | nn.Conv2d(channel * 8, channel, k, 1, k // 2), 167 | nn.ReLU(inplace=True), 168 | # nn.Dropout(0.5), 169 | nn.Conv2d(channel, 1, 3, 1, 1), 170 | ) 171 | 172 | def forward(self, sout, pred): 173 | sout_rev = self.rev_att(sout, pred) 174 | sout_rev = self.rev_aspp(sout_rev) 175 | 176 | sout_ori = self.ori_att(sout, pred) 177 | sout_ori = self.ori_aspp(sout_ori) 178 | 179 | sout_cat = torch.cat((sout_ori, sout_rev), 1) 180 | 181 | return self.out_layer(sout_cat) 182 | 183 | 184 | # Top-Down Stream 185 | class TDLayer(nn.Module): 186 | def __init__(self, in_channel, channel, k): 187 | super(TDLayer, self).__init__() 188 | self.resi_layer = ResiLayer(in_channel, channel, k) 189 | 190 | def forward(self, sout, pred): 191 | pred = nn.functional.interpolate(pred, 192 | size=sout.size()[2:], 193 | mode='bilinear', 194 | align_corners=True) 195 | residual = self.resi_layer(sout, pred) 196 | return pred + residual 197 | 198 | 199 | # Top-Down Stream with MBAM 200 | class PTDLayer(nn.Module): 201 | def __init__(self, in_channel, channel, k): 202 | super(PTDLayer, self).__init__() 203 | self.resi_layer = PResiLayer(in_channel, channel, k) 204 | 205 | def forward(self, sout, pred): 206 | pred = nn.functional.interpolate(pred, 207 | size=sout.size()[2:], 208 | mode='bilinear', 209 | align_corners=True) 210 | residual = self.resi_layer(sout, pred) 211 | return pred + residual 212 | 213 | 214 | class BiANet(nn.Module): 215 | def __init__(self): 216 | super(BiANet, self).__init__() 217 | self.rgb_sout = VGG_Sout() 218 | self.rgb_gs = GSLayer(512, 256, k=5) 219 | 220 | self.dep_sout = VGG_Sout() 221 | self.dep_gs = GSLayer(512, 256, k=5) 222 | 223 | self.rgbd_gs = GSLayer(1024, 256, k=5) 224 | 225 | self.td_layers = nn.ModuleList([ 226 | PTDLayer(1024, 32, 3), 227 | PTDLayer(1024, 32, 3), 228 | PTDLayer(512, 32, 3), 229 | TDLayer(256, 32, 3), 230 | TDLayer(128, 32, 3), 231 | ]) 232 | 233 | def forward(self, rgb, dep): 234 | [_, _, h, w] = rgb.size() 235 | 236 | rgb_souts, rgb_x = self.rgb_sout(rgb) 237 | dep_souts, dep_x = self.dep_sout(dep) 238 | 239 | rgb_pred = self.rgb_gs(rgb_x) # global 
saliency 240 | dep_pred = self.dep_gs(dep_x) # global saliency 241 | 242 | rgbd_souts = [] # cat rgb_souts and dep_souts 243 | for idx in range(len(rgb_souts)): 244 | rgbd_souts.append(torch.cat((rgb_souts[idx], dep_souts[idx]), 1)) 245 | 246 | rgbd_preds = [] 247 | rgbd_preds.append(self.rgbd_gs(torch.cat((rgb_x, dep_x), 248 | 1))) # global saliency 249 | 250 | for idx in range(len(rgbd_souts)): 251 | rgbd_preds.append(self.td_layers[idx](rgbd_souts[-(idx + 1)], 252 | rgbd_preds[idx])) 253 | 254 | scaled_preds = [] 255 | scaled_preds.append( 256 | torch.sigmoid( 257 | nn.functional.interpolate(rgb_pred, 258 | size=(h, w), 259 | mode='bilinear', 260 | align_corners=True))) 261 | scaled_preds.append( 262 | torch.sigmoid( 263 | nn.functional.interpolate(dep_pred, 264 | size=(h, w), 265 | mode='bilinear', 266 | align_corners=True))) 267 | 268 | for idx in range(len(rgbd_preds) - 1): 269 | scaled_preds.append( 270 | torch.sigmoid( 271 | nn.functional.interpolate(rgbd_preds[idx], 272 | size=(h, w), 273 | mode='bilinear', 274 | align_corners=True))) 275 | scaled_preds.append(torch.sigmoid(rgbd_preds[-1])) 276 | 277 | return scaled_preds 278 | 279 | 280 | # weight init 281 | def xavier(param): 282 | init.xavier_uniform_(param) 283 | 284 | 285 | def weights_init(m): 286 | if isinstance(m, nn.Conv2d): 287 | xavier(m.weight.data) 288 | elif isinstance(m, nn.BatchNorm2d): 289 | init.constant_(m.weight, 1) 290 | init.constant_(m.bias, 0) 291 | -------------------------------------------------------------------------------- /models/res2net_v1b.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | import torch 6 | import torch.nn.functional as F 7 | __all__ = ['Res2Net', 'res2net50_v1b', 'res2net101_v1b'] 8 | 9 | 10 | model_urls = { 11 | 'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth', 12 | 'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth', 13 | } 14 | 15 | 16 | class Bottle2neck(nn.Module): 17 | expansion = 4 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale = 4, stype='normal'): 20 | """ Constructor 21 | Args: 22 | inplanes: input channel dimensionality 23 | planes: output channel dimensionality 24 | stride: conv stride. Replaces pooling layer. 25 | downsample: None when stride = 1 26 | baseWidth: basic width of conv3x3 27 | scale: number of scale. 28 | type: 'normal': normal set. 'stage': first block of a new stage. 
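            Example: with planes=64, baseWidth=26, scale=4, then
            width = floor(64 * 26 / 64) = 26, so conv1 expands the input to
            width * scale = 104 channels, which forward() splits into 4
            groups of 26 (three pass through 3x3 convs, the last is carried
            over or pooled).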
29 | """ 30 | super(Bottle2neck, self).__init__() 31 | 32 | width = int(math.floor(planes * (baseWidth/64.0))) 33 | self.conv1 = nn.Conv2d(inplanes, width*scale, kernel_size=1, bias=False) 34 | self.bn1 = nn.BatchNorm2d(width*scale) 35 | 36 | if scale == 1: 37 | self.nums = 1 38 | else: 39 | self.nums = scale -1 40 | if stype == 'stage': 41 | self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1) 42 | convs = [] 43 | bns = [] 44 | for i in range(self.nums): 45 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride = stride, padding=1, bias=False)) 46 | bns.append(nn.BatchNorm2d(width)) 47 | self.convs = nn.ModuleList(convs) 48 | self.bns = nn.ModuleList(bns) 49 | 50 | self.conv3 = nn.Conv2d(width*scale, planes * self.expansion, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 52 | 53 | self.relu = nn.ReLU(inplace=True) 54 | self.downsample = downsample 55 | self.stype = stype 56 | self.scale = scale 57 | self.width = width 58 | 59 | def forward(self, x): 60 | residual = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | spx = torch.split(out, self.width, 1) 67 | for i in range(self.nums): 68 | if i==0 or self.stype=='stage': 69 | sp = spx[i] 70 | else: 71 | sp = sp + spx[i] 72 | sp = self.convs[i](sp) 73 | sp = self.relu(self.bns[i](sp)) 74 | if i==0: 75 | out = sp 76 | else: 77 | out = torch.cat((out, sp), 1) 78 | if self.scale != 1 and self.stype=='normal': 79 | out = torch.cat((out, spx[self.nums]),1) 80 | elif self.scale != 1 and self.stype=='stage': 81 | out = torch.cat((out, self.pool(spx[self.nums])),1) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | class Res2Net(nn.Module): 95 | 96 | def __init__(self, block, layers, baseWidth = 26, scale = 4, num_classes=1000): 97 | self.inplanes = 64 98 | super(Res2Net, self).__init__() 99 | self.baseWidth = baseWidth 100 | self.scale = scale 101 | self.conv1 = nn.Sequential( 102 | nn.Conv2d(3, 32, 3, 1, 1, bias=False), 103 | nn.BatchNorm2d(32), 104 | nn.ReLU(inplace=True), 105 | nn.Conv2d(32, 32, 3, 1, 1, bias=False), 106 | nn.BatchNorm2d(32), 107 | nn.ReLU(inplace=True), 108 | nn.Conv2d(32, 64, 3, 1, 1, bias=False) 109 | ) 110 | self.bn1 = nn.BatchNorm2d(64) 111 | self.relu = nn.ReLU() 112 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 113 | self.layer1 = self._make_layer(block, 64, layers[0]) 114 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 115 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 116 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 117 | # self.avgpool = nn.AdaptiveAvgPool2d(1) 118 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 119 | 120 | for m in self.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 123 | elif isinstance(m, nn.BatchNorm2d): 124 | nn.init.constant_(m.weight, 1) 125 | nn.init.constant_(m.bias, 0) 126 | 127 | def _make_layer(self, block, planes, blocks, stride=1): 128 | downsample = None 129 | if stride != 1 or self.inplanes != planes * block.expansion: 130 | downsample = nn.Sequential( 131 | nn.AvgPool2d(kernel_size=stride, stride=stride, 132 | ceil_mode=True, count_include_pad=False), 133 | nn.Conv2d(self.inplanes, planes * block.expansion, 134 | kernel_size=1, stride=1, bias=False), 135 | 
nn.BatchNorm2d(planes * block.expansion), 136 | ) 137 | 138 | layers = [] 139 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 140 | stype='stage', baseWidth = self.baseWidth, scale=self.scale)) 141 | self.inplanes = planes * block.expansion 142 | for i in range(1, blocks): 143 | layers.append(block(self.inplanes, planes, baseWidth = self.baseWidth, scale=self.scale)) 144 | 145 | return nn.Sequential(*layers) 146 | 147 | def forward(self, x): 148 | x = self.conv1(x) 149 | x = self.bn1(x) 150 | x = self.relu(x) 151 | x = self.maxpool(x) 152 | 153 | x = self.layer1(x) 154 | x = self.layer2(x) 155 | x = self.layer3(x) 156 | x = self.layer4(x) 157 | 158 | # x = self.avgpool(x) 159 | # x = x.view(x.size(0), -1) 160 | # x = self.fc(x) 161 | 162 | return x 163 | 164 | 165 | def res2net50_v1b(pretrained=False, **kwargs): 166 | """Constructs a Res2Net-50_v1b model. 167 | Res2Net-50 refers to the Res2Net-50_v1b_26w_4s. 168 | Args: 169 | pretrained (bool): If True, returns a model pre-trained on ImageNet 170 | """ 171 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs) 172 | if pretrained: 173 | pretrained_dict = model_zoo.load_url(model_urls['res2net50_v1b_26w_4s']) 174 | model_dict=model.state_dict() 175 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 176 | model_dict.update(pretrained_dict) 177 | model.load_state_dict(model_dict) 178 | return model 179 | 180 | def res2net101_v1b(pretrained=False, **kwargs): 181 | """Constructs a Res2Net-50_v1b_26w_4s model. 182 | Args: 183 | pretrained (bool): If True, returns a model pre-trained on ImageNet 184 | """ 185 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs) 186 | if pretrained: 187 | pretrained_dict = model_zoo.load_url(model_urls['res2net101_v1b_26w_4s']) 188 | model_dict=model.state_dict() 189 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 190 | model_dict.update(pretrained_dict) 191 | model.load_state_dict(model_dict) 192 | return model 193 | 194 | def res2net50_v1b_26w_4s(pretrained=False, **kwargs): 195 | """Constructs a Res2Net-50_v1b_26w_4s model. 196 | Args: 197 | pretrained (bool): If True, returns a model pre-trained on ImageNet 198 | """ 199 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs) 200 | if pretrained: 201 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])) 202 | return model 203 | 204 | def res2net101_v1b_26w_4s(pretrained=False, **kwargs): 205 | """Constructs a Res2Net-50_v1b_26w_4s model. 206 | Args: 207 | pretrained (bool): If True, returns a model pre-trained on ImageNet 208 | """ 209 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs) 210 | if pretrained: 211 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'])) 212 | return model 213 | 214 | def res2net152_v1b_26w_4s(pretrained=False, **kwargs): 215 | """Constructs a Res2Net-50_v1b_26w_4s model. 
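    (Despite the docstring title, the [3, 8, 36, 3] configuration below is
    Res2Net-152. Note that model_urls above has no 'res2net152_v1b_26w_4s'
    entry, so pretrained=True will raise a KeyError.)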
216 | Args: 217 | pretrained (bool): If True, returns a model pre-trained on ImageNet 218 | """ 219 | model = Res2Net(Bottle2neck, [3, 8, 36, 3], baseWidth = 26, scale = 4, **kwargs) 220 | if pretrained: 221 | model.load_state_dict(model_zoo.load_url(model_urls['res2net152_v1b_26w_4s'])) 222 | return model 223 | 224 | 225 | 226 | 227 | 228 | if __name__ == '__main__': 229 | images = torch.rand(1, 3, 224, 224).cuda(0) 230 | model = res2net50_v1b_26w_4s(pretrained=True) 231 | model = model.cuda(0) 232 | print(model(images).size()) 233 | -------------------------------------------------------------------------------- /models/resnet_conv1.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet50', 'resnet101'] # resnet101 is coming soon! 7 | 8 | 9 | model_urls = { 10 | 'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth', 11 | 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth', 12 | 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth' 13 | } 14 | 15 | 16 | def conv3x3(in_planes, out_planes, stride=1): 17 | "3x3 convolution with padding" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=1, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, inplanes, planes, stride=1, downsample=None): 26 | super(BasicBlock, self).__init__() 27 | self.conv1 = conv3x3(inplanes, planes, stride) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.conv2 = conv3x3(planes, planes) 31 | self.bn2 = nn.BatchNorm2d(planes) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | residual = x 37 | 38 | out = self.conv1(x) 39 | out = self.bn1(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | 45 | if self.downsample is not None: 46 | residual = self.downsample(x) 47 | 48 | out += residual 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class Bottleneck(nn.Module): 55 | expansion = 4 56 | 57 | def __init__(self, inplanes, planes, stride=1, downsample=None): 58 | super(Bottleneck, self).__init__() 59 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 60 | self.bn1 = nn.BatchNorm2d(planes) 61 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 62 | padding=1, bias=False) 63 | self.bn2 = nn.BatchNorm2d(planes) 64 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 65 | self.bn3 = nn.BatchNorm2d(planes * 4) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.downsample = downsample 68 | self.stride = stride 69 | 70 | def forward(self, x): 71 | residual = x 72 | 73 | out = self.conv1(x) 74 | out = self.bn1(out) 75 | out = self.relu(out) 76 | 77 | out = self.conv2(out) 78 | out = self.bn2(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv3(out) 82 | out = self.bn3(out) 83 | 84 | if self.downsample is not None: 85 | residual = self.downsample(x) 86 | 87 | out += residual 88 | out = self.relu(out) 89 | 90 | return out 91 | 92 | 93 | class ResNet(nn.Module): 94 | 95 | def __init__(self, block, layers, num_classes=1000): 96 | self.inplanes = 128 97 | super(ResNet, self).__init__() 98 | self.conv1 = conv3x3(3, 64, stride=1) 99 | self.bn1 = nn.BatchNorm2d(64) 100 | 
self.relu1 = nn.ReLU(inplace=True) 101 | self.conv2 = conv3x3(64, 64) 102 | self.bn2 = nn.BatchNorm2d(64) 103 | self.relu2 = nn.ReLU(inplace=True) 104 | self.conv3 = conv3x3(64, 128) 105 | self.bn3 = nn.BatchNorm2d(128) 106 | self.relu3 = nn.ReLU(inplace=True) 107 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 108 | 109 | self.layer1 = self._make_layer(block, 64, layers[0]) 110 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 111 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 112 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 113 | 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 117 | m.weight.data.normal_(0, math.sqrt(2. / n)) 118 | elif isinstance(m, nn.BatchNorm2d): 119 | m.weight.data.fill_(1) 120 | m.bias.data.zero_() 121 | 122 | def _make_layer(self, block, planes, blocks, stride=1): 123 | downsample = None 124 | if stride != 1 or self.inplanes != planes * block.expansion: 125 | downsample = nn.Sequential( 126 | nn.Conv2d(self.inplanes, planes * block.expansion, 127 | kernel_size=1, stride=stride, bias=False), 128 | nn.BatchNorm2d(planes * block.expansion), 129 | ) 130 | 131 | layers = [] 132 | layers.append(block(self.inplanes, planes, stride, downsample)) 133 | self.inplanes = planes * block.expansion 134 | for i in range(1, blocks): 135 | layers.append(block(self.inplanes, planes)) 136 | 137 | return nn.Sequential(*layers) 138 | 139 | def forward(self, x): 140 | x = self.relu1(self.bn1(self.conv1(x))) 141 | x = self.relu2(self.bn2(self.conv2(x))) 142 | x = self.relu3(self.bn3(self.conv3(x))) 143 | x = self.maxpool(x) 144 | 145 | x = self.layer1(x) 146 | x = self.layer2(x) 147 | x = self.layer3(x) 148 | x = self.layer4(x) 149 | 150 | return x 151 | 152 | def resnet18(pretrained=False, **kwargs): 153 | """Constructs a ResNet-18 model. 154 | Args: 155 | pretrained (bool): If True, returns a model pre-trained on ImageNet 156 | """ 157 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 158 | if pretrained: 159 | model.load_state_dict(load_url(model_urls['resnet18'])) 160 | return model 161 | 162 | ''' 163 | def resnet34(pretrained=False, **kwargs): 164 | """Constructs a ResNet-34 model. 165 | Args: 166 | pretrained (bool): If True, returns a model pre-trained on ImageNet 167 | """ 168 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 169 | if pretrained: 170 | model.load_state_dict(load_url(model_urls['resnet34'])) 171 | return model 172 | ''' 173 | 174 | def resnet50(pretrained=False, **kwargs): 175 | """Constructs a ResNet-50 model. 176 | Args: 177 | pretrained (bool): If True, returns a model pre-trained on ImageNet 178 | """ 179 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 180 | if pretrained: 181 | pretrained_dict = model_zoo.load_url(model_urls['resnet50']) 182 | model_dict=model.state_dict() 183 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 184 | model_dict.update(pretrained_dict) 185 | model.load_state_dict(model_dict) 186 | return model 187 | 188 | 189 | def resnet101(pretrained=False, **kwargs): 190 | """Constructs a ResNet-101 model. 
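    (Loaded with strict=False below: this truncated backbone has no
    avgpool/fc head, so classifier weights in the checkpoint are ignored.)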
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False)
    return model

# def resnet152(pretrained=False, **kwargs):
#     """Constructs a ResNet-152 model.
#
#     Args:
#         pretrained (bool): If True, returns a model pre-trained on ImageNet
#     """
#     model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
#     if pretrained:
#         model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
#     return model
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
from PIL import Image
from dataset import get_loader
import torch
from torchvision import transforms
from torch import nn
import os
import argparse


def main(args):

    backbone_names = args.backbones.split('+')
    dataset_names = args.datasets.split('+')

    for dataset in dataset_names:
        for backbone in backbone_names:
            print("Working on [DATASET: %s] with [BACKBONE: %s]" %
                  (dataset, backbone))

            # Configure testset path
            test_rgb_path = os.path.join(args.input_root, dataset, 'RGB')
            test_dep_path = os.path.join(args.input_root, dataset, 'depth')

            res_path = os.path.join(args.save_root, 'BiANet_' + backbone,
                                    dataset)
            os.makedirs(res_path, exist_ok=True)
            test_loader = get_loader(test_rgb_path,
                                     test_dep_path,
                                     args.size,
                                     1,
                                     num_thread=8,
                                     pin=True)

            # Load model and parameters
            # (dynamically imports models.BiANet_<backbone>)
            exec('from models import BiANet_' + backbone)
            model = eval('BiANet_' + backbone).BiANet()
            pre_dict = torch.load(
                os.path.join(args.param_root, 'BiANet_' + backbone + '.pth'))
            device = torch.device("cuda")
            model.to(device)
            # The released vgg16 checkpoint was saved from a DataParallel
            # wrapper, so its keys carry a 'module.' prefix
            if backbone == 'vgg16':
                model = torch.nn.DataParallel(model, device_ids=[0])
            model.load_state_dict(pre_dict)
            model.eval()

            # Test Go!
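            # Each batch is (rgb, dep, filename, imsize); with batch_size=1
            # the DataLoader collates imsize into [tensor([h]), tensor([w])],
            # the original (height, width) the prediction is resized back to.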
            tensor2pil = transforms.ToPILImage()
            with torch.no_grad():
                for batch in test_loader:
                    rgbs = batch[0].to(device)
                    deps = batch[1].to(device)
                    name = batch[2][0]
                    imsize = batch[3]

                    scaled_preds = model(rgbs, deps)

                    res = scaled_preds[-1]

                    res = nn.functional.interpolate(res,
                                                    size=imsize,
                                                    mode='bilinear',
                                                    align_corners=True).cpu()
                    res = res.squeeze(0)
                    res = tensor2pil(res)
                    res.save(os.path.join(res_path, name[:-3] + 'png'))

    print('Outputs were saved at: ' + args.save_root)


if __name__ == '__main__':
    # Parameters from the command line
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--backbones',
                        default='vgg16',
                        type=str,
                        help="Options: 'vgg11', 'vgg16', 'res50', 'res2_50'")
    parser.add_argument(
        '--datasets',
        default='NJU2K_TEST',
        type=str,
        help="Options: 'NJU2K_TEST', 'NLPR_TEST', 'DES', 'SSD', 'STERE', 'SIP'")
    parser.add_argument('--size', default=224, type=int, help='input size')
    parser.add_argument('--param_root',
                        default='param',
                        type=str,
                        help='folder for pre-trained model')
    parser.add_argument('--input_root',
                        default='./Testset',
                        type=str,
                        help='dataset root')
    parser.add_argument('--save_root',
                        default='./SalMap',
                        type=str,
                        help='Output folder')
    args = parser.parse_args()
    main(args)
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
CUDA_VISIBLE_DEVICES=0 python test.py --backbones vgg16+vgg11+res50+res2_50 --datasets NJU2K_TEST+NLPR_TEST+DES+SSD+STERE+SIP --param_root param --save_root ../SalMaps_Minor/pred
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
import logging
import os
import torch
import shutil
from torchvision import transforms
import numpy as np
import random


class Logger():
    def __init__(self, path="log.txt"):
        self.logger = logging.getLogger('DGNet')
        self.file_handler = logging.FileHandler(path, "w")
        self.stdout_handler = logging.StreamHandler()
        self.stdout_handler.setFormatter(
            logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
        self.file_handler.setFormatter(
            logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
        self.logger.addHandler(self.file_handler)
        self.logger.addHandler(self.stdout_handler)
        self.logger.setLevel(logging.INFO)
        self.logger.propagate = False

    def info(self, txt):
        self.logger.info(txt)

    def close(self):
        self.file_handler.close()
        self.stdout_handler.close()


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def save_tensor_img(tensor_im, path):
    im = tensor_im.cpu().clone()
    im = im.squeeze(0)
    tensor2pil = transforms.ToPILImage()
    im = tensor2pil(im)
    im.save(path)
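

# A minimal, hedged usage sketch for the helpers above (illustration only,
# not part of the original file):
#
#     logger = Logger('run.log')               # logs to file and stdout
#     meter = AverageMeter()
#     for v in (0.9, 0.8, 0.7):
#         meter.update(v)
#     logger.info('avg = %.2f' % meter.avg)    # avg = 0.80
#     logger.close()
#
# save_tensor_img expects a 1xCxHxW (or CxHxW) float tensor with values in
# [0, 1], such as the sigmoided predictions produced in test.py.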
--------------------------------------------------------------------------------
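
For quick experiments without `test.sh` and the dataloader, the sketch below chains the same steps as `test.py` for a single RGB-D pair. It is a minimal sketch, not part of the released code: the image paths are hypothetical, a CUDA device and the downloaded `param/BiANet_vgg16.pth` checkpoint are assumed, and the preprocessing mirrors `dataset.py` (224x224 bilinear resize, ImageNet normalization, depth replicated to 3 channels).

```
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import functional as TF

from models import BiANet_vgg16  # or BiANet_vgg11 / BiANet_res50 / BiANet_res2_50

MEAN, STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]


def load_image(path, size=224):
    """Mirror dataset.py: square bilinear resize, to-tensor, ImageNet-normalize."""
    img = Image.open(path).convert('RGB')  # depth maps are replicated to 3 channels
    img = img.resize((size, size), Image.BILINEAR)
    return TF.normalize(TF.to_tensor(img), MEAN, STD).unsqueeze(0)


device = torch.device('cuda')
model = BiANet_vgg16.BiANet().to(device)
# The released vgg16 checkpoint was saved from a DataParallel wrapper
# (see test.py), so wrap before loading; the other backbones load bare.
model = torch.nn.DataParallel(model, device_ids=[0])
model.load_state_dict(torch.load('param/BiANet_vgg16.pth'))
model.eval()

rgb_path = 'Testset/NJU2K_TEST/RGB/0001.jpg'  # hypothetical file names
dep_path = 'Testset/NJU2K_TEST/depth/0001.png'
w, h = Image.open(rgb_path).size  # PIL size is (width, height)

with torch.no_grad():
    # forward() returns a list of sigmoided maps; the last is the final prediction
    pred = model(load_image(rgb_path).to(device), load_image(dep_path).to(device))[-1]

pred = F.interpolate(pred, size=(h, w), mode='bilinear', align_corners=True)
TF.to_pil_image(pred.squeeze(0).cpu()).save('pred_0001.png')
```

As in `test.py`, only the backbone import and the checkpoint name change per model; the final saliency map is always the last element of the returned prediction list.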