├── README.md
├── dataloader_rgbdsod.py
├── dataset
│   └── SSD
│       ├── GT
│       │   ├── histrory10301.png
│       │   ├── histrory10401.png
│       │   └── histrory10501.png
│       ├── RGB
│       │   ├── histrory10301.jpg
│       │   ├── histrory10401.jpg
│       │   └── histrory10501.jpg
│       └── depth
│           ├── histrory10301.png
│           ├── histrory10401.png
│           └── histrory10501.png
├── eval.py
├── eval
│   └── pretrained_models
│       └── put_model_here.txt
├── figures
│   ├── D3Net-Result.jpg
│   ├── D3Net-TNNLS20.png
│   ├── DDU2.png
│   ├── RunTime.png
│   ├── SIP.png
│   └── SIP2.png
├── model
│   ├── DepthNet.py
│   ├── RgbNet.py
│   ├── RgbdNet.py
│   ├── __init__.py
│   └── vgg_new.py
├── my_custom_transforms.py
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # RGB-D Salient Object Detection: Models, Data Sets, and Large-Scale Benchmarks (TNNLS2021)
2 | Rethinking RGB-D Salient Object Detection: Models, Data Sets, and Large-Scale Benchmarks, IEEE TNNLS 2021.
3 |
4 | ### :fire: NEWS :fire:
5 | - [2022/06/09] :boom: Updated the related works.
6 | - [2020/08/02] Released the training code.
7 |
8 |
9 |
10 |
11 | Figure 1: Illustration of the proposed D3Net. In the training stage (left), the input RGB and depth images are processed by three parallel sub-networks, i.e., RgbNet, RgbdNet, and DepthNet. The three sub-networks are based on the same modified Feature Pyramid Network (FPN) structure (see § IV-A for details). We introduce these sub-networks to obtain three saliency maps (i.e., Srgb, Srgbd, and Sdepth) that consider both coarse and fine details of the input. In the test phase (right), a novel depth depurator unit (DDU) (§ IV-B) is utilized for the first time in this work to explicitly discard (i.e., fall back to Srgb) or keep (i.e., use Srgbd) the saliency map influenced by the depth map. In the training/test phases, these components form a nested structure and are elaborately designed (e.g., the gate connection in the DDU) to automatically learn the salient object from the RGB and depth images jointly.
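
A minimal sketch of this test-phase decision, following the gate implemented in `eval.py` (a threshold of 0.15 on the mean absolute difference between Srgbd and Sdepth); the function name below is ours, for illustration only:

```python
import numpy as np

def d3net_select(s_rgb, s_rgbd, s_depth, thresh=0.15):
    # All saliency maps are assumed to be arrays scaled to [0, 1].
    # If the RGB-D and depth-only predictions agree (low MAE), the depth map is
    # considered reliable and S_rgbd is kept; otherwise S_rgb is used instead.
    ddu_mae = np.mean(np.abs(s_rgbd - s_depth))
    return s_rgbd if ddu_mae < thresh else s_rgb
```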
12 |
13 |
14 |
15 | ### Table of Contents
16 | - [RGB-D Salient Object Detection](#Title)
17 | - [Table of Contents](#table-of-contents)
18 | - [Abstract](#abstract)
19 | - [Notion of Depth Depurator Unit](#Notion-of-Depth-Depurator-Unit)
20 | - [Related Works](#related-works)
21 | - [SIP dataset](#SIP-dataset)
22 | - [Train](#train)
23 | - [Evaluation](#evaluation)
24 | - [Results](#results)
25 | - [Citation](#citation)
26 |
27 | ## Abstract
28 | The use of RGB-D information for salient object detection has been explored in recent years. However, relatively little effort has been spent on modeling salient object detection in real-world human activity scenes with RGB-D. In this work, we fill the gap by making the following contributions to RGB-D salient object detection. First, we carefully collect a new Salient Person (SIP) dataset, which consists of 1K high-resolution images covering diverse real-world scenes with various viewpoints, poses, occlusions, illuminations, and backgrounds. Second, we conduct a large-scale and, so far, the most comprehensive benchmark
29 | comparing contemporary methods, which has long been missing in the area and can serve as a baseline for future research. We systematically summarize 31 popular models and evaluate 17 state-of-the-art methods over seven datasets with a total of about 91K images. Third, we propose a simple baseline architecture, called the Deep Depth-Depurator Network (D3Net). It consists of a depth depurator unit and a feature learning module, which perform initial low-quality depth map filtering and cross-modal feature learning, respectively. These components form a nested structure and are elaborately designed to be learned jointly. D3Net exceeds the performance of all prior contenders across the five metrics considered, and thus serves as a strong baseline to advance the research frontier. We also demonstrate that D3Net can efficiently extract salient person masks from real scenes, enabling an effective background-changing (e.g., book cover) application at 20 fps on a single GPU. All the saliency maps, our new SIP dataset, the baseline model, and the evaluation tools are made publicly available at https://github.com/DengPingFan/D3NetBenchmark.
30 |
31 |
32 | ## Notion of Depth Depurator Unit
33 | The statistics of the depth maps in existing datasets (e.g., NJU2K, NLPR, RGBD135, STERE, and LFSD) suggest that "high-quality depth maps usually contain clear objects, whereas the elements in low-quality depth maps are cluttered (2nd row in Fig. 2)".
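
For intuition, a smoothed depth histogram of the kind plotted in Fig. 2(c) can be computed along these lines (a minimal sketch, not part of this repository; the function name is ours):

```python
import numpy as np
from PIL import Image

def smoothed_depth_histogram(depth_path, bins=256, kernel_size=9):
    # Normalized intensity histogram of a depth map, smoothed with a moving average.
    depth = np.array(Image.open(depth_path).convert('L'), dtype=np.float32)
    hist, _ = np.histogram(depth, bins=bins, range=(0, 255), density=True)
    kernel = np.ones(kernel_size) / kernel_size
    return np.convolve(hist, kernel, mode='same')
```

A histogram dominated by a few clear peaks hints at a clean, high-quality depth map, while a cluttered, spread-out histogram hints at a low-quality one.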
34 |
35 |
36 |
37 |
38 | Figure 2: The smoothed histograms (c) of a high-quality (1st row) and a low-quality (2nd row) depth map, respectively.
39 |
40 |
41 |
42 | ## Related Works
43 | Please refer to our recent survey paper: https://github.com/taozh2017/RGBD-SODsurvey
44 |
45 | Paper with code: https://paperswithcode.com/task/rgb-d-salient-object-detection
46 |
47 | ## SIP dataset
48 |
49 |
50 |
51 | Figure 3: Representative subsets in our SIP. The images in SIP are grouped into eight subsets according to background objects (i.e., grass, car, barrier, road,
52 | sign, tree, flower, and other), different lighting conditions (i.e., low light and sunny, with clear object boundaries), and various numbers of objects (i.e., 1, 2, ≥3).
53 |
54 |
55 |
56 |
57 |
58 |
59 | Figure 4: Examples of images, depth maps, and annotations (i.e., object level and instance level) in our SIP data set with different numbers of salient objects,
60 | object sizes, object positions, scene complexities, and lighting conditions. Note that the "RGB" and "Gray" images are captured by two different monocular
61 | cameras from short distances. Thus, the "Gray" images are slightly different from the grayscale images obtained from the color (RGB) images. Our SIP data
62 | set provides new directions, such as depth estimation from "RGB" and "Gray" images, and instance-level RGB-D SOD.
63 |
64 |
65 |
66 | RGB-D SOD Datasets:
67 | **No.** |**Dataset** | **Year** | **Pub.** |**Size** | **#Obj.** | **Types** | **Resolution** | **Download**
68 | :-: | :-: | :-: | :- | :- | :-:| :-: | :-: | :-:
69 | 1 | [**STERE**]() |2012 |CVPR | 1000 | ~One |Internet | [251-1200] * [222-900] | [Baidu: rcql](https://pan.baidu.com/s/1CzBX7dHW9UNzhMC2Z02qTw)/[Google (1.29G)](https://drive.google.com/file/d/1JYfSHsKXC3GLaxcZcZSHkluMFGX0bJza/view?usp=sharing)
70 | 2 | [**GIT**](http://www.bmva.org/bmvc/2013/Papers/paper0112/abstract0112.pdf) |2013 |BMVC | 80 | Multiple |Home environment | 640 * 480 | [Baidu](https://pan.baidu.com/s/15sG1xx93oqWZAxAaVKu4lg)/[Google (35.6M)](https://drive.google.com/open?id=13zis--Pg9--bqNCjTOJGpCThbOly8Epa)
71 | 3 | [**DES**]() |2014 |ICIMCS | 135 | One |Indoor | 640 * 480 | [Baidu: qhen](https://pan.baidu.com/s/1RRp8oV9FYMmPDU5sMXYH6g)/[Google (60.4M)](https://drive.google.com/open?id=15Th-xDeRjkcefS8eDYl-vSN967JVyjoR)
72 | 4 | [**NLPR**]() |2014 |ECCV | 1000 | Multiple |Indoor/outdoor | 640 * 480, 480 * 640 | [Baidu: n701](https://pan.baidu.com/s/1o9387dhf_J2sl-V_0NniFA)/[Google (546M)](https://drive.google.com/open?id=1CbgySAZxznbsN9uOG4pNDHwUPvQIQjCn)
73 | 5 | [**LFSD**]() |2014 |CVPR | 100 | One |Indoor/outdoor | 360 * 360 | [Baidu](https://pan.baidu.com/s/17EiZrnUc9vmx-zfVnP4iIQ)/[Google (32M)](https://drive.google.com/open?id=1cEeJpUukomdt_C4vUZlBlpc1UueuWWRU)
74 | 6 | [**NJUD**]() |2014 |ICIP | 1985 | ~One |Movie/internet/photo | [231-1213] * [274-828] | [Baidu: zjmf](https://pan.baidu.com/s/156oDr-jJij01XAtkqngF7Q)/[Google (1.54G)](https://drive.google.com/open?id=1R1O2dWr6HqpTOiDn6hZxUWTesOSJteQo)
75 | 7 | [**SSD**]() |2017 |ICCVW | 80 | Multiple |Movies | 960 * 1080 | [Baidu: e4qz](https://pan.baidu.com/s/1Yp5gSdLQlhcJclSrbr-LeA)/[Google (119M)](https://drive.google.com/open?id=1k8_TQTZbbYOpnTvc9n6jgLg4Ih4xNhCj)
76 | 8 | [**DUT-RGBD**](https://openaccess.thecvf.com/content_ICCV_2019/papers/Piao_Depth-Induced_Multi-Scale_Recurrent_Attention_Network_for_Saliency_Detection_ICCV_2019_paper.pdf) |2019 |ICCV | 1200 | Multiple |Indoor/outdoor | 400 * 600 | [Baidu: 6rt0](https://pan.baidu.com/s/1oMG7fWVAr1VUz75EcbyKVg)/[Google (100M)](https://drive.google.com/open?id=1DzkswvLo-3eYPtPoitWvFPJ8qd4EHPGv)
77 | 9 | [**SIP**]() |2020 |TNNLS | 929 | Multiple |Person in wild | 992 * 774 | [Baidu: 46w8](https://pan.baidu.com/s/1wMTDG8yhCNbioPwzq7t25w)/[Google (2.16G)](https://drive.google.com/open?id=1R91EEHzI1JwfqvQJLmyciAIWU-N8VR4A)
78 | 10 | Overall | | | | | | | [Baidu: 39un](https://pan.baidu.com/s/1DgO18k2B32lAt0naY323PA)/[Google (5.33G)](https://drive.google.com/open?id=16kgnv9NxeiPGwNNx8WoZQLl4qL0qtBZN)
79 |
80 | ## Train
81 | Put the three datasets 'NJU2K_TRAIN', 'NLPR_TRAIN', and 'NJU2K_TEST' into the created folder "dataset".
82 |
83 | Put the VGG-pretrained model 'vgg16_feat.pth' ( [GoogleDrive](https://drive.google.com/file/d/1SXOV-DKnnqFD_b9yxJCIzdSkU7qiHh1X/view?usp=sharing) | [BaiduYun](https://pan.baidu.com/s/17qaLM3nbgR_eGehSK-SOrA) code: zsxh ) into the created folder "model".
84 | ```
85 | python train.py --net RgbNet
86 | python train.py --net RgbdNet
87 | python train.py --net DepthNet
88 | ```
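
Based on `dataloader_rgbdsod.py`, each dataset folder is expected to contain `RGB/*.jpg`, `GT/*.png`, and `depth/*.png` files with matching names, for example:

```
dataset/
├── NJU2K_TRAIN
│   ├── RGB
│   ├── GT
│   └── depth
├── NLPR_TRAIN
└── NJU2K_TEST
```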
89 | ### Requirements
90 | - PyTorch >= 0.4.1
91 | - OpenCV
92 |
93 | ### Pretrained models
94 | - RgbdNet, RgbNet, and DepthNet pretrained models can be downloaded from ( [GoogleDrive](https://drive.google.com/drive/folders/1jbZzUbgOC0XzbBEsy-Bgf3b-pvr62aWK?usp=sharing) | [BaiduYun](https://pan.baidu.com/s/1sgi0KExOv5KOfGQgXpDdqw) code: xf1h )
95 |
96 | ### Training and Testing Sets
97 | Our training dataset is:
98 |
99 | https://drive.google.com/open?id=1osdm_PRnupIkM82hFbz9u0EKJC_arlQI
100 |
101 | Our testing dataset is:
102 |
103 | https://drive.google.com/open?id=1ABYxq0mL4lPq2F0paNJ7-5T9ST6XVHl1
104 |
105 |
106 | ## Evaluation
107 | Put the three pretrained models into the created folder "eval/pretrained_models".
108 | ```
109 | python eval.py
110 | ```
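
The predicted saliency maps are saved to `./eval/result/D3Net/<dataset>/` (see `eval.py`). As a rough example, the MAE score used in the benchmark can be computed from a saved prediction and its ground truth as follows (a minimal sketch; use the released toolbox below for the official numbers):

```python
import numpy as np
from PIL import Image

def mae_score(pred_path, gt_path):
    # Mean absolute error between a saliency map and its ground truth, both scaled to [0, 1].
    pred = np.array(Image.open(pred_path).convert('L'), dtype=np.float32) / 255.0
    gt = np.array(Image.open(gt_path).convert('L'), dtype=np.float32) / 255.0
    return float(np.mean(np.abs(pred - gt)))
```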
111 |
112 | Toolbox (updated on 2022/06/09):
113 |
114 | [Baidu: i09j](https://pan.baidu.com/s/1ArnPZ4OwP67NR71OWYjitg)
115 |
116 | [Google](https://drive.google.com/file/d/1I4Z7rA3wefN7KeEQvkGA92u99uXS_aI_/view?usp=sharing)
117 |
118 |
119 |
120 |
121 | Table 1. Running time comparison.
122 |
123 |
124 |
125 | ## Results
126 |
127 |
128 |
129 | Results of our model on the seven benchmark datasets can be found at:
130 |
131 | [Baidu Pan](https://pan.baidu.com/s/13z0ZEptUfEU6hZ6yEEISuw) (extraction code: r295)
132 |
133 | [Google Drive](https://drive.google.com/drive/folders/1T46FyPzi3XjsB18i3HnLEqkYQWXVbCnK?usp=sharing)
134 |
135 |
136 | ## Citation
137 | If you find this work or code helpful in your research, please cite:
138 | ```
139 | @article{fan2019rethinking,
140 | title={{Rethinking RGB-D salient object detection: Models, datasets, and large-scale benchmarks}},
141 | author={Fan, Deng-Ping and Lin, Zheng and Zhang, Zhao and Zhu, Menglong and Cheng, Ming-Ming},
142 | journal={IEEE TNNLS},
143 | year={2021}
144 | }
145 | @article{zhou2021rgbd,
146 | title={RGB-D Salient Object Detection: A Survey},
147 | author={Zhou, Tao and Fan, Deng-Ping and Cheng, Ming-Ming and Shen, Jianbing and Shao, Ling},
148 | journal={CVMJ},
149 | year={2021}
150 | }
151 | ```
152 |
--------------------------------------------------------------------------------
/dataloader_rgbdsod.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from PIL import Image
4 | from torch.utils.data import Dataset
5 | import glob
6 | import os
7 | from os.path import join
8 | class RgbdSodDataset(Dataset):
9 | def __init__(self, datasets , transform=None, max_num=0 , if_memory=False):
10 | super().__init__()
11 | if not isinstance(datasets,list) : datasets=[datasets]
12 | self.imgs_list, self.gts_list, self.depths_list = [], [], []
13 |
14 | for dataset in datasets:
15 | ids=sorted(glob.glob(os.path.join(dataset,'RGB','*.jpg')))
16 | ids=[os.path.splitext(os.path.split(id)[1])[0] for id in ids]
17 | for id in ids:
18 | self.imgs_list.append(os.path.join(dataset,'RGB',id+'.jpg'))
19 | self.gts_list.append(os.path.join(dataset,'GT',id+'.png'))
20 | self.depths_list.append(os.path.join(dataset,'depth',id+'.png'))
21 |
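# Optional subsampling: a positive max_num keeps a random subset of that size,
# while a negative max_num keeps the first |max_num| samples.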
22 | if max_num!=0 and len(self.imgs_list)> abs(max_num):
23 | indices= random.sample(range(len(self.imgs_list)),max_num) if max_num>0 else range(abs(max_num))
24 | self.imgs_list= [self.imgs_list[i] for i in indices]
25 | self.gts_list = [self.gts_list[i] for i in indices]
26 | self.depths_list = [self.depths_list[i] for i in indices]
27 |
28 | self.transform, self.if_memory = transform, if_memory
29 |
30 | if if_memory:
31 | self.samples=[]
32 | for index in range(len(self.imgs_list)):
33 | self.samples.append(self.get_sample(index))
34 |
35 | def __len__(self):
36 | return len(self.imgs_list)
37 |
38 | def __getitem__(self, index):
39 | if self.if_memory:
40 | return self.transform(self.samples[index].copy()) if self.transform !=None else self.samples[index].copy()
41 | else:
42 | return self.transform(self.get_sample(index)) if self.transform !=None else self.get_sample(index)
43 |
44 | def get_sample(self,index):
45 | img = np.array(Image.open(self.imgs_list[index]).convert('RGB'))
46 | gt = np.array(Image.open(self.gts_list[index]).convert('L'))
47 | depth = np.array(Image.open(self.depths_list[index]).convert('L'))
48 | sample={'img':img , 'gt' : gt,'depth':depth}
49 |
50 | sample['meta'] = {'id': os.path.splitext(os.path.split(self.gts_list[index])[1])[0]}
51 | sample['meta']['source_size'] = np.array(gt.shape[::-1])
52 | sample['meta']['img_path'] = self.imgs_list[index]
53 | sample['meta']['gt_path'] = self.gts_list[index]
54 | sample['meta']['depth_path'] = self.depths_list[index]
55 | return sample
56 |
57 | if __name__=='__main__':
58 | pass
59 |
60 |
61 |
--------------------------------------------------------------------------------
/dataset/SSD/GT/histrory10301.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/dataset/SSD/GT/histrory10301.png
--------------------------------------------------------------------------------
/dataset/SSD/GT/histrory10401.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/dataset/SSD/GT/histrory10401.png
--------------------------------------------------------------------------------
/dataset/SSD/GT/histrory10501.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/dataset/SSD/GT/histrory10501.png
--------------------------------------------------------------------------------
/dataset/SSD/RGB/histrory10301.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/dataset/SSD/RGB/histrory10301.jpg
--------------------------------------------------------------------------------
/dataset/SSD/RGB/histrory10401.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/dataset/SSD/RGB/histrory10401.jpg
--------------------------------------------------------------------------------
/dataset/SSD/RGB/histrory10501.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/dataset/SSD/RGB/histrory10501.jpg
--------------------------------------------------------------------------------
/dataset/SSD/depth/histrory10301.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/dataset/SSD/depth/histrory10301.png
--------------------------------------------------------------------------------
/dataset/SSD/depth/histrory10401.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/dataset/SSD/depth/histrory10401.png
--------------------------------------------------------------------------------
/dataset/SSD/depth/histrory10501.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/dataset/SSD/depth/histrory10501.png
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | #pytorch
2 | import torch
3 | import torchvision
4 | from torch.utils.data import DataLoader
5 |
6 | #general
7 | import os
8 | import cv2
9 | import shutil
10 | import numpy as np
11 | from tqdm import tqdm
12 |
13 | #mine
14 | import utils
15 | import my_custom_transforms as mtr
16 | from dataloader_rgbdsod import RgbdSodDataset
17 | from PIL import Image
18 | from model.RgbNet import MyNet as RgbNet
19 | from model.RgbdNet import MyNet as RgbdNet
20 | from model.DepthNet import MyNet as DepthNet
21 |
22 | size=(224, 224)
23 | datasets_path='./dataset/'
24 | test_datasets=['SSD']
25 | pretrained_models={'RgbNet':'./eval/pretrained_models/RgbNet.pth', 'RgbdNet':'eval/pretrained_models/RgbdNet.pth' , 'DepthNet':'eval/pretrained_models/DepthNet.pth' }
26 | result_path='./eval/result/'
27 | os.makedirs(result_path,exist_ok=True)
28 |
29 | for tmp in ['D3Net']:
30 | os.makedirs(os.path.join(result_path,tmp),exist_ok=True)
31 | for test_dataset in test_datasets:
32 | os.makedirs(os.path.join(result_path,tmp,test_dataset),exist_ok=True)
33 |
34 | model_rgb=RgbNet().cuda()
35 | model_rgbd=RgbdNet().cuda()
36 | model_depth=DepthNet().cuda()
37 |
38 | model_rgb.load_state_dict(torch.load(pretrained_models['RgbNet'])['model'])
39 | model_rgbd.load_state_dict(torch.load(pretrained_models['RgbdNet'])['model'])
40 | model_depth.load_state_dict(torch.load(pretrained_models['DepthNet'])['model'])
41 |
42 | model_rgb.eval()
43 | model_rgbd.eval()
44 | model_depth.eval()
45 |
46 | transform_test = torchvision.transforms.Compose([mtr.Resize(size),mtr.ToTensor(),mtr.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],elems_do=['img'])])
47 |
48 | test_loaders=[]
49 | for test_dataset in test_datasets:
50 | val_set=RgbdSodDataset(datasets_path+test_dataset,transform=transform_test)
51 | test_loaders.append(DataLoader(val_set, batch_size=1, shuffle=False,pin_memory=True))
52 |
53 | for index, test_loader in enumerate(test_loaders):
54 | dataset=test_datasets[index]
55 | print('Test [{}]'.format(dataset))
56 |
57 | for i, sample_batched in enumerate(tqdm(test_loader)):
58 | input, gt = model_rgb.get_input(sample_batched),model_rgb.get_gt(sample_batched)
59 |
60 | with torch.no_grad():
61 | output_rgb = model_rgb(input)
62 | output_rgbd = model_rgbd(input)
63 | output_depth = model_depth(input)
64 |
65 | result_rgb = model_rgb.get_result(output_rgb)
66 | result_rgbd = model_rgbd.get_result(output_rgbd)
67 | result_depth = model_depth.get_result(output_depth)
68 |
69 | id=sample_batched['meta']['id'][0]
70 | gt_src=np.array(Image.open(sample_batched['meta']['gt_path'][0]).convert('L'))
71 |
72 | result_rgb=(cv2.resize(result_rgb, gt_src.shape[::-1], interpolation=cv2.INTER_LINEAR) *255).astype(np.uint8)
73 | result_rgbd=(cv2.resize(result_rgbd, gt_src.shape[::-1], interpolation=cv2.INTER_LINEAR) *255).astype(np.uint8)
74 | result_depth=(cv2.resize(result_depth, gt_src.shape[::-1], interpolation=cv2.INTER_LINEAR) *255).astype(np.uint8)
75 |
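# Depth Depurator Unit (DDU) gate: when the RGB-D and depth-only predictions agree
# (low mean absolute difference), the depth map is considered reliable and the RGB-D
# result is kept; otherwise the RGB-only result is used instead.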
76 | ddu_mae=np.mean(np.abs(result_rgbd/255.0 - result_depth/255.0))
77 | result_d3net=result_rgbd if ddu_mae<0.15 else result_rgb
78 |
79 | Image.fromarray(result_d3net).save(os.path.join(result_path,'D3Net',dataset,id+'.png'))
80 |
81 |
82 |
83 |
--------------------------------------------------------------------------------
/eval/pretrained_models/put_model_here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/eval/pretrained_models/put_model_here.txt
--------------------------------------------------------------------------------
/figures/D3Net-Result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/figures/D3Net-Result.jpg
--------------------------------------------------------------------------------
/figures/D3Net-TNNLS20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/figures/D3Net-TNNLS20.png
--------------------------------------------------------------------------------
/figures/DDU2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/figures/DDU2.png
--------------------------------------------------------------------------------
/figures/RunTime.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/figures/RunTime.png
--------------------------------------------------------------------------------
/figures/SIP.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/figures/SIP.png
--------------------------------------------------------------------------------
/figures/SIP2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/figures/SIP2.png
--------------------------------------------------------------------------------
/model/DepthNet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from model.vgg_new import VGG_backbone
6 |
7 | def init_weight(model):
8 | for m in model.modules():
9 | if isinstance(m, nn.Conv2d):
10 | torch.nn.init.kaiming_normal_(m.weight)
11 | if m.bias is not None:
12 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(m.weight)
13 | bound = 1 / math.sqrt(fan_in)
14 | torch.nn.init.uniform_(m.bias, -bound, bound)
15 | elif isinstance(m, nn.BatchNorm2d):
16 | m.weight.data.fill_(1)
17 | m.bias.data.zero_()
18 |
19 | class Decoder(nn.Module):
20 | def __init__(self,in_channel=32,side_channel=512):
21 | super(Decoder, self).__init__()
22 | self.reduce_conv=nn.Sequential(
23 | #nn.Conv2d(side_channel, in_channel, kernel_size=3, stride=1, padding=1),
24 | nn.Conv2d(side_channel, in_channel, kernel_size=1, stride=1, padding=0),
25 | nn.ReLU(inplace=True) ###
26 | )
27 | self.decoder = nn.Sequential(
28 | nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1),
29 | nn.ReLU(inplace=True),
30 | nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1),
31 | nn.ReLU(inplace=True) ###
32 | )
33 | init_weight(self)
34 |
35 | def forward(self, x, side):
36 | x=F.interpolate(x, size=side.size()[2:], mode='bilinear', align_corners=True)
37 | side=self.reduce_conv(side)
38 | x=torch.cat((x, side), 1)
39 | x = self.decoder(x)
40 | return x
41 |
42 |
43 | class Single_Stream(nn.Module):
44 | def __init__(self,in_channel=3):
45 | super(Single_Stream, self).__init__()
46 | self.backbone = VGG_backbone(in_channel=in_channel,pre_train_path='./model/vgg16_feat.pth')
47 | self.toplayer = nn.Sequential(
48 | nn.AvgPool2d(2, stride=2),
49 | nn.Conv2d(512, 32, kernel_size=3, stride=1, padding=1),
50 | nn.ReLU(inplace=True),
51 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
52 | nn.ReLU(inplace=True) ###
53 | )
54 | channels = [64, 128, 256, 512, 512, 32]
55 | # Decoders
56 | decoders = []
57 | for idx in range(5):
58 | decoders.append(Decoder(in_channel=32,side_channel=channels[idx]))
59 | self.decoders = nn.ModuleList(decoders)
60 | init_weight(self.toplayer)
61 |
62 | def forward(self, input):
63 | l1 = self.backbone.conv1(input)
64 | l2 = self.backbone.conv2(l1)
65 | l3 = self.backbone.conv3(l2)
66 | l4 = self.backbone.conv4(l3)
67 | l5 = self.backbone.conv5(l4)
68 | l6 = self.toplayer(l5)
69 | feats=[l1, l2, l3, l4, l5, l6]
70 |
71 | x=feats[5]
72 | for idx in [4, 3, 2, 1, 0]:
73 | x=self.decoders[idx](x,feats[idx])
74 |
75 | return x
76 |
77 | class PredLayer(nn.Module):
78 | def __init__(self, in_channel=32):
79 | super(PredLayer, self).__init__()
80 | self.enlayer = nn.Sequential(
81 | nn.Conv2d(in_channel, 32, kernel_size=3, stride=1, padding=1),
82 | nn.ReLU(inplace=True),
83 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
84 | nn.ReLU(inplace=True),
85 | )
86 | self.outlayer = nn.Sequential(
87 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
88 | nn.Sigmoid()
89 | )
90 | init_weight(self)
91 |
92 | def forward(self, x):
93 | x = self.enlayer(x)
94 | x = self.outlayer(x)
95 | return x
96 |
97 |
98 | class MyNet(nn.Module):
99 | def __init__(self):
100 | super(MyNet, self).__init__()
101 |
102 | # Main-streams
103 | self.main_stream = Single_Stream(in_channel=3)
104 |
105 | # Prediction
106 | self.pred_layer = PredLayer()
107 |
108 | def forward(self, input, if_return_feat=False):
109 | rgb, dep=input
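# Replicate the single-channel depth map to 3 channels so the RGB-pretrained VGG backbone can be reused.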
110 | dep=dep.repeat(1,3,1,1)
111 | feat = self.main_stream(dep)
112 | result = self.pred_layer(feat)
113 |
114 | if if_return_feat:
115 | return feat
116 | else:
117 | return result
118 |
119 | def get_train_params(self, lr):
120 | train_params = [{'params': self.parameters(), 'lr': lr}]
121 | return train_params
122 |
123 | def get_input(self, sample_batched):
124 | rgb,dep = sample_batched['img'].cuda(),sample_batched['depth'].cuda()
125 | return rgb,dep
126 |
127 | def get_gt(self, sample_batched):
128 | gt = sample_batched['gt'].cuda()
129 | return gt
130 |
131 | def get_result(self, output, index=0):
132 | if isinstance(output, list):
133 | result = output[0].data.cpu().numpy()[index,0,:,:]
134 | else:
135 | result = output.data.cpu().numpy()[index,0,:,:]
136 |
137 | # if isinstance(output, list):
138 | # result = torch.sigmoid(output[0].data.cpu()).numpy()[index,0,:,:]
139 | # else:
140 | # result = torch.sigmoid(output.data.cpu()).numpy()[index,0,:,:]
141 | return result
142 |
143 | def get_loss(self, output, gt, if_mean=True):
144 | criterion = nn.BCELoss().cuda()
145 | #criterion = nn.BCEWithLogitsLoss().cuda()
146 | if isinstance(output, list):
147 | loss=0
148 | for i in range(len(output)):
149 | loss+=criterion(output[i], gt)
150 | if if_mean:loss/=len(output)
151 | else:
152 | loss = criterion(output, gt)
153 | return loss
154 |
155 |
156 | if __name__ == "__main__":
157 | pass
158 |
159 |
160 |
161 |
162 |
163 |
164 |
--------------------------------------------------------------------------------
/model/RgbNet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from model.vgg_new import VGG_backbone
6 |
7 | def init_weight(model):
8 | for m in model.modules():
9 | if isinstance(m, nn.Conv2d):
10 | torch.nn.init.kaiming_normal_(m.weight)
11 | if m.bias is not None:
12 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(m.weight)
13 | bound = 1 / math.sqrt(fan_in)
14 | torch.nn.init.uniform_(m.bias, -bound, bound)
15 | elif isinstance(m, nn.BatchNorm2d):
16 | m.weight.data.fill_(1)
17 | m.bias.data.zero_()
18 |
19 | class Decoder(nn.Module):
20 | def __init__(self,in_channel=32,side_channel=512):
21 | super(Decoder, self).__init__()
22 | self.reduce_conv=nn.Sequential(
23 | #nn.Conv2d(side_channel, in_channel, kernel_size=3, stride=1, padding=1),
24 | nn.Conv2d(side_channel, in_channel, kernel_size=1, stride=1, padding=0),
25 | nn.ReLU(inplace=True) ###
26 | )
27 | self.decoder = nn.Sequential(
28 | nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1),
29 | nn.ReLU(inplace=True),
30 | nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1),
31 | nn.ReLU(inplace=True) ###
32 | )
33 | init_weight(self)
34 |
35 | def forward(self, x, side):
36 | x=F.interpolate(x, size=side.size()[2:], mode='bilinear', align_corners=True)
37 | side=self.reduce_conv(side)
38 | x=torch.cat((x, side), 1)
39 | x = self.decoder(x)
40 | return x
41 |
42 |
43 | class Single_Stream(nn.Module):
44 | def __init__(self,in_channel=3):
45 | super(Single_Stream, self).__init__()
46 | self.backbone = VGG_backbone(in_channel=in_channel,pre_train_path='./model/vgg16_feat.pth')
47 | self.toplayer = nn.Sequential(
48 | nn.AvgPool2d(2, stride=2),
49 | nn.Conv2d(512, 32, kernel_size=3, stride=1, padding=1),
50 | nn.ReLU(inplace=True),
51 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
52 | nn.ReLU(inplace=True) ###
53 | )
54 | channels = [64, 128, 256, 512, 512, 32]
55 | # Decoders
56 | decoders = []
57 | for idx in range(5):
58 | decoders.append(Decoder(in_channel=32,side_channel=channels[idx]))
59 | self.decoders = nn.ModuleList(decoders)
60 | init_weight(self.toplayer)
61 |
62 | def forward(self, input):
63 | l1 = self.backbone.conv1(input)
64 | l2 = self.backbone.conv2(l1)
65 | l3 = self.backbone.conv3(l2)
66 | l4 = self.backbone.conv4(l3)
67 | l5 = self.backbone.conv5(l4)
68 | l6 = self.toplayer(l5)
69 | feats=[l1, l2, l3, l4, l5, l6]
70 |
71 | x=feats[5]
72 | for idx in [4, 3, 2, 1, 0]:
73 | x=self.decoders[idx](x,feats[idx])
74 |
75 | return x
76 |
77 | class PredLayer(nn.Module):
78 | def __init__(self, in_channel=32):
79 | super(PredLayer, self).__init__()
80 | self.enlayer = nn.Sequential(
81 | nn.Conv2d(in_channel, 32, kernel_size=3, stride=1, padding=1),
82 | nn.ReLU(inplace=True),
83 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
84 | nn.ReLU(inplace=True),
85 | )
86 | self.outlayer = nn.Sequential(
87 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
88 | nn.Sigmoid()
89 | )
90 | init_weight(self)
91 |
92 | def forward(self, x):
93 | x = self.enlayer(x)
94 | x = self.outlayer(x)
95 | return x
96 |
97 |
98 | class MyNet(nn.Module):
99 | def __init__(self):
100 | super(MyNet, self).__init__()
101 |
102 | # Main-streams
103 | self.main_stream = Single_Stream(in_channel=3)
104 |
105 | # Prediction
106 | self.pred_layer = PredLayer()
107 |
108 | def forward(self, input, if_return_feat=False):
109 | rgb, dep=input
110 | feat = self.main_stream(rgb)
111 | result = self.pred_layer(feat)
112 |
113 | if if_return_feat:
114 | return feat
115 | else:
116 | return result
117 |
118 | def get_train_params(self, lr):
119 | train_params = [{'params': self.parameters(), 'lr': lr}]
120 | return train_params
121 |
122 | def get_input(self, sample_batched):
123 | rgb,dep = sample_batched['img'].cuda(),sample_batched['depth'].cuda()
124 | return rgb,dep
125 |
126 | def get_gt(self, sample_batched):
127 | gt = sample_batched['gt'].cuda()
128 | return gt
129 |
130 | def get_result(self, output, index=0):
131 | if isinstance(output, list):
132 | result = output[0].data.cpu().numpy()[index,0,:,:]
133 | else:
134 | result = output.data.cpu().numpy()[index,0,:,:]
135 |
136 | # if isinstance(output, list):
137 | # result = torch.sigmoid(output[0].data.cpu()).numpy()[index,0,:,:]
138 | # else:
139 | # result = torch.sigmoid(output.data.cpu()).numpy()[index,0,:,:]
140 | return result
141 |
142 | def get_loss(self, output, gt, if_mean=True):
143 | criterion = nn.BCELoss().cuda()
144 | #criterion = nn.BCEWithLogitsLoss().cuda()
145 | if isinstance(output, list):
146 | loss=0
147 | for i in range(len(output)):
148 | loss+=criterion(output[i], gt)
149 | if if_mean:loss/=len(output)
150 | else:
151 | loss = criterion(output, gt)
152 | return loss
153 |
154 |
155 | if __name__ == "__main__":
156 | pass
157 |
158 |
159 |
160 |
161 |
162 |
163 |
--------------------------------------------------------------------------------
/model/RgbdNet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from model.vgg_new import VGG_backbone
6 |
7 | def init_weight(model):
8 | for m in model.modules():
9 | if isinstance(m, nn.Conv2d):
10 | torch.nn.init.kaiming_normal_(m.weight)
11 | if m.bias is not None:
12 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(m.weight)
13 | bound = 1 / math.sqrt(fan_in)
14 | torch.nn.init.uniform_(m.bias, -bound, bound)
15 | elif isinstance(m, nn.BatchNorm2d):
16 | m.weight.data.fill_(1)
17 | m.bias.data.zero_()
18 |
19 | class Decoder(nn.Module):
20 | def __init__(self,in_channel=32,side_channel=512):
21 | super(Decoder, self).__init__()
22 | self.reduce_conv=nn.Sequential(
23 | #nn.Conv2d(side_channel, in_channel, kernel_size=3, stride=1, padding=1),
24 | nn.Conv2d(side_channel, in_channel, kernel_size=1, stride=1, padding=0),
25 | nn.ReLU(inplace=True) ###
26 | )
27 | self.decoder = nn.Sequential(
28 | nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1),
29 | nn.ReLU(inplace=True),
30 | nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1),
31 | nn.ReLU(inplace=True) ###
32 | )
33 | init_weight(self)
34 |
35 | def forward(self, x, side):
36 | x=F.interpolate(x, size=side.size()[2:], mode='bilinear', align_corners=True)
37 | side=self.reduce_conv(side)
38 | x=torch.cat((x, side), 1)
39 | x = self.decoder(x)
40 | return x
41 |
42 |
43 | class Single_Stream(nn.Module):
44 | def __init__(self,in_channel=3):
45 | super(Single_Stream, self).__init__()
46 | self.backbone = VGG_backbone(in_channel=in_channel,pre_train_path='./model/vgg16_feat.pth')
47 | self.toplayer = nn.Sequential(
48 | nn.AvgPool2d(2, stride=2),
49 | nn.Conv2d(512, 32, kernel_size=3, stride=1, padding=1),
50 | nn.ReLU(inplace=True),
51 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
52 | nn.ReLU(inplace=True) ###
53 | )
54 | channels = [64, 128, 256, 512, 512, 32]
55 | # Decoders
56 | decoders = []
57 | for idx in range(5):
58 | decoders.append(Decoder(in_channel=32,side_channel=channels[idx]))
59 | self.decoders = nn.ModuleList(decoders)
60 | init_weight(self.toplayer)
61 |
62 | def forward(self, input):
63 | l1 = self.backbone.conv1(input)
64 | l2 = self.backbone.conv2(l1)
65 | l3 = self.backbone.conv3(l2)
66 | l4 = self.backbone.conv4(l3)
67 | l5 = self.backbone.conv5(l4)
68 | l6 = self.toplayer(l5)
69 | feats=[l1, l2, l3, l4, l5, l6]
70 |
71 | x=feats[5]
72 | for idx in [4, 3, 2, 1, 0]:
73 | x=self.decoders[idx](x,feats[idx])
74 |
75 | return x
76 |
77 | class PredLayer(nn.Module):
78 | def __init__(self, in_channel=32):
79 | super(PredLayer, self).__init__()
80 | self.enlayer = nn.Sequential(
81 | nn.Conv2d(in_channel, 32, kernel_size=3, stride=1, padding=1),
82 | nn.ReLU(inplace=True),
83 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
84 | nn.ReLU(inplace=True),
85 | )
86 | self.outlayer = nn.Sequential(
87 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
88 | nn.Sigmoid()
89 | )
90 | init_weight(self)
91 |
92 | def forward(self, x):
93 | x = self.enlayer(x)
94 | x = self.outlayer(x)
95 | return x
96 |
97 |
98 | class MyNet(nn.Module):
99 | def __init__(self):
100 | super(MyNet, self).__init__()
101 |
102 | # Main-streams
103 | self.main_stream = Single_Stream(in_channel=4)
104 |
105 | # Prediction
106 | self.pred_layer = PredLayer()
107 |
108 | def forward(self, input):
109 | rgb, dep=input
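# Early fusion: concatenate RGB (3 channels) and depth (1 channel) into a 4-channel input for the backbone.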
110 | x=torch.cat((rgb,dep),1)
111 | x = self.main_stream(x)
112 | x = self.pred_layer(x)
113 | return x
114 |
115 | def get_train_params(self, lr):
116 | train_params = [{'params': self.parameters(), 'lr': lr}]
117 | return train_params
118 |
119 | def get_input(self, sample_batched):
120 | rgb,dep = sample_batched['img'].cuda(),sample_batched['depth'].cuda()
121 | return rgb,dep
122 |
123 | def get_gt(self, sample_batched):
124 | gt = sample_batched['gt'].cuda()
125 | return gt
126 |
127 | def get_result(self, output, index=0):
128 | if isinstance(output, list):
129 | result = output[0].data.cpu().numpy()[index,0,:,:]
130 | else:
131 | result = output.data.cpu().numpy()[index,0,:,:]
132 |
133 | # if isinstance(output, list):
134 | # result = torch.sigmoid(output[0].data.cpu()).numpy()[index,0,:,:]
135 | # else:
136 | # result = torch.sigmoid(output.data.cpu()).numpy()[index,0,:,:]
137 | return result
138 |
139 | def get_loss(self, output, gt, if_mean=True):
140 | criterion = nn.BCELoss().cuda()
141 | #criterion = nn.BCEWithLogitsLoss().cuda()
142 | if isinstance(output, list):
143 | loss=0
144 | for i in range(len(output)):
145 | loss+=criterion(output[i], gt)
146 | if if_mean:loss/=len(output)
147 | else:
148 | loss = criterion(output, gt)
149 | return loss
150 |
151 |
152 | if __name__ == "__main__":
153 | pass
154 |
155 |
156 |
157 |
158 |
159 |
160 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengPingFan/D3NetBenchmark/d312f944c8aed19430f8c0c628bfca3d62d9498e/model/__init__.py
--------------------------------------------------------------------------------
/model/vgg_new.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 |
5 | class VGG_backbone(nn.Module):
6 | # VGG16 with two branches
7 | # pooling layer at the front of block
8 | def __init__(self,in_channel=3,pre_train_path=None):
9 | super(VGG_backbone, self).__init__()
10 | self.in_channel=in_channel
11 | conv1 = nn.Sequential()
12 | conv1.add_module('conv1_1', nn.Conv2d(in_channel, 64, 3, 1, 1))
13 | conv1.add_module('relu1_1', nn.ReLU(inplace=True))
14 | conv1.add_module('conv1_2', nn.Conv2d(64, 64, 3, 1, 1))
15 | conv1.add_module('relu1_2', nn.ReLU(inplace=True))
16 |
17 | self.conv1 = conv1
18 | conv2 = nn.Sequential()
19 | conv2.add_module('pool1', nn.MaxPool2d(2, stride=2))
20 | conv2.add_module('conv2_1', nn.Conv2d(64, 128, 3, 1, 1))
21 | conv2.add_module('relu2_1', nn.ReLU())
22 | conv2.add_module('conv2_2', nn.Conv2d(128, 128, 3, 1, 1))
23 | conv2.add_module('relu2_2', nn.ReLU())
24 | self.conv2 = conv2
25 |
26 | conv3 = nn.Sequential()
27 | conv3.add_module('pool2', nn.MaxPool2d(2, stride=2))
28 | conv3.add_module('conv3_1', nn.Conv2d(128, 256, 3, 1, 1))
29 | conv3.add_module('relu3_1', nn.ReLU())
30 | conv3.add_module('conv3_2', nn.Conv2d(256, 256, 3, 1, 1))
31 | conv3.add_module('relu3_2', nn.ReLU())
32 | conv3.add_module('conv3_3', nn.Conv2d(256, 256, 3, 1, 1))
33 | conv3.add_module('relu3_3', nn.ReLU())
34 | self.conv3 = conv3
35 |
36 | conv4 = nn.Sequential()
37 | conv4.add_module('pool3', nn.MaxPool2d(2, stride=2))
38 | conv4.add_module('conv4_1', nn.Conv2d(256, 512, 3, 1, 1))
39 | conv4.add_module('relu4_1', nn.ReLU())
40 | conv4.add_module('conv4_2', nn.Conv2d(512, 512, 3, 1, 1))
41 | conv4.add_module('relu4_2', nn.ReLU())
42 | conv4.add_module('conv4_3', nn.Conv2d(512, 512, 3, 1, 1))
43 | conv4.add_module('relu4_3', nn.ReLU())
44 | self.conv4 = conv4
45 |
46 | conv5 = nn.Sequential()
47 | conv5.add_module('pool4', nn.MaxPool2d(2, stride=2))
48 | conv5.add_module('conv5_1', nn.Conv2d(512, 512, 3, 1, 1))
49 | conv5.add_module('relu5_1', nn.ReLU())
50 | conv5.add_module('conv5_2', nn.Conv2d(512, 512, 3, 1, 1))
51 | conv5.add_module('relu5_2', nn.ReLU())
52 | conv5.add_module('conv5_3', nn.Conv2d(512, 512, 3, 1, 1))
53 | conv5.add_module('relu5_3', nn.ReLU())
54 | self.conv5 = conv5
55 |
56 | if pre_train_path is not None and os.path.exists(pre_train_path):
57 | self._initialize_weights(torch.load(pre_train_path))
58 |
59 | def forward(self, x):
60 | x = self.conv1(x)
61 | x = self.conv2(x)
62 | x = self.conv3(x)
63 | x = self.conv4(x)
64 | x = self.conv5(x)
65 | return x
66 | def _initialize_weights(self, pre_train):
67 | keys = list(pre_train.keys())
68 |
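# conv1_1 may take more than 3 input channels (e.g., 4 for RGB-D early fusion):
# initialize all of its filters first, then copy the pretrained RGB weights into the first 3 channels.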
69 | torch.nn.init.kaiming_normal_(self.conv1.conv1_1.weight)
70 | self.conv1.conv1_1.weight.data[:,:3,:,:].copy_(pre_train[keys[0]])
71 | self.conv1.conv1_2.weight.data.copy_(pre_train[keys[2]])
72 | self.conv2.conv2_1.weight.data.copy_(pre_train[keys[4]])
73 | self.conv2.conv2_2.weight.data.copy_(pre_train[keys[6]])
74 | self.conv3.conv3_1.weight.data.copy_(pre_train[keys[8]])
75 | self.conv3.conv3_2.weight.data.copy_(pre_train[keys[10]])
76 | self.conv3.conv3_3.weight.data.copy_(pre_train[keys[12]])
77 | self.conv4.conv4_1.weight.data.copy_(pre_train[keys[14]])
78 | self.conv4.conv4_2.weight.data.copy_(pre_train[keys[16]])
79 | self.conv4.conv4_3.weight.data.copy_(pre_train[keys[18]])
80 | self.conv5.conv5_1.weight.data.copy_(pre_train[keys[20]])
81 | self.conv5.conv5_2.weight.data.copy_(pre_train[keys[22]])
82 | self.conv5.conv5_3.weight.data.copy_(pre_train[keys[24]])
83 |
84 | self.conv1.conv1_1.bias.data.copy_(pre_train[keys[1]])
85 | self.conv1.conv1_2.bias.data.copy_(pre_train[keys[3]])
86 | self.conv2.conv2_1.bias.data.copy_(pre_train[keys[5]])
87 | self.conv2.conv2_2.bias.data.copy_(pre_train[keys[7]])
88 | self.conv3.conv3_1.bias.data.copy_(pre_train[keys[9]])
89 | self.conv3.conv3_2.bias.data.copy_(pre_train[keys[11]])
90 | self.conv3.conv3_3.bias.data.copy_(pre_train[keys[13]])
91 | self.conv4.conv4_1.bias.data.copy_(pre_train[keys[15]])
92 | self.conv4.conv4_2.bias.data.copy_(pre_train[keys[17]])
93 | self.conv4.conv4_3.bias.data.copy_(pre_train[keys[19]])
94 | self.conv5.conv5_1.bias.data.copy_(pre_train[keys[21]])
95 | self.conv5.conv5_2.bias.data.copy_(pre_train[keys[23]])
96 | self.conv5.conv5_3.bias.data.copy_(pre_train[keys[25]])
97 |
98 | if __name__ == "__main__":
99 | model=VGG_backbone(in_channel=6,pre_train_path='vgg16_feat.pth')
100 |
--------------------------------------------------------------------------------
/my_custom_transforms.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import torch
4 | import random
5 | import numbers
6 | import numpy as np
7 | from PIL import Image
8 | ########################################[ function ]########################################
9 |
10 | def img_rotate(img, angle, center=None, if_expand=False, scale=1.0, mode=None):
11 | (h, w) = img.shape[:2]
12 | if center is None: center = (w // 2 ,h // 2)
13 | M = cv2.getRotationMatrix2D(center, angle, scale)
14 | if mode is None: mode=cv2.INTER_LINEAR if len(img.shape)==3 else cv2.INTER_NEAREST
15 | if if_expand:
16 | h_new=int(w*math.fabs(math.sin(math.radians(angle)))+h*math.fabs(math.cos(math.radians(angle))))
17 | w_new=int(h*math.fabs(math.sin(math.radians(angle)))+w*math.fabs(math.cos(math.radians(angle))))
18 | M[0,2] +=(w_new-w)/2
19 | M[1,2] +=(h_new-h)/2
20 | h, w =h_new, w_new
21 | rotated = cv2.warpAffine(img, M, (w, h),flags=mode)
22 | return rotated
23 |
24 |
25 | def img_rotate_point(img, angle, center=None, if_expand=False, scale=1.0):
26 | (h, w) = img.shape[:2]
27 | if center is None: center = (w // 2 ,h // 2)
28 | M = cv2.getRotationMatrix2D(center, angle, scale)
29 | if if_expand:
30 | h_new=int(w*math.fabs(math.sin(math.radians(angle)))+h*math.fabs(math.cos(math.radians(angle))))
31 | w_new=int(h*math.fabs(math.sin(math.radians(angle)))+w*math.fabs(math.cos(math.radians(angle))))
32 | M[0,2] +=(w_new-w)/2
33 | M[1,2] +=(h_new-h)/2
34 | h, w =h_new, w_new
35 |
36 |
37 | pts_y, pts_x= np.where(img==1)
38 | pts_xy=np.concatenate( (pts_x[:,np.newaxis], pts_y[:,np.newaxis]), axis=1 )
39 | pts_xy_new= np.rint(np.dot( np.insert(pts_xy,2,1,axis=1), M.T)).astype(np.int64)
40 |
41 | img_new=np.zeros((h,w),dtype=np.uint8)
42 | for pt in pts_xy_new:
43 | img_new[pt[1], pt[0]]=1
44 | return img_new
45 |
46 |
47 | def img_resize_point(img, size):
48 | (h, w) = img.shape
49 | if not isinstance(size, tuple): size=( int(w*size), int(h*size) )
50 | M=np.array([[size[0]/w,0,0],[0,size[1]/h,0]])
51 |
52 | pts_y, pts_x= np.where(img==1)
53 | pts_xy=np.concatenate( (pts_x[:,np.newaxis], pts_y[:,np.newaxis]), axis=1 )
54 | pts_xy_new= np.dot( np.insert(pts_xy,2,1,axis=1), M.T).astype(np.int64)
55 |
56 | img_new=np.zeros(size[::-1],dtype=np.uint8)
57 | for pt in pts_xy_new:
58 | img_new[pt[1], pt[0]]=1
59 | return img_new
60 |
61 | ########################################[ General ]########################################
62 |
63 | #Template for all same operation
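# Each transform operates on a dict-style sample (e.g., {'img', 'gt', 'depth', 'meta'}).
# If elems_do is given it whitelists the keys to process; keys listed in elems_undo
# (which always includes 'meta') are skipped.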
64 | class Template(object):
65 | def __init__(self, elems_do=None, elems_undo=[]):
66 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo)
67 | def __call__(self, sample):
68 | for elem in sample.keys():
69 | if self.elems_do!= None and elem not in self.elems_do :continue
70 | if elem in self.elems_undo:continue
71 | pass
72 | return sample
73 |
74 |
75 | class Transform(object):
76 | def __init__(self, transform, if_numpy=True, elems_do=None, elems_undo=[]):
77 | self.transform, self.if_numpy = transform, if_numpy
78 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo)
79 | def __call__(self, sample):
80 | for elem in sample.keys():
81 | if self.elems_do!= None and elem not in self.elems_do :continue
82 | if elem in self.elems_undo:continue
83 | tmp = self.transform(Image.fromarray(sample[elem]))
84 | sample[elem] = np.array(tmp) if self.if_numpy else tmp
85 | return sample
86 |
87 |
88 | class ToPilImage(object):
89 | def __init__(self, elems_do=None, elems_undo=[]):
90 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo)
91 | def __call__(self, sample):
92 | for elem in sample.keys():
93 | if self.elems_do!= None and elem not in self.elems_do :continue
94 | if elem in self.elems_undo:continue
95 | sample[elem] = Image.fromarray(sample[elem])
96 | return sample
97 |
98 |
99 | class ToNumpyImage(object):
100 | def __init__(self, elems_do=None, elems_undo=[]):
101 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo)
102 | def __call__(self, sample):
103 | for elem in sample.keys():
104 | if self.elems_do!= None and elem not in self.elems_do :continue
105 | if elem in self.elems_undo:continue
106 | sample[elem] = np.array(sample[elem])
107 | return sample
108 |
109 |
110 | class ImageToOne(object):
111 | def __init__(self, elems_do=None, elems_undo=[]):
112 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo)
113 | def __call__(self, sample):
114 | for elem in sample.keys():
115 | if self.elems_do!= None and elem not in self.elems_do :continue
116 | if elem in self.elems_undo:continue
117 | sample[elem]=np.array(sample[elem])/255.0
118 | return sample
119 |
120 |
121 | class ToTensor(object):
122 | def __init__(self, if_div=True, elems_do=None, elems_undo=[]):
123 | self.if_div = if_div
124 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo)
125 | def __call__(self, sample):
126 | for elem in sample.keys():
127 | if self.elems_do!= None and elem not in self.elems_do :continue
128 | if elem in self.elems_undo:continue
129 | tmp = sample[elem]
130 | tmp = tmp[np.newaxis,:,:] if tmp.ndim == 2 else tmp.transpose((2, 0, 1))
131 | tmp = torch.from_numpy(tmp).float()
132 | tmp = tmp.float().div(255) if self.if_div else tmp
133 | sample[elem] = tmp
134 | return sample
135 |
136 |
137 | class Normalize(object):
138 | def __init__(self, mean, std, elems_do=None, elems_undo=[]):
139 | self.mean, self.std = mean, std
140 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo)
141 | def __call__(self, sample):
142 | for elem in sample.keys():
143 | if self.elems_do!= None and elem not in self.elems_do :continue
144 | if elem in self.elems_undo:continue
145 | tensor = sample[elem]
146 | #print(tensor.min(),tensor.max())
147 | for t, m, s in zip(tensor, self.mean, self.std):
148 | t.sub_(m).div_(s)
149 | #print(tensor.min(),tensor.max())
150 |
151 | return sample
152 |
153 |
154 | class Show(object):
155 | def __init__(self, elems_show=['img','gt'], elems_do=None, elems_undo=[]):
156 | self.elems_show = elems_show
157 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo)
158 | def __call__(self, sample):
159 | show_list=[ sample[elem] for elem in self.elems_show ]
160 | return sample
161 |
162 |
163 |
164 | class TestDebug(object):
165 | def __init__(self, elems_do=None, elems_undo=[]):
166 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo)
167 | def __call__(self, sample):
168 | #print(sample['depth'].min(),sample['depth'].max())
169 | return sample
170 |
171 |
172 | ########################################[ Basic Image Augmentation ]########################################
173 |
174 |
175 | class RandomFlip(object):
176 | def __init__(self, direction=Image.FLIP_LEFT_RIGHT, p=0.5, elems_do=None, elems_undo=[]):
177 | self.direction, self.p = direction, p
178 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo)
179 | def __call__(self, sample):
180 | if random.random() < self.p:
181 | for elem in sample.keys():
182 | if self.elems_do!= None and elem not in self.elems_do :continue
183 | if elem in self.elems_undo:continue
184 | sample[elem]= np.array(Image.fromarray(sample[elem]).transpose(self.direction))
185 | sample['meta']['flip']=1
186 | else:
187 | sample['meta']['flip']=0
188 | return sample
189 |
190 |
191 | class RandomRotation(object):
192 | def __init__(self, angle_range=30, if_expand=False, mode=None, elems_point=['pos_points_mask','neg_points_mask'], elems_do=None, elems_undo=[]):
193 | self.angle_range = (-angle_range, angle_range) if isinstance(angle_range, numbers.Number) else angle_range
194 | self.if_expand, self.mode = if_expand, mode
195 | self.elems_point = elems_point
196 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo)
197 |
198 | def __call__(self, sample):
199 | angle = random.uniform(self.angle_range[0], self.angle_range[1])
200 | for elem in sample.keys():
201 | if self.elems_do!= None and elem not in self.elems_do :continue
202 | if elem in self.elems_undo:continue
203 |
204 | if elem in self.elems_point:
205 | sample[elem]=img_rotate_point(sample[elem], angle, if_expand=self.if_expand)
206 | continue
207 |
208 | sample[elem]=img_rotate(sample[elem], angle, if_expand=self.if_expand, mode=self.mode)
209 | return sample
210 |
211 |
212 | class Resize(object):
213 | def __init__(self, size, mode=None, elems_point=['pos_points_mask','neg_points_mask'], elems_do=None, elems_undo=[]):
214 | self.size, self.mode = size, mode
215 | self.elems_point = elems_point
216 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo)
217 | def __call__(self, sample):
218 | for elem in sample.keys():
219 | if self.elems_do!= None and elem not in self.elems_do :continue
220 | if elem in self.elems_undo:continue
221 |
222 | if elem in self.elems_point:
223 | sample[elem]=img_resize_point(sample[elem],self.size)
224 | continue
225 |
226 | if self.mode is None:
227 | mode = cv2.INTER_LINEAR if len(sample[elem].shape)==3 else cv2.INTER_NEAREST
228 | sample[elem] = cv2.resize(sample[elem], self.size, interpolation=mode)
229 |
230 | return sample
231 |
232 |
233 | # Pad the image borders (top, bottom, left, right)
234 | class Expand(object):
235 | def __init__(self, pad=(0,0,0,0), elems_do=None, elems_undo=[]):
236 | if isinstance(pad, int):
237 | self.pad=(pad, pad, pad, pad)
238 | elif len(pad)==2:
239 | self.pad=(pad[0],pad[0],pad[1],pad[1])
240 | elif len(pad)==4:
241 | self.pad= pad
242 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo)
243 | def __call__(self, sample):
244 | for elem in sample.keys():
245 | if self.elems_do!= None and elem not in self.elems_do :continue
246 | if elem in self.elems_undo:continue
247 | sample[elem]=cv2.copyMakeBorder(sample[elem],self.pad[0],self.pad[1],self.pad[2],self.pad[3],cv2.BORDER_CONSTANT)
248 | return sample
249 |
250 |
251 | class Crop(object):
252 | def __init__(self, x_range, y_range, elems_do=None, elems_undo=[]):
253 | self.x_range, self.y_range = x_range, y_range
254 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo)
255 | def __call__(self, sample):
256 | for elem in sample.keys():
257 | if self.elems_do!= None and elem not in self.elems_do :continue
258 | if elem in self.elems_undo:continue
259 | sample[elem]=sample[elem][self.y_range[0]:self.y_range[1], self.x_range[0]:self.x_range[1], ...]
260 |
261 | sample['meta']['crop_size'] = np.array((self.x_range[1]-self.x_range[0],self.y_range[1]-self.y_range[0]))
262 | sample['meta']['crop_lt'] = np.array((self.x_range[0],self.y_range[0]))
263 | return sample
264 |
265 |
266 | class RandomScale(object):
267 | def __init__(self, scale=(0.75, 1.25), elems_do=None, elems_undo=[]):
268 | self.scale = scale
269 | self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo)
270 | def __call__(self, sample):
271 | scale_tmp = random.uniform(self.scale[0], self.scale[1])
272 | src_size=sample['gt'].shape[::-1]
273 | dst_size= ( int(src_size[0]*scale_tmp), int(src_size[1]*scale_tmp))
274 | Resize(size=dst_size)(sample)
275 | return sample
276 |
277 |
278 | ########################################[ RGBD_SOD ]########################################
279 | class Depth2RGB(object):
280 | def __init__(self):
281 | pass
282 | def __call__(self, sample):
283 | #print('->old:',sample['depth'].size())
284 | sample['depth']=sample['depth'].repeat(3,1,1)
285 | #print('->new:',sample['depth'].size())
286 | return sample
287 |
288 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #pytorch
2 | import torch
3 | import torchvision
4 | from torch.utils.data import DataLoader
5 |
6 | #general
7 | import os
8 | import cv2
9 | import sys
10 | import time
11 | import math
12 | import random
13 | import shutil
14 | import argparse
15 | import numpy as np
16 | from tqdm import tqdm
17 |
18 | #mine
19 | import utils
20 | import my_custom_transforms as mtr
21 | from dataloader_rgbdsod import RgbdSodDataset
22 |
23 | #log_recorder
24 | class Logger(object):
25 | def __init__(self, filename='default.log', stream=sys.stdout):
26 | self.terminal = stream
27 | self.log = open(filename, 'w')
28 | def write(self, message):
29 | self.terminal.write(message)
30 | self.log.write(message)
31 | def flush(self):
32 | pass
33 |
34 | def SetLogFile(file_path='log'):
35 | sys.stdout = Logger(file_path, sys.stdout)
36 |
37 | parser = argparse.ArgumentParser()
38 | parser.add_argument('--net', type=str, default='RgbNet',choices=['RgbNet','RgbdNet','DepthNet'],help='train net')
39 | args = parser.parse_args()
40 |
41 | utils.set_seed(10)
42 |
43 | p={}
44 | p['datasets_path']='./dataset/'
45 | p['train_datasets']=[p['datasets_path']+'NJU2K_TRAIN',p['datasets_path']+'NLPR_TRAIN']
46 | p['val_datasets']=[p['datasets_path']+'NJU2K_TEST']
47 |
48 | p['gpu_ids']=list(range(torch.cuda.device_count()))
49 | p['start_epoch']=0
50 | p['epochs']=30
51 | p['bs']=8*len(p['gpu_ids'])
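# Base learning rate of 1.25e-5 per sample, scaled linearly with the per-GPU batch size.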
52 | p['lr']=1.25e-5*(p['bs']/len(p['gpu_ids']))
53 | p['num_workers']=4*len(p['gpu_ids'])
54 |
55 | p['optimizer']=[ 'Adam' , {} ]
56 | p['scheduler']=['Constant',{}]
57 |
58 | p['if_memory']=False
59 | p['max_num']= 0
60 | p['size']=(224, 224)
61 | p['train_only_epochs']=0
62 | p['val_interval']=1
63 | p['resume']= None
64 | p['model']=args.net
65 |
66 | p['note']=''
67 | p['if_use_tensorboard']=False
68 | p['snapshot_path']='snapshot/[{}]_[{}]'.format(time.strftime('%Y-%m-%d-%H:%M:%S',time.localtime(time.time())),p['model'])
69 | if p['note']!='': p['snapshot_path']+='_[{}]'.format(p['note'])
70 |
71 | p['if_debug']=0
72 |
73 | p['if_only_val']=0 if p['resume'] is None else 1
74 | p['if_save_checkpoint']=False
75 |
76 | if p['if_only_val']:
77 | p['snapshot_path']+='[val]'
78 | p['if_use_tensorboard']=False
79 |
80 | if p['if_debug']:
81 | if os.path.exists('snapshot/debug'):shutil.rmtree('snapshot/debug')
82 | p['snapshot_path']='snapshot/debug'
83 | p['max_num']=32
84 |
85 | exec('from model.{} import MyNet'.format(p['model']))
86 |
87 | if p['if_use_tensorboard']:
88 | from torch.utils.tensorboard import SummaryWriter
89 |
90 | class Trainer(object):
91 | def __init__(self,p):
92 | self.p=p
93 | os.makedirs(p['snapshot_path'],exist_ok=True)
94 | shutil.copyfile(os.path.join('model',p['model']+'.py'), os.path.join(p['snapshot_path'],p['model']+'.py'))
95 | SetLogFile('{}/log.txt'.format(p['snapshot_path']))
96 | if p['if_use_tensorboard']:
97 | self.writer = SummaryWriter(p['snapshot_path'])
98 |
99 | transform_train = torchvision.transforms.Compose([
100 | mtr.RandomFlip(),
101 | mtr.Resize(p['size']),
102 | mtr.ToTensor(),
103 | mtr.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],elems_do=['img']),
104 |
105 | ])
106 |
107 | transform_val = torchvision.transforms.Compose([
108 | mtr.Resize(p['size']),
109 | mtr.ToTensor(),
110 | mtr.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],elems_do=['img']),
111 | ])
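        # Both pipelines resize to p['size'] (224x224) and apply ImageNet mean/std normalization
        # only to the RGB image (elems_do=['img']); the training pipeline additionally uses a
        # random flip for augmentation, while depth and ground truth stay unnormalized.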
112 |
113 | self.train_set = RgbdSodDataset(datasets=p['train_datasets'],transform=transform_train,max_num=p['max_num'],if_memory=p['if_memory'])
114 | self.train_loader = DataLoader(self.train_set, batch_size=p['bs'], shuffle=True, num_workers=p['num_workers'],pin_memory=True)
115 |
116 | self.val_loaders=[]
117 | for val_dataset in p['val_datasets']:
118 | val_set=RgbdSodDataset(val_dataset,transform=transform_val,max_num=p['max_num'],if_memory=p['if_memory'])
119 | self.val_loaders.append(DataLoader(val_set, batch_size=1, shuffle=False,pin_memory=True))
120 |
121 | self.model=MyNet()
122 |
123 | self.model = self.model.cuda()
124 |
125 | self.optimizer = utils.get_optimizer(p['optimizer'][0], self.model.get_train_params(lr=p['lr']), p['optimizer'][1])
126 | self.scheduler = utils.get_scheduler(p['scheduler'][0], self.optimizer, p['scheduler'][1])
127 |
128 | self.best_metric=None
129 |
130 | if p['resume']!=None:
131 | print('Load checkpoint from [{}]'.format(p['resume']))
132 | checkpoint = torch.load(p['resume'])
133 | self.p['start_epoch']=checkpoint['current_epoch']+1
134 | self.best_metric=checkpoint['best_metric']
135 | self.model.load_state_dict(checkpoint['model'])
136 | self.optimizer.load_state_dict(checkpoint['optimizer'])
137 | self.scheduler.load_state_dict(checkpoint['scheduler'])
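        # To resume training (or run validation only), point p['resume'] at a saved state such as
        # the hypothetical path 'snapshot/<run>/best.pth'. With a non-None resume path,
        # p['if_only_val'] becomes 1 and main() runs a single validation pass, passing a result/
        # directory under the snapshot folder so the predicted maps can be saved.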
138 |
139 | def main(self):
140 | print('Start time : ',time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
141 | print('---[ NOTE: {} ]---'.format(self.p['note']))
142 | print('-'*79,'\ninfos : ' , self.p, '\n'+'-'*79)
143 |
144 | if self.p['if_only_val']:
145 | result_save_path=os.path.join(p['snapshot_path'],'result')
146 | os.makedirs(result_save_path,exist_ok=True)
147 | self.validation(self.p['start_epoch']-1,result_save_path)
148 | exit()
149 |
150 | for epoch in range(self.p['start_epoch'],self.p['epochs']):
151 | lr_str = ['{:.7f}'.format(i) for i in self.scheduler.get_lr()]
152 | print('-'*79+'\n'+'Epoch [{:03d}]=> |-lr:{}-| \n'.format(epoch, lr_str))
153 | #training
154 | if p['train_only_epochs']>=0:
155 | self.training(epoch)
156 | self.scheduler.step()
157 |
158 |             if epoch<self.p['train_only_epochs']: continue
220 |                 print('{} => mae:{:.4f} f_max:{:.4f}'.format(dataset,metric[0],metric[1]))
221 |
222 | metric_all+=metric
223 |
224 | metric_all=metric_all/len(self.val_loaders)
225 |
226 | is_best = utils.metric_better_than(metric_all, self.best_metric)
227 | self.best_metric = metric_all if is_best else self.best_metric
228 |
229 | print('Metric_Select[MAE]: {:.4f} ({:.4f})'.format(metric_all[0],self.best_metric[0]))
230 |
231 | pth_state={
232 | 'current_epoch': epoch,
233 | 'best_metric': self.best_metric,
234 | 'model': self.model.state_dict(),
235 | 'optimizer': self.optimizer.state_dict(),
236 | 'scheduler':self.scheduler.state_dict()
237 | }
238 |
239 | if self.p['if_save_checkpoint']:
240 | torch.save(pth_state, os.path.join(self.p['snapshot_path'], 'checkpoint.pth'))
241 | if is_best:
242 | torch.save(pth_state, os.path.join(self.p['snapshot_path'], 'best.pth'))
243 |
244 | if self.p['if_use_tensorboard']:
245 | self.writer.add_scalar('Loss/test', (loss_total / (i + 1)), epoch)
246 | self.writer.add_scalar('Metric/mae', metric_all[0], epoch)
247 | self.writer.add_scalar('Metric/f_max', metric_all[1], epoch)
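        # Model selection: metrics are averaged over all validation datasets and the "best"
        # checkpoint is picked by the first entry (MAE, lower is better) via utils.metric_better_than.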
248 |
249 |
250 | if __name__ == "__main__":
251 |     mine = Trainer(p)
252 | mine.main()
253 |
254 |
255 |
256 |
257 |
258 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 |
272 |
273 |
274 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import random
4 | import shutil
5 | import os
6 | ########################################[ Optimizer ]########################################
7 |
8 |
9 | def get_optimizer(mode, train_params, kwargs):
10 | opt_default = {'SGD' : {'momentum':0.9, 'weight_decay':5e-4, 'nesterov':False},
11 | #'Adam' : {'weight_decay':5e-4, 'betas':[0.9, 0.99]},
12 | 'Adam' : {'weight_decay':0, 'betas':[0.9, 0.99]},
13 | 'TBD' : {} }
14 |
15 | for k in opt_default[mode].keys():
16 | if k not in kwargs.keys():
17 | kwargs[k] = opt_default[mode][k]
18 |
19 | if mode=='SGD':
20 | return torch.optim.SGD(train_params,**kwargs)
21 | elif mode=='Adam':
22 | return torch.optim.Adam(train_params,**kwargs)
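# Usage sketch: keyword arguments missing from kwargs are filled in from opt_default before the
# optimizer is built, so a call like the (illustrative) one below gets weight_decay=0 and
# betas=[0.9, 0.99] in addition to the explicit learning rate:
#
#   optimizer = get_optimizer('Adam', model.parameters(), {'lr': 1e-4})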
23 |
24 |
25 | ########################################[ Scheduler ]########################################
26 | from torch.optim.lr_scheduler import _LRScheduler
27 |
28 | class PolyLR(_LRScheduler):
29 | def __init__(self, optimizer, epoch_max, power=0.9, last_epoch=-1, cutoff_epoch=1000000):
30 | self.epoch_max = epoch_max
31 | self.power = power
32 | self.cutoff_epoch = cutoff_epoch
33 | super(PolyLR, self).__init__(optimizer, last_epoch)
34 |
35 | def get_lr(self):
36 |         if self.last_epoch < self.cutoff_epoch:
69 |             y_temp = (y_pred >= thlist[i]).float()
70 | tp = (y_temp * y).sum()
71 | prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20)
72 | return prec, recall
73 |
74 | def f_measure(pred,gt):
75 | beta2 = 0.3
76 | with torch.no_grad():
77 | pred = torch.from_numpy(pred).float().cuda()
78 | gt = torch.from_numpy(gt).float().cuda()
79 |
80 | prec, recall = eval_pr(pred, gt, 255)
81 | f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall)
82 | f_score[f_score != f_score] = 0 # for Nan
83 | return f_score
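# F-measure note: precision/recall are evaluated at 255 thresholds and combined as
# F = (1 + beta^2) * P * R / (beta^2 * P + R) with beta^2 = 0.3, the weighting commonly used in
# salient object detection; NaN entries (P = R = 0) are zeroed. The f_max reported during training
# is presumably the maximum of this 255-element curve.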
84 |
85 | def get_metric(sample_batched, result,result_save_path=None,if_recover=True):
86 | id=sample_batched['meta']['id'][0]
87 | gt=np.array(Image.open(sample_batched['meta']['gt_path'][0]).convert('L'))/255.0
88 |
89 | if if_recover:
90 | result=cv2.resize(result, gt.shape[::-1], interpolation=cv2.INTER_LINEAR)
91 | else:
92 | gt=cv2.resize(gt, result.shape[::-1], interpolation=cv2.INTER_NEAREST)
93 |
94 | result=(result*255).astype(np.uint8)
95 |
96 | if result_save_path is not None:
97 | Image.fromarray(result).save(os.path.join(result_save_path,id+'.png'))
98 |
99 | result=result.astype(np.float64)/255.0
100 |
101 | mae= np.mean(np.abs(result-gt))
102 | f_score=f_measure(result,gt)
103 | return mae,f_score
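# get_metric compares one predicted map against its ground truth: with if_recover=True the
# prediction is resized back to the GT resolution (otherwise the GT is shrunk to the prediction),
# the map is optionally written out as an 8-bit PNG, and MAE is the mean absolute difference with
# both maps scaled to [0, 1]. For example, a prediction of all 0.5 against a half-white,
# half-black GT gives MAE = 0.5.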
104 |
105 | def metric_better_than(metric_a, metric_b):
106 | if metric_b is None:
107 | return True
108 | if isinstance(metric_a,list) or isinstance(metric_a,np.ndarray):
109 | metric_a,metric_b=metric_a[0],metric_b[0]
110 | return metric_a < metric_b
111 |
112 |
113 | ########################################[ Loss ]########################################
114 |
115 |
116 |
117 |
118 | ########################################[ Random ]########################################
119 | # Fix the random seed!
120 | def set_seed(seed):
121 | torch.manual_seed(seed)
122 | torch.cuda.manual_seed_all(seed)
123 | np.random.seed(seed)
124 | random.seed(seed)
125 | torch.backends.cudnn.deterministic = True
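# Reproducibility note: alongside the seeding and the cudnn.deterministic flag above, fully
# repeatable runs typically also set torch.backends.cudnn.benchmark = False; that flag is not
# touched here.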
126 |
127 | # # Save a checkpoint
128 | # def save_checkpoint(state, is_best, path, filename, if_save_checkpoint=False):
129 | # if if_save_checkpoint:
130 | # torch.save(state, os.path.join(path, 'checkpoint.pth'))
131 | # if is_best:
132 | # torch.save(state, os.path.join(path, 'best.pth'))
133 |
--------------------------------------------------------------------------------