├── LICENSE
├── README.md
├── data.py
├── dataset
└── .gitkeep
├── depth.py
├── img
├── benchmark.png
├── benchmark_vis_IJCV.png
├── qualitative_results.png
├── quantitative_results.png
├── structure_diagram.png
└── structure_diagram_IJCV.png
├── mobilenet.py
├── net.py
├── options.py
├── pretrain
└── .gitkeep
├── test.py
├── train.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 zwbx
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DFM-Net (ACM MM 2021)
2 | Official repository of the paper [Depth Quality-Inspired Feature Manipulation for Efficient RGB-D Salient Object Detection](https://arxiv.org/pdf/2107.01779.pdf) | [Chinese version](https://pan.baidu.com/s/1axKXAqBmMmQuPTvTTY_LNg?pwd=jsvr)
3 |
4 | ## News
5 | - 6/Jun/2022 🔥 The [online demo](http://rgbdsod-krf.natapp4.cc/) is newly released!
6 | - 8/Aug/2022 We extend DFM-Net to the video salient object detection (VSOD) task; see [Depth Quality-Inspired Feature Manipulation for Efficient RGB-D and Video Salient Object Detection](https://arxiv.org/abs/2208.03918).
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | Block diagram of DFM-Net.
15 |
16 |
17 |
18 |
19 | ## The most efficient RGB-D SOD method ⚡
20 | - Low model size: only **8.5 MB**, **6.7×/3.1× smaller** than the latest lightest models A2dele and MobileSal.
21 | - High accuracy: SOTA performance on 9 datasets (NJU2K, NLPR, STERE, RGBD135, LFSD, SIP, DUT-RGBD, RedWeb-S, COME).
22 | - High speed: about 50 ms per image on CPU (Intel Core i7-8700), **2.9×/2.4× faster** than the latest fastest models A2dele and MobileSal.
23 |
24 |
25 |
26 |
27 | Performance visualization. The vertical axis indicates the average S-measure over six widely used datasets (NJU2K, NLPR, STERE, RGBD135, LFSD, SIP). The horizontal axis indicates CPU speed. The circle area is proportional to the model size.
28 |
29 |
30 |
31 |
32 | ## Extension :fire:
33 | [Depth Quality-Inspired Feature Manipulation for Efficient RGB-D and Video Salient Object Detection](https://arxiv.org/abs/2208.03918)
34 | - More comprehensive comparison:
35 | - Benchmark results on DUT-RGBD, RedWeb-S, COME are updated.
36 |     - A metric of maximum-batch inference speed is added.
37 |     - We re-test the inference speed of our method and of the compared methods on Ubuntu 16.04.
38 | - Working mechanism explanation:
39 |     - Further analyses verify the ability of DQFM to distinguish depth maps of various qualities without any quality labels.
40 | - Application on efficient VSOD
41 | - One of the lightest VSOD methods!
42 | - Joint training strategy is proposed.
43 |
44 |
45 |
46 | ## Easy-to-use components to boost your RGB-D SOD network
47 | If you use a depth branch as an auxiliary to the RGB branch:
48 | - Use DQW/DHA to boost performance at the cost of only 0.007/0.042 MB of extra model size (see the sketch after this list).
49 | - Use our lightweight depth backbone to improve efficiency.
50 |
51 | If you adopt parallel encoders for RGB and depth:
52 | - Refer to our other work [BTS-Net](https://github.com/zwbx/BTS-Net).
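
For illustration, a minimal, self-contained sketch of how the DQW/DHA outputs are consumed inside `RGBBranch.forward` (net.py): at every level, the depth feature is scaled by a per-sample depth-quality weight (DQW) and by a spatial attention map (DHA) before it is added to the RGB feature. The tensor shapes below are toy values chosen for the example, not the real network sizes.

```python
import torch

# stand-ins for one fusion level in RGBBranch.forward (toy shapes)
rgb_feat   = torch.randn(2, 16, 64, 64)   # RGB feature r_i
depth_feat = torch.randn(2, 16, 64, 64)   # depth feature d_i
dqw        = torch.rand(2, 1, 1, 1)       # DQW: per-sample scalar quality weight (gate[:, i:i+1, ...])
dha        = torch.rand(2, 1, 64, 64)     # DHA: spatial attention map, upsampled to d_i's resolution

# depth features are gated before fusion: r_i <- r_i + d_i * dqw_i * dha_i
fused = rgb_feat + depth_feat * dqw * dha
print(fused.shape)  # torch.Size([2, 16, 64, 64])
```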
53 |
54 |
55 |
56 |
57 | ## Test
58 |
59 | Run test.py directly.
60 |
61 | The test saliency maps will be saved to './results/benchmark/'.
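
For reference, a minimal programmatic inference sketch mirroring what test.py does (the checkpoint path below is the one test.py loads by default; adjust it if you renamed or relocated the downloaded weights):

```python
import torch
from net import DFMNet

# assumes the downloaded checkpoint sits at the path test.py expects
model = DFMNet().cuda().eval()
model.load_state_dict(torch.load('./pretrain/DFMNet_epoch_300.pth'))

with torch.no_grad():
    rgb   = torch.randn(1, 3, 256, 256).cuda()   # ImageNet-normalized RGB, 256x256 as in test.py
    depth = torch.randn(1, 1, 256, 256).cuda()   # single-channel depth map
    sal_final, sal_depth = model(rgb, depth)     # DFMNet returns [final prediction, depth-branch prediction]
    sal_map = sal_final.sigmoid()                # saliency probability map at the input resolution
```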
62 |
63 | Data preparation:
64 | - Classic benchmark: trained on NJU2K and NLPR, tested on NJU2K, NLPR, STERE, RGBD135, LFSD, SIP.
65 | - [test data](https://pan.baidu.com/s/1wI-bxarzdSrOY39UxZaomQ) [code: 940i]
66 | - [pretrained model for DFMNet](https://pan.baidu.com/s/1pTEByo0OngNJlKCJsTcx-A?pwd=skin)
67 | - Additional test datasets [RedWeb-S](https://github.com/nnizhang/SMAC) 🆕, updated in journal version.
68 | - DUT-RGBD benchmark 🆕
69 |     - Download the training and test data from the [official repository](https://pan.baidu.com/s/1mhHAXLgoqqLQIb6r-k-hbA).
70 | - [pretrained model for DFMNet](https://pan.baidu.com/s/1GJHvxh2gTLutpM1hfESDNg?pwd=nmw3).
71 | - COME benchmark 🆕
72 |     - Download the training and test data from the [official repository](https://github.com/JingZhang617/cascaded_rgbd_sod).
73 | - [pretrained model for DFMNet](https://pan.baidu.com/s/1fCYF5p9dCC8RXRCLaWUQlg?pwd=iqyf).
74 |
75 | ## Results
76 |
77 | - We provide testing results of 9 datasets (NJU2K, NLPR, STERE, RGBD135, LFSD, SIP, DUT-RGBD 🆕, RedWeb-S 🆕, COME 🆕).
78 | - [Results of DFM-Net](https://pan.baidu.com/s/1wZyYqYISpRGZATDgKYO4nA?pwd=4jqu).
79 | - [Results of DFM-Net*](https://pan.baidu.com/s/1vemT9nfaXoSc_tqSYakSCg?pwd=pax4).
80 |
81 | - Evaluate the result maps:
82 | You can evaluate the result maps with the [MATLAB tool](http://dpfan.net/d3netbenchmark/) or the [Python (GPU) tool](https://github.com/zyjwuyan/SOD_Evaluation_Metrics); a minimal MAE-only check is sketched below this list.
83 |
84 | - Note that the parameter file is 8.9 MB, 0.4 MB larger than reported in the paper, because the keys storing parameter names also occupy some space. Put the downloaded datasets and checkpoints under the following directory structure:
85 |
86 | -dataset\
87 | -RGBD_train
88 | -NJU2K\
89 | -NLPR\
90 | ...
91 | -pretrain
92 | -DFMNet_300_epoch.pth
93 | ...
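
If you only want a quick sanity check without the external toolboxes, the MAE tracked by train.py during validation can be computed directly. A minimal sketch (the `mae` helper is illustrative, not part of this repository):

```python
import numpy as np

def mae(pred, gt):
    """Mean absolute error between a saliency map and its ground truth (both HxW arrays)."""
    pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)  # min-max normalize prediction to [0, 1]
    gt = gt.astype(np.float32) / (gt.max() + 1e-8)                 # scale ground truth to [0, 1]
    return float(np.abs(pred - gt).mean())

# toy usage
pred = np.random.rand(256, 256)
gt = (np.random.rand(256, 256) > 0.5).astype(np.uint8) * 255
print(mae(pred, gt))
```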
94 |
95 |
96 | ## Training
97 | - Download the [training data](https://pan.baidu.com/s/1ckNlS0uEIPV-iCwVzjutsQ) [code: eb2z].
98 | - Modify the settings in options.py (data paths, batch size, etc.) and run train.py; the sketch below shows the loss it optimizes.
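
The objective optimized in train.py is a plain, unweighted sum of two BCE-with-logits losses, one on the final prediction and one on the auxiliary depth-branch prediction. A minimal sketch with toy tensors (not the real data pipeline):

```python
import torch
import torch.nn.functional as F

def dfmnet_loss(outputs, gts):
    # outputs is the two-element list returned by DFMNet(images, depths)
    sal_final, sal_depth = outputs
    loss_final = F.binary_cross_entropy_with_logits(sal_final, gts)
    loss_depth = F.binary_cross_entropy_with_logits(sal_depth, gts)
    return loss_final + loss_depth

# toy usage with random logits and a binary ground-truth map
preds = [torch.randn(2, 1, 256, 256), torch.randn(2, 1, 256, 256)]
gts = torch.randint(0, 2, (2, 1, 256, 256)).float()
print(dfmnet_loss(preds, gts).item())
```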
99 |
100 |
101 | ## Application on VSOD 🆕
102 | - We provide testing results of 4 datasets (DAVIS, FBMS, MCL, DAVSOD).
103 | - [Results of DFM-Net](https://pan.baidu.com/s/1jLGP2kV_Z7esOkkY3jKFQw?pwd=58wc).
104 | - [Results of DFM-Net*](https://pan.baidu.com/s/1EV4_neyES7jAyo0op-XfTA?pwd=pp2w).
105 |
106 | ## Citation
107 |
108 | Please cite the following papers if you use this repository in your research:
109 |
110 | @inproceedings{zhang2021depth,
111 | title={Depth quality-inspired feature manipulation for efficient RGB-D salient object detection},
112 | author={Zhang, Wenbo and Ji, Ge-Peng and Wang, Zhuo and Fu, Keren and Zhao, Qijun},
113 | booktitle={Proceedings of the 29th ACM International Conference on Multimedia},
114 | pages={731--740},
115 | year={2021}
116 | }
117 |
118 |     @article{zhang2022depth,
119 |     title={Depth Quality-Inspired Feature Manipulation for Efficient RGB-D and Video Salient Object Detection},
120 |     author={Zhang, Wenbo and Fu, Keren and Wang, Zhuo and Ji, Ge-Peng and Zhao, Qijun},
121 |     journal={arXiv preprint arXiv:2208.03918},
122 |     year={2022}
123 |     }
124 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import PIL
3 | from PIL import Image
4 | import torch.utils.data as data
5 | import torchvision.transforms as transforms
6 | import random
7 | import numpy as np
8 | from PIL import ImageEnhance
9 | from natsort import natsorted
10 | import torch
11 |
12 | # several data augmentation strategies
13 | def cv_random_flip(img, label,depth,edge):
14 | flip_flag = random.randint(0, 1)
15 | # flip_flag2= random.randint(0,1)
16 | #left right flip
17 | if flip_flag == 1:
18 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
19 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
20 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT)
21 | edge = edge.transpose(Image.FLIP_LEFT_RIGHT)
22 | #top bottom flip
23 | # if flip_flag2==1:
24 | # img = img.transpose(Image.FLIP_TOP_BOTTOM)
25 | # label = label.transpose(Image.FLIP_TOP_BOTTOM)
26 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM)
27 | return img, label, depth, edge
28 | def randomCrop(image, label,depth,edge):
29 | border=30
30 | image_width = image.size[0]
31 | image_height = image.size[1]
32 | crop_win_width = np.random.randint(image_width-border , image_width)
33 | crop_win_height = np.random.randint(image_height-border , image_height)
34 | random_region = (
35 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1,
36 | (image_height + crop_win_height) >> 1)
37 | return image.crop(random_region), label.crop(random_region),depth.crop(random_region),edge.crop(random_region)
38 | def randomRotation(image,label,depth,edge):
39 | mode=Image.BICUBIC
40 | if random.random()>0.8:
41 | random_angle = np.random.randint(-15, 15)
42 | image=image.rotate(random_angle, mode)
43 | label=label.rotate(random_angle, mode)
44 | depth=depth.rotate(random_angle, mode)
45 | edge = edge.rotate(random_angle, mode)
46 | return image,label,depth,edge
47 | def colorEnhance(image):
48 | bright_intensity=random.randint(5,15)/10.0
49 | image=ImageEnhance.Brightness(image).enhance(bright_intensity)
50 | contrast_intensity=random.randint(5,15)/10.0
51 | image=ImageEnhance.Contrast(image).enhance(contrast_intensity)
52 | color_intensity=random.randint(0,20)/10.0
53 | image=ImageEnhance.Color(image).enhance(color_intensity)
54 | sharp_intensity=random.randint(0,30)/10.0
55 | image=ImageEnhance.Sharpness(image).enhance(sharp_intensity)
56 | return image
57 | def randomGaussian(image, mean=0.1, sigma=0.35):
58 | def gaussianNoisy(im, mean=mean, sigma=sigma):
59 | for _i in range(len(im)):
60 | im[_i] += random.gauss(mean, sigma)
61 | return im
62 | img = np.asarray(image)
63 | width, height = img.shape
64 | img = gaussianNoisy(img[:].flatten(), mean, sigma)
65 | img = img.reshape([width, height])
66 | return Image.fromarray(np.uint8(img))
67 | def randomPeper(img):
68 |
69 | img=np.array(img)
70 | noiseNum=int(0.0015*img.shape[0]*img.shape[1])
71 | for i in range(noiseNum):
72 |
73 | randX=random.randint(0,img.shape[0]-1)
74 |
75 | randY=random.randint(0,img.shape[1]-1)
76 |
77 | if random.randint(0,1)==0:
78 |
79 | img[randX,randY]=0
80 |
81 | else:
82 |
83 | img[randX,randY]=255
84 | return Image.fromarray(img)
85 |
86 | # dataset for training
87 | # The current loader does not use normalized depth maps for training and testing. If you use normalized depth maps
88 | # (e.g., 0 represents background and 1 represents foreground), the performance will be further improved.
89 | class SalObjDataset(data.Dataset):
90 | def __init__(self, image_root, gt_root,depth_root,edge_root, trainsize):
91 | self.trainsize = trainsize
92 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
93 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
94 | or f.endswith('.png')]
95 | self.depths=[depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
96 | or f.endswith('.png')]
97 | self.edges = [edge_root + f for f in os.listdir(edge_root) if f.endswith('.bmp')
98 | or f.endswith('.png')]
99 | self.images = natsorted(self.images)
100 | self.gts = natsorted(self.gts)
101 | self.depths= natsorted(self.depths)
102 | self.edges = natsorted(self.edges)
103 | # print(self.images)
104 | # print(self.depths)
105 | # print(self.gts)
106 | self.filter_files()
107 | self.size = len(self.images)
108 | self.img_transform = transforms.Compose([
109 | transforms.Resize((self.trainsize, self.trainsize)),
110 | transforms.ToTensor(),
111 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
112 | self.gt_transform = transforms.Compose([
113 | transforms.Resize((self.trainsize, self.trainsize)),
114 | transforms.ToTensor()])
115 | self.depths_transform = transforms.Compose([
116 | transforms.Resize((self.trainsize, self.trainsize)),
117 | transforms.ToTensor(),])
118 |
119 | def __getitem__(self, index):
120 | image = self.rgb_loader(self.images[index])
121 | gt = self.binary_loader(self.gts[index])
122 | depth = self.binary_loader(self.depths[index])
123 | edge_gt = self.binary_loader(self.edges[index])
124 | depth = PIL.ImageOps.invert(depth)
125 | image,gt,depth,edge_gt =cv_random_flip(image,gt,depth,edge_gt)
126 | image,gt,depth,edge_gt=randomCrop(image, gt,depth,edge_gt)
127 | image,gt,depth,edge_gt=randomRotation(image, gt,depth,edge_gt)
128 | image=colorEnhance(image)
129 | # depth= colorEnhance(depth)
130 | #gt=randomGaussian(gt)
131 | gt=randomPeper(gt)
132 | image = self.img_transform(image)
133 | gt = self.gt_transform(gt)
134 | edge_gt = self.gt_transform(edge_gt)
135 | edge_gt = (edge_gt - edge_gt.min()) / (edge_gt.max() - edge_gt.min() + 1e-8)
136 | depth=self.depths_transform(depth)
137 | return image, gt, depth, edge_gt
138 |
139 | def filter_files(self):
140 |         assert len(self.images) == len(self.gts) and len(self.gts) == len(self.depths)
141 | images = []
142 | gts = []
143 | depths=[]
144 | for img_path, gt_path,depth_path in zip(self.images, self.gts, self.depths):
145 | img = Image.open(img_path)
146 | gt = Image.open(gt_path)
147 | depth= Image.open(depth_path)
148 | if img.size == gt.size and gt.size==depth.size:
149 | images.append(img_path)
150 | gts.append(gt_path)
151 | depths.append(depth_path)
152 | self.images = images
153 | self.gts = gts
154 | self.depths=depths
155 |         assert len(self.images) == len(self.gts) and len(self.gts) == len(self.depths)
156 |
157 | def rgb_loader(self, path):
158 | with open(path, 'rb') as f:
159 | img = Image.open(f)
160 | return img.convert('RGB')
161 |
162 | def binary_loader(self, path):
163 | with open(path, 'rb') as f:
164 | img = Image.open(f)
165 | return img.convert('L')
166 |
167 | def rgb_loader_ops(self, path):
168 | with open(path, 'rb') as f:
169 | img = Image.open(f)
170 | return PIL.ImageOps.invert(img.convert('RGB'))
171 |
172 | def resize(self, img, gt, depth):
173 | assert img.size == gt.size and gt.size==depth.size
174 | w, h = img.size
175 | if h < self.trainsize or w < self.trainsize:
176 | h = max(h, self.trainsize)
177 | w = max(w, self.trainsize)
178 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST),depth.resize((w, h), Image.NEAREST)
179 | else:
180 | return img, gt, depth
181 |
182 | def __len__(self):
183 | return self.size
184 |
185 | #dataloader for training
186 | def get_loader(image_root, gt_root,depth_root,edge_root, batchsize, trainsize, shuffle=True, num_workers=2, pin_memory=False):
187 |
188 | dataset = SalObjDataset(image_root, gt_root, depth_root,edge_root,trainsize)
189 | data_loader = data.DataLoader(dataset=dataset,
190 | batch_size=batchsize,
191 | shuffle=shuffle,
192 | num_workers=num_workers,
193 | pin_memory=pin_memory)
194 | return data_loader
195 |
196 | #test dataset and loader
197 | class test_dataset:
198 | def __init__(self, image_root, gt_root,depth_root, testsize):
199 | self.testsize = testsize
200 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
201 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
202 | or f.endswith('.png')]
203 | self.depths=[depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
204 | or f.endswith('.png')]
205 | self.images = natsorted(self.images)
206 | self.gts = natsorted(self.gts)
207 | self.depths= natsorted(self.depths)
208 | # print(self.images)
209 | # print(self.depths)
210 | # print(self.gts)
211 | self.filter_files()
212 | self.transform = transforms.Compose([
213 | transforms.Resize((self.testsize, self.testsize)),
214 | transforms.ToTensor(),
215 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
216 | self.gt_transform = transforms.ToTensor()
217 | # self.gt_transform = transforms.Compose([
218 | # transforms.Resize((self.trainsize, self.trainsize)),
219 | # transforms.ToTensor()])
220 | self.depths_transform = transforms.Compose([
221 | transforms.Resize((self.testsize, self.testsize)),
222 | transforms.ToTensor(),
223 | ])
224 | self.size = len(self.images)
225 | self.index = 0
226 |
227 | def load_data(self):
228 | image = self.rgb_loader(self.images[self.index])
229 | image = self.transform(image).unsqueeze(0)
230 | gt = self.binary_loader(self.gts[self.index])
231 | depth=self.binary_loader_ops(self.depths[self.index])
232 |         pseudo_depth = self.depths_transform(self.rgb_loader_ops(self.gts[self.index])).unsqueeze(0)  # pseudo depth built from the GT (currently unused)
233 | depth=self.depths_transform(depth).unsqueeze(0)
234 | name = self.images[self.index].split('/')[-1]
235 | image_for_post=self.rgb_loader(self.images[self.index])
236 | image_for_post=image_for_post.resize(gt.size)
237 | if name.endswith('.jpg'):
238 | name = name.split('.jpg')[0] + '.png'
239 | self.index += 1
240 | self.index = self.index % self.size
241 | return image, gt,depth, name,np.array(image_for_post)
242 |
243 | def rgb_loader(self, path):
244 | with open(path, 'rb') as f:
245 | img = Image.open(f)
246 | return img.convert('RGB')
247 |
248 | def binary_loader(self, path):
249 | with open(path, 'rb') as f:
250 | img = Image.open(f)
251 | return img.convert('L')
252 |
253 | def binary_loader_ops(self, path):
254 | with open(path, 'rb') as f:
255 | img = Image.open(f)
256 | return PIL.ImageOps.invert(img.convert('L'))
257 |
258 | def rgb_loader_ops(self, path):
259 | with open(path, 'rb') as f:
260 | img = Image.open(f)
261 | img = PIL.ImageOps.invert(img.convert('RGB'))
262 | return img
263 |
264 | def __len__(self):
265 | return self.size
266 |
267 | def filter_files(self):
268 |         assert len(self.images) == len(self.gts) and len(self.gts) == len(self.depths)
269 | images = []
270 | gts = []
271 | depths=[]
272 | for img_path, gt_path,depth_path in zip(self.images, self.gts, self.depths):
273 | img = Image.open(img_path)
274 | gt = Image.open(gt_path)
275 | depth= Image.open(depth_path)
276 | if img.size == gt.size and gt.size==depth.size:
277 | images.append(img_path)
278 | gts.append(gt_path)
279 | depths.append(depth_path)
280 | # else:
281 | # print(img.size, depth.size, gt.size)
282 | self.images = images
283 | self.gts = gts
284 | self.depths = depths
285 |         assert len(self.images) == len(self.gts) and len(self.gts) == len(self.depths)
286 |
287 |
--------------------------------------------------------------------------------
/dataset/.gitkeep:
--------------------------------------------------------------------------------
1 | #
2 |
--------------------------------------------------------------------------------
/depth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
15 |
16 |
17 | def upsample(x, size):
18 | return F.interpolate(x, size, mode='bilinear', align_corners=True)
19 |
20 | def initialize_weights(model):
21 | m = torch.hub.load('pytorch/vision:v0.6.0', 'mobilenet_v2', pretrained=True)
22 | pretrained_dict = m.state_dict()
23 | all_params = {}
24 | for k, v in model.state_dict().items():
25 |         if k in pretrained_dict.keys() and v.shape == pretrained_dict[k].shape:
26 | v = pretrained_dict[k]
27 | all_params[k] = v
28 | # assert len(all_params.keys()) == len(self.resnet.state_dict().keys())
29 | model.load_state_dict(all_params,strict=False)
30 |
31 | class DepthBranch(nn.Module):  # light-weight depth backbone built from MobileNetV2-style inverted residual blocks
32 | def __init__(self, c1=8, c2=16, c3=32, c4=48, c5=320, **kwargs):
33 | super(DepthBranch, self).__init__()
34 | self.bottleneck1 = _make_layer(LinearBottleneck, 1, 16, blocks=1, t=3, stride=2)
35 | self.bottleneck2 = _make_layer(LinearBottleneck, 16, 24, blocks=3, t=3, stride=2)
36 | self.bottleneck3 = _make_layer(LinearBottleneck, 24, 32, blocks=7, t=3, stride=2)
37 | self.bottleneck4 = _make_layer(LinearBottleneck, 32, 96, blocks=3, t=2, stride=2)
38 | self.bottleneck5 = _make_layer(LinearBottleneck, 96, 320, blocks=1, t=2, stride=1)
39 |
40 | # self.conv_s_d = _ConvBNReLU(320,1,1,1)
41 |
42 | # nn.Sequential(_DSConv(c3, c3 // 4),
43 | # nn.Conv2d(c3 // 4, 1, 1), )
44 |
45 | def forward(self, x):
46 | size = x.size()[2:]
47 | feat = []
48 |
49 | x1 = self.bottleneck1(x)
50 | x2 = self.bottleneck2(x1)
51 | x3 = self.bottleneck3(x2)
52 | x4 = self.bottleneck4(x3)
53 | x5 = self.bottleneck5(x4)
54 | # s_d = self.conv_s_d(x5)
55 |
56 | feat.append(x1)
57 | feat.append(x2)
58 | feat.append(x3)
59 | feat.append(x4)
60 | feat.append(x5)
61 | return x1 ,feat
62 |
63 | class _ConvBNReLU(nn.Module):
64 | """Conv-BN-ReLU"""
65 |
66 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, **kwargs):
67 | super(_ConvBNReLU, self).__init__()
68 | self.conv = nn.Sequential(
69 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
70 | nn.BatchNorm2d(out_channels),
71 | nn.ReLU(True)
72 | )
73 |
74 | def forward(self, x):
75 | return self.conv(x)
76 |
77 |
78 | class _DSConv(nn.Module):
79 | """Depthwise Separable Convolutions"""
80 |
81 | def __init__(self, dw_channels, out_channels, stride=1, **kwargs):
82 | super(_DSConv, self).__init__()
83 | self.conv = nn.Sequential(
84 | nn.Conv2d(dw_channels, dw_channels, 3, stride, 1, groups=dw_channels, bias=False),
85 | nn.BatchNorm2d(dw_channels),
86 | nn.ReLU(True),
87 | nn.Conv2d(dw_channels, out_channels, 1, bias=False),
88 | nn.BatchNorm2d(out_channels),
89 | nn.ReLU(True)
90 | )
91 |
92 | def forward(self, x):
93 | return self.conv(x)
94 |
95 | def _make_layer( block, inplanes, planes, blocks, t=6, stride=1):
96 | layers = []
97 | layers.append(block(inplanes, planes, t, stride))
98 | for i in range(1, blocks):
99 | layers.append(block(planes, planes, t, 1))
100 | return nn.Sequential(*layers)
101 |
102 | class _DWConv(nn.Module):
103 | def __init__(self, dw_channels, out_channels, stride=1, **kwargs):
104 | super(_DWConv, self).__init__()
105 | self.conv = nn.Sequential(
106 | nn.Conv2d(dw_channels, out_channels, 3, stride, 1, groups=dw_channels, bias=False),
107 | nn.BatchNorm2d(out_channels),
108 | nn.ReLU(True)
109 | )
110 |
111 | def forward(self, x):
112 | return self.conv(x)
113 |
114 |
115 |
116 |
117 | class LinearBottleneck(nn.Module):
118 | """LinearBottleneck used in MobileNetV2"""
119 |
120 | def __init__(self, in_channels, out_channels, t=6, stride=2, **kwargs):
121 | super(LinearBottleneck, self).__init__()
122 | self.use_shortcut = stride == 1 and in_channels == out_channels
123 | self.block = nn.Sequential(
124 | # pw
125 | _ConvBNReLU(in_channels, in_channels * t, 1),
126 | # dw
127 | _DWConv(in_channels * t, in_channels * t, stride),
128 | # pw-linear
129 | nn.Conv2d(in_channels * t, out_channels, 1, bias=False),
130 | nn.BatchNorm2d(out_channels)
131 | )
132 |
133 | def forward(self, x):
134 | out = self.block(x)
135 | if self.use_shortcut:
136 | out = x + out
137 | return out
138 |
139 |
140 | class PyramidPooling(nn.Module):
141 | """Pyramid pooling module"""
142 |
143 | def __init__(self, in_channels, out_channels, **kwargs):
144 | super(PyramidPooling, self).__init__()
145 | inter_channels = int(in_channels / 4)
146 | self.conv1 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
147 | self.conv2 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
148 | self.conv3 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
149 | self.conv4 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
150 | self.out = _ConvBNReLU(in_channels * 2, out_channels, 1)
151 |
152 | def pool(self, x, size):
153 | avgpool = nn.AdaptiveAvgPool2d(size)
154 | return avgpool(x)
155 |
156 | def forward(self, x):
157 | size = x.size()[2:]
158 | feat1 = upsample(self.conv1(self.pool(x, 1)), size)
159 | feat2 = upsample(self.conv2(self.pool(x, 2)), size)
160 | feat3 = upsample(self.conv3(self.pool(x, 3)), size)
161 | feat4 = upsample(self.conv4(self.pool(x, 6)), size)
162 | x = torch.cat([x, feat1, feat2, feat3, feat4], dim=1)
163 | x = self.out(x)
164 | return x
165 |
166 |
167 |
168 |
169 | class BasicConv2d(nn.Module):
170 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, activation='relu'):
171 | super(BasicConv2d, self).__init__()
172 | self.conv = nn.Conv2d(in_planes, out_planes,
173 | kernel_size=kernel_size, stride=stride,
174 | padding=padding, dilation=dilation, bias=False)
175 | self.bn = nn.BatchNorm2d(out_planes)
176 | self.relu = nn.ReLU(inplace=True)
177 | self.activation = activation
178 | self.sigmoid = nn.Sigmoid()
179 |
180 | def forward(self, x):
181 | x = self.conv(x)
182 | x = self.bn(x)
183 | return self.relu(x) if self.activation=='relu' \
184 | else self.sigmoid(x) if self.activation=='sigmoid' \
185 | else x
186 |
--------------------------------------------------------------------------------
/img/benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbx/DFM-Net/3e148cb14f48045af7c7dd0530f83a036625f847/img/benchmark.png
--------------------------------------------------------------------------------
/img/benchmark_vis_IJCV.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbx/DFM-Net/3e148cb14f48045af7c7dd0530f83a036625f847/img/benchmark_vis_IJCV.png
--------------------------------------------------------------------------------
/img/qualitative_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbx/DFM-Net/3e148cb14f48045af7c7dd0530f83a036625f847/img/qualitative_results.png
--------------------------------------------------------------------------------
/img/quantitative_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbx/DFM-Net/3e148cb14f48045af7c7dd0530f83a036625f847/img/quantitative_results.png
--------------------------------------------------------------------------------
/img/structure_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbx/DFM-Net/3e148cb14f48045af7c7dd0530f83a036625f847/img/structure_diagram.png
--------------------------------------------------------------------------------
/img/structure_diagram_IJCV.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwbx/DFM-Net/3e148cb14f48045af7c7dd0530f83a036625f847/img/structure_diagram_IJCV.png
--------------------------------------------------------------------------------
/mobilenet.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torchvision.models import MobileNetV2
3 | import torch
4 |
5 |
6 | class MobileNetV2Encoder(MobileNetV2):
7 | """
8 |     MobileNetV2Encoder inherits from torchvision's official MobileNetV2. It is modified to
9 |     use dilation in the last block to maintain an output stride of 16, and the classifier
10 |     block originally used for classification is removed. The forward method additionally
11 |     returns the feature maps at all resolutions for the decoder's use.
12 | """
13 |
14 | def __init__(self, in_channels, norm_layer=None):
15 | super().__init__()
16 |
17 | # Replace first conv layer if in_channels doesn't match.
18 | if in_channels != 3:
19 | self.features[0][0] = nn.Conv2d(in_channels, 32, 3, 2, 1, bias=False)
20 |
21 | # Remove last block
22 | self.features = self.features[:-1]
23 |
24 | # Change to use dilation to maintain output stride = 16
25 | self.features[14].conv[1][0].stride = (1, 1)
26 | for feature in self.features[15:]:
27 | feature.conv[1][0].dilation = (2, 2)
28 | feature.conv[1][0].padding = (2, 2)
29 |
30 | # Delete classifier
31 | del self.classifier
32 |
33 | self.layer1 = nn.Sequential(self.features[0], self.features[1])
34 | self.layer2 = nn.Sequential(self.features[2], self.features[3])
35 | self.layer3 = nn.Sequential(self.features[4], self.features[5], self.features[6])
36 | self.layer4 = nn.Sequential(self.features[7], self.features[8], self.features[9], self.features[10],
37 | self.features[11], self.features[12], self.features[13])
38 | self.layer5 = nn.Sequential(self.features[14], self.features[15], self.features[16], self.features[17])
39 | def forward(self, x):
40 | x0 = x # 1/1
41 | x = self.features[0](x)
42 | x = self.features[1](x)
44 | x1 = x # 1/2
45 | x = self.features[2](x)
46 | x = self.features[3](x)
47 | x2 = x # 1/4
48 | x = self.features[4](x)
49 | x = self.features[5](x)
50 | x = self.features[6](x)
51 | x3 = x # 1/8
52 | x = self.features[7](x)
53 | x = self.features[8](x)
54 | x = self.features[9](x)
55 | x = self.features[10](x)
56 | x = self.features[11](x)
57 | x = self.features[12](x)
58 | x = self.features[13](x)
59 | x4 = x # 1/16
60 | x = self.features[14](x)
61 | x = self.features[15](x)
62 | x = self.features[16](x)
63 | x = self.features[17](x)
64 | x5 = x # 1/16
65 | return x1,x2,x3,x4,x5
66 |
67 |
--------------------------------------------------------------------------------
/net.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from depth import DepthBranch
8 | from mobilenet import MobileNetV2Encoder
14 |
15 |
16 | def upsample(x, size):
17 | return F.interpolate(x, size, mode='bilinear', align_corners=True)
18 |
19 | class DFMNet(nn.Module):
20 | def __init__(self, **kwargs):
21 | super(DFMNet, self).__init__()
22 | self.rgb = RGBBranch()
23 | self.depth = DepthBranch()
24 |
25 | def forward(self, r, d):
26 | size = r.shape[2:]
27 | outputs = []
28 |
29 | sal_d,feat = self.depth(d)
30 | sal_final= self.rgb(r,feat)
31 |
32 | sal_final = upsample(sal_final, size)
33 | sal_d = upsample(sal_d, size)
34 |
35 | outputs.append(sal_final)
36 | outputs.append(sal_d)
37 |
38 | return outputs
39 |
40 | class _ConvBNReLU(nn.Module):
41 | """Conv-BN-ReLU"""
42 |
43 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,dilation=1, **kwargs):
44 | super(_ConvBNReLU, self).__init__()
45 | self.conv = nn.Sequential(
46 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,dilation=dilation ,bias=False),
47 | nn.BatchNorm2d(out_channels),
48 | nn.ReLU(True)
49 | )
50 |
51 | def forward(self, x):
52 | return self.conv(x)
53 |
54 | class _ConvBNSig(nn.Module):
55 | """Conv-BN-Sigmoid"""
56 |
57 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,dilation=1, **kwargs):
58 | super(_ConvBNSig, self).__init__()
59 | self.conv = nn.Sequential(
60 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,dilation=dilation ,bias=False),
61 | nn.BatchNorm2d(out_channels),
62 | nn.Sigmoid()
63 | )
64 |
65 | def forward(self, x):
66 | return self.conv(x)
67 |
68 |
69 | class _DSConv(nn.Module):
70 | """Depthwise Separable Convolutions"""
71 |
72 | def __init__(self, dw_channels, out_channels, stride=1, **kwargs):
73 | super(_DSConv, self).__init__()
74 | self.conv = nn.Sequential(
75 | nn.Conv2d(dw_channels, dw_channels, 3, stride, 1, groups=dw_channels, bias=False),
76 | nn.BatchNorm2d(dw_channels),
77 | nn.ReLU(True),
78 | nn.Conv2d(dw_channels, out_channels, 1, bias=False),
79 | nn.BatchNorm2d(out_channels),
80 | nn.ReLU(True)
81 | )
82 |
83 | def forward(self, x):
84 | return self.conv(x)
85 |
86 | def _make_layer( block, inplanes, planes, blocks, t=6, stride=1):
87 | layers = []
88 | layers.append(block(inplanes, planes, t, stride))
89 | for i in range(1, blocks):
90 | layers.append(block(planes, planes, t, 1))
91 | return nn.Sequential(*layers)
92 |
93 | class _DWConv(nn.Module):
94 | def __init__(self, dw_channels, out_channels, stride=1, **kwargs):
95 | super(_DWConv, self).__init__()
96 | self.conv = nn.Sequential(
97 | nn.Conv2d(dw_channels, out_channels, 3, stride, 1, groups=dw_channels, bias=False),
98 | nn.BatchNorm2d(out_channels),
99 | nn.ReLU(True)
100 | )
101 |
102 | def forward(self, x):
103 | return self.conv(x)
104 |
105 |
106 | class LinearBottleneck(nn.Module):
107 | """LinearBottleneck used in MobileNetV2"""
108 |
109 | def __init__(self, in_channels, out_channels, t=6, stride=2, **kwargs):
110 | super(LinearBottleneck, self).__init__()
111 | self.use_shortcut = stride == 1 and in_channels == out_channels
112 | self.block = nn.Sequential(
113 | # pw
114 | _ConvBNReLU(in_channels, in_channels * t, 1),
115 | # dw
116 | _DWConv(in_channels * t, in_channels * t, stride),
117 | # pw-linear
118 | nn.Conv2d(in_channels * t, out_channels, 1, bias=False),
119 | nn.BatchNorm2d(out_channels)
120 | )
121 |
122 | def forward(self, x):
123 | out = self.block(x)
124 | if self.use_shortcut:
125 | out = x + out
126 | return out
127 |
128 |
129 |
130 | class PyramidPooling(nn.Module):
131 | """Pyramid pooling module"""
132 |
133 | def __init__(self, in_channels, out_channels, **kwargs):
134 | super(PyramidPooling, self).__init__()
135 | inter_channels = int(in_channels / 4)
136 | self.conv1 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
137 | self.conv2 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
138 | self.conv3 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
139 | self.conv4 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
140 | self.out = _ConvBNReLU(in_channels * 2, out_channels, 1)
141 |
142 | def pool(self, x, size):
143 | avgpool = nn.AdaptiveAvgPool2d(size)
144 | return avgpool(x)
145 |
146 | def forward(self, x):
147 | size = x.size()[2:]
148 | feat1 = upsample(self.conv1(self.pool(x, 1)), size)
149 | feat2 = upsample(self.conv2(self.pool(x, 2)), size)
150 | feat3 = upsample(self.conv3(self.pool(x, 3)), size)
151 | feat4 = upsample(self.conv4(self.pool(x, 6)), size)
152 | x = torch.cat([x, feat1, feat2, feat3, feat4], dim=1)
153 | x = self.out(x)
154 | return x
155 |
156 |
157 |
158 | class RGBBranch(nn.Module):
159 | """RGBBranch for low-level RGB feature extract"""
160 |
161 | def __init__(self, c1=16, c2=24, c3=32, c4=96,c5=320,k=32 ,**kwargs):
162 | super(RGBBranch, self).__init__()
163 | self.base = MobileNetV2Encoder(3)
164 | initialize_weights(self.base)
165 |
166 | self.conv_cp1 = _DSConv(c1,k)
167 | self.conv_cp2 = _DSConv(c2, k)
168 | self.conv_cp3 = _DSConv(c3, k)
169 | self.conv_cp4 = _DSConv(c4, k)
170 | self.conv_cp5 = _DSConv(c5, k)
171 | self.conv_s_f = nn.Sequential(_DSConv(2 * k, k),
172 | _DSConv( k, k),
173 | nn.Conv2d(k, 1, 1), )
174 |
175 | # self.focus = focus()
176 | self.ca1 = nn.Sequential(_ConvBNReLU(k, k, 1, 1), nn.Conv2d(k, k, 1, 1), nn.Sigmoid())
177 | self.ca2 = nn.Sequential(_ConvBNReLU(k, k, 1, 1), nn.Conv2d(k, k, 1, 1), nn.Sigmoid())
178 | self.ca3 = nn.Sequential(_ConvBNReLU(k, k, 1, 1), nn.Conv2d(k, k, 1, 1), nn.Sigmoid())
179 | self.ca4 = nn.Sequential(_ConvBNReLU(k, k, 1, 1), nn.Conv2d(k, k, 1, 1), nn.Sigmoid())
180 | self.ca5 = nn.Sequential(_ConvBNReLU(k, k, 1, 1), nn.Conv2d(k, k, 1, 1), nn.Sigmoid())
181 |
182 | self.conv_r1_tran = _ConvBNReLU(16, 16, 1, 1)
183 | self.conv_d1_tran = _ConvBNReLU(16, 16, 1, 1)
184 | self.mlp = nn.Sequential(_ConvBNReLU(48, 24, 1, 1),_ConvBNSig(24,5,1,1))
185 |
186 | self.conv_r1_tran2 = _ConvBNReLU(16, 16, 1, 1)
187 | self.conv_d1_tran2 = _ConvBNReLU(16, 16, 1, 1)
188 | self.conv_sgate1 = _ConvBNReLU(16, 16, 3, 1,2,2)
189 | self.conv_sgate2 = _ConvBNReLU(16, 16, 3, 1,2,2)
190 | self.conv_sgate3 = _ConvBNSig(16,5,3,1,1)
191 |
192 | self.ppm = PyramidPooling(320, 32)
193 |
194 | self.conv_guide = _ConvBNReLU(320, 16, 1, 1)
195 |
196 |
197 |
198 | def forward(self, x,feat):
199 |
200 | d1, d2, d3, d4, d5 = feat
201 |
202 | d5_guide = upsample(self.conv_guide(d5),d1.shape[2:])
203 |
204 | r1 = self.base.layer1(x)
205 |
206 | r1t = self.conv_r1_tran(r1)
207 | d1t = self.conv_d1_tran(d1)
208 | r1t2 = self.conv_r1_tran2(r1)
209 | d1t2 = self.conv_d1_tran2(d1)
210 |
211 |         # DQW (depth quality weighting): IoU-like agreement between low-level RGB and depth features at three pooling scales
212 | iou = F.adaptive_avg_pool2d(r1t * d1t, 1) / \
213 | (F.adaptive_avg_pool2d(r1t + d1t, 1))
214 |
215 | e_rp = F.max_pool2d(r1t, 2, 2)
216 | e_dp = F.max_pool2d(d1t, 2, 2)
217 |
218 | e_rp2 = F.max_pool2d(e_rp, 2, 2)
219 | e_dp2 = F.max_pool2d(e_dp, 2, 2)
220 |
221 | iou_p1 = F.adaptive_avg_pool2d(e_rp * e_dp, 1) / \
222 | (F.adaptive_avg_pool2d(e_rp + e_dp, 1))
223 |
224 | iou_p2 = F.adaptive_avg_pool2d(e_rp2 * e_dp2, 1) / \
225 | (F.adaptive_avg_pool2d(e_rp2 + e_dp2, 1))
226 |
227 |         gate = self.mlp(torch.cat((iou, iou_p1, iou_p2), dim=1))  # (B, 5, 1, 1): one depth-quality weight per fusion level
228 |
229 |
230 |         # DHA (depth holistic attention): spatial gate built from the RGB-depth product and high-level depth guidance
231 | mc = r1t2 * d1t2
232 |
233 | sgate = self.conv_sgate1(upsample(mc + d5_guide, d2.shape[2:]))
234 | d5_guide1 = mc + upsample(sgate, d1.shape[2:])
235 |
236 | sgate = self.conv_sgate1(upsample(mc + d5_guide1, d2.shape[2:]))
237 | d5_guide2 = mc + upsample(sgate, d1.shape[2:])
238 |
239 |         sgate = self.conv_sgate3(d5_guide1 + d5_guide2 + mc)  # (B, 5, H/2, W/2): one spatial attention map per fusion level
240 |
241 | dqw1 = gate[:,0:1,...]
242 | dha1 = upsample(sgate[:, 0:1, ...], d1.shape[2:])
243 | dqw2 = gate[:, 1:2, ...]
244 | dha2 = upsample(sgate[:, 1:2, ...], d2.shape[2:])
245 | dqw3 = gate[:, 2:3, ...]
246 | dha3 = upsample(sgate[:, 2:3, ...], d3.shape[2:])
247 | dqw4 = gate[:, 3:4, ...]
248 | dha4 = upsample(sgate[:, 3:4, ...], d4.shape[2:])
249 | dqw5 = gate[:, 4:5, ...]
250 | dha5 = upsample(sgate[:, 4:5, ...], d5.shape[2:])
251 |
252 | r1 = r1 + d1 * dqw1 * dha1
253 | r2 = self.base.layer2(r1) + d2 * dqw2 * dha2
254 | r3 = self.base.layer3(r2) + d3 * dqw3 * dha3
255 | r4 = self.base.layer4(r3) + d4 * dqw4 * dha4
256 | r5 = self.base.layer5(r4) + d5 * dqw5 * dha5
257 | r6 = self.ppm(r5)
258 |
259 |         # Two-stage decoder
260 | ## pre-fusion
261 | r5 = self.conv_cp5(r5)
262 | r4 = self.conv_cp4(r4)
263 | r3 = self.conv_cp3(r3)
264 | r2 = self.conv_cp2(r2)
265 | r1 = self.conv_cp1(r1)
266 |
267 | r5 = self.ca5(F.adaptive_avg_pool2d(r5, 1)) * r5
268 | r4 = self.ca4(F.adaptive_avg_pool2d(r4, 1)) * r4
269 | r3 = self.ca3(F.adaptive_avg_pool2d(r3, 1)) * r3
270 | r2 = self.ca2(F.adaptive_avg_pool2d(r2, 1)) * r2
271 | r1 = self.ca1(F.adaptive_avg_pool2d(r1, 1)) * r1
272 |
273 | r3 = upsample(r3, r1.shape[2:])
274 | r2 = upsample(r2, r1.shape[2:])
275 | rh = r4 + r5 + r6
276 | rl = r1 + r2 + r3
277 |
278 | ## full-fusion
279 | rh = upsample(rh, rl.shape[2:])
280 | sal = self.conv_s_f (torch.cat((rh,rl),dim=1))
281 |
282 | return sal
283 |
284 | def initialize_weights(model):
285 | m = torch.hub.load('pytorch/vision:v0.6.0', 'mobilenet_v2', pretrained=True)
286 | pretrained_dict = m.state_dict()
287 | all_params = {}
288 | for k, v in model.state_dict().items():
289 | if k in pretrained_dict.keys():
290 | v = pretrained_dict[k]
291 | all_params[k] = v
292 | model.load_state_dict(all_params,strict = False)
293 |
294 | if __name__ == '__main__':
295 | img = torch.randn(1, 3, 256, 256).cuda()
296 | depth = torch.randn(1, 1, 256, 256).cuda()
297 | model = DFMNet().cuda()
298 | model.eval()
299 |     time1 = time.time()
300 |     outputs = model(img, depth)
301 |     torch.cuda.synchronize()
302 |     time2 = time.time()
303 |     print('inference time: {:.2f} ms'.format((time2 - time1) * 1000))
304 | num_params = 0
305 | for p in model.parameters():
306 | num_params += p.numel()
307 | print(num_params)
308 |
--------------------------------------------------------------------------------
/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import shutil
5 | from natsort import natsorted
6 |
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--local_rank', default=-1, type=int,
9 | help='node rank for distributed training')
10 | parser.add_argument('--epoch', type=int, default=301, help='epoch number')
11 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
12 | parser.add_argument('--batchsize', type=int, default=10, help='training batch size')
13 | parser.add_argument('--trainsize', type=int, default=256, help='training dataset size')
14 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
15 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate')
16 | parser.add_argument('--decay_epoch', type=int, default=100, help='every n epochs decay learning rate')
17 | parser.add_argument('--load', type=str, default='./pre_train/resnet50-19c8e357.pth', help='train from checkpoints')
18 | parser.add_argument('--gpu_id', type=str, default='1', help='train use gpu')
19 | parser.add_argument('--rgb_root', type=str, default='E://pytorch/data/RGBDcollection_fast/RGB/', help='the training rgb images root')
20 | parser.add_argument('--depth_root', type=str, default='E://pytorch/data/RGBDcollection_fast/depth/', help='the training depth images root')
21 | parser.add_argument('--gt_root', type=str, default='E://pytorch/data/RGBDcollection_fast/GT/', help='the training gt images root')
22 | parser.add_argument('--edge_root', type=str, default='E://pytorch/data/RGBDcollection_fast/edge/', help='the training edge images root')
23 | parser.add_argument('--test_rgb_root', type=str, default='E://pytorch/data/test_in_train/RGB/', help='the test rgb images root')
24 | parser.add_argument('--test_depth_root', type=str, default='E://pytorch/data/test_in_train/depth/', help='the test depth images root')
25 | parser.add_argument('--test_gt_root', type=str, default='E://pytorch/data/test_in_train/GT/', help='the test gt images root')
26 | parser.add_argument('--save_path', type=str, default='./results/train', help='the path to save models and logs')
27 | opt = parser.parse_args()
28 |
29 |
--------------------------------------------------------------------------------
/pretrain/.gitkeep:
--------------------------------------------------------------------------------
1 | #
2 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import sys
4 | import numpy as np
5 | import os, argparse
6 | import cv2
7 | from net import DFMNet
8 | from data import test_dataset
9 | import time
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--testsize', type=int, default=256, help='testing size')
13 | parser.add_argument('--gpu_id', type=str, default='0', help='select gpu id')
14 | parser.add_argument('--test_path',type=str,default='./dataset/',help='test dataset path')
15 | opt = parser.parse_args()
16 |
17 | dataset_path = opt.test_path
18 |
19 | #set device for test
20 | if opt.gpu_id=='0':
21 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
22 | print('USE GPU 0')
23 | elif opt.gpu_id=='1':
24 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
25 | print('USE GPU 1')
26 | elif opt.gpu_id == '3':
27 | os.environ["CUDA_VISIBLE_DEVICES"] = "3"
28 | print('USE GPU 3')
29 | elif opt.gpu_id=='all':
30 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
31 |     print('USE GPU 0,1')
32 |
33 | #load the model
34 | model = DFMNet()
35 | model.load_state_dict(torch.load('./pretrain/DFMNet_epoch_300.pth'))
36 | model.cuda()
37 | model.eval()
38 |
39 | #test
40 |
41 |
42 | def save(res, gt, notation=None, sigmoid=True):
43 |     res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False)
44 |     res = res.sigmoid().data.cpu().numpy().squeeze() if sigmoid else res.data.cpu().numpy().squeeze()
45 |     res = (res - res.min()) / (res.max() - res.min() + 1e-8)
46 |     print('save img to: ', os.path.join(save_path, name if notation is None else name.replace('.png', '_' + notation + '.png')))
47 |     cv2.imwrite(os.path.join(save_path, name if notation is None else name.replace('.png', '_' + notation + '.png')), res * 255)
48 |
49 | test_datasets = ['NJU2K','NLPR','STERE', 'RGBD135', 'LFSD','SIP']
50 | for dataset in test_datasets:
51 | with torch.no_grad():
52 | save_path = './results/benchmark/' + dataset
53 | if not os.path.exists(save_path):
54 | os.makedirs(save_path)
55 | image_root = dataset_path + dataset + '/RGB/'
56 | gt_root = dataset_path + dataset + '/GT/'
57 | depth_root=dataset_path +dataset +'/depth/'
58 | test_loader = test_dataset(image_root, gt_root,depth_root, opt.testsize)
59 |
60 | for i in range(test_loader.size):
61 | image, gt,depth, name, image_for_post = test_loader.load_data()
62 | gt = np.asarray(gt, np.float32)
63 | gt /= (gt.max() + 1e-8)
64 | image = image.cuda()
65 | depth = depth.cuda()
66 | torch.cuda.synchronize()
67 | time_s = time.time()
68 | out = model(image,depth)
69 | torch.cuda.synchronize()
70 | time_e = time.time()
71 | t = time_e - time_s
72 | print("time: {:.2f} ms".format(t*1000))
73 | save(out[0],gt)
74 | print('Test Done!')
75 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn.functional as F
4 | import sys
5 | sys.path.append('./models')
6 | import numpy as np
7 | from datetime import datetime
8 | from net import DFMNet
9 | from data import get_loader,test_dataset
10 | from utils import clip_gradient, LR_Scheduler
11 | from torch.utils.tensorboard import SummaryWriter
12 | import logging
13 | import torch.backends.cudnn as cudnn
14 | from options import opt
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 |
18 |
19 | def upsample(x, size):
20 | return F.interpolate(x, size, mode='bilinear', align_corners=True)
21 |
22 | #train function
23 | def train(train_loader, model, optimizer, epoch,save_path):
24 |
25 | global step
26 | model.train()
27 | loss_all=0
28 | epoch_step=0
29 | try:
30 |         for i, (images, gts, depths, edges) in enumerate(train_loader, start=1):
31 | optimizer.zero_grad()
32 | images = images.cuda()
33 | gts = gts.cuda()
34 | depths=depths.cuda()
35 |
36 |
37 | cur_lr = lr_scheduler(optimizer, i, epoch)
38 | writer.add_scalar('learning_rate', cur_lr, global_step=(epoch-1)*total_step + i)
39 |
40 |             out = model(images, depths)
41 | loss_f = F.binary_cross_entropy_with_logits(out[0], gts)
42 | loss_d = F.binary_cross_entropy_with_logits(out[1], gts)
43 |
44 |
45 | loss = loss_f + loss_d
46 | loss.backward()
47 |
48 | clip_gradient(optimizer, opt.clip)
49 | optimizer.step()
50 | step+=1
51 | epoch_step+=1
52 | loss_all+=loss.data
53 |
54 |
55 | if i % 100 == 0 or i == total_step or i==1:
56 | print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], loss: {:.4f}, loss_final: {:.4f}, loss_d: {:.4f}'.
57 | format(datetime.now(), epoch, opt.epoch, i, total_step, loss,loss_f.data,loss_d.data ))
58 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss1: {:.4f} '.
59 | format( epoch, opt.epoch, i, total_step, loss.data))
60 | writer.add_scalar('Loss', loss.data, global_step=step)
61 |
62 | loss_all/=epoch_step
63 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Loss_AVG: {:.4f}'.format( epoch, opt.epoch, loss_all))
64 | writer.add_scalar('Loss-epoch', loss_all, global_step=epoch)
65 | if epoch == 300:
66 | torch.save(model.state_dict(), save_path+'/epoch_{}.pth'.format(epoch))
67 | except KeyboardInterrupt:
68 | print('Keyboard Interrupt: save model and exit.')
69 | if not os.path.exists(save_path):
70 | os.makedirs(save_path)
71 | torch.save(model.state_dict(), save_path+'/epoch_{}.pth'.format(epoch+1))
72 | print('save checkpoints successfully!')
73 | raise
74 |
75 | #test function
76 | def test(test_loader,model,epoch,save_path):
77 | global best_mae,best_epoch
78 | model.eval()
79 | with torch.no_grad():
80 | mae_sum=0
81 | for i in range(test_loader.size):
82 | image, gt,depth, name,img_for_post = test_loader.load_data()
83 | gt = np.asarray(gt, np.float32)
84 | gt /= (gt.max() + 1e-8)
85 | image = image.cuda()
86 | depth = depth.cuda()
87 |             res = model(image, depth)
88 |             res = F.interpolate(res[0], size=gt.shape, mode='bilinear', align_corners=False)
89 | res = res.sigmoid().data.cpu().numpy().squeeze()
90 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
91 | mae_sum+=np.sum(np.abs(res-gt))*1.0/(gt.shape[0]*gt.shape[1])
92 | mae=mae_sum/test_loader.size
93 | writer.add_scalar('MAE', torch.tensor(mae), global_step=epoch)
94 | print('Epoch: {} MAE: {} #### bestMAE: {} bestEpoch: {}'.format(epoch,mae,best_mae,best_epoch))
95 | if epoch==1:
96 | best_mae=mae
97 | torch.save(model.state_dict(), save_path + '/epoch_best.pth')
98 | else:
99 |             if mae < best_mae:
100 |                 best_mae = mae
101 |                 best_epoch = epoch
102 |                 torch.save(model.state_dict(), save_path + '/epoch_best.pth')

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
56 |         if self.warmup_iters > 0 and T < self.warmup_iters:
57 | lr = lr * 1.0 * T / self.warmup_iters
58 | # if epoch > self.epoch:
59 | # print('\n=>Epoches %i, learning rate = %.4f, \
60 | # previous best = %.4f' % (epoch, lr, best_pred))
61 | # self.epoch = epoch
62 | assert lr >= 0
63 | self._adjust_learning_rate(optimizer, lr)
64 | return lr
65 |
66 | def _adjust_learning_rate(self, optimizer, lr):
67 | if len(optimizer.param_groups) == 1:
68 | optimizer.param_groups[0]['lr'] = lr
69 | else:
70 | # enlarge the lr at the head
71 | optimizer.param_groups[0]['lr'] = lr
72 | for i in range(1, len(optimizer.param_groups)):
73 | optimizer.param_groups[i]['lr'] = lr * 10
--------------------------------------------------------------------------------