├── LICENSE ├── README.md ├── data.py ├── dataset └── .gitkeep ├── depth.py ├── img ├── benchmark.png ├── benchmark_vis_IJCV.png ├── qualitative_results.png ├── quantitative_results.png ├── structure_diagram.png └── structure_diagram_IJCV.png ├── mobilenet.py ├── net.py ├── options.py ├── pretrain └── .gitkeep ├── test.py ├── train.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 zwbx 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DFM-Net (ACM MM 2021) 2 | Official repository for paper [Depth Quality-Inspired Feature Manipulation for Efficient RGB-D Salient Object Detection](https://arxiv.org/pdf/2107.01779.pdf) | [中文版](https://pan.baidu.com/s/1axKXAqBmMmQuPTvTTY_LNg?pwd=jsvr) 3 | 4 | ## News 5 | - 6/Jun/2022🔥[online demo](http://rgbdsod-krf.natapp4.cc/) is newly realeased! 6 | - 8/Aug/2022 we extend DFM-Net to Video Salient Object Detection task, which refers to [Depth Quality-Inspired Feature Manipulation for Efficient RGB-D and Video Salient Object Detection](https://arxiv.org/abs/2208.03918) 7 | 8 | 9 | 10 | 11 |

12 |
13 | 
14 |     Block diagram of DFM-Net.
15 | 
16 | 

17 | 
18 | 
19 | ## The most efficient RGB-D SOD method ⚡
20 | - Low model size: only **8.5 MB**, **6.7×/3.1× smaller** than the latest lightest models A2dele and MobileSal.
21 | - High accuracy: SOTA performance on 9 datasets (NJU2K, NLPR, STERE, RGBD135, LFSD, SIP, DUT-RGBD, RedWeb-S, COME).
22 | - High speed: 50 ms per image on CPU (Core i7-8700), **2.9×/2.4× faster** than the latest fastest models A2dele and MobileSal.
23 | 
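For reference, the size and speed figures above can be sanity-checked against `net.py` (its `__main__` block does something similar). The sketch below is not part of the repository; it assumes float32 weights and the default 256×256 input resolution used in `options.py`/`test.py`.

```python
# Hedged sketch: count DFM-Net parameters and time a CPU forward pass.
import time
import torch
from net import DFMNet

model = DFMNet().eval()

# Parameter count -> approximate size in MB, assuming 4-byte float32 weights.
num_params = sum(p.numel() for p in model.parameters())
print('params: {:.2f} M, approx. size: {:.1f} MB'.format(num_params / 1e6, num_params * 4 / 2 ** 20))

rgb = torch.randn(1, 3, 256, 256)    # dummy RGB input
depth = torch.randn(1, 1, 256, 256)  # dummy depth input
with torch.no_grad():
    for _ in range(5):               # warm-up iterations
        model(rgb, depth)
    t0 = time.time()
    for _ in range(20):
        out = model(rgb, depth)      # out[0]/out[1]: final / depth-branch saliency logits at input size
    print('CPU time per image: {:.1f} ms'.format((time.time() - t0) / 20 * 1000))
```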

25 |
26 | 27 | Performance visualization. Performance visualization. The vertical axis indicates the average S-measure over six widely used datasets (NJU2K, NLPR, STERE, RGBD135, LFSD, SIP). The horizontal axis indicates CPU speed. The circle area is proportional to the model size. 28 | 29 |

30 | 
31 | 
32 | ## Extension :fire:
33 | [Depth Quality-Inspired Feature Manipulation for Efficient RGB-D and Video Salient Object Detection](https://arxiv.org/abs/2208.03918)
34 | - More comprehensive comparison:
35 |     - Benchmark results on DUT-RGBD, RedWeb-S, and COME are updated.
36 |     - A metric of maximum-batch inference speed is added.
37 |     - We re-test the inference speed of our method and the compared methods on Ubuntu 16.04.
38 | - Working mechanism explanation:
39 |     - Further analyses verify the ability of DQFM to distinguish depth maps of various qualities without any quality labels.
40 | - Application to efficient VSOD:
41 |     - One of the lightest VSOD methods!
42 |     - A joint training strategy is proposed.
43 | 
44 | 
45 | 
46 | ## Easy-to-use components to boost your RGB-D SOD network
47 | If you use a depth branch as an affiliate to the RGB branch:
48 | - Use DQW/DHA to boost performance at the cost of only 0.007/0.042 MB of extra model size.
49 | - Use our lightweight depth backbone to improve efficiency.
50 | 
51 | If you adopt parallel encoders for RGB and depth:
52 | - Refer to our other work [BTS-Net](https://github.com/zwbx/BTS-Net).
53 | 
54 | 
55 | 
56 | 
57 | ## Test
58 | 
59 | Run test.py directly (a minimal single-image inference sketch is also included below, after the VSOD section).
60 | 
61 | The test maps will be saved to './results/benchmark/'.
62 | 
63 | Data preparation:
64 | - Classic benchmark: train on NJU2K and NLPR, test on NJU2K, NLPR, STERE, RGBD135, LFSD, SIP.
65 |     - [test data](https://pan.baidu.com/s/1wI-bxarzdSrOY39UxZaomQ) [code: 940i]
66 |     - [pretrained model for DFMNet](https://pan.baidu.com/s/1pTEByo0OngNJlKCJsTcx-A?pwd=skin)
67 |     - Additional test dataset [RedWeb-S](https://github.com/nnizhang/SMAC) 🆕, updated in the journal version.
68 | - DUT-RGBD benchmark 🆕
69 |     - Download the training and test data from the [official repository](https://pan.baidu.com/s/1mhHAXLgoqqLQIb6r-k-hbA).
70 |     - [pretrained model for DFMNet](https://pan.baidu.com/s/1GJHvxh2gTLutpM1hfESDNg?pwd=nmw3).
71 | - COME benchmark 🆕
72 |     - Download the training and test data from the [official repository](https://github.com/JingZhang617/cascaded_rgbd_sod).
73 |     - [pretrained model for DFMNet](https://pan.baidu.com/s/1fCYF5p9dCC8RXRCLaWUQlg?pwd=iqyf).
74 | 
75 | ## Results
76 | 
77 | - We provide testing results on 9 datasets (NJU2K, NLPR, STERE, RGBD135, LFSD, SIP, DUT-RGBD 🆕, RedWeb-S 🆕, COME 🆕).
78 |     - [Results of DFM-Net](https://pan.baidu.com/s/1wZyYqYISpRGZATDgKYO4nA?pwd=4jqu).
79 |     - [Results of DFM-Net*](https://pan.baidu.com/s/1vemT9nfaXoSc_tqSYakSCg?pwd=pax4).
80 | 
81 | - Evaluate the result maps:
82 |     You can evaluate the result maps using the [Matlab version](http://dpfan.net/d3netbenchmark/) or the [Python GPU version](https://github.com/zyjwuyan/SOD_Evaluation_Metrics) of the evaluation tool.
83 | 
84 | - Note that the parameter file is 8.9 MB, 0.4 MB larger than reported in the paper, because the keys storing parameter names also occupy some space. Put the downloaded data and pretrained models under the following directory structure:
85 | 
86 |         -dataset\
87 |            -RGBD_train
88 |            -NJU2K\
89 |            -NLPR\
90 |            ...
91 |         -pretrain
92 |            -DFMNet_300_epoch.pth
93 |            ...
94 | 
95 | 
96 | ## Training
97 | - Download the [training data](https://pan.baidu.com/s/1ckNlS0uEIPV-iCwVzjutsQ) [code: eb2z].
98 | - Modify the settings in options.py and run train.py.
99 | 
100 | 
101 | ## Application on VSOD 🆕
102 | - We provide testing results on 4 datasets (DAVIS, FBMS, MCL, DAVSOD).
103 |     - [Results of DFM-Net](https://pan.baidu.com/s/1jLGP2kV_Z7esOkkY3jKFQw?pwd=58wc).
104 |     - [Results of DFM-Net*](https://pan.baidu.com/s/1EV4_neyES7jAyo0op-XfTA?pwd=pp2w).
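As referenced in the Test section above, here is a minimal single-image inference sketch distilled from `test.py` and `data.py`. It is not part of the repository: the checkpoint path follows `test.py`, while the input/output file names are placeholders to replace with your own.

```python
# Hedged sketch: run DFM-Net on one RGB-D pair and save the saliency map.
import cv2
import numpy as np
import torch
from PIL import Image, ImageOps
from torchvision import transforms
from net import DFMNet

model = DFMNet().cuda().eval()
model.load_state_dict(torch.load('./pretrain/DFMNet_epoch_300.pth'))   # checkpoint name as in test.py

rgb_tf = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
depth_tf = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

rgb = Image.open('example_rgb.jpg').convert('RGB')                     # placeholder file names
depth = ImageOps.invert(Image.open('example_depth.png').convert('L'))  # depth is inverted, as in data.py

with torch.no_grad():
    out = model(rgb_tf(rgb).unsqueeze(0).cuda(), depth_tf(depth).unsqueeze(0).cuda())
    sal = out[0].sigmoid().squeeze().cpu().numpy()
    sal = (sal - sal.min()) / (sal.max() - sal.min() + 1e-8)           # min-max normalize, as in test.py
cv2.imwrite('example_pred.png', (cv2.resize(sal, rgb.size) * 255).astype(np.uint8))
```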
105 | 106 | ## Citation 107 | 108 | Please cite the following paper if you use this repository in your research 109 | 110 | @inproceedings{zhang2021depth, 111 | title={Depth quality-inspired feature manipulation for efficient RGB-D salient object detection}, 112 | author={Zhang, Wenbo and Ji, Ge-Peng and Wang, Zhuo and Fu, Keren and Zhao, Qijun}, 113 | booktitle={Proceedings of the 29th ACM International Conference on Multimedia}, 114 | pages={731--740}, 115 | year={2021} 116 | } 117 | 118 | @artical{zhang2022depth, 119 | title={Depth Quality-Inspired Feature Manipulation for Efficient RGB-D and Video Salient Object Detection}, 120 | author={Zhang, Wenbo and Fu, Keren and Wang, Zhuo and Ji, Ge-Peng and Zhao, Qijun}, 121 | booktitle={arXiv:2208.03918}, 122 | year={2022} 123 | } 124 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import PIL 3 | from PIL import Image 4 | import torch.utils.data as data 5 | import torchvision.transforms as transforms 6 | import random 7 | import numpy as np 8 | from PIL import ImageEnhance 9 | from natsort import natsorted 10 | import torch 11 | 12 | #several data augumentation strategies 13 | def cv_random_flip(img, label,depth,edge): 14 | flip_flag = random.randint(0, 1) 15 | # flip_flag2= random.randint(0,1) 16 | #left right flip 17 | if flip_flag == 1: 18 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 19 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 20 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT) 21 | edge = edge.transpose(Image.FLIP_LEFT_RIGHT) 22 | #top bottom flip 23 | # if flip_flag2==1: 24 | # img = img.transpose(Image.FLIP_TOP_BOTTOM) 25 | # label = label.transpose(Image.FLIP_TOP_BOTTOM) 26 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM) 27 | return img, label, depth, edge 28 | def randomCrop(image, label,depth,edge): 29 | border=30 30 | image_width = image.size[0] 31 | image_height = image.size[1] 32 | crop_win_width = np.random.randint(image_width-border , image_width) 33 | crop_win_height = np.random.randint(image_height-border , image_height) 34 | random_region = ( 35 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, 36 | (image_height + crop_win_height) >> 1) 37 | return image.crop(random_region), label.crop(random_region),depth.crop(random_region),edge.crop(random_region) 38 | def randomRotation(image,label,depth,edge): 39 | mode=Image.BICUBIC 40 | if random.random()>0.8: 41 | random_angle = np.random.randint(-15, 15) 42 | image=image.rotate(random_angle, mode) 43 | label=label.rotate(random_angle, mode) 44 | depth=depth.rotate(random_angle, mode) 45 | edge = edge.rotate(random_angle, mode) 46 | return image,label,depth,edge 47 | def colorEnhance(image): 48 | bright_intensity=random.randint(5,15)/10.0 49 | image=ImageEnhance.Brightness(image).enhance(bright_intensity) 50 | contrast_intensity=random.randint(5,15)/10.0 51 | image=ImageEnhance.Contrast(image).enhance(contrast_intensity) 52 | color_intensity=random.randint(0,20)/10.0 53 | image=ImageEnhance.Color(image).enhance(color_intensity) 54 | sharp_intensity=random.randint(0,30)/10.0 55 | image=ImageEnhance.Sharpness(image).enhance(sharp_intensity) 56 | return image 57 | def randomGaussian(image, mean=0.1, sigma=0.35): 58 | def gaussianNoisy(im, mean=mean, sigma=sigma): 59 | for _i in range(len(im)): 60 | im[_i] += random.gauss(mean, sigma) 61 | return im 62 | img = 
np.asarray(image) 63 | width, height = img.shape 64 | img = gaussianNoisy(img[:].flatten(), mean, sigma) 65 | img = img.reshape([width, height]) 66 | return Image.fromarray(np.uint8(img)) 67 | def randomPeper(img): 68 | 69 | img=np.array(img) 70 | noiseNum=int(0.0015*img.shape[0]*img.shape[1]) 71 | for i in range(noiseNum): 72 | 73 | randX=random.randint(0,img.shape[0]-1) 74 | 75 | randY=random.randint(0,img.shape[1]-1) 76 | 77 | if random.randint(0,1)==0: 78 | 79 | img[randX,randY]=0 80 | 81 | else: 82 | 83 | img[randX,randY]=255 84 | return Image.fromarray(img) 85 | 86 | # dataset for training 87 | #The current loader is not using the normalized depth maps for training and test. If you use the normalized depth maps 88 | #(e.g., 0 represents background and 1 represents foreground.), the performance will be further improved. 89 | class SalObjDataset(data.Dataset): 90 | def __init__(self, image_root, gt_root,depth_root,edge_root, trainsize): 91 | self.trainsize = trainsize 92 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 93 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 94 | or f.endswith('.png')] 95 | self.depths=[depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 96 | or f.endswith('.png')] 97 | self.edges = [edge_root + f for f in os.listdir(edge_root) if f.endswith('.bmp') 98 | or f.endswith('.png')] 99 | self.images = natsorted(self.images) 100 | self.gts = natsorted(self.gts) 101 | self.depths= natsorted(self.depths) 102 | self.edges = natsorted(self.edges) 103 | # print(self.images) 104 | # print(self.depths) 105 | # print(self.gts) 106 | self.filter_files() 107 | self.size = len(self.images) 108 | self.img_transform = transforms.Compose([ 109 | transforms.Resize((self.trainsize, self.trainsize)), 110 | transforms.ToTensor(), 111 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 112 | self.gt_transform = transforms.Compose([ 113 | transforms.Resize((self.trainsize, self.trainsize)), 114 | transforms.ToTensor()]) 115 | self.depths_transform = transforms.Compose([ 116 | transforms.Resize((self.trainsize, self.trainsize)), 117 | transforms.ToTensor(),]) 118 | 119 | def __getitem__(self, index): 120 | image = self.rgb_loader(self.images[index]) 121 | gt = self.binary_loader(self.gts[index]) 122 | depth = self.binary_loader(self.depths[index]) 123 | edge_gt = self.binary_loader(self.edges[index]) 124 | depth = PIL.ImageOps.invert(depth) 125 | image,gt,depth,edge_gt =cv_random_flip(image,gt,depth,edge_gt) 126 | image,gt,depth,edge_gt=randomCrop(image, gt,depth,edge_gt) 127 | image,gt,depth,edge_gt=randomRotation(image, gt,depth,edge_gt) 128 | image=colorEnhance(image) 129 | # depth= colorEnhance(depth) 130 | #gt=randomGaussian(gt) 131 | gt=randomPeper(gt) 132 | image = self.img_transform(image) 133 | gt = self.gt_transform(gt) 134 | edge_gt = self.gt_transform(edge_gt) 135 | edge_gt = (edge_gt - edge_gt.min()) / (edge_gt.max() - edge_gt.min() + 1e-8) 136 | depth=self.depths_transform(depth) 137 | return image, gt, depth, edge_gt 138 | 139 | def filter_files(self): 140 | assert len(self.images) == len(self.gts) and len(self.gts)==len(self.images) 141 | images = [] 142 | gts = [] 143 | depths=[] 144 | for img_path, gt_path,depth_path in zip(self.images, self.gts, self.depths): 145 | img = Image.open(img_path) 146 | gt = Image.open(gt_path) 147 | depth= Image.open(depth_path) 148 | if img.size == gt.size and gt.size==depth.size: 149 | images.append(img_path) 150 | gts.append(gt_path) 
151 | depths.append(depth_path) 152 | self.images = images 153 | self.gts = gts 154 | self.depths=depths 155 | assert len(self.images) == len(self.gts) and len(self.gts) == len(self.images) 156 | 157 | def rgb_loader(self, path): 158 | with open(path, 'rb') as f: 159 | img = Image.open(f) 160 | return img.convert('RGB') 161 | 162 | def binary_loader(self, path): 163 | with open(path, 'rb') as f: 164 | img = Image.open(f) 165 | return img.convert('L') 166 | 167 | def rgb_loader_ops(self, path): 168 | with open(path, 'rb') as f: 169 | img = Image.open(f) 170 | return PIL.ImageOps.invert(img.convert('RGB')) 171 | 172 | def resize(self, img, gt, depth): 173 | assert img.size == gt.size and gt.size==depth.size 174 | w, h = img.size 175 | if h < self.trainsize or w < self.trainsize: 176 | h = max(h, self.trainsize) 177 | w = max(w, self.trainsize) 178 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST),depth.resize((w, h), Image.NEAREST) 179 | else: 180 | return img, gt, depth 181 | 182 | def __len__(self): 183 | return self.size 184 | 185 | #dataloader for training 186 | def get_loader(image_root, gt_root,depth_root,edge_root, batchsize, trainsize, shuffle=True, num_workers=2, pin_memory=False): 187 | 188 | dataset = SalObjDataset(image_root, gt_root, depth_root,edge_root,trainsize) 189 | data_loader = data.DataLoader(dataset=dataset, 190 | batch_size=batchsize, 191 | shuffle=shuffle, 192 | num_workers=num_workers, 193 | pin_memory=pin_memory) 194 | return data_loader 195 | 196 | #test dataset and loader 197 | class test_dataset: 198 | def __init__(self, image_root, gt_root,depth_root, testsize): 199 | self.testsize = testsize 200 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 201 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 202 | or f.endswith('.png')] 203 | self.depths=[depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 204 | or f.endswith('.png')] 205 | self.images = natsorted(self.images) 206 | self.gts = natsorted(self.gts) 207 | self.depths= natsorted(self.depths) 208 | # print(self.images) 209 | # print(self.depths) 210 | # print(self.gts) 211 | self.filter_files() 212 | self.transform = transforms.Compose([ 213 | transforms.Resize((self.testsize, self.testsize)), 214 | transforms.ToTensor(), 215 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 216 | self.gt_transform = transforms.ToTensor() 217 | # self.gt_transform = transforms.Compose([ 218 | # transforms.Resize((self.trainsize, self.trainsize)), 219 | # transforms.ToTensor()]) 220 | self.depths_transform = transforms.Compose([ 221 | transforms.Resize((self.testsize, self.testsize)), 222 | transforms.ToTensor(), 223 | ]) 224 | self.size = len(self.images) 225 | self.index = 0 226 | 227 | def load_data(self): 228 | image = self.rgb_loader(self.images[self.index]) 229 | image = self.transform(image).unsqueeze(0) 230 | gt = self.binary_loader(self.gts[self.index]) 231 | depth=self.binary_loader_ops(self.depths[self.index]) 232 | pesudo_depth = self.depths_transform(self.rgb_loader_ops(self.gts[self.index])).unsqueeze(0) 233 | depth=self.depths_transform(depth).unsqueeze(0) 234 | name = self.images[self.index].split('/')[-1] 235 | image_for_post=self.rgb_loader(self.images[self.index]) 236 | image_for_post=image_for_post.resize(gt.size) 237 | if name.endswith('.jpg'): 238 | name = name.split('.jpg')[0] + '.png' 239 | self.index += 1 240 | self.index = self.index % self.size 241 | return image, gt,depth, 
name,np.array(image_for_post) 242 | 243 | def rgb_loader(self, path): 244 | with open(path, 'rb') as f: 245 | img = Image.open(f) 246 | return img.convert('RGB') 247 | 248 | def binary_loader(self, path): 249 | with open(path, 'rb') as f: 250 | img = Image.open(f) 251 | return img.convert('L') 252 | 253 | def binary_loader_ops(self, path): 254 | with open(path, 'rb') as f: 255 | img = Image.open(f) 256 | return PIL.ImageOps.invert(img.convert('L')) 257 | 258 | def rgb_loader_ops(self, path): 259 | with open(path, 'rb') as f: 260 | img = Image.open(f) 261 | img = PIL.ImageOps.invert(img.convert('RGB')) 262 | return img 263 | 264 | def __len__(self): 265 | return self.size 266 | 267 | def filter_files(self): 268 | assert len(self.images) == len(self.gts) and len(self.gts)==len(self.images) 269 | images = [] 270 | gts = [] 271 | depths=[] 272 | for img_path, gt_path,depth_path in zip(self.images, self.gts, self.depths): 273 | img = Image.open(img_path) 274 | gt = Image.open(gt_path) 275 | depth= Image.open(depth_path) 276 | if img.size == gt.size and gt.size==depth.size: 277 | images.append(img_path) 278 | gts.append(gt_path) 279 | depths.append(depth_path) 280 | # else: 281 | # print(img.size, depth.size, gt.size) 282 | self.images = images 283 | self.gts = gts 284 | self.depths = depths 285 | assert len(self.images) == len(self.gts) and len(self.gts) == len(self.images) 286 | 287 | -------------------------------------------------------------------------------- /dataset/.gitkeep: -------------------------------------------------------------------------------- 1 | # 2 | -------------------------------------------------------------------------------- /depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | from torch.nn import functional as F 5 | import time 6 | import timm 7 | from mobilenet import MobileNetV2Encoder 8 | 9 | 10 | import os 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | 17 | def upsample(x, size): 18 | return F.interpolate(x, size, mode='bilinear', align_corners=True) 19 | 20 | def initialize_weights(model): 21 | m = torch.hub.load('pytorch/vision:v0.6.0', 'mobilenet_v2', pretrained=True) 22 | pretrained_dict = m.state_dict() 23 | all_params = {} 24 | for k, v in model.state_dict().items(): 25 | if k in pretrained_dict.keys() and v.shape == pretrained_dict[k]: 26 | v = pretrained_dict[k] 27 | all_params[k] = v 28 | # assert len(all_params.keys()) == len(self.resnet.state_dict().keys()) 29 | model.load_state_dict(all_params,strict=False) 30 | 31 | class DepthBranch(nn.Module): 32 | def __init__(self, c1=8, c2=16, c3=32, c4=48, c5=320, **kwargs): 33 | super(DepthBranch, self).__init__() 34 | self.bottleneck1 = _make_layer(LinearBottleneck, 1, 16, blocks=1, t=3, stride=2) 35 | self.bottleneck2 = _make_layer(LinearBottleneck, 16, 24, blocks=3, t=3, stride=2) 36 | self.bottleneck3 = _make_layer(LinearBottleneck, 24, 32, blocks=7, t=3, stride=2) 37 | self.bottleneck4 = _make_layer(LinearBottleneck, 32, 96, blocks=3, t=2, stride=2) 38 | self.bottleneck5 = _make_layer(LinearBottleneck, 96, 320, blocks=1, t=2, stride=1) 39 | 40 | # self.conv_s_d = _ConvBNReLU(320,1,1,1) 41 | 42 | # nn.Sequential(_DSConv(c3, c3 // 4), 43 | # nn.Conv2d(c3 // 4, 1, 1), ) 44 | 45 | def forward(self, x): 46 | size = x.size()[2:] 47 | feat = [] 48 | 49 | x1 = self.bottleneck1(x) 50 | x2 = self.bottleneck2(x1) 51 | x3 = 
self.bottleneck3(x2) 52 | x4 = self.bottleneck4(x3) 53 | x5 = self.bottleneck5(x4) 54 | # s_d = self.conv_s_d(x5) 55 | 56 | feat.append(x1) 57 | feat.append(x2) 58 | feat.append(x3) 59 | feat.append(x4) 60 | feat.append(x5) 61 | return x1 ,feat 62 | 63 | class _ConvBNReLU(nn.Module): 64 | """Conv-BN-ReLU""" 65 | 66 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, **kwargs): 67 | super(_ConvBNReLU, self).__init__() 68 | self.conv = nn.Sequential( 69 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False), 70 | nn.BatchNorm2d(out_channels), 71 | nn.ReLU(True) 72 | ) 73 | 74 | def forward(self, x): 75 | return self.conv(x) 76 | 77 | 78 | class _DSConv(nn.Module): 79 | """Depthwise Separable Convolutions""" 80 | 81 | def __init__(self, dw_channels, out_channels, stride=1, **kwargs): 82 | super(_DSConv, self).__init__() 83 | self.conv = nn.Sequential( 84 | nn.Conv2d(dw_channels, dw_channels, 3, stride, 1, groups=dw_channels, bias=False), 85 | nn.BatchNorm2d(dw_channels), 86 | nn.ReLU(True), 87 | nn.Conv2d(dw_channels, out_channels, 1, bias=False), 88 | nn.BatchNorm2d(out_channels), 89 | nn.ReLU(True) 90 | ) 91 | 92 | def forward(self, x): 93 | return self.conv(x) 94 | 95 | def _make_layer( block, inplanes, planes, blocks, t=6, stride=1): 96 | layers = [] 97 | layers.append(block(inplanes, planes, t, stride)) 98 | for i in range(1, blocks): 99 | layers.append(block(planes, planes, t, 1)) 100 | return nn.Sequential(*layers) 101 | 102 | class _DWConv(nn.Module): 103 | def __init__(self, dw_channels, out_channels, stride=1, **kwargs): 104 | super(_DWConv, self).__init__() 105 | self.conv = nn.Sequential( 106 | nn.Conv2d(dw_channels, out_channels, 3, stride, 1, groups=dw_channels, bias=False), 107 | nn.BatchNorm2d(out_channels), 108 | nn.ReLU(True) 109 | ) 110 | 111 | def forward(self, x): 112 | return self.conv(x) 113 | 114 | 115 | 116 | 117 | class LinearBottleneck(nn.Module): 118 | """LinearBottleneck used in MobileNetV2""" 119 | 120 | def __init__(self, in_channels, out_channels, t=6, stride=2, **kwargs): 121 | super(LinearBottleneck, self).__init__() 122 | self.use_shortcut = stride == 1 and in_channels == out_channels 123 | self.block = nn.Sequential( 124 | # pw 125 | _ConvBNReLU(in_channels, in_channels * t, 1), 126 | # dw 127 | _DWConv(in_channels * t, in_channels * t, stride), 128 | # pw-linear 129 | nn.Conv2d(in_channels * t, out_channels, 1, bias=False), 130 | nn.BatchNorm2d(out_channels) 131 | ) 132 | 133 | def forward(self, x): 134 | out = self.block(x) 135 | if self.use_shortcut: 136 | out = x + out 137 | return out 138 | 139 | 140 | class PyramidPooling(nn.Module): 141 | """Pyramid pooling module""" 142 | 143 | def __init__(self, in_channels, out_channels, **kwargs): 144 | super(PyramidPooling, self).__init__() 145 | inter_channels = int(in_channels / 4) 146 | self.conv1 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs) 147 | self.conv2 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs) 148 | self.conv3 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs) 149 | self.conv4 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs) 150 | self.out = _ConvBNReLU(in_channels * 2, out_channels, 1) 151 | 152 | def pool(self, x, size): 153 | avgpool = nn.AdaptiveAvgPool2d(size) 154 | return avgpool(x) 155 | 156 | def forward(self, x): 157 | size = x.size()[2:] 158 | feat1 = upsample(self.conv1(self.pool(x, 1)), size) 159 | feat2 = upsample(self.conv2(self.pool(x, 2)), size) 160 | feat3 = upsample(self.conv3(self.pool(x, 
3)), size) 161 | feat4 = upsample(self.conv4(self.pool(x, 6)), size) 162 | x = torch.cat([x, feat1, feat2, feat3, feat4], dim=1) 163 | x = self.out(x) 164 | return x 165 | 166 | 167 | 168 | 169 | class BasicConv2d(nn.Module): 170 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, activation='relu'): 171 | super(BasicConv2d, self).__init__() 172 | self.conv = nn.Conv2d(in_planes, out_planes, 173 | kernel_size=kernel_size, stride=stride, 174 | padding=padding, dilation=dilation, bias=False) 175 | self.bn = nn.BatchNorm2d(out_planes) 176 | self.relu = nn.ReLU(inplace=True) 177 | self.activation = activation 178 | self.sigmoid = nn.Sigmoid() 179 | 180 | def forward(self, x): 181 | x = self.conv(x) 182 | x = self.bn(x) 183 | return self.relu(x) if self.activation=='relu' \ 184 | else self.sigmoid(x) if self.activation=='sigmoid' \ 185 | else x 186 | -------------------------------------------------------------------------------- /img/benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwbx/DFM-Net/3e148cb14f48045af7c7dd0530f83a036625f847/img/benchmark.png -------------------------------------------------------------------------------- /img/benchmark_vis_IJCV.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwbx/DFM-Net/3e148cb14f48045af7c7dd0530f83a036625f847/img/benchmark_vis_IJCV.png -------------------------------------------------------------------------------- /img/qualitative_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwbx/DFM-Net/3e148cb14f48045af7c7dd0530f83a036625f847/img/qualitative_results.png -------------------------------------------------------------------------------- /img/quantitative_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwbx/DFM-Net/3e148cb14f48045af7c7dd0530f83a036625f847/img/quantitative_results.png -------------------------------------------------------------------------------- /img/structure_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwbx/DFM-Net/3e148cb14f48045af7c7dd0530f83a036625f847/img/structure_diagram.png -------------------------------------------------------------------------------- /img/structure_diagram_IJCV.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwbx/DFM-Net/3e148cb14f48045af7c7dd0530f83a036625f847/img/structure_diagram_IJCV.png -------------------------------------------------------------------------------- /mobilenet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torchvision.models import MobileNetV2 3 | import torch 4 | 5 | 6 | class MobileNetV2Encoder(MobileNetV2): 7 | """ 8 | MobileNetV2Encoder inherits from torchvision's official MobileNetV2. It is modified to 9 | use dilation on the last block to maintain output stride 16, and deleted the 10 | classifier block that was originally used for classification. The forward method 11 | additionally returns the feature maps at all resolutions for decoder's use. 12 | """ 13 | 14 | def __init__(self, in_channels, norm_layer=None): 15 | super().__init__() 16 | 17 | # Replace first conv layer if in_channels doesn't match. 
18 | if in_channels != 3: 19 | self.features[0][0] = nn.Conv2d(in_channels, 32, 3, 2, 1, bias=False) 20 | 21 | # Remove last block 22 | self.features = self.features[:-1] 23 | 24 | # Change to use dilation to maintain output stride = 16 25 | self.features[14].conv[1][0].stride = (1, 1) 26 | for feature in self.features[15:]: 27 | feature.conv[1][0].dilation = (2, 2) 28 | feature.conv[1][0].padding = (2, 2) 29 | 30 | # Delete classifier 31 | del self.classifier 32 | 33 | self.layer1 = nn.Sequential(self.features[0], self.features[1]) 34 | self.layer2 = nn.Sequential(self.features[2], self.features[3]) 35 | self.layer3 = nn.Sequential(self.features[4], self.features[5], self.features[6]) 36 | self.layer4 = nn.Sequential(self.features[7], self.features[8], self.features[9], self.features[10], 37 | self.features[11], self.features[12], self.features[13]) 38 | self.layer5 = nn.Sequential(self.features[14], self.features[15], self.features[16], self.features[17]) 39 | def forward(self, x): 40 | x0 = x # 1/1 41 | x = self.features[0](x) 42 | x = self.features[1](x) 43 | x = x 44 | x1 = x # 1/2 45 | x = self.features[2](x) 46 | x = self.features[3](x) 47 | x2 = x # 1/4 48 | x = self.features[4](x) 49 | x = self.features[5](x) 50 | x = self.features[6](x) 51 | x3 = x # 1/8 52 | x = self.features[7](x) 53 | x = self.features[8](x) 54 | x = self.features[9](x) 55 | x = self.features[10](x) 56 | x = self.features[11](x) 57 | x = self.features[12](x) 58 | x = self.features[13](x) 59 | x4 = x # 1/16 60 | x = self.features[14](x) 61 | x = self.features[15](x) 62 | x = self.features[16](x) 63 | x = self.features[17](x) 64 | x5 = x # 1/16 65 | return x1,x2,x3,x4,x5 66 | 67 | -------------------------------------------------------------------------------- /net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | from torch.nn import functional as F 5 | import time 6 | import timm 7 | import random 8 | import os 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from depth import DepthBranch 13 | from mobilenet import MobileNetV2Encoder 14 | 15 | 16 | def upsample(x, size): 17 | return F.interpolate(x, size, mode='bilinear', align_corners=True) 18 | 19 | class DFMNet(nn.Module): 20 | def __init__(self, **kwargs): 21 | super(DFMNet, self).__init__() 22 | self.rgb = RGBBranch() 23 | self.depth = DepthBranch() 24 | 25 | def forward(self, r, d): 26 | size = r.shape[2:] 27 | outputs = [] 28 | 29 | sal_d,feat = self.depth(d) 30 | sal_final= self.rgb(r,feat) 31 | 32 | sal_final = upsample(sal_final, size) 33 | sal_d = upsample(sal_d, size) 34 | 35 | outputs.append(sal_final) 36 | outputs.append(sal_d) 37 | 38 | return outputs 39 | 40 | class _ConvBNReLU(nn.Module): 41 | """Conv-BN-ReLU""" 42 | 43 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,dilation=1, **kwargs): 44 | super(_ConvBNReLU, self).__init__() 45 | self.conv = nn.Sequential( 46 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,dilation=dilation ,bias=False), 47 | nn.BatchNorm2d(out_channels), 48 | nn.ReLU(True) 49 | ) 50 | 51 | def forward(self, x): 52 | return self.conv(x) 53 | 54 | class _ConvBNSig(nn.Module): 55 | """Conv-BN-Sigmoid""" 56 | 57 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,dilation=1, **kwargs): 58 | super(_ConvBNSig, self).__init__() 59 | self.conv = nn.Sequential( 60 | nn.Conv2d(in_channels, 
out_channels, kernel_size, stride, padding,dilation=dilation ,bias=False), 61 | nn.BatchNorm2d(out_channels), 62 | nn.Sigmoid() 63 | ) 64 | 65 | def forward(self, x): 66 | return self.conv(x) 67 | 68 | 69 | class _DSConv(nn.Module): 70 | """Depthwise Separable Convolutions""" 71 | 72 | def __init__(self, dw_channels, out_channels, stride=1, **kwargs): 73 | super(_DSConv, self).__init__() 74 | self.conv = nn.Sequential( 75 | nn.Conv2d(dw_channels, dw_channels, 3, stride, 1, groups=dw_channels, bias=False), 76 | nn.BatchNorm2d(dw_channels), 77 | nn.ReLU(True), 78 | nn.Conv2d(dw_channels, out_channels, 1, bias=False), 79 | nn.BatchNorm2d(out_channels), 80 | nn.ReLU(True) 81 | ) 82 | 83 | def forward(self, x): 84 | return self.conv(x) 85 | 86 | def _make_layer( block, inplanes, planes, blocks, t=6, stride=1): 87 | layers = [] 88 | layers.append(block(inplanes, planes, t, stride)) 89 | for i in range(1, blocks): 90 | layers.append(block(planes, planes, t, 1)) 91 | return nn.Sequential(*layers) 92 | 93 | class _DWConv(nn.Module): 94 | def __init__(self, dw_channels, out_channels, stride=1, **kwargs): 95 | super(_DWConv, self).__init__() 96 | self.conv = nn.Sequential( 97 | nn.Conv2d(dw_channels, out_channels, 3, stride, 1, groups=dw_channels, bias=False), 98 | nn.BatchNorm2d(out_channels), 99 | nn.ReLU(True) 100 | ) 101 | 102 | def forward(self, x): 103 | return self.conv(x) 104 | 105 | 106 | class LinearBottleneck(nn.Module): 107 | """LinearBottleneck used in MobileNetV2""" 108 | 109 | def __init__(self, in_channels, out_channels, t=6, stride=2, **kwargs): 110 | super(LinearBottleneck, self).__init__() 111 | self.use_shortcut = stride == 1 and in_channels == out_channels 112 | self.block = nn.Sequential( 113 | # pw 114 | _ConvBNReLU(in_channels, in_channels * t, 1), 115 | # dw 116 | _DWConv(in_channels * t, in_channels * t, stride), 117 | # pw-linear 118 | nn.Conv2d(in_channels * t, out_channels, 1, bias=False), 119 | nn.BatchNorm2d(out_channels) 120 | ) 121 | 122 | def forward(self, x): 123 | out = self.block(x) 124 | if self.use_shortcut: 125 | out = x + out 126 | return out 127 | 128 | 129 | 130 | class PyramidPooling(nn.Module): 131 | """Pyramid pooling module""" 132 | 133 | def __init__(self, in_channels, out_channels, **kwargs): 134 | super(PyramidPooling, self).__init__() 135 | inter_channels = int(in_channels / 4) 136 | self.conv1 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs) 137 | self.conv2 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs) 138 | self.conv3 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs) 139 | self.conv4 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs) 140 | self.out = _ConvBNReLU(in_channels * 2, out_channels, 1) 141 | 142 | def pool(self, x, size): 143 | avgpool = nn.AdaptiveAvgPool2d(size) 144 | return avgpool(x) 145 | 146 | def forward(self, x): 147 | size = x.size()[2:] 148 | feat1 = upsample(self.conv1(self.pool(x, 1)), size) 149 | feat2 = upsample(self.conv2(self.pool(x, 2)), size) 150 | feat3 = upsample(self.conv3(self.pool(x, 3)), size) 151 | feat4 = upsample(self.conv4(self.pool(x, 6)), size) 152 | x = torch.cat([x, feat1, feat2, feat3, feat4], dim=1) 153 | x = self.out(x) 154 | return x 155 | 156 | 157 | 158 | class RGBBranch(nn.Module): 159 | """RGBBranch for low-level RGB feature extract""" 160 | 161 | def __init__(self, c1=16, c2=24, c3=32, c4=96,c5=320,k=32 ,**kwargs): 162 | super(RGBBranch, self).__init__() 163 | self.base = MobileNetV2Encoder(3) 164 | initialize_weights(self.base) 165 | 166 | self.conv_cp1 = _DSConv(c1,k) 
167 | self.conv_cp2 = _DSConv(c2, k) 168 | self.conv_cp3 = _DSConv(c3, k) 169 | self.conv_cp4 = _DSConv(c4, k) 170 | self.conv_cp5 = _DSConv(c5, k) 171 | self.conv_s_f = nn.Sequential(_DSConv(2 * k, k), 172 | _DSConv( k, k), 173 | nn.Conv2d(k, 1, 1), ) 174 | 175 | # self.focus = focus() 176 | self.ca1 = nn.Sequential(_ConvBNReLU(k, k, 1, 1), nn.Conv2d(k, k, 1, 1), nn.Sigmoid()) 177 | self.ca2 = nn.Sequential(_ConvBNReLU(k, k, 1, 1), nn.Conv2d(k, k, 1, 1), nn.Sigmoid()) 178 | self.ca3 = nn.Sequential(_ConvBNReLU(k, k, 1, 1), nn.Conv2d(k, k, 1, 1), nn.Sigmoid()) 179 | self.ca4 = nn.Sequential(_ConvBNReLU(k, k, 1, 1), nn.Conv2d(k, k, 1, 1), nn.Sigmoid()) 180 | self.ca5 = nn.Sequential(_ConvBNReLU(k, k, 1, 1), nn.Conv2d(k, k, 1, 1), nn.Sigmoid()) 181 | 182 | self.conv_r1_tran = _ConvBNReLU(16, 16, 1, 1) 183 | self.conv_d1_tran = _ConvBNReLU(16, 16, 1, 1) 184 | self.mlp = nn.Sequential(_ConvBNReLU(48, 24, 1, 1),_ConvBNSig(24,5,1,1)) 185 | 186 | self.conv_r1_tran2 = _ConvBNReLU(16, 16, 1, 1) 187 | self.conv_d1_tran2 = _ConvBNReLU(16, 16, 1, 1) 188 | self.conv_sgate1 = _ConvBNReLU(16, 16, 3, 1,2,2) 189 | self.conv_sgate2 = _ConvBNReLU(16, 16, 3, 1,2,2) 190 | self.conv_sgate3 = _ConvBNSig(16,5,3,1,1) 191 | 192 | self.ppm = PyramidPooling(320, 32) 193 | 194 | self.conv_guide = _ConvBNReLU(320, 16, 1, 1) 195 | 196 | 197 | 198 | def forward(self, x,feat): 199 | 200 | d1, d2, d3, d4, d5 = feat 201 | 202 | d5_guide = upsample(self.conv_guide(d5),d1.shape[2:]) 203 | 204 | r1 = self.base.layer1(x) 205 | 206 | r1t = self.conv_r1_tran(r1) 207 | d1t = self.conv_d1_tran(d1) 208 | r1t2 = self.conv_r1_tran2(r1) 209 | d1t2 = self.conv_d1_tran2(d1) 210 | 211 | # QDW 212 | iou = F.adaptive_avg_pool2d(r1t * d1t, 1) / \ 213 | (F.adaptive_avg_pool2d(r1t + d1t, 1)) 214 | 215 | e_rp = F.max_pool2d(r1t, 2, 2) 216 | e_dp = F.max_pool2d(d1t, 2, 2) 217 | 218 | e_rp2 = F.max_pool2d(e_rp, 2, 2) 219 | e_dp2 = F.max_pool2d(e_dp, 2, 2) 220 | 221 | iou_p1 = F.adaptive_avg_pool2d(e_rp * e_dp, 1) / \ 222 | (F.adaptive_avg_pool2d(e_rp + e_dp, 1)) 223 | 224 | iou_p2 = F.adaptive_avg_pool2d(e_rp2 * e_dp2, 1) / \ 225 | (F.adaptive_avg_pool2d(e_rp2 + e_dp2, 1)) 226 | 227 | gate = self.mlp(torch.cat((iou, iou_p1, iou_p2), dim=1)) 228 | 229 | 230 | # DHA 231 | mc = r1t2 * d1t2 232 | 233 | sgate = self.conv_sgate1(upsample(mc + d5_guide, d2.shape[2:])) 234 | d5_guide1 = mc + upsample(sgate, d1.shape[2:]) 235 | 236 | sgate = self.conv_sgate1(upsample(mc + d5_guide1, d2.shape[2:])) 237 | d5_guide2 = mc + upsample(sgate, d1.shape[2:]) 238 | 239 | sgate = self.conv_sgate3(d5_guide1 + d5_guide2 + mc) 240 | 241 | dqw1 = gate[:,0:1,...] 242 | dha1 = upsample(sgate[:, 0:1, ...], d1.shape[2:]) 243 | dqw2 = gate[:, 1:2, ...] 244 | dha2 = upsample(sgate[:, 1:2, ...], d2.shape[2:]) 245 | dqw3 = gate[:, 2:3, ...] 246 | dha3 = upsample(sgate[:, 2:3, ...], d3.shape[2:]) 247 | dqw4 = gate[:, 3:4, ...] 248 | dha4 = upsample(sgate[:, 3:4, ...], d4.shape[2:]) 249 | dqw5 = gate[:, 4:5, ...] 
250 | dha5 = upsample(sgate[:, 4:5, ...], d5.shape[2:]) 251 | 252 | r1 = r1 + d1 * dqw1 * dha1 253 | r2 = self.base.layer2(r1) + d2 * dqw2 * dha2 254 | r3 = self.base.layer3(r2) + d3 * dqw3 * dha3 255 | r4 = self.base.layer4(r3) + d4 * dqw4 * dha4 256 | r5 = self.base.layer5(r4) + d5 * dqw5 * dha5 257 | r6 = self.ppm(r5) 258 | 259 | # Two stage decoder 260 | ## pre-fusion 261 | r5 = self.conv_cp5(r5) 262 | r4 = self.conv_cp4(r4) 263 | r3 = self.conv_cp3(r3) 264 | r2 = self.conv_cp2(r2) 265 | r1 = self.conv_cp1(r1) 266 | 267 | r5 = self.ca5(F.adaptive_avg_pool2d(r5, 1)) * r5 268 | r4 = self.ca4(F.adaptive_avg_pool2d(r4, 1)) * r4 269 | r3 = self.ca3(F.adaptive_avg_pool2d(r3, 1)) * r3 270 | r2 = self.ca2(F.adaptive_avg_pool2d(r2, 1)) * r2 271 | r1 = self.ca1(F.adaptive_avg_pool2d(r1, 1)) * r1 272 | 273 | r3 = upsample(r3, r1.shape[2:]) 274 | r2 = upsample(r2, r1.shape[2:]) 275 | rh = r4 + r5 + r6 276 | rl = r1 + r2 + r3 277 | 278 | ## full-fusion 279 | rh = upsample(rh, rl.shape[2:]) 280 | sal = self.conv_s_f (torch.cat((rh,rl),dim=1)) 281 | 282 | return sal 283 | 284 | def initialize_weights(model): 285 | m = torch.hub.load('pytorch/vision:v0.6.0', 'mobilenet_v2', pretrained=True) 286 | pretrained_dict = m.state_dict() 287 | all_params = {} 288 | for k, v in model.state_dict().items(): 289 | if k in pretrained_dict.keys(): 290 | v = pretrained_dict[k] 291 | all_params[k] = v 292 | model.load_state_dict(all_params,strict = False) 293 | 294 | if __name__ == '__main__': 295 | img = torch.randn(1, 3, 256, 256).cuda() 296 | depth = torch.randn(1, 1, 256, 256).cuda() 297 | model = DFMNet().cuda() 298 | model.eval() 299 | time1= time.time() 300 | outputs = model(img,depth) 301 | time2 = time.time() 302 | torch.cuda.synchronize() 303 | print(1000/(time2-time1)) 304 | num_params = 0 305 | for p in model.parameters(): 306 | num_params += p.numel() 307 | print(num_params) 308 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import shutil 5 | from natsort import natsorted 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--local_rank', default=-1, type=int, 9 | help='node rank for distributed training') 10 | parser.add_argument('--epoch', type=int, default=301, help='epoch number') 11 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate') 12 | parser.add_argument('--batchsize', type=int, default=10, help='training batch size') 13 | parser.add_argument('--trainsize', type=int, default=256, help='training dataset size') 14 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin') 15 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate') 16 | parser.add_argument('--decay_epoch', type=int, default=100, help='every n epochs decay learning rate') 17 | parser.add_argument('--load', type=str, default='./pre_train/resnet50-19c8e357.pth', help='train from checkpoints') 18 | parser.add_argument('--gpu_id', type=str, default='1', help='train use gpu') 19 | parser.add_argument('--rgb_root', type=str, default='E://pytorch/data/RGBDcollection_fast/RGB/', help='the training rgb images root') 20 | parser.add_argument('--depth_root', type=str, default='E://pytorch/data/RGBDcollection_fast/depth/', help='the training depth images root') 21 | parser.add_argument('--gt_root', type=str, 
default='E://pytorch/data/RGBDcollection_fast/GT/', help='the training gt images root') 22 | parser.add_argument('--edge_root', type=str, default='E://pytorch/data/RGBDcollection_fast/edge/', help='the training edge images root') 23 | parser.add_argument('--test_rgb_root', type=str, default='E://pytorch/data/test_in_train/RGB/', help='the test rgb images root') 24 | parser.add_argument('--test_depth_root', type=str, default='E://pytorch/data/test_in_train/depth/', help='the test depth images root') 25 | parser.add_argument('--test_gt_root', type=str, default='E://pytorch/data/test_in_train/GT/', help='the test gt images root') 26 | parser.add_argument('--save_path', type=str, default='./results/train', help='the path to save models and logs') 27 | opt = parser.parse_args() 28 | 29 | -------------------------------------------------------------------------------- /pretrain/.gitkeep: -------------------------------------------------------------------------------- 1 | # 2 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import sys 4 | import numpy as np 5 | import os, argparse 6 | import cv2 7 | from net import DFMNet 8 | from data import test_dataset 9 | import time 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--testsize', type=int, default=256, help='testing size') 13 | parser.add_argument('--gpu_id', type=str, default='0', help='select gpu id') 14 | parser.add_argument('--test_path',type=str,default='./dataset/',help='test dataset path') 15 | opt = parser.parse_args() 16 | 17 | dataset_path = opt.test_path 18 | 19 | #set device for test 20 | if opt.gpu_id=='0': 21 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 22 | print('USE GPU 0') 23 | elif opt.gpu_id=='1': 24 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 25 | print('USE GPU 1') 26 | elif opt.gpu_id == '3': 27 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 28 | print('USE GPU 3') 29 | elif opt.gpu_id=='all': 30 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" 31 | print('USE GPU 0,1,2,3') 32 | 33 | #load the model 34 | model = DFMNet() 35 | model.load_state_dict(torch.load('./pretrain/DFMNet_epoch_300.pth')) 36 | model.cuda() 37 | model.eval() 38 | 39 | #test 40 | 41 | 42 | def save(res,gt,notation=None,sigmoid=True): 43 | res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False) 44 | res = res.sigmoid().data.cpu().numpy().squeeze() if sigmoid ==True else res.data.cpu().numpy().squeeze() 45 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 46 | print('save img to: ', os.path.join(save_path, name.replace('.png','_'+notation+'.png') if notation != None else name)) 47 | cv2.imwrite(os.path.join(save_path, name.replace('.png','_'+notation+'.png') if notation != None else name), res * 255) 48 | 49 | test_datasets = ['NJU2K','NLPR','STERE', 'RGBD135', 'LFSD','SIP'] 50 | for dataset in test_datasets: 51 | with torch.no_grad(): 52 | save_path = './results/benchmark/' + dataset 53 | if not os.path.exists(save_path): 54 | os.makedirs(save_path) 55 | image_root = dataset_path + dataset + '/RGB/' 56 | gt_root = dataset_path + dataset + '/GT/' 57 | depth_root=dataset_path +dataset +'/depth/' 58 | test_loader = test_dataset(image_root, gt_root,depth_root, opt.testsize) 59 | 60 | for i in range(test_loader.size): 61 | image, gt,depth, name, image_for_post = test_loader.load_data() 62 | gt = np.asarray(gt, np.float32) 63 | gt /= 
(gt.max() + 1e-8) 64 | image = image.cuda() 65 | depth = depth.cuda() 66 | torch.cuda.synchronize() 67 | time_s = time.time() 68 | out = model(image,depth) 69 | torch.cuda.synchronize() 70 | time_e = time.time() 71 | t = time_e - time_s 72 | print("time: {:.2f} ms".format(t*1000)) 73 | save(out[0],gt) 74 | print('Test Done!') 75 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | import sys 5 | sys.path.append('./models') 6 | import numpy as np 7 | from datetime import datetime 8 | from net import DFMNet 9 | from data import get_loader,test_dataset 10 | from utils import clip_gradient, LR_Scheduler 11 | from torch.utils.tensorboard import SummaryWriter 12 | import logging 13 | import torch.backends.cudnn as cudnn 14 | from options import opt 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | 19 | def upsample(x, size): 20 | return F.interpolate(x, size, mode='bilinear', align_corners=True) 21 | 22 | #train function 23 | def train(train_loader, model, optimizer, epoch,save_path): 24 | 25 | global step 26 | model.train() 27 | loss_all=0 28 | epoch_step=0 29 | try: 30 | for i, (images, gts, depths) in enumerate(train_loader, start=1): 31 | optimizer.zero_grad() 32 | images = images.cuda() 33 | gts = gts.cuda() 34 | depths=depths.cuda() 35 | 36 | 37 | cur_lr = lr_scheduler(optimizer, i, epoch) 38 | writer.add_scalar('learning_rate', cur_lr, global_step=(epoch-1)*total_step + i) 39 | 40 | out,feature_r,feature_d = model(images,depths) 41 | loss_f = F.binary_cross_entropy_with_logits(out[0], gts) 42 | loss_d = F.binary_cross_entropy_with_logits(out[1], gts) 43 | 44 | 45 | loss = loss_f + loss_d 46 | loss.backward() 47 | 48 | clip_gradient(optimizer, opt.clip) 49 | optimizer.step() 50 | step+=1 51 | epoch_step+=1 52 | loss_all+=loss.data 53 | 54 | 55 | if i % 100 == 0 or i == total_step or i==1: 56 | print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], loss: {:.4f}, loss_final: {:.4f}, loss_d: {:.4f}'. 57 | format(datetime.now(), epoch, opt.epoch, i, total_step, loss,loss_f.data,loss_d.data )) 58 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss1: {:.4f} '. 
59 | format( epoch, opt.epoch, i, total_step, loss.data)) 60 | writer.add_scalar('Loss', loss.data, global_step=step) 61 | 62 | loss_all/=epoch_step 63 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Loss_AVG: {:.4f}'.format( epoch, opt.epoch, loss_all)) 64 | writer.add_scalar('Loss-epoch', loss_all, global_step=epoch) 65 | if epoch == 300: 66 | torch.save(model.state_dict(), save_path+'/epoch_{}.pth'.format(epoch)) 67 | except KeyboardInterrupt: 68 | print('Keyboard Interrupt: save model and exit.') 69 | if not os.path.exists(save_path): 70 | os.makedirs(save_path) 71 | torch.save(model.state_dict(), save_path+'/epoch_{}.pth'.format(epoch+1)) 72 | print('save checkpoints successfully!') 73 | raise 74 | 75 | #test function 76 | def test(test_loader,model,epoch,save_path): 77 | global best_mae,best_epoch 78 | model.eval() 79 | with torch.no_grad(): 80 | mae_sum=0 81 | for i in range(test_loader.size): 82 | image, gt,depth, name,img_for_post = test_loader.load_data() 83 | gt = np.asarray(gt, np.float32) 84 | gt /= (gt.max() + 1e-8) 85 | image = image.cuda() 86 | depth = depth.cuda() 87 | res,_,_ = model(image,depth) 88 | res = F.upsample(res[0], size=gt.shape, mode='bilinear', align_corners=False) 89 | res = res.sigmoid().data.cpu().numpy().squeeze() 90 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 91 | mae_sum+=np.sum(np.abs(res-gt))*1.0/(gt.shape[0]*gt.shape[1]) 92 | mae=mae_sum/test_loader.size 93 | writer.add_scalar('MAE', torch.tensor(mae), global_step=epoch) 94 | print('Epoch: {} MAE: {} #### bestMAE: {} bestEpoch: {}'.format(epoch,mae,best_mae,best_epoch)) 95 | if epoch==1: 96 | best_mae=mae 97 | torch.save(model.state_dict(), save_path + '/epoch_best.pth') 98 | else: 99 | if mae 0 and T < self.warmup_iters: 57 | lr = lr * 1.0 * T / self.warmup_iters 58 | # if epoch > self.epoch: 59 | # print('\n=>Epoches %i, learning rate = %.4f, \ 60 | # previous best = %.4f' % (epoch, lr, best_pred)) 61 | # self.epoch = epoch 62 | assert lr >= 0 63 | self._adjust_learning_rate(optimizer, lr) 64 | return lr 65 | 66 | def _adjust_learning_rate(self, optimizer, lr): 67 | if len(optimizer.param_groups) == 1: 68 | optimizer.param_groups[0]['lr'] = lr 69 | else: 70 | # enlarge the lr at the head 71 | optimizer.param_groups[0]['lr'] = lr 72 | for i in range(1, len(optimizer.param_groups)): 73 | optimizer.param_groups[i]['lr'] = lr * 10 --------------------------------------------------------------------------------