├── README.md
├── dataloader_rgbdsod.py
├── dataset
│   └── SSD
│       ├── GT
│       │   ├── histrory10301.png
│       │   ├── histrory10401.png
│       │   └── histrory10501.png
│       ├── RGB
│       │   ├── histrory10301.jpg
│       │   ├── histrory10401.jpg
│       │   └── histrory10501.jpg
│       └── depth
│           ├── histrory10301.png
│           ├── histrory10401.png
│           └── histrory10501.png
├── eval.py
├── eval
│   └── pretrained_models
│       └── put_model_here.txt
├── figures
│   ├── D3Net-Result.jpg
│   ├── D3Net-TNNLS20.png
│   ├── DDU2.png
│   ├── RunTime.png
│   ├── SIP.png
│   └── SIP2.png
├── model
│   ├── DepthNet.py
│   ├── RgbNet.py
│   ├── RgbdNet.py
│   ├── __init__.py
│   └── vgg_new.py
├── my_custom_transforms.py
├── train.py
└── utils.py

/README.md: -------------------------------------------------------------------------------- 1 | # RGB-D Salient Object Detection: Models, Data Sets, and Large-Scale Benchmarks (TNNLS2021) 2 | Rethinking RGB-D Salient Object Detection: Models, Data Sets, and Large-Scale Benchmarks, IEEE TNNLS 2021. 3 | 4 | ### :fire: NEWS :fire: 5 | - [2022/06/09] :boom: Updated the related works. 6 | - [2020/08/02]: Released the training code. 7 | 8 |

9 |
10 | 11 | Figure 1: Illustration of the proposed D3Net. In the training stage (left), the input RGB and depth images are processed by three parallel sub-networks, i.e., RgbNet, RgbdNet, and DepthNet. The three sub-networks share the same modified Feature Pyramid Network (FPN) structure (see § IV-A for details). These sub-networks produce three saliency maps (i.e., Srgb, Srgbd, and Sdepth) that capture both coarse and fine details of the input. In the test phase (right), a novel depth depurator unit (DDU) (§ IV-B), introduced for the first time in this work, explicitly decides whether to discard the saliency map introduced by the depth map (outputting Srgb) or to keep it (outputting Srgbd). In both the training and test phases, these components form a nested structure and are elaborately designed (e.g., the gate connection in the DDU) to learn the salient object from the RGB and depth images jointly. 12 | 13 |
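In the released eval.py, the test-time DDU gate is a simple rule: the RGB-D and depth-only predictions are resized to the ground-truth size, quantized to 8 bits, and compared by their mean absolute difference; if that difference is below 0.15, Srgbd is output, otherwise Srgb is output. The sketch below restates that rule in plain Python on float maps in [0, 1] and skips the resizing and 8-bit quantization. The names `mean_abs_diff` and `ddu_select` are hypothetical helpers, not functions from the repository.

```python
from typing import List

SaliencyMap = List[List[float]]  # H x W map with values in [0, 1]


def mean_abs_diff(a: SaliencyMap, b: SaliencyMap) -> float:
    """Mean absolute difference between two equally sized maps."""
    if len(a) != len(b) or any(len(ra) != len(rb) for ra, rb in zip(a, b)):
        raise ValueError("maps must have the same shape")
    count = sum(len(row) for row in a)
    if count == 0:
        raise ValueError("maps must not be empty")
    total = sum(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))
    return total / count


def ddu_select(s_rgb: SaliencyMap, s_rgbd: SaliencyMap, s_depth: SaliencyMap,
               threshold: float = 0.15) -> SaliencyMap:
    """Hypothetical restatement of the DDU gate hard-coded in eval.py.

    If the RGB-D and depth-only predictions agree (MAE below the
    threshold), S_rgbd is kept; otherwise S_rgb is returned.
    """
    if mean_abs_diff(s_rgbd, s_depth) < threshold:
        return s_rgbd
    return s_rgb


if __name__ == "__main__":
    s_rgb = [[0.7, 0.0]]
    s_rgbd = [[0.9, 0.1]]
    s_depth = [[0.8, 0.2]]  # MAE against s_rgbd is about 0.1, below 0.15
    print(ddu_select(s_rgb, s_rgbd, s_depth))
```

Because the gate only compares two predictions, it adds no trainable parameters at test time; the 0.15 default mirrors the constant in eval.py.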

14 | 15 | ### Table of Contents 16 | - [RGB-D Salient Object Detection](#Title) 17 | - [Table of Contents](#table-of-contents) 18 | - [Abstract](#abstract) 19 | - [Notion of Depth Depurator Unit](#Notion-of-Depth-Depurator-Unit) 20 | - [Related Works](#related-works) 21 | - [SIP dataset](#SIP-dataset) 22 | - [Train](#train) 23 | - [Evaluation](#evaluation) 24 | - [Results](#results) 25 | - [Citation](#citation) 26 | 27 | ## Abstract 28 | The use of RGB-D information for salient object detection has been explored in recent years. However, relatively little effort has been spent on modeling salient object detection in real-world human activity scenes with RGB-D. In this work, we fill the gap by making the following contributions to RGB-D salient object detection. First, we carefully collect a new salient person (SIP) dataset, which consists of 1K high-resolution images covering diverse real-world scenes with various viewpoints, poses, occlusions, illumination conditions, and backgrounds. Second, we conduct a large-scale and, so far, the most comprehensive benchmark 29 | comparing contemporary methods, which has long been missing in the area and can serve as a baseline for future research. We systematically summarize 31 popular models and evaluate 17 state-of-the-art methods on seven datasets containing about 91K images in total. Third, we propose a simple baseline architecture, called Deep Depth-Depurator Network (D3Net). It consists of a depth depurator unit and a feature learning module, which perform initial low-quality depth map filtering and cross-modal feature learning, respectively. These components form a nested structure and are elaborately designed to be learned jointly. D3Net exceeds the performance of all prior contenders across the five metrics considered, thus serving as a strong baseline to advance the research frontier.
We also demonstrate that D3Net can efficiently extract salient person masks from real scenes, enabling an effective background-changing book cover application running at 20 fps on a single GPU. All the saliency maps, our new SIP dataset, the baseline model, and the evaluation tools are publicly available at https://github.com/DengPingFan/D3NetBenchmark. 30 | 31 | 32 | ## Notion of Depth Depurator Unit 33 | The statistics of the depth maps in existing datasets (e.g., NJU2K, NLPR, RGBD135, STERE, and LFSD) suggest that “high-quality depth maps usually contain clear objects, but the elements in low-quality depth maps are cluttered (2nd row in Fig. 2).” 34 | 35 |

36 |
37 | 38 | Figure 2: Smoothed histograms (c) of a high-quality (1st row) and a low-quality (2nd row) depth map. 39 | 40 |
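Figure 2 compares smoothed histograms of depth maps. The sketch below computes such a curve for an 8-bit depth map (values 0 to 255, matching how dataloader_rgbdsod.py loads depth as a grayscale 'L' image). The bin count, the normalization to a probability distribution, and the centered moving-average smoothing window are assumptions for illustration; the smoothing used for the paper's figure may differ.

```python
from typing import List, Sequence


def depth_histogram(depth: Sequence[Sequence[int]], bins: int = 256) -> List[float]:
    """Normalized histogram of an 8-bit depth map (values 0..255)."""
    if not 1 <= bins <= 256:
        raise ValueError("bins must be between 1 and 256")
    hist = [0] * bins
    n = 0
    for row in depth:
        for value in row:
            v = int(value)  # also accepts NumPy uint8 scalars without overflow
            if not 0 <= v <= 255:
                raise ValueError("expected 8-bit depth values")
            hist[v * bins // 256] += 1
            n += 1
    if n == 0:
        raise ValueError("empty depth map")
    return [h / n for h in hist]


def smooth(hist: Sequence[float], window: int = 5) -> List[float]:
    """Centered moving average; the window is truncated at the borders."""
    if window < 1 or window % 2 == 0:
        raise ValueError("window must be a positive odd integer")
    half = window // 2
    out = []
    for i in range(len(hist)):
        lo, hi = max(0, i - half), min(len(hist), i + half + 1)
        out.append(sum(hist[lo:hi]) / (hi - lo))
    return out


if __name__ == "__main__":
    toy_depth = [[10, 10, 200], [10, 200, 200]]
    curve = smooth(depth_histogram(toy_depth, bins=32), window=3)
    print([round(c, 3) for c in curve])
```

How such curves separate high-quality from low-quality maps is what Fig. 2 illustrates; this sketch only produces the curves and does not score depth quality by itself.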

41 | 42 | ## Related Works 43 | Please refer to our recent survey paper: https://github.com/taozh2017/RGBD-SODsurvey 44 | 45 | Papers with Code: https://paperswithcode.com/task/rgb-d-salient-object-detection 46 | 47 | ## SIP dataset 48 |

49 |
50 | 51 | Figure 3: Representative subsets of our SIP dataset. The images in SIP are grouped into eight subsets according to background objects (i.e., grass, car, barrier, road, 52 | sign, tree, flower, and other), different lighting conditions (i.e., low light and sunny with clear object boundaries), and various numbers of objects (i.e., 1, 2, ≥3). 53 | 54 |

55 | 56 |

57 |
58 | 59 | Figure 4: Examples of images, depth maps, and annotations (i.e., object level and instance level) in our SIP data set with different numbers of salient objects, 60 | object sizes, object positions, scene complexities, and lighting conditions. Note that the “RGB” and “Gray” images are captured by two different monocular 61 | cameras from short distances. Thus, the “Gray” images are slightly different from grayscale images converted from the color (RGB) images. Our SIP data set 62 | also opens new directions, such as depth estimation from “RGB” and “Gray” images, and instance-level RGB-D SOD. 63 | 64 |

65 | 66 | RGB-D SOD Datasets: 67 | **No.** |**Dataset** | **Year** | **Pub.** |**Size** | **#Obj.** | **Types** | **Resolution** | **Download** 68 | :-: | :-: | :-: | :- | :- | :-:| :-: | :-: | :-: 69 | 1 | [**STERE**]() |2012 |CVPR | 1000 | ~One |Internet | [251-1200] * [222-900] | [Baidu: rcql](https://pan.baidu.com/s/1CzBX7dHW9UNzhMC2Z02qTw)/[Google (1.29G)](https://drive.google.com/file/d/1JYfSHsKXC3GLaxcZcZSHkluMFGX0bJza/view?usp=sharing) 70 | 2 | [**GIT**](http://www.bmva.org/bmvc/2013/Papers/paper0112/abstract0112.pdf) |2013 |BMVC | 80 | Multiple |Home environment | 640 * 480 | [Baidu](https://pan.baidu.com/s/15sG1xx93oqWZAxAaVKu4lg)/[Google (35.6M)](https://drive.google.com/open?id=13zis--Pg9--bqNCjTOJGpCThbOly8Epa) 71 | 3 | [**DES**]() |2014 |ICIMCS | 135 | One |Indoor | 640 * 480 | [Baidu: qhen](https://pan.baidu.com/s/1RRp8oV9FYMmPDU5sMXYH6g)/[Google (60.4M)](https://drive.google.com/open?id=15Th-xDeRjkcefS8eDYl-vSN967JVyjoR) 72 | 4 | [**NLPR**]() |2014 |ECCV | 1000 | Multiple |Indoor/outdoor | 640 * 480, 480 * 640 | [Baidu: n701](https://pan.baidu.com/s/1o9387dhf_J2sl-V_0NniFA)/[Google (546M)](https://drive.google.com/open?id=1CbgySAZxznbsN9uOG4pNDHwUPvQIQjCn) 73 | 5 | [**LFSD**]() |2014 |CVPR | 100 | One |Indoor/outdoor | 360 * 360 | [Baidu](https://pan.baidu.com/s/17EiZrnUc9vmx-zfVnP4iIQ)/[Google (32M)](https://drive.google.com/open?id=1cEeJpUukomdt_C4vUZlBlpc1UueuWWRU) 74 | 6 | [**NJUD**]() |2014 |ICIP | 1985 | ~One |Movie/internet/photo | [231-1213] * [274-828] | [Baidu: zjmf](https://pan.baidu.com/s/156oDr-jJij01XAtkqngF7Q)/[Google (1.54G)](https://drive.google.com/open?id=1R1O2dWr6HqpTOiDn6hZxUWTesOSJteQo) 75 | 7 | [**SSD**]() |2017 |ICCVW | 80 | Multiple |Movies | 960 * 1080 | [Baidu: e4qz](https://pan.baidu.com/s/1Yp5gSdLQlhcJclSrbr-LeA)/[Google (119M)](https://drive.google.com/open?id=1k8_TQTZbbYOpnTvc9n6jgLg4Ih4xNhCj) 76 | 8 |
[**DUT-RGBD**](https://openaccess.thecvf.com/content_ICCV_2019/papers/Piao_Depth-Induced_Multi-Scale_Recurrent_Attention_Network_for_Saliency_Detection_ICCV_2019_paper.pdf) |2019 |ICCV | 1200 | Multiple |Indoor/outdoor | 400 * 600 | [Baidu: 6rt0](https://pan.baidu.com/s/1oMG7fWVAr1VUz75EcbyKVg)/[Google (100M)](https://drive.google.com/open?id=1DzkswvLo-3eYPtPoitWvFPJ8qd4EHPGv) 77 | 9 | [**SIP**]() |2020 |TNNLS | 929 | Multiple |Person in the wild | 992 * 774 | [Baidu: 46w8](https://pan.baidu.com/s/1wMTDG8yhCNbioPwzq7t25w)/[Google (2.16G)](https://drive.google.com/open?id=1R91EEHzI1JwfqvQJLmyciAIWU-N8VR4A) 78 | 10 | Overall | | | | | | | [Baidu: 39un](https://pan.baidu.com/s/1DgO18k2B32lAt0naY323PA)/[Google (5.33G)](https://drive.google.com/open?id=16kgnv9NxeiPGwNNx8WoZQLl4qL0qtBZN) 79 | 80 | ## Train 81 | Put the three datasets 'NJU2K_TRAIN', 'NLPR_TRAIN', and 'NJU2K_TEST' into a newly created folder "dataset". 82 | 83 | Put the VGG-16 pretrained model 'vgg16_feat.pth' ([GoogleDrive](https://drive.google.com/file/d/1SXOV-DKnnqFD_b9yxJCIzdSkU7qiHh1X/view?usp=sharing) | [BaiduYun](https://pan.baidu.com/s/17qaLM3nbgR_eGehSK-SOrA), code: zsxh) into the folder "model".
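RgbdSodDataset (in dataloader_rgbdsod.py) expects each dataset folder to contain RGB/<id>.jpg, GT/<id>.png, and depth/<id>.png, and it builds the GT and depth paths from the RGB file names without checking that those files exist. Before running the training commands, a small pre-flight check such as the one below can catch missing files early. `find_incomplete_samples` is a hypothetical helper, not part of the repository; the folder names in the `__main__` block are the three datasets listed above.

```python
import glob
import os
from typing import List, Tuple


def find_incomplete_samples(dataset_dir: str) -> Tuple[int, List[str]]:
    """Check the RGB/GT/depth layout that RgbdSodDataset reads.

    Returns the number of RGB images found and the list of expected
    GT/depth paths that do not exist on disk.
    """
    rgb_paths = sorted(glob.glob(os.path.join(dataset_dir, "RGB", "*.jpg")))
    missing = []
    for rgb_path in rgb_paths:
        sample_id = os.path.splitext(os.path.basename(rgb_path))[0]
        for sub in ("GT", "depth"):
            expected = os.path.join(dataset_dir, sub, sample_id + ".png")
            if not os.path.isfile(expected):
                missing.append(expected)
    return len(rgb_paths), missing


if __name__ == "__main__":
    for name in ("NJU2K_TRAIN", "NLPR_TRAIN", "NJU2K_TEST"):
        count, missing = find_incomplete_samples(os.path.join("dataset", name))
        print(f"{name}: {count} RGB images, {len(missing)} missing GT/depth files")
```

A folder that reports 0 RGB images usually means the dataset was extracted one level too deep or into a differently named folder.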
84 | ``` 85 | python train.py --net RgbNet 86 | python train.py --net RgbdNet 87 | python train.py --net DepthNet 88 | ``` 89 | ### Requirements 90 | - PyTorch>=0.4.1 91 | - OpenCV 92 | 93 | ### Pretrained models 94 | - The pretrained RgbdNet, RgbNet, and DepthNet models can be downloaded from ([GoogleDrive](https://drive.google.com/drive/folders/1jbZzUbgOC0XzbBEsy-Bgf3b-pvr62aWK?usp=sharing) | [BaiduYun](https://pan.baidu.com/s/1sgi0KExOv5KOfGQgXpDdqw), code: xf1h). 95 | 96 | ### Training and testing sets 97 | Our training dataset is available at: 98 | 99 | https://drive.google.com/open?id=1osdm_PRnupIkM82hFbz9u0EKJC_arlQI 100 | 101 | Our testing dataset is available at: 102 | 103 | https://drive.google.com/open?id=1ABYxq0mL4lPq2F0paNJ7-5T9ST6XVHl1 104 | 105 | 106 | ## Evaluation 107 | Put the three pretrained models (RgbNet.pth, RgbdNet.pth, and DepthNet.pth) into the folder "eval/pretrained_models". 108 | ``` 109 | python eval.py 110 | ``` 111 | 112 | Evaluation toolbox (updated on 2022/06/09): 113 | 114 | [Baidu: i09j](https://pan.baidu.com/s/1ArnPZ4OwP67NR71OWYjitg) 115 | 116 | [Google](https://drive.google.com/file/d/1I4Z7rA3wefN7KeEQvkGA92u99uXS_aI_/view?usp=sharing) 117 | 118 |
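The evaluation toolbox is the reference implementation for the benchmark metrics. As an independent illustration of one metric commonly reported in the SOD literature, the sketch below computes the F-measure at an adaptive threshold (twice the mean predicted saliency, with β² = 0.3) for a single prediction, for example a D3Net map written by eval.py to eval/result/D3Net/<dataset>/<id>.png after scaling it to [0, 1]. The function name and the binarization of the ground truth at 0.5 are assumptions; the toolbox's exact conventions may differ, so its numbers are the ones to report.

```python
from typing import Sequence

Map = Sequence[Sequence[float]]  # H x W values in [0, 1]


def adaptive_f_measure(pred: Map, gt: Map, beta2: float = 0.3) -> float:
    """F-measure of `pred` against `gt` at an adaptive threshold.

    The prediction is binarized at min(2 * mean(pred), 1); the ground
    truth is binarized at 0.5 (an assumption of this sketch).
    """
    flat_p = [v for row in pred for v in row]
    flat_g = [v >= 0.5 for row in gt for v in row]
    if not flat_p or len(flat_p) != len(flat_g):
        raise ValueError("pred and gt must be non-empty and equally sized")
    threshold = min(2 * sum(flat_p) / len(flat_p), 1.0)
    binary = [v >= threshold for v in flat_p]
    tp = sum(1 for b, g in zip(binary, flat_g) if b and g)
    if tp == 0:
        return 0.0  # no true positives: precision or recall is zero
    precision = tp / sum(binary)
    recall = tp / sum(flat_g)
    return (1 + beta2) * precision * recall / (beta2 * precision + recall)


if __name__ == "__main__":
    pred = [[1.0, 1.0], [0.0, 0.0]]
    gt = [[1.0, 0.0], [0.0, 0.0]]
    print(round(adaptive_f_measure(pred, gt), 4))
```

Weighting precision more heavily through β² < 1 is the usual choice in SOD benchmarks, since a map that marks everything salient would otherwise score well on recall alone.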

119 |
120 | 121 | Table 1. Running time comparison. 122 | 123 |

124 | 125 | ## Results 126 |

127 |
128 |

129 | Results of our model on seven benchmark datasets can be downloaded from: 130 | 131 | [Baidu Pan](https://pan.baidu.com/s/13z0ZEptUfEU6hZ6yEEISuw) (extraction code: r295) 132 | 133 | [Google Drive](https://drive.google.com/drive/folders/1T46FyPzi3XjsB18i3HnLEqkYQWXVbCnK?usp=sharing) 134 | 135 | 136 | ## Citation 137 | If you find this work or code helpful in your research, please cite: 138 | ``` 139 | @article{fan2019rethinking, 140 | title={{Rethinking RGB-D salient object detection: Models, datasets, and large-scale benchmarks}}, 141 | author={Fan, Deng-Ping and Lin, Zheng and Zhang, Zhao and Zhu, Menglong and Cheng, Ming-Ming}, 142 | journal={IEEE TNNLS}, 143 | year={2021} 144 | } 145 | @article{zhou2021rgbd, 146 | title={RGB-D Salient Object Detection: A Survey}, 147 | author={Zhou, Tao and Fan, Deng-Ping and Cheng, Ming-Ming and Shen, Jianbing and Shao, Ling}, 148 | journal={CVMJ}, 149 | year={2021} 150 | } 151 | ``` 152 | -------------------------------------------------------------------------------- /dataloader_rgbdsod.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | import glob 6 | import os 7 | from os.path import join 8 | class RgbdSodDataset(Dataset): 9 | def __init__(self, datasets , transform=None, max_num=0 , if_memory=False): 10 | super().__init__() 11 | if not isinstance(datasets,list) : datasets=[datasets] 12 | self.imgs_list, self.gts_list, self.depths_list = [], [], [] 13 | 14 | for dataset in datasets: 15 | ids=sorted(glob.glob(os.path.join(dataset,'RGB','*.jpg'))) 16 | ids=[os.path.splitext(os.path.split(id)[1])[0] for id in ids] 17 | for id in ids: 18 | self.imgs_list.append(os.path.join(dataset,'RGB',id+'.jpg')) 19 | self.gts_list.append(os.path.join(dataset,'GT',id+'.png')) 20 | self.depths_list.append(os.path.join(dataset,'depth',id+'.png')) 21 | 22 | if max_num!=0 and len(self.imgs_list)> abs(max_num): 23 |
indices= random.sample(range(len(self.imgs_list)),max_num) if max_num>0 else range(abs(max_num)) 24 | self.imgs_list= [self.imgs_list[i] for i in indices] 25 | self.gts_list = [self.gts_list[i] for i in indices] 26 | self.depths_list = [self.depths_list[i] for i in indices] 27 | 28 | self.transform, self.if_memory = transform, if_memory 29 | 30 | if if_memory: 31 | self.samples=[] 32 | for index in range(len(self.imgs_list)): 33 | self.samples.append(self.get_sample(index)) 34 | 35 | def __len__(self): 36 | return len(self.imgs_list) 37 | 38 | def __getitem__(self, index): 39 | if self.if_memory: 40 | return self.transform(self.samples[index].copy()) if self.transform is not None else self.samples[index].copy() 41 | else: 42 | return self.transform(self.get_sample(index)) if self.transform is not None else self.get_sample(index) 43 | 44 | def get_sample(self,index): 45 | img = np.array(Image.open(self.imgs_list[index]).convert('RGB')) 46 | gt = np.array(Image.open(self.gts_list[index]).convert('L')) 47 | depth = np.array(Image.open(self.depths_list[index]).convert('L')) 48 | sample={'img':img , 'gt' : gt,'depth':depth} 49 | 50 | sample['meta'] = {'id': os.path.splitext(os.path.split(self.gts_list[index])[1])[0]} 51 | sample['meta']['source_size'] = np.array(gt.shape[::-1]) 52 | sample['meta']['img_path'] = self.imgs_list[index] 53 | sample['meta']['gt_path'] = self.gts_list[index] 54 | sample['meta']['depth_path'] = self.depths_list[index] 55 | return sample 56 | 57 | if __name__=='__main__': 58 | pass 59 | 60 | 61 | -------------------------------------------------------------------------------- /dataset/SSD/GT/histrory10301.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/dataset/SSD/GT/histrory10301.png -------------------------------------------------------------------------------- /dataset/SSD/GT/histrory10401.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/dataset/SSD/GT/histrory10401.png -------------------------------------------------------------------------------- /dataset/SSD/GT/histrory10501.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/dataset/SSD/GT/histrory10501.png -------------------------------------------------------------------------------- /dataset/SSD/RGB/histrory10301.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/dataset/SSD/RGB/histrory10301.jpg -------------------------------------------------------------------------------- /dataset/SSD/RGB/histrory10401.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/dataset/SSD/RGB/histrory10401.jpg -------------------------------------------------------------------------------- /dataset/SSD/RGB/histrory10501.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/dataset/SSD/RGB/histrory10501.jpg -------------------------------------------------------------------------------- /dataset/SSD/depth/histrory10301.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/dataset/SSD/depth/histrory10301.png -------------------------------------------------------------------------------- 
/dataset/SSD/depth/histrory10401.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/dataset/SSD/depth/histrory10401.png -------------------------------------------------------------------------------- /dataset/SSD/depth/histrory10501.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/dataset/SSD/depth/histrory10501.png -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | #pytorch 2 | import torch 3 | import torchvision 4 | from torch.utils.data import DataLoader 5 | 6 | #general 7 | import os 8 | import cv2 9 | import shutil 10 | import numpy as np 11 | from tqdm import tqdm 12 | 13 | #mine 14 | import utils 15 | import my_custom_transforms as mtr 16 | from dataloader_rgbdsod import RgbdSodDataset 17 | from PIL import Image 18 | from model.RgbNet import MyNet as RgbNet 19 | from model.RgbdNet import MyNet as RgbdNet 20 | from model.DepthNet import MyNet as DepthNet 21 | 22 | size=(224, 224) 23 | datasets_path='./dataset/' 24 | test_datasets=['SSD'] 25 | pretrained_models={'RgbNet':'./eval/pretrained_models/RgbNet.pth', 'RgbdNet':'eval/pretrained_models/RgbdNet.pth' , 'DepthNet':'eval/pretrained_models/DepthNet.pth' } 26 | result_path='./eval/result/' 27 | os.makedirs(result_path,exist_ok=True) 28 | 29 | for tmp in ['D3Net']: 30 | os.makedirs(os.path.join(result_path,tmp),exist_ok=True) 31 | for test_dataset in test_datasets: 32 | os.makedirs(os.path.join(result_path,tmp,test_dataset),exist_ok=True) 33 | 34 | model_rgb=RgbNet().cuda() 35 | model_rgbd=RgbdNet().cuda() 36 | model_depth=DepthNet().cuda() 37 | 38 | 
model_rgb.load_state_dict(torch.load(pretrained_models['RgbNet'])['model']) 39 | model_rgbd.load_state_dict(torch.load(pretrained_models['RgbdNet'])['model']) 40 | model_depth.load_state_dict(torch.load(pretrained_models['DepthNet'])['model']) 41 | 42 | model_rgb.eval() 43 | model_rgbd.eval() 44 | model_depth.eval() 45 | 46 | transform_test = torchvision.transforms.Compose([mtr.Resize(size),mtr.ToTensor(),mtr.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],elems_do=['img'])]) 47 | 48 | test_loaders=[] 49 | for test_dataset in test_datasets: 50 | val_set=RgbdSodDataset(datasets_path+test_dataset,transform=transform_test) 51 | test_loaders.append(DataLoader(val_set, batch_size=1, shuffle=False,pin_memory=True)) 52 | 53 | for index, test_loader in enumerate(test_loaders): 54 | dataset=test_datasets[index] 55 | print('Test [{}]'.format(dataset)) 56 | 57 | for i, sample_batched in enumerate(tqdm(test_loader)): 58 | input, gt = model_rgb.get_input(sample_batched),model_rgb.get_gt(sample_batched) 59 | 60 | with torch.no_grad(): 61 | output_rgb = model_rgb(input) 62 | output_rgbd = model_rgbd(input) 63 | output_depth = model_depth(input) 64 | 65 | result_rgb = model_rgb.get_result(output_rgb) 66 | result_rgbd = model_rgbd.get_result(output_rgbd) 67 | result_depth = model_depth.get_result(output_depth) 68 | 69 | id=sample_batched['meta']['id'][0] 70 | gt_src=np.array(Image.open(sample_batched['meta']['gt_path'][0]).convert('L')) 71 | 72 | result_rgb=(cv2.resize(result_rgb, gt_src.shape[::-1], interpolation=cv2.INTER_LINEAR) *255).astype(np.uint8) 73 | result_rgbd=(cv2.resize(result_rgbd, gt_src.shape[::-1], interpolation=cv2.INTER_LINEAR) *255).astype(np.uint8) 74 | result_depth=(cv2.resize(result_depth, gt_src.shape[::-1], interpolation=cv2.INTER_LINEAR) *255).astype(np.uint8) 75 | 76 | ddu_mae=np.mean(np.abs(result_rgbd/255.0 - result_depth/255.0)) 77 | result_d3net=result_rgbd if ddu_mae<0.15 else result_rgb 78 | 79 | 
Image.fromarray(result_d3net).save(os.path.join(result_path,'D3Net',dataset,id+'.png')) 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /eval/pretrained_models/put_model_here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/eval/pretrained_models/put_model_here.txt -------------------------------------------------------------------------------- /figures/D3Net-Result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/figures/D3Net-Result.jpg -------------------------------------------------------------------------------- /figures/D3Net-TNNLS20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/figures/D3Net-TNNLS20.png -------------------------------------------------------------------------------- /figures/DDU2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/figures/DDU2.png -------------------------------------------------------------------------------- /figures/RunTime.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/figures/RunTime.png -------------------------------------------------------------------------------- /figures/SIP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/figures/SIP.png 
-------------------------------------------------------------------------------- /figures/SIP2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/figures/SIP2.png -------------------------------------------------------------------------------- /model/DepthNet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from model.vgg_new import VGG_backbone 6 | 7 | def init_weight(model): 8 | for m in model.modules(): 9 | if isinstance(m, nn.Conv2d): 10 | torch.nn.init.kaiming_normal_(m.weight) 11 | if m.bias is not None: 12 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(m.weight) 13 | bound = 1 / math.sqrt(fan_in) 14 | torch.nn.init.uniform_(m.bias, -bound, bound) 15 | elif isinstance(m, nn.BatchNorm2d): 16 | m.weight.data.fill_(1) 17 | m.bias.data.zero_() 18 | 19 | class Decoder(nn.Module): 20 | def __init__(self,in_channel=32,side_channel=512): 21 | super(Decoder, self).__init__() 22 | self.reduce_conv=nn.Sequential( 23 | #nn.Conv2d(side_channel, in_channel, kernel_size=3, stride=1, padding=1), 24 | nn.Conv2d(side_channel, in_channel, kernel_size=1, stride=1, padding=0), 25 | nn.ReLU(inplace=True) ### 26 | ) 27 | self.decoder = nn.Sequential( 28 | nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1), 29 | nn.ReLU(inplace=True), 30 | nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1), 31 | nn.ReLU(inplace=True) ### 32 | ) 33 | init_weight(self) 34 | 35 | def forward(self, x, side): 36 | x=F.interpolate(x, size=side.size()[2:], mode='bilinear', align_corners=True) 37 | side=self.reduce_conv(side) 38 | x=torch.cat((x, side), 1) 39 | x = self.decoder(x) 40 | return x 41 | 42 | 43 | class Single_Stream(nn.Module): 44 | def __init__(self,in_channel=3): 45 | 
super(Single_Stream, self).__init__() 46 | self.backbone = VGG_backbone(in_channel=in_channel,pre_train_path='./model/vgg16_feat.pth') 47 | self.toplayer = nn.Sequential( 48 | nn.AvgPool2d(2, stride=2), 49 | nn.Conv2d(512, 32, kernel_size=3, stride=1, padding=1), 50 | nn.ReLU(inplace=True), 51 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1), 52 | nn.ReLU(inplace=True) ### 53 | ) 54 | channels = [64, 128, 256, 512, 512, 32] 55 | # Decoders 56 | decoders = [] 57 | for idx in range(5): 58 | decoders.append(Decoder(in_channel=32,side_channel=channels[idx])) 59 | self.decoders = nn.ModuleList(decoders) 60 | init_weight(self.toplayer) 61 | 62 | def forward(self, input): 63 | l1 = self.backbone.conv1(input) 64 | l2 = self.backbone.conv2(l1) 65 | l3 = self.backbone.conv3(l2) 66 | l4 = self.backbone.conv4(l3) 67 | l5 = self.backbone.conv5(l4) 68 | l6 = self.toplayer(l5) 69 | feats=[l1, l2, l3, l4, l5, l6] 70 | 71 | x=feats[5] 72 | for idx in [4, 3, 2, 1, 0]: 73 | x=self.decoders[idx](x,feats[idx]) 74 | 75 | return x 76 | 77 | class PredLayer(nn.Module): 78 | def __init__(self, in_channel=32): 79 | super(PredLayer, self).__init__() 80 | self.enlayer = nn.Sequential( 81 | nn.Conv2d(in_channel, 32, kernel_size=3, stride=1, padding=1), 82 | nn.ReLU(inplace=True), 83 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1), 84 | nn.ReLU(inplace=True), 85 | ) 86 | self.outlayer = nn.Sequential( 87 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 88 | nn.Sigmoid() 89 | ) 90 | init_weight(self) 91 | 92 | def forward(self, x): 93 | x = self.enlayer(x) 94 | x = self.outlayer(x) 95 | return x 96 | 97 | 98 | class MyNet(nn.Module): 99 | def __init__(self): 100 | super(MyNet, self).__init__() 101 | 102 | # Main-streams 103 | self.main_stream = Single_Stream(in_channel=3) 104 | 105 | # Prediction 106 | self.pred_layer = PredLayer() 107 | 108 | def forward(self, input, if_return_feat=False): 109 | rgb, dep=input 110 | dep=dep.repeat(1,3,1,1) 111 | feat = 
self.main_stream(dep) 112 | result = self.pred_layer(feat) 113 | 114 | if if_return_feat: 115 | return feat 116 | else: 117 | return result 118 | 119 | def get_train_params(self, lr): 120 | train_params = [{'params': self.parameters(), 'lr': lr}] 121 | return train_params 122 | 123 | def get_input(self, sample_batched): 124 | rgb,dep = sample_batched['img'].cuda(),sample_batched['depth'].cuda() 125 | return rgb,dep 126 | 127 | def get_gt(self, sample_batched): 128 | gt = sample_batched['gt'].cuda() 129 | return gt 130 | 131 | def get_result(self, output, index=0): 132 | if isinstance(output, list): 133 | result = output[0].data.cpu().numpy()[index,0,:,:] 134 | else: 135 | result = output.data.cpu().numpy()[index,0,:,:] 136 | 137 | # if isinstance(output, list): 138 | # result = torch.sigmoid(output[0].data.cpu()).numpy()[index,0,:,:] 139 | # else: 140 | # result = torch.sigmoid(output.data.cpu()).numpy()[index,0,:,:] 141 | return result 142 | 143 | def get_loss(self, output, gt, if_mean=True): 144 | criterion = nn.BCELoss().cuda() 145 | #criterion = nn.BCEWithLogitsLoss().cuda() 146 | if isinstance(output, list): 147 | loss=0 148 | for i in range(len(output)): 149 | loss+=criterion(output[i], gt) 150 | if if_mean:loss/=len(output) 151 | else: 152 | loss = criterion(output, gt) 153 | return loss 154 | 155 | 156 | if __name__ == "__main__": 157 | pass 158 | 159 | 160 | 161 | 162 | 163 | 164 | -------------------------------------------------------------------------------- /model/RgbNet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from model.vgg_new import VGG_backbone 6 | 7 | def init_weight(model): 8 | for m in model.modules(): 9 | if isinstance(m, nn.Conv2d): 10 | torch.nn.init.kaiming_normal_(m.weight) 11 | if m.bias is not None: 12 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(m.weight) 13 | bound = 1 / 
math.sqrt(fan_in) 14 | torch.nn.init.uniform_(m.bias, -bound, bound) 15 | elif isinstance(m, nn.BatchNorm2d): 16 | m.weight.data.fill_(1) 17 | m.bias.data.zero_() 18 | 19 | class Decoder(nn.Module): 20 | def __init__(self,in_channel=32,side_channel=512): 21 | super(Decoder, self).__init__() 22 | self.reduce_conv=nn.Sequential( 23 | #nn.Conv2d(side_channel, in_channel, kernel_size=3, stride=1, padding=1), 24 | nn.Conv2d(side_channel, in_channel, kernel_size=1, stride=1, padding=0), 25 | nn.ReLU(inplace=True) ### 26 | ) 27 | self.decoder = nn.Sequential( 28 | nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1), 29 | nn.ReLU(inplace=True), 30 | nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1), 31 | nn.ReLU(inplace=True) ### 32 | ) 33 | init_weight(self) 34 | 35 | def forward(self, x, side): 36 | x=F.interpolate(x, size=side.size()[2:], mode='bilinear', align_corners=True) 37 | side=self.reduce_conv(side) 38 | x=torch.cat((x, side), 1) 39 | x = self.decoder(x) 40 | return x 41 | 42 | 43 | class Single_Stream(nn.Module): 44 | def __init__(self,in_channel=3): 45 | super(Single_Stream, self).__init__() 46 | self.backbone = VGG_backbone(in_channel=in_channel,pre_train_path='./model/vgg16_feat.pth') 47 | self.toplayer = nn.Sequential( 48 | nn.AvgPool2d(2, stride=2), 49 | nn.Conv2d(512, 32, kernel_size=3, stride=1, padding=1), 50 | nn.ReLU(inplace=True), 51 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1), 52 | nn.ReLU(inplace=True) ### 53 | ) 54 | channels = [64, 128, 256, 512, 512, 32] 55 | # Decoders 56 | decoders = [] 57 | for idx in range(5): 58 | decoders.append(Decoder(in_channel=32,side_channel=channels[idx])) 59 | self.decoders = nn.ModuleList(decoders) 60 | init_weight(self.toplayer) 61 | 62 | def forward(self, input): 63 | l1 = self.backbone.conv1(input) 64 | l2 = self.backbone.conv2(l1) 65 | l3 = self.backbone.conv3(l2) 66 | l4 = self.backbone.conv4(l3) 67 | l5 = self.backbone.conv5(l4) 68 | l6 = self.toplayer(l5) 
69 | feats=[l1, l2, l3, l4, l5, l6] 70 | 71 | x=feats[5] 72 | for idx in [4, 3, 2, 1, 0]: 73 | x=self.decoders[idx](x,feats[idx]) 74 | 75 | return x 76 | 77 | class PredLayer(nn.Module): 78 | def __init__(self, in_channel=32): 79 | super(PredLayer, self).__init__() 80 | self.enlayer = nn.Sequential( 81 | nn.Conv2d(in_channel, 32, kernel_size=3, stride=1, padding=1), 82 | nn.ReLU(inplace=True), 83 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1), 84 | nn.ReLU(inplace=True), 85 | ) 86 | self.outlayer = nn.Sequential( 87 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 88 | nn.Sigmoid() 89 | ) 90 | init_weight(self) 91 | 92 | def forward(self, x): 93 | x = self.enlayer(x) 94 | x = self.outlayer(x) 95 | return x 96 | 97 | 98 | class MyNet(nn.Module): 99 | def __init__(self): 100 | super(MyNet, self).__init__() 101 | 102 | # Main-streams 103 | self.main_stream = Single_Stream(in_channel=3) 104 | 105 | # Prediction 106 | self.pred_layer = PredLayer() 107 | 108 | def forward(self, input, if_return_feat=False): 109 | rgb, dep=input 110 | feat = self.main_stream(rgb) 111 | result = self.pred_layer(feat) 112 | 113 | if if_return_feat: 114 | return feat 115 | else: 116 | return result 117 | 118 | def get_train_params(self, lr): 119 | train_params = [{'params': self.parameters(), 'lr': lr}] 120 | return train_params 121 | 122 | def get_input(self, sample_batched): 123 | rgb,dep = sample_batched['img'].cuda(),sample_batched['depth'].cuda() 124 | return rgb,dep 125 | 126 | def get_gt(self, sample_batched): 127 | gt = sample_batched['gt'].cuda() 128 | return gt 129 | 130 | def get_result(self, output, index=0): 131 | if isinstance(output, list): 132 | result = output[0].data.cpu().numpy()[index,0,:,:] 133 | else: 134 | result = output.data.cpu().numpy()[index,0,:,:] 135 | 136 | # if isinstance(output, list): 137 | # result = torch.sigmoid(output[0].data.cpu()).numpy()[index,0,:,:] 138 | # else: 139 | # result = torch.sigmoid(output.data.cpu()).numpy()[index,0,:,:] 
140 | return result 141 | 142 | def get_loss(self, output, gt, if_mean=True): 143 | criterion = nn.BCELoss().cuda() 144 | #criterion = nn.BCEWithLogitsLoss().cuda() 145 | if isinstance(output, list): 146 | loss=0 147 | for i in range(len(output)): 148 | loss+=criterion(output[i], gt) 149 | if if_mean:loss/=len(output) 150 | else: 151 | loss = criterion(output, gt) 152 | return loss 153 | 154 | 155 | if __name__ == "__main__": 156 | pass 157 | 158 | 159 | 160 | 161 | 162 | 163 | -------------------------------------------------------------------------------- /model/RgbdNet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from model.vgg_new import VGG_backbone 6 | 7 | def init_weight(model): 8 | for m in model.modules(): 9 | if isinstance(m, nn.Conv2d): 10 | torch.nn.init.kaiming_normal_(m.weight) 11 | if m.bias is not None: 12 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(m.weight) 13 | bound = 1 / math.sqrt(fan_in) 14 | torch.nn.init.uniform_(m.bias, -bound, bound) 15 | elif isinstance(m, nn.BatchNorm2d): 16 | m.weight.data.fill_(1) 17 | m.bias.data.zero_() 18 | 19 | class Decoder(nn.Module): 20 | def __init__(self,in_channel=32,side_channel=512): 21 | super(Decoder, self).__init__() 22 | self.reduce_conv=nn.Sequential( 23 | #nn.Conv2d(side_channel, in_channel, kernel_size=3, stride=1, padding=1), 24 | nn.Conv2d(side_channel, in_channel, kernel_size=1, stride=1, padding=0), 25 | nn.ReLU(inplace=True) ### 26 | ) 27 | self.decoder = nn.Sequential( 28 | nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1), 29 | nn.ReLU(inplace=True), 30 | nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1), 31 | nn.ReLU(inplace=True) ### 32 | ) 33 | init_weight(self) 34 | 35 | def forward(self, x, side): 36 | x=F.interpolate(x, size=side.size()[2:], mode='bilinear', align_corners=True) 37 | 
side=self.reduce_conv(side) 38 | x=torch.cat((x, side), 1) 39 | x = self.decoder(x) 40 | return x 41 | 42 | 43 | class Single_Stream(nn.Module): 44 | def __init__(self,in_channel=3): 45 | super(Single_Stream, self).__init__() 46 | self.backbone = VGG_backbone(in_channel=in_channel,pre_train_path='./model/vgg16_feat.pth') 47 | self.toplayer = nn.Sequential( 48 | nn.AvgPool2d(2, stride=2), 49 | nn.Conv2d(512, 32, kernel_size=3, stride=1, padding=1), 50 | nn.ReLU(inplace=True), 51 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1), 52 | nn.ReLU(inplace=True) ### 53 | ) 54 | channels = [64, 128, 256, 512, 512, 32] 55 | # Decoders 56 | decoders = [] 57 | for idx in range(5): 58 | decoders.append(Decoder(in_channel=32,side_channel=channels[idx])) 59 | self.decoders = nn.ModuleList(decoders) 60 | init_weight(self.toplayer) 61 | 62 | def forward(self, input): 63 | l1 = self.backbone.conv1(input) 64 | l2 = self.backbone.conv2(l1) 65 | l3 = self.backbone.conv3(l2) 66 | l4 = self.backbone.conv4(l3) 67 | l5 = self.backbone.conv5(l4) 68 | l6 = self.toplayer(l5) 69 | feats=[l1, l2, l3, l4, l5, l6] 70 | 71 | x=feats[5] 72 | for idx in [4, 3, 2, 1, 0]: 73 | x=self.decoders[idx](x,feats[idx]) 74 | 75 | return x 76 | 77 | class PredLayer(nn.Module): 78 | def __init__(self, in_channel=32): 79 | super(PredLayer, self).__init__() 80 | self.enlayer = nn.Sequential( 81 | nn.Conv2d(in_channel, 32, kernel_size=3, stride=1, padding=1), 82 | nn.ReLU(inplace=True), 83 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1), 84 | nn.ReLU(inplace=True), 85 | ) 86 | self.outlayer = nn.Sequential( 87 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 88 | nn.Sigmoid() 89 | ) 90 | init_weight(self) 91 | 92 | def forward(self, x): 93 | x = self.enlayer(x) 94 | x = self.outlayer(x) 95 | return x 96 | 97 | 98 | class MyNet(nn.Module): 99 | def __init__(self): 100 | super(MyNet, self).__init__() 101 | 102 | # Main-streams 103 | self.main_stream = Single_Stream(in_channel=4) 104 | 105 | # 
Prediction 106 | self.pred_layer = PredLayer() 107 | 108 | def forward(self, input): 109 | rgb, dep=input 110 | x=torch.cat((rgb,dep),1) 111 | x = self.main_stream(x) 112 | x = self.pred_layer(x) 113 | return x 114 | 115 | def get_train_params(self, lr): 116 | train_params = [{'params': self.parameters(), 'lr': lr}] 117 | return train_params 118 | 119 | def get_input(self, sample_batched): 120 | rgb,dep = sample_batched['img'].cuda(),sample_batched['depth'].cuda() 121 | return rgb,dep 122 | 123 | def get_gt(self, sample_batched): 124 | gt = sample_batched['gt'].cuda() 125 | return gt 126 | 127 | def get_result(self, output, index=0): 128 | if isinstance(output, list): 129 | result = output[0].data.cpu().numpy()[index,0,:,:] 130 | else: 131 | result = output.data.cpu().numpy()[index,0,:,:] 132 | 133 | # if isinstance(output, list): 134 | # result = torch.sigmoid(output[0].data.cpu()).numpy()[index,0,:,:] 135 | # else: 136 | # result = torch.sigmoid(output.data.cpu()).numpy()[index,0,:,:] 137 | return result 138 | 139 | def get_loss(self, output, gt, if_mean=True): 140 | criterion = nn.BCELoss().cuda() 141 | #criterion = nn.BCEWithLogitsLoss().cuda() 142 | if isinstance(output, list): 143 | loss=0 144 | for i in range(len(output)): 145 | loss+=criterion(output[i], gt) 146 | if if_mean:loss/=len(output) 147 | else: 148 | loss = criterion(output, gt) 149 | return loss 150 | 151 | 152 | if __name__ == "__main__": 153 | pass 154 | 155 | 156 | 157 | 158 | 159 | 160 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/model/__init__.py -------------------------------------------------------------------------------- /model/vgg_new.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn as nn 3 | import os 4 | 5 | class VGG_backbone(nn.Module): 6 | # VGG16 with two branches 7 | # pooling layer at the front of block 8 | def __init__(self,in_channel=3,pre_train_path=None): 9 | super(VGG_backbone, self).__init__() 10 | self.in_channel=in_channel 11 | conv1 = nn.Sequential() 12 | conv1.add_module('conv1_1', nn.Conv2d(in_channel, 64, 3, 1, 1)) 13 | conv1.add_module('relu1_1', nn.ReLU(inplace=True)) 14 | conv1.add_module('conv1_2', nn.Conv2d(64, 64, 3, 1, 1)) 15 | conv1.add_module('relu1_2', nn.ReLU(inplace=True)) 16 | 17 | self.conv1 = conv1 18 | conv2 = nn.Sequential() 19 | conv2.add_module('pool1', nn.MaxPool2d(2, stride=2)) 20 | conv2.add_module('conv2_1', nn.Conv2d(64, 128, 3, 1, 1)) 21 | conv2.add_module('relu2_1', nn.ReLU()) 22 | conv2.add_module('conv2_2', nn.Conv2d(128, 128, 3, 1, 1)) 23 | conv2.add_module('relu2_2', nn.ReLU()) 24 | self.conv2 = conv2 25 | 26 | conv3 = nn.Sequential() 27 | conv3.add_module('pool2', nn.MaxPool2d(2, stride=2)) 28 | conv3.add_module('conv3_1', nn.Conv2d(128, 256, 3, 1, 1)) 29 | conv3.add_module('relu3_1', nn.ReLU()) 30 | conv3.add_module('conv3_2', nn.Conv2d(256, 256, 3, 1, 1)) 31 | conv3.add_module('relu3_2', nn.ReLU()) 32 | conv3.add_module('conv3_3', nn.Conv2d(256, 256, 3, 1, 1)) 33 | conv3.add_module('relu3_3', nn.ReLU()) 34 | self.conv3 = conv3 35 | 36 | conv4 = nn.Sequential() 37 | conv4.add_module('pool3', nn.MaxPool2d(2, stride=2)) 38 | conv4.add_module('conv4_1', nn.Conv2d(256, 512, 3, 1, 1)) 39 | conv4.add_module('relu4_1', nn.ReLU()) 40 | conv4.add_module('conv4_2', nn.Conv2d(512, 512, 3, 1, 1)) 41 | conv4.add_module('relu4_2', nn.ReLU()) 42 | conv4.add_module('conv4_3', nn.Conv2d(512, 512, 3, 1, 1)) 43 | conv4.add_module('relu4_3', nn.ReLU()) 44 | self.conv4 = conv4 45 | 46 | conv5 = nn.Sequential() 47 | conv5.add_module('pool4', nn.MaxPool2d(2, stride=2)) 48 | conv5.add_module('conv5_1', nn.Conv2d(512, 512, 3, 1, 1)) 49 | conv5.add_module('relu5_1', nn.ReLU()) 50 | conv5.add_module('conv5_2', 
nn.Conv2d(512, 512, 3, 1, 1)) 51 | conv5.add_module('relu5_2', nn.ReLU()) 52 | conv5.add_module('conv5_3', nn.Conv2d(512, 512, 3, 1, 1)) 53 | conv5.add_module('relu5_3', nn.ReLU()) 54 | self.conv5 = conv5 55 | 56 | if pre_train_path is not None and os.path.exists(pre_train_path): 57 | self._initialize_weights(torch.load(pre_train_path)) 58 | 59 | def forward(self, x): 60 | x = self.conv1(x) 61 | x = self.conv2(x) 62 | x = self.conv3(x) 63 | x = self.conv4(x) 64 | x = self.conv5(x) 65 | return x 66 | def _initialize_weights(self, pre_train): 67 | keys = list(pre_train.keys()) 68 | 69 | torch.nn.init.kaiming_normal_(self.conv1.conv1_1.weight) 70 | self.conv1.conv1_1.weight.data[:,:3,:,:].copy_(pre_train[keys[0]]) 71 | self.conv1.conv1_2.weight.data.copy_(pre_train[keys[2]]) 72 | self.conv2.conv2_1.weight.data.copy_(pre_train[keys[4]]) 73 | self.conv2.conv2_2.weight.data.copy_(pre_train[keys[6]]) 74 | self.conv3.conv3_1.weight.data.copy_(pre_train[keys[8]]) 75 | self.conv3.conv3_2.weight.data.copy_(pre_train[keys[10]]) 76 | self.conv3.conv3_3.weight.data.copy_(pre_train[keys[12]]) 77 | self.conv4.conv4_1.weight.data.copy_(pre_train[keys[14]]) 78 | self.conv4.conv4_2.weight.data.copy_(pre_train[keys[16]]) 79 | self.conv4.conv4_3.weight.data.copy_(pre_train[keys[18]]) 80 | self.conv5.conv5_1.weight.data.copy_(pre_train[keys[20]]) 81 | self.conv5.conv5_2.weight.data.copy_(pre_train[keys[22]]) 82 | self.conv5.conv5_3.weight.data.copy_(pre_train[keys[24]]) 83 | 84 | self.conv1.conv1_1.bias.data.copy_(pre_train[keys[1]]) 85 | self.conv1.conv1_2.bias.data.copy_(pre_train[keys[3]]) 86 | self.conv2.conv2_1.bias.data.copy_(pre_train[keys[5]]) 87 | self.conv2.conv2_2.bias.data.copy_(pre_train[keys[7]]) 88 | self.conv3.conv3_1.bias.data.copy_(pre_train[keys[9]]) 89 | self.conv3.conv3_2.bias.data.copy_(pre_train[keys[11]]) 90 | self.conv3.conv3_3.bias.data.copy_(pre_train[keys[13]]) 91 | self.conv4.conv4_1.bias.data.copy_(pre_train[keys[15]]) 92 | 
self.conv4.conv4_2.bias.data.copy_(pre_train[keys[17]]) 93 | self.conv4.conv4_3.bias.data.copy_(pre_train[keys[19]]) 94 | self.conv5.conv5_1.bias.data.copy_(pre_train[keys[21]]) 95 | self.conv5.conv5_2.bias.data.copy_(pre_train[keys[23]]) 96 | self.conv5.conv5_3.bias.data.copy_(pre_train[keys[25]]) 97 | 98 | if __name__ == "__main__": 99 | model=VGG_backbone(in_channel=6,pre_train_path='vgg16_feat.pth') 100 | -------------------------------------------------------------------------------- /my_custom_transforms.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import torch 4 | import random 5 | import numbers 6 | import numpy as np 7 | from PIL import Image 8 | ########################################[ function ]######################################## 9 | 10 | def img_rotate(img, angle, center=None, if_expand=False, scale=1.0, mode=None): 11 | (h, w) = img.shape[:2] 12 | if center is None: center = (w // 2 ,h // 2) 13 | M = cv2.getRotationMatrix2D(center, angle, scale) 14 | if mode is None: mode=cv2.INTER_LINEAR if len(img.shape)==3 else cv2.INTER_NEAREST 15 | if if_expand: 16 | h_new=int(w*math.fabs(math.sin(math.radians(angle)))+h*math.fabs(math.cos(math.radians(angle)))) 17 | w_new=int(h*math.fabs(math.sin(math.radians(angle)))+w*math.fabs(math.cos(math.radians(angle)))) 18 | M[0,2] +=(w_new-w)/2 19 | M[1,2] +=(h_new-h)/2 20 | h, w =h_new, w_new 21 | rotated = cv2.warpAffine(img, M, (w, h),flags=mode) 22 | return rotated 23 | 24 | 25 | def img_rotate_point(img, angle, center=None, if_expand=False, scale=1.0): 26 | (h, w) = img.shape[:2] 27 | if center is None: center = (w // 2 ,h // 2) 28 | M = cv2.getRotationMatrix2D(center, angle, scale) 29 | if if_expand: 30 | h_new=int(w*math.fabs(math.sin(math.radians(angle)))+h*math.fabs(math.cos(math.radians(angle)))) 31 | w_new=int(h*math.fabs(math.sin(math.radians(angle)))+w*math.fabs(math.cos(math.radians(angle)))) 32 | M[0,2] +=(w_new-w)/2 33 | 
M[1,2] +=(h_new-h)/2 34 | h, w =h_new, w_new 35 | 36 | 37 | pts_y, pts_x= np.where(img==1) 38 | pts_xy=np.concatenate( (pts_x[:,np.newaxis], pts_y[:,np.newaxis]), axis=1 ) 39 | pts_xy_new= np.rint(np.dot( np.insert(pts_xy,2,1,axis=1), M.T)).astype(np.int64) 40 | 41 | img_new=np.zeros((h,w),dtype=np.uint8) 42 | for pt in pts_xy_new: 43 | if 0 <= pt[0] < w and 0 <= pt[1] < h: img_new[pt[1], pt[0]]=1 # drop points rotated off the canvas 44 | return img_new 45 | 46 | 47 | def img_resize_point(img, size): 48 | (h, w) = img.shape 49 | if not isinstance(size, tuple): size=( int(w*size), int(h*size) ) 50 | M=np.array([[size[0]/w,0,0],[0,size[1]/h,0]]) 51 | 52 | pts_y, pts_x= np.where(img==1) 53 | pts_xy=np.concatenate( (pts_x[:,np.newaxis], pts_y[:,np.newaxis]), axis=1 ) 54 | pts_xy_new= np.dot( np.insert(pts_xy,2,1,axis=1), M.T).astype(np.int64) 55 | 56 | img_new=np.zeros(size[::-1],dtype=np.uint8) 57 | for pt in pts_xy_new: 58 | img_new[pt[1], pt[0]]=1 59 | return img_new 60 | 61 | ########################################[ General ]######################################## 62 | 63 | # Template shared by all per-element operations 64 | class Template(object): 65 | def __init__(self, elems_do=None, elems_undo=[]): 66 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo) 67 | def __call__(self, sample): 68 | for elem in sample.keys(): 69 | if self.elems_do!= None and elem not in self.elems_do :continue 70 | if elem in self.elems_undo:continue 71 | pass 72 | return sample 73 | 74 | 75 | class Transform(object): 76 | def __init__(self, transform, if_numpy=True, elems_do=None, elems_undo=[]): 77 | self.transform, self.if_numpy = transform, if_numpy 78 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo) 79 | def __call__(self, sample): 80 | for elem in sample.keys(): 81 | if self.elems_do!= None and elem not in self.elems_do :continue 82 | if elem in self.elems_undo:continue 83 | tmp = self.transform(Image.fromarray(sample[elem])) 84 | sample[elem] = np.array(tmp) if self.if_numpy else tmp 85 | return sample 86 | 87 | 88 | class
ToPilImage(object): 89 | def __init__(self, elems_do=None, elems_undo=[]): 90 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo) 91 | def __call__(self, sample): 92 | for elem in sample.keys(): 93 | if self.elems_do!= None and elem not in self.elems_do :continue 94 | if elem in self.elems_undo:continue 95 | sample[elem] = Image.fromarray(sample[elem]) 96 | return sample 97 | 98 | 99 | class ToNumpyImage(object): 100 | def __init__(self, elems_do=None, elems_undo=[]): 101 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo) 102 | def __call__(self, sample): 103 | for elem in sample.keys(): 104 | if self.elems_do!= None and elem not in self.elems_do :continue 105 | if elem in self.elems_undo:continue 106 | sample[elem] = np.array(sample[elem]) 107 | return sample 108 | 109 | 110 | class ImageToOne(object): 111 | def __init__(self, elems_do=None, elems_undo=[]): 112 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo) 113 | def __call__(self, sample): 114 | for elem in sample.keys(): 115 | if self.elems_do!= None and elem not in self.elems_do :continue 116 | if elem in self.elems_undo:continue 117 | sample[elem]=np.array(sample[elem])/255.0 118 | return sample 119 | 120 | 121 | class ToTensor(object): 122 | def __init__(self, if_div=True, elems_do=None, elems_undo=[]): 123 | self.if_div = if_div 124 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo) 125 | def __call__(self, sample): 126 | for elem in sample.keys(): 127 | if self.elems_do!= None and elem not in self.elems_do :continue 128 | if elem in self.elems_undo:continue 129 | tmp = sample[elem] 130 | tmp = tmp[np.newaxis,:,:] if tmp.ndim == 2 else tmp.transpose((2, 0, 1)) 131 | tmp = torch.from_numpy(tmp).float() 132 | tmp = tmp.float().div(255) if self.if_div else tmp 133 | sample[elem] = tmp 134 | return sample 135 | 136 | 137 | class Normalize(object): 138 | def __init__(self, mean, std, elems_do=None, elems_undo=[]): 139 | self.mean, self.std = mean, 
std 140 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo) 141 | def __call__(self, sample): 142 | for elem in sample.keys(): 143 | if self.elems_do!= None and elem not in self.elems_do :continue 144 | if elem in self.elems_undo:continue 145 | tensor = sample[elem] 146 | #print(tensor.min(),tensor.max()) 147 | for t, m, s in zip(tensor, self.mean, self.std): 148 | t.sub_(m).div_(s) 149 | #print(tensor.min(),tensor.max()) 150 | 151 | return sample 152 | 153 | 154 | class Show(object): 155 | def __init__(self, elems_show=['img','gt'], elems_do=None, elems_undo=[]): 156 | self.elems_show = elems_show 157 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo) 158 | def __call__(self, sample): 159 | show_list=[ sample[elem] for elem in self.elems_show ] 160 | return sample 161 | 162 | 163 | 164 | class TestDebug(object): 165 | def __init__(self, elems_do=None, elems_undo=[]): 166 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo) 167 | def __call__(self, sample): 168 | #print(sample['depth'].min(),sample['depth'].max()) 169 | return sample 170 | 171 | 172 | ########################################[ Basic Image Augmentation ]######################################## 173 | 174 | 175 | class RandomFlip(object): 176 | def __init__(self, direction=Image.FLIP_LEFT_RIGHT, p=0.5, elems_do=None, elems_undo=[]): 177 | self.direction, self.p = direction, p 178 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo) 179 | def __call__(self, sample): 180 | if random.random() < self.p: 181 | for elem in sample.keys(): 182 | if self.elems_do!= None and elem not in self.elems_do :continue 183 | if elem in self.elems_undo:continue 184 | sample[elem]= np.array(Image.fromarray(sample[elem]).transpose(self.direction)) 185 | sample['meta']['flip']=1 186 | else: 187 | sample['meta']['flip']=0 188 | return sample 189 | 190 | 191 | class RandomRotation(object): 192 | def __init__(self, angle_range=30, if_expand=False, mode=None, 
elems_point=['pos_points_mask','neg_points_mask'], elems_do=None, elems_undo=[]): 193 | self.angle_range = (-angle_range, angle_range) if isinstance(angle_range, numbers.Number) else angle_range 194 | self.if_expand, self.mode = if_expand, mode 195 | self.elems_point = elems_point 196 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo) 197 | 198 | def __call__(self, sample): 199 | angle = random.uniform(self.angle_range[0], self.angle_range[1]) 200 | for elem in sample.keys(): 201 | if self.elems_do!= None and elem not in self.elems_do :continue 202 | if elem in self.elems_undo:continue 203 | 204 | if elem in self.elems_point: 205 | sample[elem]=img_rotate_point(sample[elem], angle, if_expand=self.if_expand) 206 | continue 207 | 208 | sample[elem]=img_rotate(sample[elem], angle, if_expand=self.if_expand, mode=self.mode) 209 | return sample 210 | 211 | 212 | class Resize(object): 213 | def __init__(self, size, mode=None, elems_point=['pos_points_mask','neg_points_mask'], elems_do=None, elems_undo=[]): 214 | self.size, self.mode = size, mode 215 | self.elems_point = elems_point 216 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo) 217 | def __call__(self, sample): 218 | for elem in sample.keys(): 219 | if self.elems_do!= None and elem not in self.elems_do :continue 220 | if elem in self.elems_undo:continue 221 | 222 | if elem in self.elems_point: 223 | sample[elem]=img_resize_point(sample[elem],self.size) 224 | continue 225 | 226 | mode = self.mode if self.mode is not None else ( 227 | cv2.INTER_LINEAR if len(sample[elem].shape)==3 else cv2.INTER_NEAREST) 228 | sample[elem] = cv2.resize(sample[elem], self.size, interpolation=mode) 229 | 230 | return sample 231 | 232 | 233 | # Expand: pad the border (top, bottom, left, right) 234 | class Expand(object): 235 | def __init__(self, pad=(0,0,0,0), elems_do=None, elems_undo=[]): 236 | if isinstance(pad, int): 237 | self.pad=(pad, pad, pad, pad) 238 | elif len(pad)==2: 239 | self.pad=(pad[0],pad[0],pad[1],pad[1]) 240 | elif len(pad)==4: 241 |
self.pad= pad 242 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo) 243 | def __call__(self, sample): 244 | for elem in sample.keys(): 245 | if self.elems_do!= None and elem not in self.elems_do :continue 246 | if elem in self.elems_undo:continue 247 | sample[elem]=cv2.copyMakeBorder(sample[elem],self.pad[0],self.pad[1],self.pad[2],self.pad[3],cv2.BORDER_CONSTANT) 248 | return sample 249 | 250 | 251 | class Crop(object): 252 | def __init__(self, x_range, y_range, elems_do=None, elems_undo=[]): 253 | self.x_range, self.y_range = x_range, y_range 254 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo) 255 | def __call__(self, sample): 256 | for elem in sample.keys(): 257 | if self.elems_do!= None and elem not in self.elems_do :continue 258 | if elem in self.elems_undo:continue 259 | sample[elem]=sample[elem][self.y_range[0]:self.y_range[1], self.x_range[0]:self.x_range[1], ...] 260 | 261 | sample['meta']['crop_size'] = np.array((self.x_range[1]-self.x_range[0],self.y_range[1]-self.y_range[0])) 262 | sample['meta']['crop_lt'] = np.array((self.x_range[0],self.y_range[0])) 263 | return sample 264 | 265 | 266 | class RandomScale(object): 267 | def __init__(self, scale=(0.75, 1.25), elems_do=None, elems_undo=[]): 268 | self.scale = scale 269 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo) 270 | def __call__(self, sample): 271 | scale_tmp = random.uniform(self.scale[0], self.scale[1]) 272 | src_size=sample['gt'].shape[::-1] 273 | dst_size= ( int(src_size[0]*scale_tmp), int(src_size[1]*scale_tmp)) 274 | Resize(size=dst_size)(sample) 275 | return sample 276 | 277 | 278 | ########################################[ RGBD_SOD ]######################################## 279 | class Depth2RGB(object): 280 | def __init__(self): 281 | pass 282 | def __call__(self, sample): 283 | #print('->old:',sample['depth'].size()) 284 | sample['depth']=sample['depth'].repeat(3,1,1) 285 | #print('->new:',sample['depth'].size()) 286 | return sample 
287 | 288 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #pytorch 2 | import torch 3 | import torchvision 4 | from torch.utils.data import DataLoader 5 | 6 | #general 7 | import os 8 | import cv2 9 | import sys 10 | import time 11 | import math 12 | import random 13 | import shutil 14 | import argparse 15 | import numpy as np 16 | from tqdm import tqdm 17 | 18 | #mine 19 | import utils 20 | import my_custom_transforms as mtr 21 | from dataloader_rgbdsod import RgbdSodDataset 22 | 23 | #log_recorder 24 | class Logger(object): 25 | def __init__(self, filename='default.log', stream=sys.stdout): 26 | self.terminal = stream 27 | self.log = open(filename, 'w') 28 | def write(self, message): 29 | self.terminal.write(message) 30 | self.log.write(message) 31 | def flush(self): 32 | pass 33 | 34 | def SetLogFile(file_path='log'): 35 | sys.stdout = Logger(file_path, sys.stdout) 36 | 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument('--net', type=str, default='RgbNet',choices=['RgbNet','RgbdNet','DepthNet'],help='train net') 39 | args = parser.parse_args() 40 | 41 | utils.set_seed(10) 42 | 43 | p={} 44 | p['datasets_path']='./dataset/' 45 | p['train_datasets']=[p['datasets_path']+'NJU2K_TRAIN',p['datasets_path']+'NLPR_TRAIN'] 46 | p['val_datasets']=[p['datasets_path']+'NJU2K_TEST'] 47 | 48 | p['gpu_ids']=list(range(torch.cuda.device_count())) 49 | p['start_epoch']=0 50 | p['epochs']=30 51 | p['bs']=8*len(p['gpu_ids']) 52 | p['lr']=1.25e-5*(p['bs']/len(p['gpu_ids'])) 53 | p['num_workers']=4*len(p['gpu_ids']) 54 | 55 | p['optimizer']=[ 'Adam' , {} ] 56 | p['scheduler']=['Constant',{}] 57 | 58 | p['if_memory']=False 59 | p['max_num']= 0 60 | p['size']=(224, 224) 61 | p['train_only_epochs']=0 62 | p['val_interval']=1 63 | p['resume']= None 64 | p['model']=args.net 65 | 66 | p['note']='' 67 | p['if_use_tensorboard']=False 68 | 
p['snapshot_path']='snapshot/[{}]_[{}]'.format(time.strftime('%Y-%m-%d-%H:%M:%S',time.localtime(time.time())),p['model']) 69 | if p['note']!='': p['snapshot_path']+='_[{}]'.format(p['note']) 70 | 71 | p['if_debug']=0 72 | 73 | p['if_only_val']=0 if p['resume'] is None else 1 74 | p['if_save_checkpoint']=False 75 | 76 | if p['if_only_val']: 77 | p['snapshot_path']+='[val]' 78 | p['if_use_tensorboard']=False 79 | 80 | if p['if_debug']: 81 | if os.path.exists('snapshot/debug'):shutil.rmtree('snapshot/debug') 82 | p['snapshot_path']='snapshot/debug' 83 | p['max_num']=32 84 | 85 | exec('from model.{} import MyNet'.format(p['model'])) 86 | 87 | if p['if_use_tensorboard']: 88 | from torch.utils.tensorboard import SummaryWriter 89 | 90 | class Trainer(object): 91 | def __init__(self,p): 92 | self.p=p 93 | os.makedirs(p['snapshot_path'],exist_ok=True) 94 | shutil.copyfile(os.path.join('model',p['model']+'.py'), os.path.join(p['snapshot_path'],p['model']+'.py')) 95 | SetLogFile('{}/log.txt'.format(p['snapshot_path'])) 96 | if p['if_use_tensorboard']: 97 | self.writer = SummaryWriter(p['snapshot_path']) 98 | 99 | transform_train = torchvision.transforms.Compose([ 100 | mtr.RandomFlip(), 101 | mtr.Resize(p['size']), 102 | mtr.ToTensor(), 103 | mtr.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],elems_do=['img']), 104 | 105 | ]) 106 | 107 | transform_val = torchvision.transforms.Compose([ 108 | mtr.Resize(p['size']), 109 | mtr.ToTensor(), 110 | mtr.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],elems_do=['img']), 111 | ]) 112 | 113 | self.train_set = RgbdSodDataset(datasets=p['train_datasets'],transform=transform_train,max_num=p['max_num'],if_memory=p['if_memory']) 114 | self.train_loader = DataLoader(self.train_set, batch_size=p['bs'], shuffle=True, num_workers=p['num_workers'],pin_memory=True) 115 | 116 | self.val_loaders=[] 117 | for val_dataset in p['val_datasets']: 118 | 
val_set=RgbdSodDataset(val_dataset,transform=transform_val,max_num=p['max_num'],if_memory=p['if_memory']) 119 | self.val_loaders.append(DataLoader(val_set, batch_size=1, shuffle=False,pin_memory=True)) 120 | 121 | self.model=MyNet() 122 | 123 | self.model = self.model.cuda() 124 | 125 | self.optimizer = utils.get_optimizer(p['optimizer'][0], self.model.get_train_params(lr=p['lr']), p['optimizer'][1]) 126 | self.scheduler = utils.get_scheduler(p['scheduler'][0], self.optimizer, p['scheduler'][1]) 127 | 128 | self.best_metric=None 129 | 130 | if p['resume']!=None: 131 | print('Load checkpoint from [{}]'.format(p['resume'])) 132 | checkpoint = torch.load(p['resume']) 133 | self.p['start_epoch']=checkpoint['current_epoch']+1 134 | self.best_metric=checkpoint['best_metric'] 135 | self.model.load_state_dict(checkpoint['model']) 136 | self.optimizer.load_state_dict(checkpoint['optimizer']) 137 | self.scheduler.load_state_dict(checkpoint['scheduler']) 138 | 139 | def main(self): 140 | print('Start time : ',time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))) 141 | print('---[ NOTE: {} ]---'.format(self.p['note'])) 142 | print('-'*79,'\ninfos : ' , self.p, '\n'+'-'*79) 143 | 144 | if self.p['if_only_val']: 145 | result_save_path=os.path.join(p['snapshot_path'],'result') 146 | os.makedirs(result_save_path,exist_ok=True) 147 | self.validation(self.p['start_epoch']-1,result_save_path) 148 | exit() 149 | 150 | for epoch in range(self.p['start_epoch'],self.p['epochs']): 151 | lr_str = ['{:.7f}'.format(i) for i in self.scheduler.get_lr()] 152 | print('-'*79+'\n'+'Epoch [{:03d}]=> |-lr:{}-| \n'.format(epoch, lr_str)) 153 | #training 154 | if p['train_only_epochs']>=0: 155 | self.training(epoch) 156 | self.scheduler.step() 157 | 158 | if epoch mae:{:.4f} f_max:{:.4f}'.format(dataset,metric[0],metric[1])) 221 | 222 | metric_all+=metric 223 | 224 | metric_all=metric_all/len(self.val_loaders) 225 | 226 | is_best = utils.metric_better_than(metric_all, self.best_metric) 227 | 
self.best_metric = metric_all if is_best else self.best_metric 228 | 229 | print('Metric_Select[MAE]: {:.4f} ({:.4f})'.format(metric_all[0],self.best_metric[0])) 230 | 231 | pth_state={ 232 | 'current_epoch': epoch, 233 | 'best_metric': self.best_metric, 234 | 'model': self.model.state_dict(), 235 | 'optimizer': self.optimizer.state_dict(), 236 | 'scheduler':self.scheduler.state_dict() 237 | } 238 | 239 | if self.p['if_save_checkpoint']: 240 | torch.save(pth_state, os.path.join(self.p['snapshot_path'], 'checkpoint.pth')) 241 | if is_best: 242 | torch.save(pth_state, os.path.join(self.p['snapshot_path'], 'best.pth')) 243 | 244 | if self.p['if_use_tensorboard']: 245 | self.writer.add_scalar('Loss/test', (loss_total / (i + 1)), epoch) 246 | self.writer.add_scalar('Metric/mae', metric_all[0], epoch) 247 | self.writer.add_scalar('Metric/f_max', metric_all[1], epoch) 248 | 249 | 250 | if __name__ == "__main__": 251 | mine =Trainer(p) 252 | mine.main() 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random 4 | import shutil 5 | import os 6 | ########################################[ Optimizer ]######################################## 7 | 8 | 9 | def get_optimizer(mode, train_params, kwargs): 10 | opt_default = {'SGD' : {'momentum':0.9, 'weight_decay':5e-4, 'nesterov':False}, 11 | #'Adam' : {'weight_decay':5e-4, 'betas':[0.9, 0.99]}, 12 | 'Adam' : {'weight_decay':0, 'betas':[0.9, 0.99]}, 13 | 'TBD' : {} } 14 | 15 | for k in opt_default[mode].keys(): 16 | if k not in kwargs.keys(): 17 | kwargs[k] = opt_default[mode][k] 18 | 19 | if mode=='SGD': 20 | return torch.optim.SGD(train_params,**kwargs) 21 | elif mode=='Adam': 22 | return torch.optim.Adam(train_params,**kwargs) 
23 | 24 | 25 | ########################################[ Scheduler ]######################################## 26 | from torch.optim.lr_scheduler import _LRScheduler 27 | 28 | class PolyLR(_LRScheduler): 29 | def __init__(self, optimizer, epoch_max, power=0.9, last_epoch=-1, cutoff_epoch=1000000): 30 | self.epoch_max = epoch_max 31 | self.power = power 32 | self.cutoff_epoch = cutoff_epoch 33 | super(PolyLR, self).__init__(optimizer, last_epoch) 34 | 35 | def get_lr(self): 36 | if self.last_epoch= thlist[i]).float() 70 | tp = (y_temp * y).sum() 71 | prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20) 72 | return prec, recall 73 | 74 | def f_measure(pred,gt): 75 | beta2 = 0.3 76 | with torch.no_grad(): 77 | pred = torch.from_numpy(pred).float().cuda() 78 | gt = torch.from_numpy(gt).float().cuda() 79 | 80 | prec, recall = eval_pr(pred, gt, 255) 81 | f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall) 82 | f_score[f_score != f_score] = 0 # for Nan 83 | return f_score 84 | 85 | def get_metric(sample_batched, result,result_save_path=None,if_recover=True): 86 | id=sample_batched['meta']['id'][0] 87 | gt=np.array(Image.open(sample_batched['meta']['gt_path'][0]).convert('L'))/255.0 88 | 89 | if if_recover: 90 | result=cv2.resize(result, gt.shape[::-1], interpolation=cv2.INTER_LINEAR) 91 | else: 92 | gt=cv2.resize(gt, result.shape[::-1], interpolation=cv2.INTER_NEAREST) 93 | 94 | result=(result*255).astype(np.uint8) 95 | 96 | if result_save_path is not None: 97 | Image.fromarray(result).save(os.path.join(result_save_path,id+'.png')) 98 | 99 | result=result.astype(np.float64)/255.0 100 | 101 | mae= np.mean(np.abs(result-gt)) 102 | f_score=f_measure(result,gt) 103 | return mae,f_score 104 | 105 | def metric_better_than(metric_a, metric_b): 106 | if metric_b is None: 107 | return True 108 | if isinstance(metric_a,list) or isinstance(metric_a,np.ndarray): 109 | metric_a,metric_b=metric_a[0],metric_b[0] 110 | return metric_a < metric_b 111 | 112 | 
113 | ########################################[ Loss ]######################################## 114 | 115 | 116 | 117 | 118 | ########################################[ Random ]######################################## 119 | # Fix all random seeds for reproducibility! 120 | def set_seed(seed): 121 | torch.manual_seed(seed) 122 | torch.cuda.manual_seed_all(seed) 123 | np.random.seed(seed) 124 | random.seed(seed) 125 | torch.backends.cudnn.deterministic = True 126 | 127 | # # Save a checkpoint 128 | # def save_checkpoint(state, is_best, path, filename, if_save_checkpoint=False): 129 | # if if_save_checkpoint: 130 | # torch.save(state, os.path.join(path, 'checkpoint.pth')) 131 | # if is_best: 132 | # torch.save(state, os.path.join(path, 'best.pth')) 133 | --------------------------------------------------------------------------------
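The `Single_Stream` module shared by `model/RgbNet.py` and `model/RgbdNet.py` is an FPN-style encoder-decoder. Five VGG16 blocks plus an extra `toplayer` produce `l1` through `l6`, and five `Decoder` modules then fuse these features top-down. The plain-Python sketch below traces channel counts and resolutions along that path without needing PyTorch. It assumes a square 224×224 input (the `p['size']` set in `train.py`) whose side stays even through every 2×2 pooling step. `trace_single_stream` is a hypothetical helper and is not part of the repository.

```python
def trace_single_stream(size=224):
    """Trace (channels, resolution) through Single_Stream without PyTorch.

    Assumes a square input whose side stays even through every 2x2 pooling.
    """
    # (feature name, output channels, spatial downsampling before the block)
    stages = [("l1", 64, 1),   # conv1: no pooling
              ("l2", 128, 2),  # conv2: pool1 + convs
              ("l3", 256, 2),  # conv3: pool2 + convs
              ("l4", 512, 2),  # conv4: pool3 + convs
              ("l5", 512, 2),  # conv5: pool4 + convs
              ("l6", 32, 2)]   # toplayer: AvgPool2d(2) + two 3x3 convs
    feats, side = [], size
    for name, channels, down in stages:
        side //= down
        feats.append((name, channels, side))

    decode = []
    x_channels = feats[5][1]  # decoding starts from l6 (32 channels)
    for idx in [4, 3, 2, 1, 0]:
        name, side_channels, side_res = feats[idx]
        # Decoder: upsample x to side_res, reduce the side feature to 32
        # channels (1x1 conv), concatenate, then two 3x3 convs back to 32.
        concat_channels = x_channels + 32
        x_channels = 32
        decode.append((name, side_channels, concat_channels, x_channels, side_res))
    return feats, decode


if __name__ == "__main__":
    feats, decode = trace_single_stream(224)
    for name, channels, res in feats:
        print(f"encoder {name}: {channels} ch @ {res}x{res}")
    for name, side_c, cat_c, out_c, res in decode:
        print(f"decoder fuses {name} ({side_c} ch): concat {cat_c} -> {out_c} ch @ {res}x{res}")
```

For a 224×224 input, the encoder resolutions are 224, 112, 56, 28, 14 and 7. The last decoder step returns a 32-channel map at the full 224×224 resolution, and `PredLayer` then turns that map into a one-channel sigmoid saliency map.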
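`VGG_backbone._initialize_weights` in `model/vgg_new.py` lets the same ImageNet-pretrained VGG16 weights serve inputs with more than three channels, such as the 4-channel RGB-D concatenation in `RgbdNet`. It first Kaiming-initializes `conv1_1`, then overwrites the first three input channels with the pretrained RGB kernels, so any extra channel keeps its random initialization. The NumPy sketch below mirrors that idea. `adapt_first_conv` is a hypothetical helper, and the standard deviation `sqrt(2 / fan_in)` assumes the defaults of `torch.nn.init.kaiming_normal_` (`mode='fan_in'`, `a=0`, which gives a gain of sqrt(2)).

```python
import numpy as np


def adapt_first_conv(pretrained_w, in_channel, seed=0):
    """Build a conv1_1 weight for `in_channel` inputs from 3-channel weights.

    pretrained_w: array of shape (out_channels, 3, kh, kw).
    Channels 0..2 receive the pretrained RGB kernels; any extra channel keeps
    a Kaiming-normal init (fan_in mode, gain sqrt(2)).
    """
    if in_channel < 3:
        raise ValueError("in_channel must be >= 3 to hold the RGB kernels")
    out_c, _, kh, kw = pretrained_w.shape
    fan_in = in_channel * kh * kw
    std = np.sqrt(2.0 / fan_in)
    rng = np.random.default_rng(seed)
    w = rng.normal(0.0, std, size=(out_c, in_channel, kh, kw))
    w[:, :3] = pretrained_w  # overwrite RGB slots, keep random extra channels
    return w


if __name__ == "__main__":
    rgb_kernels = np.random.default_rng(1).normal(size=(64, 3, 3, 3))
    w = adapt_first_conv(rgb_kernels, in_channel=4)  # RGB + depth, as in RgbdNet
    print(w.shape)  # (64, 4, 3, 3)
```

Like the original slice copy, the sketch requires `in_channel >= 3`. A one-channel input would need a different strategy, for example averaging the RGB kernels. That option is not implemented in the code shown here.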
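When `if_expand=True`, `img_rotate` and `img_rotate_point` in `my_custom_transforms.py` enlarge the output canvas to the bounding box of the rotated image, using `w' = h|sin θ| + w|cos θ|` and `h' = w|sin θ| + h|cos θ|`. They then shift the affine matrix from `cv2.getRotationMatrix2D` by half the growth in each direction, so the content stays centred. The standard-library sketch below reproduces only that arithmetic. `expanded_canvas` is a hypothetical helper name.

```python
import math


def expanded_canvas(w, h, angle_deg):
    """Canvas size and translation used by the if_expand branch of img_rotate.

    Returns (w_new, h_new, (dx, dy)), where dx and dy are added to M[0, 2] and
    M[1, 2] of the 2x3 rotation matrix. int() truncates, as in the original.
    """
    s = abs(math.sin(math.radians(angle_deg)))
    c = abs(math.cos(math.radians(angle_deg)))
    w_new = int(h * s + w * c)
    h_new = int(w * s + h * c)
    return w_new, h_new, ((w_new - w) / 2, (h_new - h) / 2)


if __name__ == "__main__":
    print(expanded_canvas(200, 100, 0))   # (200, 100, (0.0, 0.0))
    print(expanded_canvas(200, 100, 90))  # (100, 200, (-50.0, 50.0))
```

Because `int()` truncates, the expanded canvas can be up to one pixel smaller than the exact bounding box. As a result, rotated content touching the border may be clipped by a pixel.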
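`utils.get_metric` scores each validation prediction with the mean absolute error and an F-measure curve. `eval_pr` binarizes the saliency map at a list of thresholds and computes precision and recall against the ground truth. `f_measure` then combines them with β² = 0.3 as F = (1 + β²)PR / (β²P + R) and sets NaN entries to zero.

The NumPy sketch below recomputes both scores on the CPU. Two parts of it are assumptions:

- The threshold list (255 evenly spaced values in [0, 1)), because its definition is not shown in this excerpt.
- Taking the maximum of the curve as `f_max`, which is inferred from the log label in `train.py`.

The function names are hypothetical.

```python
import numpy as np


def mae(pred, gt):
    """Mean absolute error between saliency map and ground truth, both in [0, 1]."""
    return float(np.mean(np.abs(pred - gt)))


def f_measure_curve(pred, gt, num=255, beta2=0.3):
    """F-measure at `num` thresholds (CPU re-statement of eval_pr + f_measure).

    Assumption: thresholds are evenly spaced in [0, 1).
    """
    thresholds = np.linspace(0.0, 1.0 - 1e-10, num)
    prec = np.zeros(num)
    recall = np.zeros(num)
    for i, th in enumerate(thresholds):
        y_temp = (pred >= th).astype(np.float64)  # binarized prediction
        tp = (y_temp * gt).sum()                  # true positives
        prec[i] = tp / (y_temp.sum() + 1e-20)
        recall[i] = tp / (gt.sum() + 1e-20)
    with np.errstate(divide="ignore", invalid="ignore"):
        f = (1 + beta2) * prec * recall / (beta2 * prec + recall)
    f[np.isnan(f)] = 0.0  # 0/0 when precision and recall are both zero
    return f


if __name__ == "__main__":
    gt = np.zeros((8, 8))
    gt[2:6, 2:6] = 1.0
    print(mae(gt, gt), f_measure_curve(gt, gt).max())  # 0.0 1.0
```

Selecting the best model by lowest MAE, as `metric_better_than` does, ignores the F-measure curve. The curve is still logged for reference.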