├── .gitignore
├── README.md
├── dataset.py
├── imgs
└── BiANet_logo.png
├── models
├── BiANet_res2_50.py
├── BiANet_res50.py
├── BiANet_vgg11.py
├── BiANet_vgg16.py
├── res2net_v1b.py
└── resnet_conv1.py
├── test.py
├── test.sh
└── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__
2 | /param
3 | /Testset
4 | /SalMaps
5 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![BiANet logo](imgs/BiANet_logo.png)
2 |
3 | # Bilateral Attention Network for RGB-D Salient Object Detection
9 |
10 | Published in IEEE Transactions on Image Processing (TIP)
11 |
12 | [Paper 📄]
13 | [ArXiv 🌐]
14 |
15 | [Homepage 🏠]
16 |
17 |
18 |
19 |
20 | ***
21 |
23 |
24 |
25 | ## Prerequisites
26 | #### Environments
27 | * PyTorch >= 1.0
28 | * Ubuntu 18.04
29 |
30 |
31 |
32 | ## Usage
33 | 1. Download the [model parameters](#model-parameters-and-prediction-results) and [datasets](http://dpfan.net/d3netbenchmark/)
34 | 2. Configure `test.sh`
35 |
36 | ```
37 | --backbones vgg16+vgg11+res50+res2_50 (join multiple items with '+')
38 | --datasets dataset1+dataset2+dataset3
39 | --param_root param (folder with the pretrained parameters)
40 | --input_root your_data_root (one subfolder per dataset, each containing RGB/ and depth/)
41 | --save_root your_output_root
42 | ```
43 |
44 | 3. Run:
45 | ```
46 | sh test.sh
47 | ```
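
`test.py` can also be called directly. A minimal sketch (paths are placeholders; the script expects weights named `BiANet_<backbone>.pth` inside `--param_root`, and each dataset folder under `--input_root` to contain `RGB/` and `depth/` subfolders):
```
python test.py --backbones vgg16 --datasets NJU2K_TEST \
    --param_root param --input_root ./Testset --save_root ./SalMaps
```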
48 | ## Model parameters and prediction results
49 | | | Model parameters | Prediction results |
50 | | ---- | ---- | ---- |
51 | | **VGG-16** | [[Google Drive]](https://drive.google.com/file/d/1yfE2-4GH-QJo5JvvJbKRwXgzaRQ5e8h_/view?usp=sharing) [[Baidu Pan (bfrn)]](https://pan.baidu.com/s/1gXkDYUU0wxzM2EjyBoO6Yg) | [[Google Drive]](https://drive.google.com/file/d/1BI43wDAT9lON-8mKK6X00j-AmcZnwoZG/view?usp=sharing) [[Baidu Pan (k01w)]](https://pan.baidu.com/s/1lFPPf9LynKlBx2tOyoP_2A) |
52 | | VGG-11 | [[Google Drive]](https://drive.google.com/file/d/1TdTvZmPIbPfaX_BYI7dNTUoMI7IVXvFe/view?usp=sharing) [[Baidu Pan (2a5c)]](https://pan.baidu.com/s/1Usr-SNCPZADyISaIXPEZxA) | [[Google Drive]](https://drive.google.com/file/d/14aP1634QFjc0wQu8Unjme0lsmaJtlnFp/view?usp=sharing) [[Baidu Pan (d0t7)]](https://pan.baidu.com/s/1U-7hkmvfN8Pjj0pnC8VLGQ) |
53 | | ResNet-50 | [[Google Drive]](https://drive.google.com/file/d/13vHFAR44v2bojEJppoB058QV0Vc9-Tm7/view?usp=sharing) [[Baidu Pan (o9l2)]](https://pan.baidu.com/s/1m0p7IN4GB2BWCcoj6kM_lw) | [[Google Drive]](https://drive.google.com/file/d/1CFgXVlB-jmHArTv6kdK-CZvQ6nuEpve3/view?usp=sharing) [[Baidu Pan (dqw1)]](https://pan.baidu.com/s/1KJUy4cu4dpVfdF5Nqw2uOw) |
54 | | Res2Net-50 | [[Google Drive]](https://drive.google.com/file/d/1DppyXLs_toFi6bM5ZbGWip35BxLGfw4y/view?usp=sharing) [[Baidu Pan (k761)]](https://pan.baidu.com/s/1ycs9SI5bmIKBUbcNsrR7qQ) | [[Google Drive]](https://drive.google.com/file/d/1at-K6DfKNP2Gnao9f0v9agmzADkgt0Ik/view?usp=sharing) [[Baidu Pan (h3t9)]](https://pan.baidu.com/s/1YHVrDEl1-dCHgS2Fuc1Qzw) |
55 |
56 | ## Citation
57 | ```
58 | @article{zhang2020bianet,
59 | title={Bilateral attention network for rgb-d salient object detection},
60 | author={Zhang, Zhao and Lin, Zheng and Xu, Jun and Jin, Wenda and Lu, Shao-Ping and Fan, Deng-Ping},
61 | journal={IEEE Transactions on Image Processing (TIP)},
62 | volume={30},
63 | pages={1949-1961},
64 | doi={10.1109/TIP.2021.3049959},
65 | year={2021},
66 | }
67 | ```
68 |
69 | ## Contact
70 | If you have any questions, feel free to contact me via `zzhang@mail.nankai.edu.cn`
71 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch
4 | import random
5 | import numpy as np
6 | from torch.utils import data
7 | from torchvision import transforms
8 | from torchvision.transforms import functional as F
9 |
10 |
11 | class ImageData(data.Dataset):
12 | def __init__(self, rgb_root, dep_root, transform):
13 |
14 | self.rgb_path = list(
15 | map(lambda x: os.path.join(rgb_root, x), os.listdir(rgb_root)))
16 | self.dep_path = list(
17 | map(
18 | lambda x: os.path.join(dep_root,
19 | x.split('/')[-1][:-3] + 'png'),
20 | self.rgb_path))
21 |
22 | self.transform = transform
23 |
24 | def __getitem__(self, item):
25 |
26 | rgb = Image.open(self.rgb_path[item]).convert('RGB')
27 | dep = Image.open(self.dep_path[item]).convert('RGB')
28 |         [w, h] = dep.size  # PIL's size is (width, height)
29 |         imsize = [h, w]    # (H, W), matching F.interpolate's size argument
30 |
31 | [rgb, dep] = self.transform(rgb, dep)
32 |
33 | return rgb, dep, self.rgb_path[item].split('/')[-1], imsize
34 |
35 | def __len__(self):
36 | return len(self.rgb_path)
37 |
38 |
39 | class FixedResize(object):
40 | def __init__(self, size):
41 | self.size = (size, size) # size: (h, w)
42 |
43 | def __call__(self, rgb, dep):
44 |
45 | assert rgb.size == dep.size
46 |
47 | rgb = rgb.resize(self.size, Image.BILINEAR)
48 | dep = dep.resize(self.size, Image.BILINEAR)
49 |
50 | return rgb, dep
51 |
52 |
53 | class ToTensor(object):
54 |     """Convert PIL images to tensors."""
55 | def __call__(self, rgb, dep):
56 |
57 | return F.to_tensor(rgb), F.to_tensor(dep)
58 |
59 |
60 | class Normalize(object):
61 | """Normalize a tensor image with mean and standard deviation.
62 | Args:
63 | mean (tuple): means for each channel.
64 | std (tuple): standard deviations for each channel.
65 | """
66 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
67 | self.mean = mean
68 | self.std = std
69 |
70 | def __call__(self, rgb, dep):
71 |
72 | dep = F.normalize(dep, self.mean, self.std)
73 |
74 | rgb = F.normalize(rgb, self.mean, self.std)
75 |
76 | return rgb, dep
77 |
78 |
79 | class RandomHorizontalFlip(object):
80 | def __init__(self, p=0.5):
81 | self.p = p
82 |
83 | def __call__(self, rgb, dep):
84 | if random.random() < self.p:
85 | rgb = rgb.transpose(Image.FLIP_LEFT_RIGHT)
86 | dep = dep.transpose(Image.FLIP_LEFT_RIGHT)
87 |
88 | return rgb, dep
89 |
90 |
91 | class Compose(object):
92 | def __init__(self, transforms):
93 | self.transforms = transforms
94 |
95 | def __call__(self, rgb, dep):
96 | for t in self.transforms:
97 | rgb, dep = t(rgb, dep)
98 | return rgb, dep
99 |
100 | def __repr__(self):
101 | format_string = self.__class__.__name__ + '('
102 | for t in self.transforms:
103 | format_string += '\n'
104 | format_string += ' {0}'.format(t)
105 | format_string += '\n)'
106 | return format_string
107 |
108 |
109 | def get_loader(rgb_root,
110 | dep_root,
111 | img_size,
112 | batch_size=1,
113 | num_thread=1,
114 | pin=False):
115 | test_transform = Compose([
116 | FixedResize(img_size),
117 | ToTensor(),
118 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
119 | ])
120 |
121 | dataset = ImageData(rgb_root, dep_root, transform=test_transform)
122 | data_loader = data.DataLoader(dataset=dataset,
123 | batch_size=batch_size,
124 | shuffle=False,
125 | num_workers=num_thread,
126 | pin_memory=pin)
127 | return data_loader
128 |
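# Minimal usage sketch; the paths are hypothetical and mirror the layout
# test.py assumes (per-dataset folders with RGB/ and depth/ subfolders
# holding RGB images and .png depth maps with matching names).
if __name__ == '__main__':
    loader = get_loader('Testset/NJU2K_TEST/RGB',
                        'Testset/NJU2K_TEST/depth',
                        img_size=224)
    for rgb, dep, name, imsize in loader:
        print(name[0], rgb.shape, dep.shape)  # torch.Size([1, 3, 224, 224]) each
        break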
--------------------------------------------------------------------------------
/imgs/BiANet_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zzhanghub/bianet/0d557772b944ba2847a1bf83b0ef89752b2d6f7e/imgs/BiANet_logo.png
--------------------------------------------------------------------------------
/models/BiANet_res2_50.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch.nn import init
5 | from models.res2net_v1b import res2net50_v1b
6 |
7 |
8 | # RGB Stream (Res2Net-50)
9 | class RGB_Stream(nn.Module):
10 | def __init__(self):
11 | super(RGB_Stream, self).__init__()
12 | self.backbone = res2net50_v1b(pretrained=True)
13 | self.toplayer = nn.Sequential(
14 | nn.MaxPool2d(2, stride=2),
15 | nn.Conv2d(2048, 32, kernel_size=5, stride=1, padding=3),
16 | nn.BatchNorm2d(32),
17 | nn.ReLU(inplace=True),
18 | nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=3),
19 | nn.BatchNorm2d(32),
20 | nn.ReLU(inplace=True),
21 | )
22 |
23 | def forward(self, rgb):
24 | rgb = self.backbone.conv1(rgb)
25 | rgb = self.backbone.bn1(rgb)
26 | rgb = self.backbone.relu(rgb)
27 | rgb1 = rgb
28 | rgb = self.backbone.maxpool(rgb)
29 | rgb2 = self.backbone.layer1(rgb)
30 | rgb3 = self.backbone.layer2(rgb2)
31 | rgb4 = self.backbone.layer3(rgb3)
32 | rgb5 = self.backbone.layer4(rgb4)
33 | rgb6 = self.toplayer(rgb5)
34 |
35 | return [rgb1, rgb2, rgb3, rgb4, rgb5, rgb6]
36 |
37 |
38 | # Depth Stream (Res2Net-50)
39 | class Dep_Stream(nn.Module):
40 | def __init__(self):
41 | super(Dep_Stream, self).__init__()
42 | self.backbone = res2net50_v1b(pretrained=True)
43 | self.toplayer = nn.Sequential(
44 | nn.MaxPool2d(2, stride=2),
45 | nn.Conv2d(2048, 32, kernel_size=5, stride=1, padding=3),
46 | nn.BatchNorm2d(32),
47 | nn.ReLU(inplace=True),
48 | nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=3),
49 | nn.BatchNorm2d(32),
50 | nn.ReLU(inplace=True),
51 | )
52 |
53 | def forward(self, dep):
54 | dep = self.backbone.conv1(dep)
55 | dep = self.backbone.bn1(dep)
56 | dep = self.backbone.relu(dep)
57 | dep1 = dep
58 | dep = self.backbone.maxpool(dep)
59 | dep2 = self.backbone.layer1(dep)
60 | dep3 = self.backbone.layer2(dep2)
61 | dep4 = self.backbone.layer3(dep3)
62 | dep5 = self.backbone.layer4(dep4)
63 | dep6 = self.toplayer(dep5)
64 | return [dep1, dep2, dep3, dep4, dep5, dep6]
65 |
66 |
67 | class Pred_Layer(nn.Module):
68 | def __init__(self, in_c=32):
69 | super(Pred_Layer, self).__init__()
70 | self.enlayer = nn.Sequential(
71 | nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1),
72 | nn.BatchNorm2d(32),
73 | nn.ReLU(inplace=True),
74 | )
75 | self.outlayer = nn.Sequential(
76 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), )
77 |
78 | def forward(self, x):
79 | x = self.enlayer(x)
80 | x = self.outlayer(x)
81 | return x
82 |
83 |
84 | # BAM
85 | class BAM(nn.Module):
86 | def __init__(self, in_c):
87 | super(BAM, self).__init__()
88 | self.reduce = nn.Conv2d(in_c * 2, 32, 1)
89 | self.ff_conv = nn.Sequential(
90 | nn.Conv2d(32, 32, 3, 1, 1),
91 | nn.BatchNorm2d(32),
92 | nn.ReLU(inplace=True),
93 | )
94 | self.bf_conv = nn.Sequential(
95 | nn.Conv2d(32, 32, 3, 1, 1),
96 | nn.BatchNorm2d(32),
97 | nn.ReLU(inplace=True),
98 | )
99 | self.rgbd_pred_layer = Pred_Layer(32 * 2)
100 |
101 | def forward(self, rgb_feat, dep_feat, pred):
102 | feat = torch.cat((rgb_feat, dep_feat), 1)
103 | feat = self.reduce(feat)
104 | [_, _, H, W] = feat.size()
105 | pred = torch.sigmoid(
106 | F.interpolate(pred,
107 | size=(H, W),
108 | mode='bilinear',
109 | align_corners=True))
110 | ff_feat = self.ff_conv(feat * pred)
111 | bf_feat = self.bf_conv(feat * (1 - pred))
112 | new_pred = self.rgbd_pred_layer(torch.cat((ff_feat, bf_feat), 1))
113 | return new_pred
114 |
115 |
116 | # FF (foreground-attention branch only; not used by the BiANet class below)
117 | class FF(nn.Module):
118 | def __init__(self, in_c):
119 | super(FF, self).__init__()
120 |         self.reduce = nn.Conv2d(in_c * 2, 32, 1)  # rgb and dep features are concatenated
121 | self.ff_conv = nn.Sequential(
122 |             nn.Conv2d(32, 32, 3, 1, 1),
123 | nn.BatchNorm2d(32),
124 | nn.ReLU(inplace=True),
125 | )
126 | self.rgbd_pred_layer = Pred_Layer(32)
127 |
128 | def forward(self, rgb_feat, dep_feat, pred):
129 |         feat = torch.cat((rgb_feat, dep_feat), 1)
130 |         feat = self.reduce(feat)
130 | [_, _, H, W] = feat.size()
131 | pred = torch.sigmoid(
132 | F.interpolate(pred,
133 | size=(H, W),
134 | mode='bilinear',
135 | align_corners=True))
136 | ff_feat = self.ff_conv(feat * pred)
137 | new_pred = self.rgbd_pred_layer(ff_feat)
138 | return new_pred
139 |
140 |
141 | # BF (background-attention branch only; not used by the BiANet class below)
142 | class BF(nn.Module):
143 | def __init__(self, in_c):
144 | super(BF, self).__init__()
145 | self.reduce = nn.Conv2d(in_c * 2, 32, 1)
146 | self.bf_conv = nn.Sequential(
147 | nn.Conv2d(32, 32, 3, 1, 1),
148 | nn.BatchNorm2d(32),
149 | nn.ReLU(inplace=True),
150 | )
151 | self.rgbd_pred_layer = Pred_Layer(32)
152 |
153 | def forward(self, rgb_feat, dep_feat, pred):
154 |         feat = torch.cat((rgb_feat, dep_feat), 1)
155 |         feat = self.reduce(feat)
155 | [_, _, H, W] = feat.size()
156 | pred = torch.sigmoid(
157 | F.interpolate(pred,
158 | size=(H, W),
159 | mode='bilinear',
160 | align_corners=True))
161 | bf_feat = self.bf_conv(feat * (1 - pred))
162 | new_pred = self.rgbd_pred_layer(bf_feat)
163 | return new_pred
164 |
165 |
166 | # ASPP for MBAM
167 | class ASPP(nn.Module):
168 | def __init__(self, in_c):
169 | super(ASPP, self).__init__()
170 |
171 | self.aspp1 = nn.Sequential(
172 | nn.Conv2d(in_c * 2, 32, 1, 1),
173 | nn.BatchNorm2d(32),
174 | nn.ReLU(inplace=True),
175 | )
176 | self.aspp2 = nn.Sequential(
177 | nn.Conv2d(in_c * 2, 32, 3, 1, padding=3, dilation=3),
178 | nn.BatchNorm2d(32),
179 | nn.ReLU(inplace=True),
180 | )
181 |
182 | self.aspp3 = nn.Sequential(
183 | nn.Conv2d(in_c * 2, 32, 3, 1, padding=5, dilation=5),
184 | nn.BatchNorm2d(32),
185 | nn.ReLU(inplace=True),
186 | )
187 | self.aspp4 = nn.Sequential(
188 | nn.Conv2d(in_c * 2, 32, 3, 1, padding=7, dilation=7),
189 | nn.BatchNorm2d(32),
190 | nn.ReLU(inplace=True),
191 | )
192 |
193 | def forward(self, x):
194 | x1 = self.aspp1(x)
195 | x2 = self.aspp2(x)
196 | x3 = self.aspp3(x)
197 | x4 = self.aspp4(x)
198 | x = torch.cat((x1, x2, x3, x4), dim=1)
199 |
200 | return x
201 |
202 |
203 | # MBAM
204 | class MBAM(nn.Module):
205 | def __init__(self, in_c):
206 | super(MBAM, self).__init__()
207 | self.ff_conv = ASPP(in_c)
208 | self.bf_conv = ASPP(in_c)
209 | self.rgbd_pred_layer = Pred_Layer(32 * 8)
210 |
211 | def forward(self, rgb_feat, dep_feat, pred):
212 | feat = torch.cat((rgb_feat, dep_feat), 1)
213 | [_, _, H, W] = feat.size()
214 | pred = torch.sigmoid(
215 | F.interpolate(pred,
216 | size=(H, W),
217 | mode='bilinear',
218 | align_corners=True))
219 |
220 | ff_feat = self.ff_conv(feat * pred)
221 | bf_feat = self.bf_conv(feat * (1 - pred))
222 | new_pred = self.rgbd_pred_layer(torch.cat((ff_feat, bf_feat), 1))
223 | return new_pred
224 |
225 |
226 | class BiANet(nn.Module):
227 | def __init__(self):
228 | super(BiANet, self).__init__()
229 |
230 | # two-streams
231 | self.rgb_stream = RGB_Stream()
232 | self.dep_stream = Dep_Stream()
233 |
234 | # Global Pred
235 | self.rgb_global = Pred_Layer(32)
236 | self.dep_global = Pred_Layer(32)
237 | self.rgbd_global = Pred_Layer(32 * 2)
238 |
239 |         # Short-Connection
240 | self.bams = nn.ModuleList([
241 | BAM(64),
242 | BAM(256),
243 | MBAM(512),
244 | MBAM(1024),
245 | MBAM(2048),
246 | ])
247 |
248 | def _upsample_add(self, x, y):
249 | [_, _, H, W] = y.size()
250 | return F.interpolate(
251 | x, size=(H, W), mode='bilinear', align_corners=True) + y
252 |
253 | def forward(self, rgb, dep):
254 | [_, _, H, W] = rgb.size()
255 | rgb_feats = self.rgb_stream(rgb)
256 | dep_feats = self.dep_stream(dep)
257 |
258 |         # Global Prediction
259 | rgb_pred = self.rgb_global(rgb_feats[5])
260 | dep_pred = self.dep_global(dep_feats[5])
261 | rgbd_pred = self.rgbd_global(torch.cat((rgb_feats[5], dep_feats[5]),
262 | 1))
263 | preds = [
264 | torch.sigmoid(
265 | F.interpolate(rgb_pred,
266 | size=(H, W),
267 | mode='bilinear',
268 | align_corners=True)),
269 | torch.sigmoid(
270 | F.interpolate(dep_pred,
271 | size=(H, W),
272 | mode='bilinear',
273 | align_corners=True)),
274 | torch.sigmoid(
275 | F.interpolate(rgbd_pred,
276 | size=(H, W),
277 | mode='bilinear',
278 | align_corners=True)),
279 | ]
280 |
281 | p = rgbd_pred
282 | for idx in [4, 3, 2, 1, 0]:
283 | _p = self.bams[idx](rgb_feats[idx], dep_feats[idx], p)
284 | p = self._upsample_add(p, _p)
285 | preds.append(
286 | torch.sigmoid(
287 | F.interpolate(p,
288 | size=(H, W),
289 | mode='bilinear',
290 | align_corners=True)))
291 | return preds
292 |
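# Smoke-test sketch (run as `python -m models.BiANet_res2_50` from the repo
# root). Constructing BiANet downloads the pretrained Res2Net weights on
# first use. The model returns eight sigmoided saliency maps: three global
# predictions plus one per decoding stage, all at the input resolution.
if __name__ == '__main__':
    model = BiANet().eval()
    rgb, dep = torch.rand(1, 3, 224, 224), torch.rand(1, 3, 224, 224)
    with torch.no_grad():
        preds = model(rgb, dep)
    print(len(preds), preds[-1].shape)  # 8 torch.Size([1, 1, 224, 224])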
--------------------------------------------------------------------------------
/models/BiANet_res50.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch.nn import init
5 | from models.resnet_conv1 import resnet50
6 |
7 |
8 | # RGB Stream (ResNet-50)
9 | class RGB_Stream(nn.Module):
10 | def __init__(self):
11 | super(RGB_Stream, self).__init__()
12 | self.backbone = resnet50(pretrained=True)
13 | self.toplayer = nn.Sequential(
14 | nn.MaxPool2d(2, stride=2),
15 | nn.Conv2d(2048, 32, kernel_size=5, stride=1, padding=3),
16 | nn.BatchNorm2d(32),
17 | nn.ReLU(inplace=True),
18 | nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=3),
19 | nn.BatchNorm2d(32),
20 | nn.ReLU(inplace=True),
21 | )
22 |
23 | def forward(self, rgb):
24 | rgb = self.backbone.relu1(self.backbone.bn1(self.backbone.conv1(rgb)))
25 | rgb = self.backbone.relu2(self.backbone.bn2(self.backbone.conv2(rgb)))
26 | rgb = self.backbone.relu3(self.backbone.bn3(self.backbone.conv3(rgb)))
27 | rgb1 = rgb
28 | rgb = self.backbone.maxpool(rgb)
29 | rgb2 = self.backbone.layer1(rgb)
30 | rgb3 = self.backbone.layer2(rgb2)
31 | rgb4 = self.backbone.layer3(rgb3)
32 | rgb5 = self.backbone.layer4(rgb4)
33 | rgb6 = self.toplayer(rgb5)
34 |
35 | return [rgb1, rgb2, rgb3, rgb4, rgb5, rgb6]
36 |
37 |
38 | # Depth Stream (ResNet-50)
39 | class Dep_Stream(nn.Module):
40 | def __init__(self):
41 | super(Dep_Stream, self).__init__()
42 | self.backbone = resnet50(pretrained=True)
43 | self.toplayer = nn.Sequential(
44 | nn.MaxPool2d(2, stride=2),
45 | nn.Conv2d(2048, 32, kernel_size=5, stride=1, padding=3),
46 | nn.BatchNorm2d(32),
47 | nn.ReLU(inplace=True),
48 | nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=3),
49 | nn.BatchNorm2d(32),
50 | nn.ReLU(inplace=True),
51 | )
52 |
53 | def forward(self, dep):
54 | dep = self.backbone.relu1(self.backbone.bn1(self.backbone.conv1(dep)))
55 | dep = self.backbone.relu2(self.backbone.bn2(self.backbone.conv2(dep)))
56 | dep = self.backbone.relu3(self.backbone.bn3(self.backbone.conv3(dep)))
57 | dep1 = dep
58 | dep = self.backbone.maxpool(dep)
59 | dep2 = self.backbone.layer1(dep)
60 | dep3 = self.backbone.layer2(dep2)
61 | dep4 = self.backbone.layer3(dep3)
62 | dep5 = self.backbone.layer4(dep4)
63 | dep6 = self.toplayer(dep5)
64 | return [dep1, dep2, dep3, dep4, dep5, dep6]
65 |
66 |
67 | class Pred_Layer(nn.Module):
68 | def __init__(self, in_c=32):
69 | super(Pred_Layer, self).__init__()
70 | self.enlayer = nn.Sequential(
71 | nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1),
72 | nn.BatchNorm2d(32),
73 | nn.ReLU(inplace=True),
74 | )
75 | self.outlayer = nn.Sequential(
76 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), )
77 |
78 | def forward(self, x):
79 | x = self.enlayer(x)
80 | x = self.outlayer(x)
81 | return x
82 |
83 |
84 | # BAM
85 | class BAM(nn.Module):
86 | def __init__(self, in_c):
87 | super(BAM, self).__init__()
88 | self.reduce = nn.Conv2d(in_c * 2, 32, 1)
89 | self.ff_conv = nn.Sequential(
90 | nn.Conv2d(32, 32, 3, 1, 1),
91 | nn.BatchNorm2d(32),
92 | nn.ReLU(inplace=True),
93 | )
94 | self.bf_conv = nn.Sequential(
95 | nn.Conv2d(32, 32, 3, 1, 1),
96 | nn.BatchNorm2d(32),
97 | nn.ReLU(inplace=True),
98 | )
99 | self.rgbd_pred_layer = Pred_Layer(32 * 2)
100 |
101 | def forward(self, rgb_feat, dep_feat, pred):
102 | feat = torch.cat((rgb_feat, dep_feat), 1)
103 | feat = self.reduce(feat)
104 | [_, _, H, W] = feat.size()
105 | pred = torch.sigmoid(
106 | F.interpolate(pred,
107 | size=(H, W),
108 | mode='bilinear',
109 | align_corners=True))
110 | ff_feat = self.ff_conv(feat * pred)
111 | bf_feat = self.bf_conv(feat * (1 - pred))
112 | new_pred = self.rgbd_pred_layer(torch.cat((ff_feat, bf_feat), 1))
113 | return new_pred
114 |
115 |
116 | # FF (foreground-attention branch only; not used by the BiANet class below)
117 | class FF(nn.Module):
118 | def __init__(self, in_c):
119 | super(FF, self).__init__()
120 |         self.reduce = nn.Conv2d(in_c * 2, 32, 1)  # rgb and dep features are concatenated
121 | self.ff_conv = nn.Sequential(
122 |             nn.Conv2d(32, 32, 3, 1, 1),
123 | nn.BatchNorm2d(32),
124 | nn.ReLU(inplace=True),
125 | )
126 | self.rgbd_pred_layer = Pred_Layer(32)
127 |
128 | def forward(self, rgb_feat, dep_feat, pred):
129 |         feat = torch.cat((rgb_feat, dep_feat), 1)
130 |         feat = self.reduce(feat)
130 | [_, _, H, W] = feat.size()
131 | pred = torch.sigmoid(
132 | F.interpolate(pred,
133 | size=(H, W),
134 | mode='bilinear',
135 | align_corners=True))
136 | ff_feat = self.ff_conv(feat * pred)
137 | new_pred = self.rgbd_pred_layer(ff_feat)
138 | return new_pred
139 |
140 |
141 | # BF (background-attention branch only; not used by the BiANet class below)
142 | class BF(nn.Module):
143 | def __init__(self, in_c):
144 | super(BF, self).__init__()
145 | self.reduce = nn.Conv2d(in_c * 2, 32, 1)
146 | self.bf_conv = nn.Sequential(
147 | nn.Conv2d(32, 32, 3, 1, 1),
148 | nn.BatchNorm2d(32),
149 | nn.ReLU(inplace=True),
150 | )
151 | self.rgbd_pred_layer = Pred_Layer(32)
152 |
153 | def forward(self, rgb_feat, dep_feat, pred):
154 |         feat = torch.cat((rgb_feat, dep_feat), 1)
155 |         feat = self.reduce(feat)
155 | [_, _, H, W] = feat.size()
156 | pred = torch.sigmoid(
157 | F.interpolate(pred,
158 | size=(H, W),
159 | mode='bilinear',
160 | align_corners=True))
161 | bf_feat = self.bf_conv(feat * (1 - pred))
162 | new_pred = self.rgbd_pred_layer(bf_feat)
163 | return new_pred
164 |
165 |
166 | # ASPP for MBAM
167 | class ASPP(nn.Module):
168 | def __init__(self, in_c):
169 | super(ASPP, self).__init__()
170 |
171 | self.aspp1 = nn.Sequential(
172 | nn.Conv2d(in_c * 2, 32, 1, 1),
173 | nn.BatchNorm2d(32),
174 | nn.ReLU(inplace=True),
175 | )
176 | self.aspp2 = nn.Sequential(
177 | nn.Conv2d(in_c * 2, 32, 3, 1, padding=3, dilation=3),
178 | nn.BatchNorm2d(32),
179 | nn.ReLU(inplace=True),
180 | )
181 |
182 | self.aspp3 = nn.Sequential(
183 | nn.Conv2d(in_c * 2, 32, 3, 1, padding=5, dilation=5),
184 | nn.BatchNorm2d(32),
185 | nn.ReLU(inplace=True),
186 | )
187 | self.aspp4 = nn.Sequential(
188 | nn.Conv2d(in_c * 2, 32, 3, 1, padding=7, dilation=7),
189 | nn.BatchNorm2d(32),
190 | nn.ReLU(inplace=True),
191 | )
192 |
193 | def forward(self, x):
194 | x1 = self.aspp1(x)
195 | x2 = self.aspp2(x)
196 | x3 = self.aspp3(x)
197 | x4 = self.aspp4(x)
198 | x = torch.cat((x1, x2, x3, x4), dim=1)
199 |
200 | return x
201 |
202 |
203 | # MBAM
204 | class MBAM(nn.Module):
205 | def __init__(self, in_c):
206 | super(MBAM, self).__init__()
207 | self.ff_conv = ASPP(in_c)
208 | self.bf_conv = ASPP(in_c)
209 | self.rgbd_pred_layer = Pred_Layer(32 * 8)
210 |
211 | def forward(self, rgb_feat, dep_feat, pred):
212 | feat = torch.cat((rgb_feat, dep_feat), 1)
213 | [_, _, H, W] = feat.size()
214 | pred = torch.sigmoid(
215 | F.interpolate(pred,
216 | size=(H, W),
217 | mode='bilinear',
218 | align_corners=True))
219 |
220 | ff_feat = self.ff_conv(feat * pred)
221 | bf_feat = self.bf_conv(feat * (1 - pred))
222 | new_pred = self.rgbd_pred_layer(torch.cat((ff_feat, bf_feat), 1))
223 | return new_pred
224 |
225 |
226 | class BiANet(nn.Module):
227 | def __init__(self):
228 | super(BiANet, self).__init__()
229 |
230 | # two-streams
231 | self.rgb_stream = RGB_Stream()
232 | self.dep_stream = Dep_Stream()
233 |
234 | # Global Pred
235 | self.rgb_global = Pred_Layer(32)
236 | self.dep_global = Pred_Layer(32)
237 | self.rgbd_global = Pred_Layer(32 * 2)
238 |
239 |         # Short-Connection
240 | self.bams = nn.ModuleList([
241 | BAM(128),
242 | BAM(256),
243 | MBAM(512),
244 | MBAM(1024),
245 | MBAM(2048),
246 | ])
247 |
248 | def _upsample_add(self, x, y):
249 | [_, _, H, W] = y.size()
250 | return F.interpolate(
251 | x, size=(H, W), mode='bilinear', align_corners=True) + y
252 |
253 | def forward(self, rgb, dep):
254 | [_, _, H, W] = rgb.size()
255 | rgb_feats = self.rgb_stream(rgb)
256 | dep_feats = self.dep_stream(dep)
257 |
258 |         # Global Prediction
259 | rgb_pred = self.rgb_global(rgb_feats[5])
260 | dep_pred = self.dep_global(dep_feats[5])
261 | rgbd_pred = self.rgbd_global(torch.cat((rgb_feats[5], dep_feats[5]),
262 | 1))
263 | preds = [
264 | torch.sigmoid(
265 | F.interpolate(rgb_pred,
266 | size=(H, W),
267 | mode='bilinear',
268 | align_corners=True)),
269 | torch.sigmoid(
270 | F.interpolate(dep_pred,
271 | size=(H, W),
272 | mode='bilinear',
273 | align_corners=True)),
274 | torch.sigmoid(
275 | F.interpolate(rgbd_pred,
276 | size=(H, W),
277 | mode='bilinear',
278 | align_corners=True)),
279 | ]
280 |
281 | p = rgbd_pred
282 | for idx in [4, 3, 2, 1, 0]:
283 | _p = self.bams[idx](rgb_feats[idx], dep_feats[idx], p)
284 | p = self._upsample_add(p, _p)
285 | preds.append(
286 | torch.sigmoid(
287 | F.interpolate(p,
288 | size=(H, W),
289 | mode='bilinear',
290 | align_corners=True)))
291 | return preds
292 |
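# Sketch of loading trained weights the way test.py does for the non-vgg16
# backbones (the checkpoint path is hypothetical; see the README for the
# actual downloads). Run as `python -m models.BiANet_res50` from the repo root.
if __name__ == '__main__':
    model = BiANet()
    state = torch.load('param/BiANet_res50.pth', map_location='cpu')
    model.load_state_dict(state)
    model.eval()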
--------------------------------------------------------------------------------
/models/BiANet_vgg11.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import init
4 | import torch.nn.functional as F
5 |
6 | backbone = {
7 | 'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
8 | }
9 |
10 |
11 | # VGG layer builder
12 | def vgg(cfg, i=3, batch_norm=False):
13 | layers = []
14 | in_channels = i
15 | for v in cfg:
16 | if v == 'M':
17 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
18 | else:
19 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
20 | if batch_norm:
21 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
22 | else:
23 | layers += [conv2d, nn.ReLU(inplace=True)]
24 | in_channels = v
25 | return layers
26 |
27 |
28 | # VGG-11 with Side Outputs
29 | class VGG_Sout(nn.Module):
30 | def __init__(self, extract=[1, 4, 9, 14, 19]):
31 | super(VGG_Sout, self).__init__()
32 | self.vgg = nn.ModuleList(vgg(cfg=backbone['vgg11']))
33 | self.extract = extract
34 |
35 | def forward(self, x):
36 | souts = []
37 | for idx in range(len(self.vgg)):
38 | x = self.vgg[idx](x)
39 | if idx in self.extract:
40 | souts.append(x)
41 |
42 | return souts, x
43 |
44 |
45 | # Global Saliency (a block after the VGG-11 backbone that predicts the global saliency map)
46 | class GSLayer(nn.Module):
47 | def __init__(self, in_channel, channel, k):
48 | super(GSLayer, self).__init__()
49 | self.conv1x1 = nn.Conv2d(in_channel, channel, 1)
50 | self.convs = nn.Sequential(nn.Conv2d(channel, channel, k, 1, k // 2),
51 | nn.ReLU(inplace=True),
52 | nn.Conv2d(channel, channel, k, 1, k // 2),
53 | nn.ReLU(inplace=True),
54 | nn.Conv2d(channel, channel, k, 1, k // 2),
55 | nn.ReLU(inplace=True))
56 | self.out_layer = nn.Conv2d(channel, 1, 1)
57 |
58 | def forward(self, x):
59 | x = self.conv1x1(x)
60 | x = self.convs(x)
61 | out = self.out_layer(x)
62 | return out
63 |
64 |
65 | # Original Attention
66 | class OriAtt(nn.Module):
67 | def __init__(self):
68 | super(OriAtt, self).__init__()
69 |
70 | def forward(self, sout, pred):
71 | return sout.mul(torch.sigmoid(pred))
72 |
73 |
74 | # Reverse Attention
75 | class RevAtt(nn.Module):
76 | def __init__(self):
77 | super(RevAtt, self).__init__()
78 |
79 | def forward(self, sout, pred):
80 | return sout.mul(1 - torch.sigmoid(pred))
81 |
82 |
83 | # ASPP block
84 | class ASPP(nn.Module):
85 | def __init__(self, in_channel, channel):
86 | super(ASPP, self).__init__()
87 |
88 | self.aspp1 = nn.Sequential(
89 | nn.Conv2d(in_channel, channel, 1, 1),
90 | nn.ReLU(inplace=True),
91 | )
92 | self.aspp2 = nn.Sequential(
93 | nn.Conv2d(in_channel, channel, 3, 1, padding=3, dilation=3),
94 | nn.ReLU(inplace=True),
95 | )
96 |
97 | self.aspp3 = nn.Sequential(
98 | nn.Conv2d(in_channel, channel, 3, 1, padding=5, dilation=5),
99 | nn.ReLU(inplace=True),
100 | )
101 | self.aspp4 = nn.Sequential(
102 | nn.Conv2d(in_channel, channel, 3, 1, padding=7, dilation=7),
103 | nn.ReLU(inplace=True),
104 | )
105 |
106 | def forward(self, x):
107 | x1 = self.aspp1(x)
108 | x2 = self.aspp2(x)
109 | x3 = self.aspp3(x)
110 | x4 = self.aspp4(x)
111 | # x4 = F.interpolate(x4, size=x3.size()[2:], mode='bilinear', align_corners=True)
112 | x = torch.cat((x1, x2, x3, x4), dim=1)
113 |
114 | return x
115 |
116 |
117 | # Output residual (Dual-stream Attention)
118 | class ResiLayer(nn.Module):
119 | def __init__(self, in_channel, channel, k):
120 | super(ResiLayer, self).__init__()
121 | self.conv1x1 = nn.Conv2d(in_channel, channel, 1)
122 |
123 | self.rev_att = RevAtt()
124 | self.rev_conv = nn.Sequential(
125 | nn.Conv2d(channel, channel, k, 1, k // 2),
126 | # nn.BatchNorm2d(channel),
127 | nn.ReLU(inplace=True),
128 | )
129 |
130 | self.ori_att = OriAtt()
131 | self.ori_conv = nn.Sequential(
132 | nn.Conv2d(channel, channel, k, 1, k // 2),
133 | # nn.BatchNorm2d(channel),
134 | nn.ReLU(inplace=True),
135 | )
136 |
137 | self.out_layer = nn.Sequential(
138 | nn.Conv2d(channel * 2, channel, k, 1, k // 2),
139 | nn.ReLU(inplace=True),
140 | nn.Conv2d(channel, 1, 3, 1, 1),
141 | )
142 |
143 | def forward(self, sout, pred):
144 | sout = self.conv1x1(sout)
145 |
146 | sout_rev = self.rev_att(sout, pred)
147 | sout_rev = self.rev_conv(sout_rev)
148 |
149 | sout_ori = self.ori_att(sout, pred)
150 | sout_ori = self.ori_conv(sout_ori)
151 |
152 | return self.out_layer(torch.cat((sout_ori, sout_rev), 1))
153 |
154 |
155 | # Multi-Scaled Attention Residual Prediction
156 | class PResiLayer(nn.Module):
157 | def __init__(self, in_channel, channel, k):
158 | super(PResiLayer, self).__init__()
159 | # self.conv1x1 = nn.Conv2d(in_channel, channel, 1)
160 |
161 | self.rev_att = RevAtt()
162 | self.ori_att = OriAtt()
163 |
164 | self.ori_aspp = ASPP(in_channel, channel)
165 | self.rev_aspp = ASPP(in_channel, channel)
166 |
167 | self.out_layer = nn.Sequential(
168 | nn.Conv2d(channel * 8, channel, k, 1, k // 2),
169 | nn.ReLU(inplace=True),
170 | # nn.Dropout(0.5),
171 | nn.Conv2d(channel, 1, 3, 1, 1),
172 | )
173 |
174 | def forward(self, sout, pred):
175 | # sout = self.conv1x1(sout)
176 |
177 | sout_rev = self.rev_att(sout, pred)
178 | sout_rev = self.rev_aspp(sout_rev)
179 |
180 | sout_ori = self.ori_att(sout, pred)
181 | sout_ori = self.ori_aspp(sout_ori)
182 |
183 | sout_cat = torch.cat((sout_ori, sout_rev), 1)
184 |
185 | return self.out_layer(sout_cat)
186 |
187 |
188 | # Top-Down Stream with dual (foreground/background) attention
189 | class TDLayer(nn.Module):
190 | def __init__(self, in_channel, channel, k):
191 | super(TDLayer, self).__init__()
192 | self.resi_layer = ResiLayer(in_channel, channel, k)
193 |
194 | def forward(self, sout, pred):
195 | pred = nn.functional.interpolate(pred,
196 | size=sout.size()[2:],
197 | mode='bilinear',
198 | align_corners=True)
199 | residual = self.resi_layer(sout, pred)
200 | return pred + residual
201 |
202 |
203 | # Top-Down Stream with multi-scale bilateral attention
204 | class PTDLayer(nn.Module):
205 | def __init__(self, in_channel, channel, k):
206 | super(PTDLayer, self).__init__()
207 | self.resi_layer = PResiLayer(in_channel, channel, k)
208 |
209 | def forward(self, sout, pred):
210 | pred = nn.functional.interpolate(pred,
211 | size=sout.size()[2:],
212 | mode='bilinear',
213 | align_corners=True)
214 | residual = self.resi_layer(sout, pred)
215 | return pred + residual
216 |
217 |
218 | # BiANet model
219 | class BiANet(nn.Module):
220 | def __init__(self):
221 | super(BiANet, self).__init__()
222 | self.rgb_sout = VGG_Sout()
223 | self.rgb_gs = GSLayer(512, 256, k=5)
224 |
225 | self.dep_sout = VGG_Sout()
226 | self.dep_gs = GSLayer(512, 256, k=5)
227 |
228 | self.rgbd_gs = GSLayer(1024, 256, k=5)
229 |
230 | self.td_layers = nn.ModuleList([
231 | PTDLayer(1024, 32, 3),
232 | PTDLayer(1024, 32, 3),
233 | PTDLayer(512, 32, 3),
234 | TDLayer(256, 32, 3),
235 | TDLayer(128, 32, 3),
236 | ])
237 |
238 | def forward(self, rgb, dep):
239 | [_, _, h, w] = rgb.size()
240 |
241 | rgb_souts, rgb_x = self.rgb_sout(rgb)
242 | dep_souts, dep_x = self.dep_sout(dep)
243 |
244 | rgb_pred = self.rgb_gs(rgb_x) # global saliency
245 | dep_pred = self.dep_gs(dep_x) # global saliency
246 |
247 | rgbd_souts = [] # cat rgb_souts and dep_souts
248 | for idx in range(len(rgb_souts)):
249 | rgbd_souts.append(torch.cat((rgb_souts[idx], dep_souts[idx]), 1))
250 |
251 | rgbd_preds = []
252 | rgbd_preds.append(self.rgbd_gs(torch.cat((rgb_x, dep_x),
253 | 1))) # global saliency
254 |
255 | for idx in range(len(rgbd_souts)):
256 | rgbd_preds.append(self.td_layers[idx](rgbd_souts[-(idx + 1)],
257 | rgbd_preds[idx]))
258 |
259 | scaled_preds = []
260 | scaled_preds.append(
261 | torch.sigmoid(
262 | nn.functional.interpolate(rgb_pred,
263 | size=(h, w),
264 | mode='bilinear',
265 | align_corners=True)))
266 | scaled_preds.append(
267 | torch.sigmoid(
268 | nn.functional.interpolate(dep_pred,
269 | size=(h, w),
270 | mode='bilinear',
271 | align_corners=True)))
272 |
273 | for idx in range(len(rgbd_preds) - 1):
274 | scaled_preds.append(
275 | torch.sigmoid(
276 | nn.functional.interpolate(rgbd_preds[idx],
277 | size=(h, w),
278 | mode='bilinear',
279 | align_corners=True)))
280 | scaled_preds.append(torch.sigmoid(rgbd_preds[-1]))
281 |
282 |         # Order: rgb_gs, dep_gs, then rgbd predictions from top to bottom; the final map is scaled_preds[-1]
283 | return scaled_preds
284 |
285 |
286 | # weight init
287 | def xavier(param):
288 | init.xavier_uniform_(param)
289 |
290 |
291 | def weights_init(m):
292 | if isinstance(m, nn.Conv2d):
293 | xavier(m.weight.data)
294 | elif isinstance(m, nn.BatchNorm2d):
295 | init.constant_(m.weight, 1)
296 | init.constant_(m.bias, 0)
297 |
--------------------------------------------------------------------------------
/models/BiANet_vgg16.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import init
4 | import torch.nn.functional as F
5 |
6 | backbone = {
7 | 'vgg16': [
8 | 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M',
9 | 512, 512, 512, 'M'
10 | ]
11 | }
12 |
13 |
14 | # VGG16 backbone
15 | def vgg(cfg, i=3, batch_norm=False):
16 | layers = []
17 | in_channels = i
18 | for v in cfg:
19 | if v == 'M':
20 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
21 | else:
22 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
23 | if batch_norm:
24 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
25 | else:
26 | layers += [conv2d, nn.ReLU(inplace=True)]
27 | in_channels = v
28 | return layers
29 |
30 |
31 | # VGG16 with Side Outputs
32 | class VGG_Sout(nn.Module):
33 | def __init__(self, extract=[3, 8, 15, 22, 29]):
34 | super(VGG_Sout, self).__init__()
35 | self.vgg = nn.ModuleList(vgg(cfg=backbone['vgg16']))
36 | self.extract = extract
37 |
38 | def forward(self, x):
39 | souts = []
40 | for idx in range(len(self.vgg)):
41 | x = self.vgg[idx](x)
42 | if idx in self.extract:
43 | souts.append(x)
44 |
45 | return souts, x
46 |
47 |
48 | # Global Saliency (a block after the VGG-16 backbone that predicts the global saliency map)
49 | class GSLayer(nn.Module):
50 | def __init__(self, in_channel, channel, k):
51 | super(GSLayer, self).__init__()
52 | self.conv1x1 = nn.Conv2d(in_channel, channel, 1)
53 | self.convs = nn.Sequential(nn.Conv2d(channel, channel, k, 1, k // 2),
54 | nn.ReLU(inplace=True),
55 | nn.Conv2d(channel, channel, k, 1, k // 2),
56 | nn.ReLU(inplace=True),
57 | nn.Conv2d(channel, channel, k, 1, k // 2),
58 | nn.ReLU(inplace=True))
59 | self.out_layer = nn.Conv2d(channel, 1, 1)
60 |
61 | def forward(self, x):
62 | x = self.conv1x1(x)
63 | x = self.convs(x)
64 | out = self.out_layer(x)
65 | return out
66 |
67 |
68 | # Foreground Attention
69 | class OriAtt(nn.Module):
70 | def __init__(self):
71 | super(OriAtt, self).__init__()
72 |
73 | def forward(self, sout, pred):
74 | return sout.mul(torch.sigmoid(pred))
75 |
76 |
77 | # Background Attention
78 | class RevAtt(nn.Module):
79 | def __init__(self):
80 | super(RevAtt, self).__init__()
81 |
82 | def forward(self, sout, pred):
83 | return sout.mul(1 - torch.sigmoid(pred))
84 |
85 |
86 | # ASPP (multi-scale dilated convolutions, used inside the MBAM-style PResiLayer)
87 | class ASPP(nn.Module):
88 | def __init__(self, in_channel, channel):
89 | super(ASPP, self).__init__()
90 |
91 | self.aspp1 = nn.Sequential(
92 | nn.Conv2d(in_channel, channel, 1, 1),
93 | nn.ReLU(inplace=True),
94 | )
95 | self.aspp2 = nn.Sequential(
96 | nn.Conv2d(in_channel, channel, 3, 1, padding=3, dilation=3),
97 | nn.ReLU(inplace=True),
98 | )
99 |
100 | self.aspp3 = nn.Sequential(
101 | nn.Conv2d(in_channel, channel, 3, 1, padding=5, dilation=5),
102 | nn.ReLU(inplace=True),
103 | )
104 | self.aspp4 = nn.Sequential(
105 | nn.Conv2d(in_channel, channel, 3, 1, padding=7, dilation=7),
106 | nn.ReLU(inplace=True),
107 | )
108 |
109 | def forward(self, x):
110 | x1 = self.aspp1(x)
111 | x2 = self.aspp2(x)
112 | x3 = self.aspp3(x)
113 | x4 = self.aspp4(x)
114 | x = torch.cat((x1, x2, x3, x4), dim=1)
115 |
116 | return x
117 |
118 |
119 | # Output residual
120 | class ResiLayer(nn.Module):
121 | def __init__(self, in_channel, channel, k):
122 | super(ResiLayer, self).__init__()
123 | self.conv1x1 = nn.Conv2d(in_channel, channel, 1)
124 |
125 | self.rev_att = RevAtt()
126 | self.rev_conv = nn.Sequential(
127 | nn.Conv2d(channel, channel, k, 1, k // 2),
128 | nn.ReLU(inplace=True),
129 | )
130 |
131 | self.ori_att = OriAtt()
132 | self.ori_conv = nn.Sequential(
133 | nn.Conv2d(channel, channel, k, 1, k // 2),
134 | nn.ReLU(inplace=True),
135 | )
136 |
137 | self.out_layer = nn.Sequential(
138 | nn.Conv2d(channel * 2, channel, k, 1, k // 2),
139 | nn.ReLU(inplace=True),
140 | nn.Conv2d(channel, 1, 3, 1, 1),
141 | )
142 |
143 | def forward(self, sout, pred):
144 | sout = self.conv1x1(sout)
145 |
146 | sout_rev = self.rev_att(sout, pred)
147 | sout_rev = self.rev_conv(sout_rev)
148 |
149 | sout_ori = self.ori_att(sout, pred)
150 | sout_ori = self.ori_conv(sout_ori)
151 |
152 | return self.out_layer(torch.cat((sout_ori, sout_rev), 1))
153 |
154 |
155 | # Multi-Scaled Residual
156 | class PResiLayer(nn.Module):
157 | def __init__(self, in_channel, channel, k):
158 | super(PResiLayer, self).__init__()
159 | self.rev_att = RevAtt()
160 | self.ori_att = OriAtt()
161 |
162 | self.ori_aspp = ASPP(in_channel, channel)
163 | self.rev_aspp = ASPP(in_channel, channel)
164 |
165 | self.out_layer = nn.Sequential(
166 | nn.Conv2d(channel * 8, channel, k, 1, k // 2),
167 | nn.ReLU(inplace=True),
168 | # nn.Dropout(0.5),
169 | nn.Conv2d(channel, 1, 3, 1, 1),
170 | )
171 |
172 | def forward(self, sout, pred):
173 | sout_rev = self.rev_att(sout, pred)
174 | sout_rev = self.rev_aspp(sout_rev)
175 |
176 | sout_ori = self.ori_att(sout, pred)
177 | sout_ori = self.ori_aspp(sout_ori)
178 |
179 | sout_cat = torch.cat((sout_ori, sout_rev), 1)
180 |
181 | return self.out_layer(sout_cat)
182 |
183 |
184 | # Top-Down Stream
185 | class TDLayer(nn.Module):
186 | def __init__(self, in_channel, channel, k):
187 | super(TDLayer, self).__init__()
188 | self.resi_layer = ResiLayer(in_channel, channel, k)
189 |
190 | def forward(self, sout, pred):
191 | pred = nn.functional.interpolate(pred,
192 | size=sout.size()[2:],
193 | mode='bilinear',
194 | align_corners=True)
195 | residual = self.resi_layer(sout, pred)
196 | return pred + residual
197 |
198 |
199 | # Top-Down Stream with MBAM
200 | class PTDLayer(nn.Module):
201 | def __init__(self, in_channel, channel, k):
202 | super(PTDLayer, self).__init__()
203 | self.resi_layer = PResiLayer(in_channel, channel, k)
204 |
205 | def forward(self, sout, pred):
206 | pred = nn.functional.interpolate(pred,
207 | size=sout.size()[2:],
208 | mode='bilinear',
209 | align_corners=True)
210 | residual = self.resi_layer(sout, pred)
211 | return pred + residual
212 |
213 |
214 | class BiANet(nn.Module):
215 | def __init__(self):
216 | super(BiANet, self).__init__()
217 | self.rgb_sout = VGG_Sout()
218 | self.rgb_gs = GSLayer(512, 256, k=5)
219 |
220 | self.dep_sout = VGG_Sout()
221 | self.dep_gs = GSLayer(512, 256, k=5)
222 |
223 | self.rgbd_gs = GSLayer(1024, 256, k=5)
224 |
225 | self.td_layers = nn.ModuleList([
226 | PTDLayer(1024, 32, 3),
227 | PTDLayer(1024, 32, 3),
228 | PTDLayer(512, 32, 3),
229 | TDLayer(256, 32, 3),
230 | TDLayer(128, 32, 3),
231 | ])
232 |
233 | def forward(self, rgb, dep):
234 | [_, _, h, w] = rgb.size()
235 |
236 | rgb_souts, rgb_x = self.rgb_sout(rgb)
237 | dep_souts, dep_x = self.dep_sout(dep)
238 |
239 | rgb_pred = self.rgb_gs(rgb_x) # global saliency
240 | dep_pred = self.dep_gs(dep_x) # global saliency
241 |
242 | rgbd_souts = [] # cat rgb_souts and dep_souts
243 | for idx in range(len(rgb_souts)):
244 | rgbd_souts.append(torch.cat((rgb_souts[idx], dep_souts[idx]), 1))
245 |
246 | rgbd_preds = []
247 | rgbd_preds.append(self.rgbd_gs(torch.cat((rgb_x, dep_x),
248 | 1))) # global saliency
249 |
250 | for idx in range(len(rgbd_souts)):
251 | rgbd_preds.append(self.td_layers[idx](rgbd_souts[-(idx + 1)],
252 | rgbd_preds[idx]))
253 |
254 | scaled_preds = []
255 | scaled_preds.append(
256 | torch.sigmoid(
257 | nn.functional.interpolate(rgb_pred,
258 | size=(h, w),
259 | mode='bilinear',
260 | align_corners=True)))
261 | scaled_preds.append(
262 | torch.sigmoid(
263 | nn.functional.interpolate(dep_pred,
264 | size=(h, w),
265 | mode='bilinear',
266 | align_corners=True)))
267 |
268 | for idx in range(len(rgbd_preds) - 1):
269 | scaled_preds.append(
270 | torch.sigmoid(
271 | nn.functional.interpolate(rgbd_preds[idx],
272 | size=(h, w),
273 | mode='bilinear',
274 | align_corners=True)))
275 | scaled_preds.append(torch.sigmoid(rgbd_preds[-1]))
276 |
277 | return scaled_preds
278 |
279 |
280 | # weight init
281 | def xavier(param):
282 | init.xavier_uniform_(param)
283 |
284 |
285 | def weights_init(m):
286 | if isinstance(m, nn.Conv2d):
287 | xavier(m.weight.data)
288 | elif isinstance(m, nn.BatchNorm2d):
289 | init.constant_(m.weight, 1)
290 | init.constant_(m.bias, 0)
291 |
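# weights_init is defined but never called in this file; a sketch of the
# usual pattern for training from scratch (no training script ships with
# this repo, so the usage below is an assumption):
if __name__ == '__main__':
    model = BiANet()
    model.apply(weights_init)  # Xavier for Conv2d, constant 1/0 for BatchNorm2d
    preds = model(torch.rand(1, 3, 224, 224), torch.rand(1, 3, 224, 224))
    print(len(preds), preds[-1].shape)  # 8 torch.Size([1, 1, 224, 224])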
--------------------------------------------------------------------------------
/models/res2net_v1b.py:
--------------------------------------------------------------------------------
1 |
2 | import torch.nn as nn
3 | import math
4 | import torch.utils.model_zoo as model_zoo
5 | import torch
6 | import torch.nn.functional as F
7 | __all__ = ['Res2Net', 'res2net50_v1b', 'res2net101_v1b']
8 |
9 |
10 | model_urls = {
11 | 'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth',
12 | 'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth',
13 | }
14 |
15 |
16 | class Bottle2neck(nn.Module):
17 | expansion = 4
18 |
19 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale = 4, stype='normal'):
20 | """ Constructor
21 | Args:
22 | inplanes: input channel dimensionality
23 | planes: output channel dimensionality
24 | stride: conv stride. Replaces pooling layer.
25 | downsample: None when stride = 1
26 | baseWidth: basic width of conv3x3
27 | scale: number of scale.
28 |             stype: 'normal' for a regular block, 'stage' for the first block of a new stage.
29 | """
30 | super(Bottle2neck, self).__init__()
31 |
32 | width = int(math.floor(planes * (baseWidth/64.0)))
33 | self.conv1 = nn.Conv2d(inplanes, width*scale, kernel_size=1, bias=False)
34 | self.bn1 = nn.BatchNorm2d(width*scale)
35 |
36 | if scale == 1:
37 | self.nums = 1
38 | else:
39 | self.nums = scale -1
40 | if stype == 'stage':
41 | self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1)
42 | convs = []
43 | bns = []
44 | for i in range(self.nums):
45 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride = stride, padding=1, bias=False))
46 | bns.append(nn.BatchNorm2d(width))
47 | self.convs = nn.ModuleList(convs)
48 | self.bns = nn.ModuleList(bns)
49 |
50 | self.conv3 = nn.Conv2d(width*scale, planes * self.expansion, kernel_size=1, bias=False)
51 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
52 |
53 | self.relu = nn.ReLU(inplace=True)
54 | self.downsample = downsample
55 | self.stype = stype
56 | self.scale = scale
57 | self.width = width
58 |
59 | def forward(self, x):
60 | residual = x
61 |
62 | out = self.conv1(x)
63 | out = self.bn1(out)
64 | out = self.relu(out)
65 |
66 | spx = torch.split(out, self.width, 1)
67 | for i in range(self.nums):
68 | if i==0 or self.stype=='stage':
69 | sp = spx[i]
70 | else:
71 | sp = sp + spx[i]
72 | sp = self.convs[i](sp)
73 | sp = self.relu(self.bns[i](sp))
74 | if i==0:
75 | out = sp
76 | else:
77 | out = torch.cat((out, sp), 1)
78 | if self.scale != 1 and self.stype=='normal':
79 | out = torch.cat((out, spx[self.nums]),1)
80 | elif self.scale != 1 and self.stype=='stage':
81 | out = torch.cat((out, self.pool(spx[self.nums])),1)
82 |
83 | out = self.conv3(out)
84 | out = self.bn3(out)
85 |
86 | if self.downsample is not None:
87 | residual = self.downsample(x)
88 |
89 | out += residual
90 | out = self.relu(out)
91 |
92 | return out
93 |
94 | class Res2Net(nn.Module):
95 |
96 | def __init__(self, block, layers, baseWidth = 26, scale = 4, num_classes=1000):
97 | self.inplanes = 64
98 | super(Res2Net, self).__init__()
99 | self.baseWidth = baseWidth
100 | self.scale = scale
101 | self.conv1 = nn.Sequential(
102 | nn.Conv2d(3, 32, 3, 1, 1, bias=False),
103 | nn.BatchNorm2d(32),
104 | nn.ReLU(inplace=True),
105 | nn.Conv2d(32, 32, 3, 1, 1, bias=False),
106 | nn.BatchNorm2d(32),
107 | nn.ReLU(inplace=True),
108 | nn.Conv2d(32, 64, 3, 1, 1, bias=False)
109 | )
110 | self.bn1 = nn.BatchNorm2d(64)
111 | self.relu = nn.ReLU()
112 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
113 | self.layer1 = self._make_layer(block, 64, layers[0])
114 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
115 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
116 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
117 | # self.avgpool = nn.AdaptiveAvgPool2d(1)
118 | # self.fc = nn.Linear(512 * block.expansion, num_classes)
119 |
120 | for m in self.modules():
121 | if isinstance(m, nn.Conv2d):
122 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
123 | elif isinstance(m, nn.BatchNorm2d):
124 | nn.init.constant_(m.weight, 1)
125 | nn.init.constant_(m.bias, 0)
126 |
127 | def _make_layer(self, block, planes, blocks, stride=1):
128 | downsample = None
129 | if stride != 1 or self.inplanes != planes * block.expansion:
130 | downsample = nn.Sequential(
131 | nn.AvgPool2d(kernel_size=stride, stride=stride,
132 | ceil_mode=True, count_include_pad=False),
133 | nn.Conv2d(self.inplanes, planes * block.expansion,
134 | kernel_size=1, stride=1, bias=False),
135 | nn.BatchNorm2d(planes * block.expansion),
136 | )
137 |
138 | layers = []
139 | layers.append(block(self.inplanes, planes, stride, downsample=downsample,
140 | stype='stage', baseWidth = self.baseWidth, scale=self.scale))
141 | self.inplanes = planes * block.expansion
142 | for i in range(1, blocks):
143 | layers.append(block(self.inplanes, planes, baseWidth = self.baseWidth, scale=self.scale))
144 |
145 | return nn.Sequential(*layers)
146 |
147 | def forward(self, x):
148 | x = self.conv1(x)
149 | x = self.bn1(x)
150 | x = self.relu(x)
151 | x = self.maxpool(x)
152 |
153 | x = self.layer1(x)
154 | x = self.layer2(x)
155 | x = self.layer3(x)
156 | x = self.layer4(x)
157 |
158 | # x = self.avgpool(x)
159 | # x = x.view(x.size(0), -1)
160 | # x = self.fc(x)
161 |
162 | return x
163 |
164 |
165 | def res2net50_v1b(pretrained=False, **kwargs):
166 | """Constructs a Res2Net-50_v1b model.
167 | Res2Net-50 refers to the Res2Net-50_v1b_26w_4s.
168 | Args:
169 | pretrained (bool): If True, returns a model pre-trained on ImageNet
170 | """
171 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs)
172 | if pretrained:
173 | pretrained_dict = model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])
174 | model_dict=model.state_dict()
175 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
176 | model_dict.update(pretrained_dict)
177 | model.load_state_dict(model_dict)
178 | return model
179 |
180 | def res2net101_v1b(pretrained=False, **kwargs):
181 |     """Constructs a Res2Net-101_v1b model.
182 | Args:
183 | pretrained (bool): If True, returns a model pre-trained on ImageNet
184 | """
185 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs)
186 | if pretrained:
187 | pretrained_dict = model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'])
188 | model_dict=model.state_dict()
189 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
190 | model_dict.update(pretrained_dict)
191 | model.load_state_dict(model_dict)
192 | return model
193 |
194 | def res2net50_v1b_26w_4s(pretrained=False, **kwargs):
195 | """Constructs a Res2Net-50_v1b_26w_4s model.
196 | Args:
197 | pretrained (bool): If True, returns a model pre-trained on ImageNet
198 | """
199 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs)
200 | if pretrained:
201 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s']))
202 | return model
203 |
204 | def res2net101_v1b_26w_4s(pretrained=False, **kwargs):
205 |     """Constructs a Res2Net-101_v1b_26w_4s model.
206 | Args:
207 | pretrained (bool): If True, returns a model pre-trained on ImageNet
208 | """
209 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs)
210 | if pretrained:
211 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s']))
212 | return model
213 |
214 | def res2net152_v1b_26w_4s(pretrained=False, **kwargs):
215 |     """Constructs a Res2Net-152_v1b_26w_4s model.
216 | Args:
217 | pretrained (bool): If True, returns a model pre-trained on ImageNet
218 | """
219 | model = Res2Net(Bottle2neck, [3, 8, 36, 3], baseWidth = 26, scale = 4, **kwargs)
220 |     if pretrained:
221 |         # NOTE: model_urls above has no 'res2net152_v1b_26w_4s' entry, so this lookup raises KeyError.
222 |         model.load_state_dict(model_zoo.load_url(model_urls['res2net152_v1b_26w_4s']))
222 | return model
223 |
224 |
225 |
226 |
227 |
228 | if __name__ == '__main__':
229 | images = torch.rand(1, 3, 224, 224).cuda(0)
230 | model = res2net50_v1b_26w_4s(pretrained=True)
231 | model = model.cuda(0)
232 | print(model(images).size())
233 |
--------------------------------------------------------------------------------
/models/resnet_conv1.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 |
5 |
6 | __all__ = ['ResNet', 'resnet18', 'resnet50', 'resnet101']
7 |
8 |
9 | model_urls = {
10 | 'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth',
11 | 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth',
12 | 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth'
13 | }
14 |
15 |
16 | def conv3x3(in_planes, out_planes, stride=1):
17 | "3x3 convolution with padding"
18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
19 | padding=1, bias=False)
20 |
21 |
22 | class BasicBlock(nn.Module):
23 | expansion = 1
24 |
25 | def __init__(self, inplanes, planes, stride=1, downsample=None):
26 | super(BasicBlock, self).__init__()
27 | self.conv1 = conv3x3(inplanes, planes, stride)
28 | self.bn1 = nn.BatchNorm2d(planes)
29 | self.relu = nn.ReLU(inplace=True)
30 | self.conv2 = conv3x3(planes, planes)
31 | self.bn2 = nn.BatchNorm2d(planes)
32 | self.downsample = downsample
33 | self.stride = stride
34 |
35 | def forward(self, x):
36 | residual = x
37 |
38 | out = self.conv1(x)
39 | out = self.bn1(out)
40 | out = self.relu(out)
41 |
42 | out = self.conv2(out)
43 | out = self.bn2(out)
44 |
45 | if self.downsample is not None:
46 | residual = self.downsample(x)
47 |
48 | out += residual
49 | out = self.relu(out)
50 |
51 | return out
52 |
53 |
54 | class Bottleneck(nn.Module):
55 | expansion = 4
56 |
57 | def __init__(self, inplanes, planes, stride=1, downsample=None):
58 | super(Bottleneck, self).__init__()
59 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
60 | self.bn1 = nn.BatchNorm2d(planes)
61 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
62 | padding=1, bias=False)
63 | self.bn2 = nn.BatchNorm2d(planes)
64 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
65 | self.bn3 = nn.BatchNorm2d(planes * 4)
66 | self.relu = nn.ReLU(inplace=True)
67 | self.downsample = downsample
68 | self.stride = stride
69 |
70 | def forward(self, x):
71 | residual = x
72 |
73 | out = self.conv1(x)
74 | out = self.bn1(out)
75 | out = self.relu(out)
76 |
77 | out = self.conv2(out)
78 | out = self.bn2(out)
79 | out = self.relu(out)
80 |
81 | out = self.conv3(out)
82 | out = self.bn3(out)
83 |
84 | if self.downsample is not None:
85 | residual = self.downsample(x)
86 |
87 | out += residual
88 | out = self.relu(out)
89 |
90 | return out
91 |
92 |
93 | class ResNet(nn.Module):
94 |
95 | def __init__(self, block, layers, num_classes=1000):
96 | self.inplanes = 128
97 | super(ResNet, self).__init__()
98 | self.conv1 = conv3x3(3, 64, stride=1)
99 | self.bn1 = nn.BatchNorm2d(64)
100 | self.relu1 = nn.ReLU(inplace=True)
101 | self.conv2 = conv3x3(64, 64)
102 | self.bn2 = nn.BatchNorm2d(64)
103 | self.relu2 = nn.ReLU(inplace=True)
104 | self.conv3 = conv3x3(64, 128)
105 | self.bn3 = nn.BatchNorm2d(128)
106 | self.relu3 = nn.ReLU(inplace=True)
107 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
108 |
109 | self.layer1 = self._make_layer(block, 64, layers[0])
110 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
111 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
112 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
113 |
114 | for m in self.modules():
115 | if isinstance(m, nn.Conv2d):
116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
117 | m.weight.data.normal_(0, math.sqrt(2. / n))
118 | elif isinstance(m, nn.BatchNorm2d):
119 | m.weight.data.fill_(1)
120 | m.bias.data.zero_()
121 |
122 | def _make_layer(self, block, planes, blocks, stride=1):
123 | downsample = None
124 | if stride != 1 or self.inplanes != planes * block.expansion:
125 | downsample = nn.Sequential(
126 | nn.Conv2d(self.inplanes, planes * block.expansion,
127 | kernel_size=1, stride=stride, bias=False),
128 | nn.BatchNorm2d(planes * block.expansion),
129 | )
130 |
131 | layers = []
132 | layers.append(block(self.inplanes, planes, stride, downsample))
133 | self.inplanes = planes * block.expansion
134 | for i in range(1, blocks):
135 | layers.append(block(self.inplanes, planes))
136 |
137 | return nn.Sequential(*layers)
138 |
139 | def forward(self, x):
140 | x = self.relu1(self.bn1(self.conv1(x)))
141 | x = self.relu2(self.bn2(self.conv2(x)))
142 | x = self.relu3(self.bn3(self.conv3(x)))
143 | x = self.maxpool(x)
144 |
145 | x = self.layer1(x)
146 | x = self.layer2(x)
147 | x = self.layer3(x)
148 | x = self.layer4(x)
149 |
150 | return x
151 |
152 | def resnet18(pretrained=False, **kwargs):
153 | """Constructs a ResNet-18 model.
154 | Args:
155 | pretrained (bool): If True, returns a model pre-trained on ImageNet
156 | """
157 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
158 | if pretrained:
159 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
160 | return model
161 |
162 | '''
163 | def resnet34(pretrained=False, **kwargs):
164 | """Constructs a ResNet-34 model.
165 | Args:
166 | pretrained (bool): If True, returns a model pre-trained on ImageNet
167 | """
168 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
169 | if pretrained:
170 | model.load_state_dict(load_url(model_urls['resnet34']))
171 | return model
172 | '''
173 |
174 | def resnet50(pretrained=False, **kwargs):
175 | """Constructs a ResNet-50 model.
176 | Args:
177 | pretrained (bool): If True, returns a model pre-trained on ImageNet
178 | """
179 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
180 | if pretrained:
181 | pretrained_dict = model_zoo.load_url(model_urls['resnet50'])
182 | model_dict=model.state_dict()
183 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
184 | model_dict.update(pretrained_dict)
185 | model.load_state_dict(model_dict)
186 | return model
187 |
188 |
189 | def resnet101(pretrained=False, **kwargs):
190 | """Constructs a ResNet-101 model.
191 | Args:
192 | pretrained (bool): If True, returns a model pre-trained on ImageNet
193 | """
194 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
195 | if pretrained:
196 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False)
197 | return model
198 |
199 | # def resnet152(pretrained=False, **kwargs):
200 | # """Constructs a ResNet-152 model.
201 | #
202 | # Args:
203 | # pretrained (bool): If True, returns a model pre-trained on ImageNet
204 | # """
205 | # model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
206 | # if pretrained:
207 | # model.load_state_dict(load_url(model_urls['resnet152']))
208 | # return model
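

# Shape-check sketch: the three stride-1 3x3 convs above replace the usual
# 7x7 stride-2 stem, so layer4 stays at 1/16 of the input resolution
# (14x14 for a 224x224 input) instead of torchvision's 1/32.
if __name__ == '__main__':
    import torch
    model = resnet50(pretrained=False).eval()
    with torch.no_grad():
        print(model(torch.rand(1, 3, 224, 224)).shape)  # torch.Size([1, 2048, 14, 14])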
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from dataset import get_loader
3 | import torch
4 | from torchvision import transforms
5 | from torch import nn
6 | import os
7 | import argparse
8 | import importlib
8 |
9 |
10 | def main(args):
11 |
12 | backbone_names = args.backbones.split('+')
13 | dataset_names = args.datasets.split('+')
14 |
15 | for dataset in dataset_names:
16 | for backbone in backbone_names:
17 | print("Working on [DATASET: %s] with [BACKBONE: %s]" %
18 | (dataset, backbone))
19 |
20 | # Configure testset path
21 | test_rgb_path = os.path.join(args.input_root, dataset, 'RGB')
22 | test_dep_path = os.path.join(args.input_root, dataset, 'depth')
23 |
24 | res_path = os.path.join(args.save_root, 'BiANet_' + backbone,
25 | dataset)
26 | os.makedirs(res_path, exist_ok=True)
27 | test_loader = get_loader(test_rgb_path,
28 | test_dep_path,
29 |                                      args.size,
30 | 1,
31 | num_thread=8,
32 | pin=True)
33 |
34 | # Load model and parameters
35 |             model_module = importlib.import_module('models.BiANet_' + backbone)
36 |             model = model_module.BiANet()
37 | pre_dict = torch.load(
38 | os.path.join(args.param_root, 'BiANet_' + backbone + '.pth'))
39 | device = torch.device("cuda")
40 | model.to(device)
41 | if backbone == 'vgg16':
42 | model = torch.nn.DataParallel(model, device_ids=[0])
43 | model.load_state_dict(pre_dict)
44 | model.eval()
45 |
46 | # Test Go!
47 | tensor2pil = transforms.ToPILImage()
48 | with torch.no_grad():
49 | for batch in test_loader:
50 | rgbs = batch[0].to(device)
51 | deps = batch[1].to(device)
52 | name = batch[2][0]
53 | imsize = batch[3]
54 |
55 | scaled_preds = model(rgbs, deps)
56 |
57 | res = scaled_preds[-1]
58 |
59 |                     res = nn.functional.interpolate(res,
60 |                                                     size=[int(s) for s in imsize],  # collated tensors -> ints
61 |                                                     mode='bilinear',
62 |                                                     align_corners=True).cpu()
63 | res = res.squeeze(0)
64 | res = tensor2pil(res)
65 | res.save(os.path.join(res_path, name[:-3] + 'png'))
66 |
67 |     print('Outputs were saved at: ' + args.save_root)
68 |
69 |
70 | if __name__ == '__main__':
71 |     # Parameters from the command line
72 |     parser = argparse.ArgumentParser(description='BiANet testing script')
73 | parser.add_argument('--backbones',
74 | default='vgg16',
75 | type=str,
76 |                         help="Options: 'vgg11', 'vgg16', 'res50', 'res2_50'")
77 |     parser.add_argument(
78 |         '--datasets',
79 |         default='NJU2K_TEST',
80 |         type=str,
81 |         help="Options: 'NJU2K_TEST', 'NLPR_TEST', 'DES', 'SSD', 'STERE', 'SIP'")
82 | parser.add_argument('--size', default=224, type=int, help='input size')
83 | parser.add_argument('--param_root',
84 | default='param',
85 | type=str,
86 | help='folder for pre-trained model')
87 | parser.add_argument('--input_root',
88 | default='./Testset',
89 | type=str,
90 | help='dataset root')
92 | parser.add_argument('--save_root',
93 | default='./SalMap',
94 | type=str,
95 | help='Output folder')
96 | args = parser.parse_args()
97 | main(args)
98 |
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python test.py --backbones vgg16+vgg11+res50+res2_50 --datasets NJU2K_TEST+NLPR_TEST+DES+SSD+STERE+SIP --param_root param --save_root ../SalMaps_Minor/pred
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import torch
4 | import shutil
5 | from torchvision import transforms
6 | import numpy as np
7 | import random
8 |
9 |
10 | class Logger():
11 | def __init__(self, path="log.txt"):
12 |         self.logger = logging.getLogger('BiANet')
13 | self.file_handler = logging.FileHandler(path, "w")
14 | self.stdout_handler = logging.StreamHandler()
15 | self.stdout_handler.setFormatter(
16 | logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
17 | self.file_handler.setFormatter(
18 | logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
19 | self.logger.addHandler(self.file_handler)
20 | self.logger.addHandler(self.stdout_handler)
21 | self.logger.setLevel(logging.INFO)
22 | self.logger.propagate = False
23 |
24 | def info(self, txt):
25 | self.logger.info(txt)
26 |
27 | def close(self):
28 | self.file_handler.close()
29 | self.stdout_handler.close()
30 |
31 |
32 | class AverageMeter(object):
33 | """Computes and stores the average and current value"""
34 | def __init__(self):
35 | self.reset()
36 |
37 | def reset(self):
38 | self.val = 0.0
39 | self.avg = 0.0
40 | self.sum = 0.0
41 | self.count = 0.0
42 |
43 | def update(self, val, n=1):
44 | self.val = val
45 | self.sum += val * n
46 | self.count += n
47 | self.avg = self.sum / self.count
48 |
49 |
50 | def save_tensor_img(tensor_im, path):
51 |     im = tensor_im.cpu().clone()
52 | im = im.squeeze(0)
53 | tensor2pil = transforms.ToPILImage()
54 | im = tensor2pil(im)
55 | im.save(path)
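

# Usage sketch for these helpers (filenames are placeholders; util.py is not
# imported by test.py, so this is illustrative only).
if __name__ == '__main__':
    logger = Logger('log.txt')      # writes to both log.txt and stdout
    meter = AverageMeter()
    for step in range(3):
        meter.update(0.5 * step)    # keeps a running average of a scalar
    logger.info('avg: %.3f' % meter.avg)
    save_tensor_img(torch.rand(1, 1, 64, 64), 'example.png')
    logger.close()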
--------------------------------------------------------------------------------