├── README.md
├── dataset
│   ├── __init__.py
│   ├── cityscapes.py
│   ├── custom_transforms.py
│   └── kitti.py
├── img
│   ├── KITTI.gif
│   └── usnet.png
├── loss.py
├── model
│   ├── __init__.py
│   ├── backbone.py
│   ├── module.py
│   ├── sne_model.py
│   └── usnet.py
├── requirements.txt
├── test.py
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # USNet
2 |
3 |
4 | ### Fast Road Segmentation via Uncertainty-aware Symmetric Network
5 |
6 | This repository provides the implementation of USNet [[arxiv]](https://arxiv.org/abs/2203.04537) in PyTorch.
7 |
8 | Road segmentation is essential for self-driving and mobile robot applications. USNet is proposed to achieve a favorable trade-off between speed and accuracy on this task.
9 |
10 |
11 |
12 |
13 |
14 | The segmentation result and the uncertainty map are shown below:
15 |
16 |
17 |
18 |
19 |
20 |
21 | ## Data Preparation
22 |
23 |
24 | #### KITTI Road Dataset
25 |
26 | You may download the KITTI Road dataset from [Google Drive](https://drive.google.com/file/d/12BvNsVgSZ5cEqetNhkkbrmdUAlADphAm/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1Y0pg85lLeVQADqRCAFk92A) (Code: k5vp). Then please set up the dataset according to the following directory structure:
27 | ```
28 | USNet
29 | |-- data
30 | | |-- KITTI
31 | | | |-- training
32 | | | | |-- calib
33 | | | | |-- depth_u16
34 | | | | |-- gt_image_2
35 | | | | |-- image_2
36 | | | |-- validating
37 | | | | |-- calib
38 | | | | |-- depth_u16
39 | | | | |-- gt_image_2
40 | | | | |-- image_2
41 | | | |-- testing
42 | | | | |-- calib
43 | | | | |-- depth_u16
44 | | | | |-- image_2
45 | |-- models
46 | ...
47 | ```
48 |
49 |
50 | ## Installation
51 | The code is developed with Python 3.7 and PyTorch 1.6.0, and tested on a single NVIDIA GTX 1080Ti GPU.
52 | You can create a conda environment and install the required packages by running:
53 | ```
54 | $ conda create -n usnet python=3.7
55 | $ conda activate usnet && pip install -r requirements.txt
56 | ```
57 |
58 |
59 | ## Training
60 |
61 | For training USNet on KITTI Road dataset, you can run:
62 |
63 | ```
64 | $ cd $USNET_ROOT
65 | $ python train.py
66 | ```
67 | When training is complete, the checkpoint will be saved to `./log/KITTI_model`.
68 |
69 |
70 | ## Testing
71 |
72 | **Note that** before testing you need to configure the necessary paths and variables. Please ensure that the checkpoint exists at `checkpoint_path`.
73 |
74 | To run the test on KITTI Road dataset:
75 | ```
76 | $ python test.py
77 | ```
78 | You can download our trained model from [Google Drive](https://drive.google.com/file/d/1qB85Pt-jgnC5wf5-U2ExYBxzjvmaAZNb/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1prA2UsSr5keuCXqewKShCw) (Code: 9zgf). The BEV-results obtained from this released model can be found in [Google Drive](https://drive.google.com/file/d/1MFZwPz141Wgrhk7YW14lPbtwUrJw8LxX/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1c4hT2adWo9-8AIniMVLoXQ) (Code: csar).
79 |
80 | If you submit this result to the KITTI benchmark, you will get a MaxF score of **96.87** for **URBAN**, which is similar to the results reported in our paper.
81 |
82 |
83 | ### Citation
84 |
85 | If you find USNet useful in your research, please consider citing:
86 | ```
87 | @inproceedings{Chang22Fast,
88 | title = {Fast Road Segmentation via Uncertainty-aware Symmetric Network},
89 | author = {Chang, Yicong and Xue, Feng and Sheng, Fei and Liang, Wenteng and Ming, Anlong},
90 | booktitle = {IEEE International Conference on Robotics and Automation (ICRA)},
91 | year = {2022}
92 | }
93 | ```
94 |
95 | ## Acknowledgement
96 | The source code of the surface normal estimator in our method follows [SNE-RoadSeg](https://github.com/hlwang1124/SNE-RoadSeg); we appreciate this great work. In addition, the code for acquiring uncertainty in our method is adapted from [TMC](https://github.com/hanmenghan/TMC).
97 |
--------------------------------------------------------------------------------
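A quick pointer for the **Testing** section above: the paths it refers to live in the `params` list at the bottom of `test.py` (shown later in this listing). For convenience, these are the entries you will typically need to adjust to your own layout:

```
'--checkpoint_path', './log/KITTI_model/usnet_best.pth',
'--data', './data/KITTI',
'--save_path', './result/kitti_test',
```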
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/morancyc/USNet/d761158ad42df7dcb62fa257dd02ce11c85f94a5/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/cityscapes.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from PIL import Image
5 | from torch.utils import data
6 | from torchvision import transforms
7 | import cv2
8 | import glob
9 | import json
10 | from model.sne_model import SNE
11 | from dataset import custom_transforms as tr
12 | import matplotlib.pyplot as plt
13 |
14 |
15 | def Disp2depth(fx, baseline, disp):
16 | delta = 256
17 | disp_mask = disp > 0
18 | depth = disp.astype(np.float32)
19 | depth[disp_mask] = (depth[disp_mask] - 1) / delta
20 | disp_mask = depth > 0
21 | depth[disp_mask] = fx * baseline / depth[disp_mask]
22 | return depth
23 |
24 |
25 | def read_calib_file(filepath):
26 | with open(filepath, 'r') as f:
27 | calib_info = json.load(f)
28 | baseline = calib_info['extrinsic']['baseline']
29 | fx = calib_info['intrinsic']['fx']
30 | fy = calib_info['intrinsic']['fy']
31 | u0 = calib_info['intrinsic']['u0']
32 | v0 = calib_info['intrinsic']['v0']
33 | return baseline, fx, fy, u0, v0
34 |
35 |
36 | class Cityscapes_Dataset(data.Dataset):
37 | NUM_CLASSES = 2
38 |
39 | def __init__(self, args, root='./data/Cityscapes/', split='train'):
40 |
41 | self.root = root
42 | self.split = split
43 | self.args = args
44 | self.images = {}
45 | self.disparities = {}
46 | self.labels = {}
47 | self.calibs = {}
48 |
49 | self.image_base = os.path.join(self.root, 'leftImg8bit_trainvaltest/leftImg8bit', self.split)
50 | self.disparity_base = os.path.join(self.root, 'disparity_trainvaltest/disparity', self.split)
51 | self.label_base = os.path.join(self.root, 'gtFine_trainvaltest/gtFine', self.split)
52 | self.calib_base = os.path.join(self.root, 'camera_trainvaltest/camera', self.split)
53 |
54 | self.images[split] = []
55 | self.images[split] = self.recursive_glob(rootdir=self.image_base, suffix='.png')
56 | self.images[split].sort()
57 |
58 | self.disparities[split] = []
59 | self.disparities[split] = self.recursive_glob(rootdir=self.disparity_base, suffix='.png')
60 | self.disparities[split].sort()
61 |
62 | self.labels[split] = []
63 | self.labels[split] = self.recursive_glob(rootdir=self.label_base, suffix='_labelIds.png')
64 | self.labels[split].sort()
65 |
66 | self.calibs[split] = []
67 | self.calibs[split] = self.recursive_glob(rootdir=self.calib_base, suffix='.json')
68 | self.calibs[split].sort()
69 |
70 | self.sne_model = SNE(crop_top=False)
71 |
72 | if not self.images[split]:
73 | raise Exception("No RGB images for split=[%s] found in %s" % (split, self.image_base))
74 | if not self.disparities[split]:
75 | raise Exception("No depth images for split=[%s] found in %s" % (split, self.disparity_base))
76 |
77 | print("Found %d %s RGB images" % (len(self.images[split]), split))
78 | print("Found %d %s disparity images" % (len(self.disparities[split]), split))
79 |
80 | def __len__(self):
81 | return len(self.images[self.split])
82 |
83 | def __getitem__(self, index):
84 |
85 | img_path = self.images[self.split][index].rstrip()
86 | disp_path = self.disparities[self.split][index].rstrip()
87 | calib_path = self.calibs[self.split][index].rstrip()
88 | lbl_path = self.labels[self.split][index].rstrip()
89 |
90 | label_image = cv2.imread(lbl_path, cv2.IMREAD_GRAYSCALE)
91 | oriHeight, oriWidth = label_image.shape
92 | label = np.zeros((oriHeight, oriWidth), dtype=np.uint8)
93 | # keep only the 'road' class (labelId 7)
94 | label[label_image == 7] = 1
95 |
96 | _img = Image.open(img_path).convert('RGB')
97 | disp_image = cv2.imread(disp_path, cv2.IMREAD_ANYDEPTH)
98 | baseline, fx, fy, u0, v0 = read_calib_file(calib_path)
99 | depth = Disp2depth(fx, baseline, disp_image)
100 | _depth = Image.fromarray(depth)
101 |
102 | _target = Image.fromarray(label)
103 |
104 | sample = {'image': _img, 'depth': _depth, 'label': _target}
105 |
106 | if self.split == 'train':
107 | sample = self.transform_tr(sample)
108 | elif self.split == 'val':
109 | sample = self.transform_val(sample)
110 | elif self.split == 'test':
111 | sample = self.transform_ts(sample)
112 | else:
113 | sample = self.transform_ts(sample)
114 |
115 | depth_image = np.array(sample['depth'])
116 |
117 | calib = np.array([[fx, 0, u0],
118 | [0, fy, v0],
119 | [0, 0, 1]])
120 | camParam = torch.tensor(calib, dtype=torch.float32)
121 | normal = self.sne_model(torch.tensor(depth_image.astype(np.float32)), camParam)
122 | normal = normal.cpu().numpy()
123 | normal = np.transpose(normal, [1, 2, 0])
124 | normal = cv2.resize(normal, (self.args.crop_width, self.args.crop_height))
125 |
126 | normal = transforms.ToTensor()(normal)
127 |
128 | sample['depth'] = normal
129 |
130 | sample['label'] = np.array(sample['label'])
131 | sample['label'] = torch.from_numpy(sample['label']).long()
132 |
133 | sample['oriHeight'] = oriHeight
134 | sample['oriWidth'] = oriWidth
135 |
136 | sample['img_path'] = img_path
137 | sample['depth_path'] = disp_path
138 | sample['calib_path'] = calib_path
139 | sample['lbl_path'] = lbl_path
140 |
141 | return sample
142 |
143 | def recursive_glob(self, rootdir='.', suffix=''):
144 | """Performs recursive glob with given suffix and rootdir
145 | :param rootdir is the root directory
146 | :param suffix is the suffix to be searched
147 | """
148 | return [os.path.join(looproot, filename)
149 | for looproot, _, filenames in os.walk(rootdir)
150 | for filename in filenames if filename.endswith(suffix)]
151 |
152 | def transform_tr(self, sample):
153 | composed_transforms = transforms.Compose([
154 | tr.RandomHorizontalFlip(),
155 | tr.RandomGaussianBlur(),
156 | tr.RandomGaussianNoise(),
157 | tr.Resize(size=(self.args.crop_width, self.args.crop_height)),
158 | tr.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
159 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
160 | tr.ToTensor()])
161 | return composed_transforms(sample)
162 |
163 | def transform_val(self, sample):
164 | composed_transforms = transforms.Compose([
165 | tr.Resize(size=(self.args.crop_width, self.args.crop_height)),
166 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
167 | tr.ToTensor()])
168 | return composed_transforms(sample)
169 |
170 | def transform_ts(self, sample):
171 | composed_transforms = transforms.Compose([
172 | tr.Resize(size=(self.args.crop_width, self.args.crop_height)),
173 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
174 | tr.ToTensor()])
175 | return composed_transforms(sample)
176 |
--------------------------------------------------------------------------------
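A brief note on `Disp2depth` in `dataset/cityscapes.py` above (my summary of what the code computes, not part of the original file): Cityscapes stores disparity in 16-bit PNGs where a pixel value `p > 0` encodes the disparity `d = (p - 1) / 256`, and metric depth is then recovered from the stereo geometry using the focal length `fx` and baseline `B` read from the camera JSON file:

```
d = \frac{p - 1}{256}, \qquad Z = \frac{f_x \, B}{d}
```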
/dataset/custom_transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | from torchvision import transforms
5 | from PIL import Image, ImageFilter, ImageEnhance
6 |
7 |
8 | class Normalize(object):
9 | """Normalize a tensor image with mean and standard deviation.
10 | Args:
11 | mean (tuple): means for each channel.
12 | std (tuple): standard deviations for each channel.
13 | """
14 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
15 | self.mean = mean
16 | self.std = std
17 |
18 | def __call__(self, sample):
19 | img = sample['image']
20 | depth = sample['depth']
21 | mask = sample['label']
22 | img = np.array(img).astype(np.float32)
23 | depth = np.array(depth).astype(np.float32)
24 | mask = np.array(mask).astype(np.float32)
25 | img /= 255.0
26 | img -= self.mean
27 | img /= self.std
28 |
29 | return {'image': img,
30 | 'depth': depth,
31 | 'label': mask}
32 |
33 |
34 | class ToTensor(object):
35 | """Convert Image object in sample to Tensors."""
36 |
37 | def __call__(self, sample):
38 | # swap color axis because
39 | # numpy image: H x W x C
40 | # torch image: C X H X W
41 | img = sample['image']
42 | depth = sample['depth']
43 | mask = sample['label']
44 | img = np.array(img).astype(np.float32).transpose((2, 0, 1))
45 | depth = np.array(depth).astype(np.float32)
46 | mask = np.array(mask).astype(np.float32)
47 |
48 | img = torch.from_numpy(img).float()
49 | depth = torch.from_numpy(depth).float()
50 | mask = torch.from_numpy(mask).float()
51 |
52 | return {'image': img,
53 | 'depth': depth,
54 | 'label': mask}
55 |
56 |
57 | class ColorJitter(object):
58 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4):
59 | self.brightness = [max(1 - brightness, 0), 1 + brightness]
60 | self.contrast = [max(1 - contrast, 0), 1 + contrast]
61 | self.saturation = [max(1 - saturation, 0), 1 + saturation]
62 |
63 | def __call__(self, sample):
64 | img = sample['image']
65 | depth = sample['depth']
66 | mask = sample['label']
67 | r_brightness = random.uniform(self.brightness[0], self.brightness[1])
68 | r_contrast = random.uniform(self.contrast[0], self.contrast[1])
69 | r_saturation = random.uniform(self.saturation[0], self.saturation[1])
70 | img = ImageEnhance.Brightness(img).enhance(r_brightness)
71 | img = ImageEnhance.Contrast(img).enhance(r_contrast)
72 | img = ImageEnhance.Color(img).enhance(r_saturation)
73 | return {'image': img,
74 | 'depth': depth,
75 | 'label': mask}
76 |
77 |
78 | class RandomHorizontalFlip(object):
79 | def __call__(self, sample):
80 | img = sample['image']
81 | depth = sample['depth']
82 | mask = sample['label']
83 | if random.random() < 0.5:
84 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
85 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT)
86 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
87 |
88 | return {'image': img,
89 | 'depth': depth,
90 | 'label': mask}
91 |
92 |
93 | class HorizontalFlip(object):
94 | def __call__(self, sample):
95 | img = sample['image']
96 | depth = sample['depth']
97 | mask = sample['label']
98 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
99 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT)
100 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
101 |
102 | return {'image': img,
103 | 'depth': depth,
104 | 'label': mask}
105 |
106 |
107 | class RandomGaussianBlur(object):
108 | def __init__(self, radius=1):
109 | self.radius = radius
110 |
111 | def __call__(self, sample):
112 | img = sample['image']
113 | depth = sample['depth']
114 | mask = sample['label']
115 | if random.random() < 0.5:
116 | img = img.filter(ImageFilter.GaussianBlur(
117 | radius=self.radius*random.random()))
118 |
119 | return {'image': img,
120 | 'depth': depth,
121 | 'label': mask}
122 |
123 |
124 | class RandomGaussianNoise(object):
125 | def __init__(self, mean=0, sigma=10):
126 | self.mean = mean
127 | self.sigma = sigma
128 |
129 | def gaussianNoisy(self, im, mean=0, sigma=10):
130 | noise = np.random.normal(mean, sigma, len(im))
131 | im = im + noise
132 |
133 | im = np.clip(im, 0, 255)
134 | return im
135 |
136 | def __call__(self, sample):
137 | img = sample['image']
138 | depth = sample['depth']
139 | mask = sample['label']
140 | if random.random() < 0.5:
141 | img = np.asarray(img)
142 | img = img.astype(np.int)
143 | width, height = img.shape[:2]
144 | img_r = self.gaussianNoisy(img[:, :, 0].flatten(), self.mean, self.sigma)
145 | img_g = self.gaussianNoisy(img[:, :, 1].flatten(), self.mean, self.sigma)
146 | img_b = self.gaussianNoisy(img[:, :, 2].flatten(), self.mean, self.sigma)
147 | img[:, :, 0] = img_r.reshape([width, height])
148 | img[:, :, 1] = img_g.reshape([width, height])
149 | img[:, :, 2] = img_b.reshape([width, height])
150 | img = Image.fromarray(np.uint8(img))
151 | return {'image': img,
152 | 'depth': depth,
153 | 'label': mask}
154 |
155 |
156 | class Resize(object):
157 | """Resize rgb and label images, while keep depth image unchanged. """
158 | def __init__(self, size):
159 | self.size = size # size: (w, h)
160 |
161 | def __call__(self, sample):
162 | img = sample['image']
163 | depth = sample['depth']
164 | mask = sample['label']
165 |
166 | assert img.size == depth.size == mask.size
167 |
168 | # resize rgb and label
169 | img = img.resize(self.size, Image.BILINEAR)
170 | # depth = depth.resize(self.size, Image.BILINEAR)
171 | mask = mask.resize(self.size, Image.NEAREST)
172 |
173 | return {'image': img,
174 | 'depth': depth,
175 | 'label': mask}
176 |
--------------------------------------------------------------------------------
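All of the transforms above operate on a sample dict with keys `'image'`, `'depth'` and `'label'` (PIL images on input) and are chained with `torchvision.transforms.Compose`, as in `transform_tr` of the dataset classes. A minimal sketch, assuming three same-sized inputs (the file names are placeholders):

```python
from PIL import Image
from torchvision import transforms

from dataset import custom_transforms as tr

# Hypothetical inputs: an RGB image, a float depth map and a binary road mask of the same size.
img = Image.open('example_rgb.png').convert('RGB')
depth = Image.open('example_depth.png').convert('F')
label = Image.open('example_label.png').convert('L')

pipeline = transforms.Compose([
    tr.RandomHorizontalFlip(),
    tr.Resize(size=(1248, 384)),   # (width, height), as in the dataset classes
    tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    tr.ToTensor(),
])

sample = pipeline({'image': img, 'depth': depth, 'label': label})
# sample['image'] -> float tensor (3, 384, 1248); sample['label'] -> (384, 1248);
# sample['depth'] keeps its original resolution and is later replaced by the SNE normal map.
```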
/dataset/kitti.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from PIL import Image
5 | from torch.utils import data
6 | from torchvision import transforms
7 | import cv2
8 | import glob
9 | from model.sne_model import SNE
10 | from dataset import custom_transforms as tr
11 | import matplotlib.pyplot as plt
12 |
13 |
14 | class kittiCalibInfo():
15 | """
16 | Read calibration files in the kitti dataset,
17 | we need to use the intrinsic parameter of the cam2
18 | """
19 | def __init__(self, filepath):
20 | """
21 | Args:
22 | filepath ([str]): calibration file path (AAA.txt)
23 | """
24 | self.data = self._load_calib(filepath)
25 |
26 | def get_cam_param(self):
27 | """
28 | Returns:
29 | [numpy.array]: intrinsic parameter of the cam2
30 | """
31 | return self.data['P2']
32 |
33 | def _load_calib(self, filepath):
34 | rawdata = self._read_calib_file(filepath)
35 | data = {}
36 | P0 = np.reshape(rawdata['P0'], (3,4))
37 | P1 = np.reshape(rawdata['P1'], (3,4))
38 | P2 = np.reshape(rawdata['P2'], (3,4))
39 | P3 = np.reshape(rawdata['P3'], (3,4))
40 | R0_rect = np.reshape(rawdata['R0_rect'], (3,3))
41 | Tr_velo_to_cam = np.reshape(rawdata['Tr_velo_to_cam'], (3,4))
42 |
43 | data['P0'] = P0
44 | data['P1'] = P1
45 | data['P2'] = P2
46 | data['P3'] = P3
47 | data['R0_rect'] = R0_rect
48 | data['Tr_velo_to_cam'] = Tr_velo_to_cam
49 |
50 | return data
51 |
52 | def _read_calib_file(self, filepath):
53 | """Read in a calibration file and parse into a dictionary."""
54 | data = {}
55 |
56 | with open(filepath, 'r') as f:
57 | for line in f.readlines():
58 | key, value = line.split(':', 1)
59 | # The only non-float values in these files are dates, which
60 | # we don't care about anyway
61 | try:
62 | data[key] = np.array([float(x) for x in value.split()])
63 | except ValueError:
64 | pass
65 | return data
66 |
67 |
68 | class Kitti_Dataset(data.Dataset):
69 | NUM_CLASSES = 2
70 |
71 | def __init__(self, args, root='./data/KITTI/', split='training'):
72 |
73 | self.root = root
74 | self.split = split
75 | self.args = args
76 | self.images = {}
77 | self.depths = {}
78 | self.labels = {}
79 | self.calibs = {}
80 |
81 | self.image_base = os.path.join(self.root, self.split, 'image_2')
82 | self.depth_base = os.path.join(self.root, self.split, 'depth_u16')
83 | self.label_base = os.path.join(self.root, self.split, 'gt_image_2')
84 | self.calib_base = os.path.join(self.root, self.split, 'calib')
85 |
86 | self.images[split] = []
87 | self.images[split].extend(glob.glob(os.path.join(self.image_base, '*.png')))
88 | self.images[split].sort()
89 |
90 | self.depths[split] = []
91 | self.depths[split].extend(glob.glob(os.path.join(self.depth_base, '*.png')))
92 | self.depths[split].sort()
93 |
94 | self.labels[split] = []
95 | self.labels[split].extend(glob.glob(os.path.join(self.label_base, '*.png')))
96 | self.labels[split].sort()
97 |
98 | self.calibs[split] = []
99 | self.calibs[split].extend(glob.glob(os.path.join(self.calib_base, '*.txt')))
100 | self.calibs[split].sort()
101 |
102 | self.sne_model = SNE(crop_top=True)
103 |
104 | if not self.images[split]:
105 | raise Exception("No RGB images for split=[%s] found in %s" % (split, self.image_base))
106 | if not self.depths[split]:
107 | raise Exception("No depth images for split=[%s] found in %s" % (split, self.depth_base))
108 |
109 | print("Found %d %s RGB images" % (len(self.images[split]), split))
110 | print("Found %d %s depth images" % (len(self.depths[split]), split))
111 |
112 | def __len__(self):
113 | return len(self.images[self.split])
114 |
115 | def __getitem__(self, index):
116 |
117 | img_path = self.images[self.split][index].rstrip()
118 | depth_path = self.depths[self.split][index].rstrip()
119 | calib_path = self.calibs[self.split][index].rstrip()
120 |
121 | useDir = "/".join(img_path.split('/')[:-2])
122 | name = img_path.split('/')[-1]
123 |
124 | _img = Image.open(img_path).convert('RGB')
125 | depth_image = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH)
126 | depth_image = depth_image.astype(np.float32)
127 | oriHeight, oriWidth = depth_image.shape
128 | _depth = Image.fromarray(depth_image)
129 |
130 | label = np.zeros((oriHeight, oriWidth), dtype=np.uint8)
131 | if not self.split == 'testing':
132 | lbl_path = os.path.join(useDir, 'gt_image_2', name[:-10] + 'road_' + name[-10:])
133 | label_image = cv2.cvtColor(cv2.imread(lbl_path), cv2.COLOR_BGR2RGB)
134 | label[label_image[:, :, 2] > 0] = 1
135 |
136 | _target = Image.fromarray(label)
137 |
138 | sample = {'image': _img, 'depth': _depth, 'label': _target}
139 |
140 | if self.split == 'training':
141 | sample = self.transform_tr(sample)
142 | elif self.split == 'validating':
143 | sample = self.transform_val(sample)
144 | elif self.split == 'testing':
145 | sample = self.transform_ts(sample)
146 | else:
147 | sample = self.transform_ts(sample)
148 |
149 | depth_image = np.array(sample['depth'])
150 |
151 | calib = kittiCalibInfo(calib_path)
152 | camParam = torch.tensor(calib.get_cam_param(), dtype=torch.float32)
153 | normal = self.sne_model(torch.tensor(depth_image.astype(np.float32) / 1000), camParam)
154 | normal = normal.cpu().numpy()
155 | normal = np.transpose(normal, [1, 2, 0])
156 | normal = cv2.resize(normal, (self.args.crop_width, self.args.crop_height))
157 |
158 | normal = transforms.ToTensor()(normal)
159 |
160 | sample['depth'] = normal
161 |
162 | sample['label'] = np.array(sample['label'])
163 | sample['label'] = torch.from_numpy(sample['label']).long()
164 |
165 | sample['oriHeight'] = oriHeight
166 | sample['oriWidth'] = oriWidth
167 |
168 | sample['img_path'] = img_path
169 | sample['depth_path'] = depth_path
170 | sample['calib_path'] = calib_path
171 |
172 | return sample
173 |
174 | def transform_tr(self, sample):
175 | composed_transforms = transforms.Compose([
176 | tr.RandomHorizontalFlip(),
177 | tr.RandomGaussianBlur(),
178 | tr.RandomGaussianNoise(),
179 | tr.Resize(size=(self.args.crop_width, self.args.crop_height)),
180 | tr.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
181 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
182 | tr.ToTensor()])
183 | return composed_transforms(sample)
184 |
185 | def transform_val(self, sample):
186 | composed_transforms = transforms.Compose([
187 | tr.Resize(size=(self.args.crop_width, self.args.crop_height)),
188 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
189 | tr.ToTensor()])
190 | return composed_transforms(sample)
191 |
192 | def transform_ts(self, sample):
193 | composed_transforms = transforms.Compose([
194 | tr.Resize(size=(self.args.crop_width, self.args.crop_height)),
195 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
196 | tr.ToTensor()])
197 | return composed_transforms(sample)
198 |
--------------------------------------------------------------------------------
/img/KITTI.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/morancyc/USNet/d761158ad42df7dcb62fa257dd02ce11c85f94a5/img/KITTI.gif
--------------------------------------------------------------------------------
/img/usnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/morancyc/USNet/d761158ad42df7dcb62fa257dd02ce11c85f94a5/img/usnet.png
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | # loss function
6 | def KL(alpha, c):
7 | beta = torch.ones((1, c), device=alpha.device)  # uniform Dirichlet prior, created on alpha's device so the loss also runs on CPU
8 | S_alpha = torch.sum(alpha, dim=1, keepdim=True)
9 | S_beta = torch.sum(beta, dim=1, keepdim=True)
10 | lnB = torch.lgamma(S_alpha) - torch.sum(torch.lgamma(alpha), dim=1, keepdim=True)
11 | lnB_uni = torch.sum(torch.lgamma(beta), dim=1, keepdim=True) - torch.lgamma(S_beta)
12 | dg0 = torch.digamma(S_alpha)
13 | dg1 = torch.digamma(alpha)
14 | kl = torch.sum((alpha - beta) * (dg1 - dg0), dim=1, keepdim=True) + lnB + lnB_uni
15 | return kl
16 |
17 |
18 | def ce_loss(p, alpha, c, global_step, annealing_step):
19 | S = torch.sum(alpha, dim=1, keepdim=True)
20 | E = alpha - 1
21 | label = F.one_hot(p, num_classes=c)
22 | A = torch.sum(label * (torch.digamma(S) - torch.digamma(alpha)), dim=1, keepdim=True)
23 |
24 | annealing_coef = min(1, global_step / annealing_step)
25 |
26 | alp = E * (1 - label) + 1
27 | B = annealing_coef * KL(alp, c)
28 |
29 | return (A + B)
30 |
31 |
32 | def mse_loss(p, alpha, c, global_step, annealing_step=1):
33 | S = torch.sum(alpha, dim=1, keepdim=True)
34 | E = alpha - 1
35 | m = alpha / S
36 | label = F.one_hot(p, num_classes=c)
37 | A = torch.sum((label - m) ** 2, dim=1, keepdim=True)
38 | B = torch.sum(alpha * (S - alpha) / (S * S * (S + 1)), dim=1, keepdim=True)
39 | annealing_coef = min(1, global_step / annealing_step)
40 | alp = E * (1 - label) + 1
41 | C = annealing_coef * KL(alp, c)
42 | return (A + B) + C
43 |
--------------------------------------------------------------------------------
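For orientation (this summary is mine, not part of the original file): `ce_loss` above is the evidential cross-entropy used by USNet, following the TMC formulation credited in the README. With Dirichlet parameters $\alpha$, $S = \sum_k \alpha_k$, one-hot label $y$, $K$ classes, digamma function $\psi$ and annealing coefficient $\lambda_t = \min(1, t / \text{annealing\_step})$, it computes:

```
\mathcal{L}(y, \alpha) = \sum_{k=1}^{K} y_k \bigl(\psi(S) - \psi(\alpha_k)\bigr)
  + \lambda_t \, \mathrm{KL}\bigl[\mathrm{Dir}(\tilde{\alpha}) \,\|\, \mathrm{Dir}(\mathbf{1})\bigr],
\qquad \tilde{\alpha} = y + (1 - y) \odot \alpha

\mathrm{KL}\bigl[\mathrm{Dir}(\alpha) \,\|\, \mathrm{Dir}(\mathbf{1})\bigr]
  = \ln\Gamma(S) - \sum_k \ln\Gamma(\alpha_k) - \ln\Gamma(K)
  + \sum_k (\alpha_k - 1)\bigl(\psi(\alpha_k) - \psi(S)\bigr)
```

`KL` implements the second equation, and `mse_loss` is the analogous variant built on the expected squared error.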
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/morancyc/USNet/d761158ad42df7dcb62fa257dd02ce11c85f94a5/model/__init__.py
--------------------------------------------------------------------------------
/model/backbone.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import models
3 |
4 |
5 | class symmetric_backbone(torch.nn.Module):
6 | def __init__(self, name, pretrained=True):
7 | super().__init__()
8 | if name == 'resnet18':
9 | features = models.resnet18(pretrained=pretrained)
10 | features_d = models.resnet18(pretrained=pretrained)
11 | elif name == 'resnet101':
12 | features = models.resnet101(pretrained=pretrained)
13 | features_d = models.resnet101(pretrained=pretrained)
14 | else:
15 | print('Error: unspported backbone \n')
16 |
17 | self.conv1 = features.conv1
18 | self.bn1 = features.bn1
19 | self.relu = features.relu
20 | self.maxpool1 = features.maxpool
21 | self.layer1 = features.layer1
22 | self.layer2 = features.layer2
23 | self.layer3 = features.layer3
24 | self.layer4 = features.layer4
25 |
26 | self.conv1_d = features_d.conv1
27 | self.bn1_d = features_d.bn1
28 | self.relu_d = features_d.relu
29 | self.maxpool1_d = features_d.maxpool
30 | self.layer1_d = features_d.layer1
31 | self.layer2_d = features_d.layer2
32 | self.layer3_d = features_d.layer3
33 | self.layer4_d = features_d.layer4
34 |
35 | def forward(self, input_rgb, input_depth):
36 | # Symmetric Network
37 | x = self.conv1(input_rgb)
38 | x = self.relu(self.bn1(x))
39 | feature0 = self.maxpool1(x)
40 |
41 | feature1 = self.layer1(feature0) # 1 / 4
42 | feature2 = self.layer2(feature1) # 1 / 8
43 | feature3 = self.layer3(feature2) # 1 / 16
44 | feature4 = self.layer4(feature3) # 1 / 32
45 |
46 | y = self.conv1_d(input_depth)
47 | y = self.relu_d(self.bn1_d(y))
48 | feature0_d = self.maxpool1_d(y)
49 |
50 | feature1_d = self.layer1_d(feature0_d) # 1 / 4
51 | feature2_d = self.layer2_d(feature1_d) # 1 / 8
52 | feature3_d = self.layer3_d(feature2_d) # 1 / 16
53 | feature4_d = self.layer4_d(feature3_d) # 1 / 32
54 |
55 | return feature1, feature2, feature3, feature4, feature1_d, feature2_d, feature3_d, feature4_d
56 |
--------------------------------------------------------------------------------
/model/module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class _ASPPModule(nn.Module):
7 | def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
8 | super(_ASPPModule, self).__init__()
9 | self.atrous_conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
10 | stride=1, padding=padding, dilation=dilation, bias=False)
11 | self.bn = nn.BatchNorm2d(out_channels)
12 | self.relu = nn.ReLU()
13 |
14 | self._init_weight()
15 |
16 | def forward(self, x):
17 | x = self.atrous_conv(x)
18 | x = self.bn(x)
19 |
20 | return self.relu(x)
21 |
22 | def _init_weight(self):
23 | for m in self.modules():
24 | if isinstance(m, nn.Conv2d):
25 | torch.nn.init.kaiming_normal_(m.weight)
26 | elif isinstance(m, nn.BatchNorm2d):
27 | m.weight.data.fill_(1)
28 | m.bias.data.zero_()
29 |
30 |
31 | class ASPP(nn.Module):
32 | def __init__(self, in_channels, out_channels, dilations):
33 | super(ASPP, self).__init__()
34 | self.aspp1 = _ASPPModule(in_channels, out_channels, 1, padding=0, dilation=dilations[0])
35 | self.aspp2 = _ASPPModule(in_channels, out_channels, 3, padding=dilations[1], dilation=dilations[1])
36 | self.aspp3 = _ASPPModule(in_channels, out_channels, 3, padding=dilations[2], dilation=dilations[2])
37 | self.aspp4 = _ASPPModule(in_channels, out_channels, 3, padding=dilations[3], dilation=dilations[3])
38 |
39 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
40 | nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=False),
41 | nn.BatchNorm2d(out_channels),
42 | nn.ReLU())
43 | self.conv1 = nn.Conv2d(out_channels * 5, out_channels, 1, bias=False)
44 | self.bn1 = nn.BatchNorm2d(out_channels)
45 | self.relu = nn.ReLU()
46 | self.dropout = nn.Dropout(0.5)
47 | self._init_weight()
48 |
49 | def forward(self, x):
50 | x1 = self.aspp1(x)
51 | x2 = self.aspp2(x)
52 | x3 = self.aspp3(x)
53 | x4 = self.aspp4(x)
54 | x5 = self.global_avg_pool(x)
55 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
56 | x = torch.cat((x1, x2, x3, x4, x5), dim=1)
57 |
58 | x = self.conv1(x)
59 | x = self.bn1(x)
60 | x = self.relu(x)
61 |
62 | return self.dropout(x)
63 |
64 | def _init_weight(self):
65 | for m in self.modules():
66 | if isinstance(m, nn.Conv2d):
67 | torch.nn.init.kaiming_normal_(m.weight)
68 | elif isinstance(m, nn.BatchNorm2d):
69 | m.weight.data.fill_(1)
70 | m.bias.data.zero_()
71 |
--------------------------------------------------------------------------------
/model/sne_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class SNE(nn.Module):
7 | """Our SNE takes depth and camera intrinsic parameters as input,
8 | and outputs normal estimations.
9 | """
10 | def __init__(self, crop_top=True):
11 | super(SNE, self).__init__()
12 | self.crop_top = crop_top
13 |
14 | def forward(self, depth, camParam):
15 | h,w = depth.size()
16 | v_map, u_map = torch.meshgrid(torch.arange(h), torch.arange(w))
17 | v_map = v_map.type(torch.float32)
18 | u_map = u_map.type(torch.float32)
19 |
20 | Z = depth # h, w
21 | Y = Z.mul((v_map - camParam[1,2])) / camParam[0,0] # h, w
22 | X = Z.mul((u_map - camParam[0,2])) / camParam[0,0] # h, w
23 |
24 | if self.crop_top:
25 | Z[Y <= 0] = 0
26 | Y[Y <= 0] = 0
27 |
28 | Z[torch.isnan(Z)] = 0
29 | D = torch.div(torch.ones(h, w), Z) # h, w
30 |
31 | Gx = torch.tensor([[0,0,0],[-1,0,1],[0,0,0]], dtype=torch.float32)
32 | Gy = torch.tensor([[0,-1,0],[0,0,0],[0,1,0]], dtype=torch.float32)
33 |
34 | Gu = F.conv2d(D.view(1,1,h,w), Gx.view(1,1,3,3), padding=1)
35 | Gv = F.conv2d(D.view(1,1,h,w), Gy.view(1,1,3,3), padding=1)
36 |
37 | nx_t = Gu * camParam[0,0] # 1, 1, h, w
38 | ny_t = Gv * camParam[1,1] # 1, 1, h, w
39 |
40 | phi = torch.atan(torch.div(ny_t, nx_t)) + torch.ones([1,1,h,w])*3.141592657
41 | a = torch.cos(phi)
42 | b = torch.sin(phi)
43 |
44 | diffKernelArray = torch.tensor([[-1, 0, 0, 0, 1, 0, 0, 0, 0],
45 | [ 0,-1, 0, 0, 1, 0, 0, 0, 0],
46 | [ 0, 0,-1, 0, 1, 0, 0, 0, 0],
47 | [ 0, 0, 0,-1, 1, 0, 0, 0, 0],
48 | [ 0, 0, 0, 0, 1,-1, 0, 0, 0],
49 | [ 0, 0, 0, 0, 1, 0,-1, 0, 0],
50 | [ 0, 0, 0, 0, 1, 0, 0,-1, 0],
51 | [ 0, 0, 0, 0, 1, 0, 0, 0,-1]], dtype=torch.float32)
52 |
53 | sum_nx = torch.zeros((1,1,h,w), dtype=torch.float32)
54 | sum_ny = torch.zeros((1,1,h,w), dtype=torch.float32)
55 | sum_nz = torch.zeros((1,1,h,w), dtype=torch.float32)
56 |
57 | for i in range(8):
58 | diffKernel = diffKernelArray[i].view(1,1,3,3)
59 | X_d = F.conv2d(X.view(1,1,h,w), diffKernel, padding=1)
60 | Y_d = F.conv2d(Y.view(1,1,h,w), diffKernel, padding=1)
61 | Z_d = F.conv2d(Z.view(1,1,h,w), diffKernel, padding=1)
62 |
63 | nz_i = torch.div((torch.mul(nx_t, X_d) + torch.mul(ny_t, Y_d)), Z_d)
64 | norm = torch.sqrt(torch.mul(nx_t, nx_t) + torch.mul(ny_t, ny_t) + torch.mul(nz_i, nz_i))
65 | nx_t_i = torch.div(nx_t, norm)
66 | ny_t_i = torch.div(ny_t, norm)
67 | nz_t_i = torch.div(nz_i, norm)
68 |
69 | nx_t_i[torch.isnan(nx_t_i)] = 0
70 | ny_t_i[torch.isnan(ny_t_i)] = 0
71 | nz_t_i[torch.isnan(nz_t_i)] = 0
72 |
73 | sum_nx = sum_nx + nx_t_i
74 | sum_ny = sum_ny + ny_t_i
75 | sum_nz = sum_nz + nz_t_i
76 |
77 | theta = -torch.atan(torch.div((torch.mul(sum_nx, a) + torch.mul(sum_ny, b)), sum_nz))
78 | nx = torch.mul(torch.sin(theta), torch.cos(phi))
79 | ny = torch.mul(torch.sin(theta), torch.sin(phi))
80 | nz = torch.cos(theta)
81 |
82 | nx[torch.isnan(nz)] = 0
83 | ny[torch.isnan(nz)] = 0
84 | nz[torch.isnan(nz)] = -1
85 |
86 | sign = torch.ones((1,1,h,w), dtype=torch.float32)
87 | sign[ny > 0] = -1
88 |
89 | nx = torch.mul(nx, sign).squeeze(dim=0)
90 | ny = torch.mul(ny, sign).squeeze(dim=0)
91 | nz = torch.mul(nz, sign).squeeze(dim=0)
92 |
93 | return torch.cat([nx, ny, nz], dim=0)
94 |
--------------------------------------------------------------------------------
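A minimal sketch of how the datasets drive the SNE module above (it mirrors `dataset/kitti.py`; the sample name below is a placeholder):

```python
import cv2
import numpy as np
import torch

from dataset.kitti import kittiCalibInfo
from model.sne_model import SNE

sne_model = SNE(crop_top=True)

# Placeholder sample name -- substitute any image from the KITTI Road training split.
depth = cv2.imread('./data/KITTI/training/depth_u16/um_000000.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
calib = kittiCalibInfo('./data/KITTI/training/calib/um_000000.txt')
camParam = torch.tensor(calib.get_cam_param(), dtype=torch.float32)

# dataset/kitti.py scales the raw depth_u16 values by 1/1000 before estimating normals.
normal = sne_model(torch.from_numpy(depth / 1000.0), camParam)  # (3, H, W) unit normals
normal = normal.cpu().numpy().transpose(1, 2, 0)                # (H, W, 3) for resizing/visualization
```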
/model/usnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from model.backbone import symmetric_backbone
5 | from model.module import ASPP
6 | from torchsummary import summary
7 |
8 |
9 | class ConvBlock(nn.Module):
10 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1):
11 | super().__init__()
12 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
13 | bias=False)
14 | self.bn = nn.BatchNorm2d(out_channels)
15 | self.relu = nn.ReLU()
16 |
17 | def forward(self, input):
18 | x = self.conv1(input)
19 | return self.relu(self.bn(x))
20 |
21 |
22 | # Feature Compression and Adaptation block
23 | class FCA(nn.Module):
24 | def __init__(self, in_channels, out_channels):
25 | super().__init__()
26 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
27 | self.bn = nn.BatchNorm2d(out_channels)
28 | self.sigmoid = nn.Sigmoid()
29 | self.in_channels = in_channels
30 | self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
31 |
32 | def forward(self, input):
33 | # global average pooling
34 | x = self.avgpool(input)
35 | assert self.in_channels == x.size(1), 'in_channels and out_channels should all be {}'.format(x.size(1))
36 | x = self.conv(x)
37 | x = self.sigmoid(x)
38 | x = torch.mul(input, x)
39 | return x
40 |
41 |
42 | class Up_layer(nn.Module):
43 | def __init__(self, in_channels, out_channels):
44 | super().__init__()
45 | self.conv = ConvBlock(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
46 | self.fca = FCA(in_channels=out_channels, out_channels=out_channels)
47 |
48 | def forward(self, x, x_aux):
49 | x = F.interpolate(x, size=x_aux.size()[-2:], mode='bilinear')
50 | x_aux = self.fca(self.conv(x_aux))
51 | x = x + x_aux
52 | return x
53 |
54 |
55 | # Multi-scale Evidence Collection Module
56 | class MEC(nn.Module):
57 | def __init__(self, in_channels, out_channels):
58 | super().__init__()
59 | self.scale_1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
60 | self.scale_2 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=3, dilation=3)
61 | self.scale_3 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=6, dilation=6)
62 | self.softplus = nn.Softplus()
63 |
64 | def forward(self, x, input):
65 | e_1 = self.scale_1(x)
66 | e_2 = self.scale_2(x)
67 | e_3 = self.scale_3(x)
68 |
69 | e_1 = F.interpolate(e_1, size=input.size()[-2:], mode='bilinear')
70 | e_2 = F.interpolate(e_2, size=input.size()[-2:], mode='bilinear')
71 | e_3 = F.interpolate(e_3, size=input.size()[-2:], mode='bilinear')
72 |
73 | e_1 = self.softplus(e_1)
74 | e_2 = self.softplus(e_2)
75 | e_3 = self.softplus(e_3)
76 |
77 | e = (e_1 + e_2 + e_3) / 3
78 | return e_1, e_2, e_3, e
79 |
80 |
81 | class USNet(nn.Module):
82 | def __init__(self, num_classes, backbone_name):
83 | super().__init__()
84 | # build backbone
85 | self.backbone = symmetric_backbone(name=backbone_name)
86 |
87 | if backbone_name == 'resnet101':
88 | # ASPP
89 | dilations = [1, 6, 12, 18]
90 | self.aspp_r = ASPP(2048, 256, dilations)
91 | self.aspp_d = ASPP(2048, 256, dilations)
92 |
93 | self.conv_r = ConvBlock(in_channels=256, out_channels=64, kernel_size=1, stride=1, padding=0)
94 | self.conv_d = ConvBlock(in_channels=256, out_channels=64, kernel_size=1, stride=1, padding=0)
95 |
96 | self.up_r1 = Up_layer(in_channels=1024, out_channels=64)
97 | self.up_r2 = Up_layer(in_channels=512, out_channels=64)
98 | self.up_r3 = Up_layer(in_channels=256, out_channels=64)
99 |
100 | self.up_d1 = Up_layer(in_channels=1024, out_channels=64)
101 | self.up_d2 = Up_layer(in_channels=512, out_channels=64)
102 | self.up_d3 = Up_layer(in_channels=256, out_channels=64)
103 |
104 | self.mec_r = MEC(in_channels=64, out_channels=num_classes)
105 | self.mec_d = MEC(in_channels=64, out_channels=num_classes)
106 |
107 | elif backbone_name == 'resnet18':
108 | # ASPP
109 | dilations = [1, 6, 12, 18]
110 | self.aspp_r = ASPP(512, 256, dilations)
111 | self.aspp_d = ASPP(512, 256, dilations)
112 |
113 | self.conv_r = ConvBlock(in_channels=256, out_channels=64, kernel_size=1, stride=1, padding=0)
114 | self.conv_d = ConvBlock(in_channels=256, out_channels=64, kernel_size=1, stride=1, padding=0)
115 |
116 | self.up_r1 = Up_layer(in_channels=256, out_channels=64)
117 | self.up_r2 = Up_layer(in_channels=128, out_channels=64)
118 | self.up_r3 = Up_layer(in_channels=64, out_channels=64)
119 |
120 | self.up_d1 = Up_layer(in_channels=256, out_channels=64)
121 | self.up_d2 = Up_layer(in_channels=128, out_channels=64)
122 | self.up_d3 = Up_layer(in_channels=64, out_channels=64)
123 |
124 | self.mec_r = MEC(in_channels=64, out_channels=num_classes)
125 | self.mec_d = MEC(in_channels=64, out_channels=num_classes)
126 | else:
127 | raise ValueError('unsupported backbone: %s' % backbone_name)
128 |
129 | self.num_classes = num_classes
130 | self.softplus = nn.Softplus()
131 | self.init_weight()
132 |
133 | def init_weight(self):
134 | for name, m in self.named_modules():
135 | if 'backbone' not in name:
136 | if isinstance(m, nn.Conv2d):
137 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
138 | elif isinstance(m, nn.BatchNorm2d):
139 | m.eps = 1e-5
140 | m.momentum = 0.1
141 | nn.init.constant_(m.weight, 1)
142 | nn.init.constant_(m.bias, 0)
143 |
144 | def DS_Combin(self, alpha):
145 | """
146 | :param alpha: All Dirichlet distribution parameters.
147 | :return: Combined Dirichlet distribution parameters.
148 | """
149 |
150 | def DS_Combin_two(alpha1, alpha2):
151 | """
152 | :param alpha1: Dirichlet distribution parameters of view 1
153 | :param alpha2: Dirichlet distribution parameters of view 2
154 | :return: Combined Dirichlet distribution parameters
155 | """
156 | alpha = dict()
157 | alpha[0], alpha[1] = alpha1, alpha2
158 | b, S, E, u = dict(), dict(), dict(), dict()
159 | for v in range(2):
160 | S[v] = torch.sum(alpha[v], dim=1, keepdim=True)
161 | E[v] = alpha[v] - 1
162 | b[v] = E[v] / (S[v].expand(E[v].shape))
163 | u[v] = self.num_classes / S[v]
164 |
165 | # b^0 @ b^(0+1)
166 | bb = torch.bmm(b[0].view(-1, self.num_classes, 1), b[1].view(-1, 1, self.num_classes))
167 | # b^0 * u^1
168 | uv1_expand = u[1].expand(b[0].shape)
169 | bu = torch.mul(b[0], uv1_expand)
170 | # b^1 * u^0
171 | uv_expand = u[0].expand(b[0].shape)
172 | ub = torch.mul(b[1], uv_expand)
173 | # calculate C
174 | bb_sum = torch.sum(bb, dim=(1, 2), out=None)
175 | bb_diag = torch.diagonal(bb, dim1=-2, dim2=-1).sum(-1)
176 | # bb_diag1 = torch.diag(torch.mm(b[v], torch.transpose(b[v+1], 0, 1)))
177 | C = bb_sum - bb_diag
178 |
179 | # calculate b^a
180 | b_a = (torch.mul(b[0], b[1]) + bu + ub) / ((1 - C).view(-1, 1).expand(b[0].shape))
181 | # calculate u^a
182 | u_a = torch.mul(u[0], u[1]) / ((1 - C).view(-1, 1).expand(u[0].shape))
183 |
184 | # calculate new S
185 | S_a = self.num_classes / u_a
186 | # calculate new e_k
187 | e_a = torch.mul(b_a, S_a.expand(b_a.shape))
188 | alpha_a = e_a + 1
189 | return alpha_a
190 |
191 | for v in range(len(alpha) - 1):
192 | if v == 0:
193 | alpha_a = DS_Combin_two(alpha[0], alpha[1])
194 | else:
195 | alpha_a = DS_Combin_two(alpha_a, alpha[v + 1])
196 | return alpha_a
197 |
198 | def forward(self, input_rgb, input_depth):
199 | # output of backbone
200 | x_r1, x_r2, x_r3, x_r4, x_d1, x_d2, x_d3, x_d4 = self.backbone(input_rgb, input_depth)
201 |
202 | # ASPP
203 | x_r4 = self.aspp_r(x_r4)
204 | x_d4 = self.aspp_d(x_d4)
205 |
206 | # compress channel to 64
207 | x_r4 = self.conv_r(x_r4)
208 | x_d4 = self.conv_d(x_d4)
209 |
210 | # decoder
211 | x_r3 = self.up_r1(x_r4, x_r3)
212 | x_r2 = self.up_r2(x_r3, x_r2)
213 | x_r1 = self.up_r3(x_r2, x_r1)
214 |
215 | x_d3 = self.up_d1(x_d4, x_d3)
216 | x_d2 = self.up_d2(x_d3, x_d2)
217 | x_d1 = self.up_d3(x_d2, x_d1)
218 |
219 | # MEC module
220 | e_r1, e_r2, e_r3, e_r = self.mec_r(x_r1, input_rgb)
221 | e_d1, e_d2, e_d3, e_d = self.mec_d(x_d1, input_depth)
222 |
223 | # compute evidence, alpha
224 | evidence = dict()
225 | evidence[0] = e_r.permute(0, 2, 3, 1).reshape(-1, self.num_classes) # (B*H*W, 2)
226 | evidence[1] = e_d.permute(0, 2, 3, 1).reshape(-1, self.num_classes) # (B*H*W, 2)
227 |
228 | alpha = dict()
229 | for v_num in range(len(evidence)):
230 | alpha[v_num] = evidence[v_num] + 1
231 |
232 | alpha_a = self.DS_Combin(alpha)
233 | evidence_a = alpha_a - 1
234 |
235 | if self.training:
236 | evidence_sup = dict()
237 | evidence_sup[0] = e_r1.permute(0, 2, 3, 1).reshape(-1, self.num_classes) # (B*H*W, 2)
238 | evidence_sup[1] = e_d1.permute(0, 2, 3, 1).reshape(-1, self.num_classes) # (B*H*W, 2)
239 | evidence_sup[2] = e_r2.permute(0, 2, 3, 1).reshape(-1, self.num_classes) # (B*H*W, 2)
240 | evidence_sup[3] = e_d2.permute(0, 2, 3, 1).reshape(-1, self.num_classes) # (B*H*W, 2)
241 | evidence_sup[4] = e_r3.permute(0, 2, 3, 1).reshape(-1, self.num_classes) # (B*H*W, 2)
242 | evidence_sup[5] = e_d3.permute(0, 2, 3, 1).reshape(-1, self.num_classes) # (B*H*W, 2)
243 |
244 | alpha_sup = dict()
245 | for v_num in range(len(evidence_sup)):
246 | alpha_sup[v_num] = evidence_sup[v_num] + 1
247 |
248 | return evidence_sup, alpha_sup, evidence, evidence_a, alpha, alpha_a
249 |
250 | return evidence, evidence_a, alpha, alpha_a
251 |
252 |
253 | if __name__ == '__main__':
254 |
255 | model = USNet(2, 'resnet18')
256 | model = model.cuda()
257 | summary(model, [(3, 1248, 384),(3, 1248, 384)])
258 |
--------------------------------------------------------------------------------
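For reference (my notation, summarizing the code rather than quoting the paper): `DS_Combin` above fuses the RGB and surface-normal Dirichlet opinions with the reduced Dempster-Shafer combination rule adapted from TMC. For two views $v \in \{1, 2\}$ with $K$ classes:

```
b_k^v = \frac{\alpha_k^v - 1}{S^v}, \qquad u^v = \frac{K}{S^v}, \qquad S^v = \sum_k \alpha_k^v

C = \sum_{i \neq j} b_i^1 b_j^2, \qquad
b_k^a = \frac{b_k^1 b_k^2 + b_k^1 u^2 + b_k^2 u^1}{1 - C}, \qquad
u^a = \frac{u^1 u^2}{1 - C}

S^a = \frac{K}{u^a}, \qquad \alpha_k^a = b_k^a \, S^a + 1
```

`forward` returns the fused parameters as `alpha_a`; the corresponding fused uncertainty is $u^a = K / S^a$.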
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.6.0
2 | torchvision>=0.7.0
3 | torchsummary
4 | matplotlib
5 | scikit-image
6 | scipy
7 | tqdm
8 | timm
9 | opencv-python
10 | protobuf
11 | tensorboardX
12 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import argparse
3 | import os
4 | import torch
5 | import numpy as np
6 | from torch.utils.data import DataLoader
7 | from dataset.kitti import Kitti_Dataset
8 | from model.usnet import USNet
9 | from utils import fast_hist, getScores
10 |
11 |
12 | def test(args, model, dataloader):
13 | print('start test!')
14 | with torch.no_grad():
15 | model.eval()
16 | hist = np.zeros((args.num_classes, args.num_classes))
17 | for i, sample in enumerate(dataloader):
18 | image, depth, label = sample['image'], sample['depth'], sample['label']
19 | oriHeight, oriWidth = sample['oriHeight'], sample['oriWidth']
20 | oriWidth = oriWidth.cpu().numpy()[0]
21 | oriHeight = oriHeight.cpu().numpy()[0]
22 |
23 | if torch.cuda.is_available() and args.use_gpu:
24 | image = image.cuda()
25 | depth = depth.cuda()
26 | label = label.cuda()
27 |
28 | # get predict image
29 | evidence, evidence_a, alpha, alpha_a = model(image, depth)
30 |
31 | s = torch.sum(alpha_a, dim=1, keepdim=True)
32 | p = alpha_a / (s.expand(alpha_a.shape))
33 |
34 | pred = p[:,1]
35 | pred = pred.view(args.crop_height, args.crop_width)
36 | pred = pred.detach().cpu().numpy()
37 | pred = np.array(pred)
38 |
39 | # save predict image
40 | visualize = cv2.resize(pred, (oriWidth, oriHeight))
41 | visualize = np.floor(255*(visualize - visualize.min()) / (visualize.max()-visualize.min()))
42 | img_path = sample['img_path'][0]
43 | img_name = img_path.split('/')[-1]
44 | save_name = img_name.split('_')[0]+'_road_'+img_name.split('_')[1]
45 | cv2.imwrite(os.path.join(args.save_path, save_name), np.uint8(visualize))
46 |
47 | pred = np.uint8(pred > 0.5)
48 | pred = cv2.resize(pred, (oriWidth, oriHeight), interpolation=cv2.INTER_NEAREST)
49 |
50 | # get label image
51 | label = label.squeeze()
52 | label = label.cpu().numpy()
53 | label = np.array(label)
54 | label = cv2.resize(np.uint8(label), (oriWidth, oriHeight), interpolation=cv2.INTER_NEAREST)
55 |
56 | hist += fast_hist(label.flatten(), pred.flatten(), args.num_classes)
57 | F_score, pre, recall, fpr, fnr = getScores(hist)
58 | print('F_score: %.3f' % F_score)
59 | print('pre : %.3f' % pre)
60 | print('recall: %.3f' % recall)
61 | print('fpr: %.3f' % fpr)
62 | print('fnr: %.3f' % fnr)
63 | return F_score, pre, recall, fpr, fnr
64 |
65 |
66 | def main(params):
67 | # basic parameters
68 | parser = argparse.ArgumentParser()
69 | parser.add_argument('--checkpoint_path', type=str, default=None, required=True, help='The path to the pretrained weights of model')
70 | parser.add_argument('--crop_height', type=int, default=384, help='Height of cropped/resized input image to network')
71 | parser.add_argument('--crop_width', type=int, default=1248, help='Width of cropped/resized input image to network')
72 | parser.add_argument('--data', type=str, default='', help='Path of testing data')
73 | parser.add_argument('--dataset', type=str, default="KITTI", help='Dataset you are using.')
74 | parser.add_argument('--batch_size', type=int, default=1, help='Number of images in each batch')
75 | parser.add_argument('--backbone_name', type=str, default="resnet18", help='The backbone model you are using.')
76 | parser.add_argument('--cuda', type=str, default='0', help='GPU ids used for training')
77 | parser.add_argument('--use_gpu', type=bool, default=True, help='Whether to use gpu for training')
78 | parser.add_argument('--num_classes', type=int, default=2, help='num of object classes (with void)')
79 | parser.add_argument('--save_path', type=str, default=None, required=True, help='Path to save predict image')
80 | args = parser.parse_args(params)
81 |
82 | dataset = Kitti_Dataset(args, root=args.data, split='validating')
83 | dataloader = DataLoader(
84 | dataset,
85 | batch_size=args.batch_size,
86 | shuffle=False,
87 | num_workers=1,
88 | )
89 |
90 | # build model
91 | os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
92 | model = USNet(args.num_classes, args.backbone_name)
93 | if torch.cuda.is_available() and args.use_gpu:
94 | model = torch.nn.DataParallel(model).cuda()
95 |
96 | # load trained model
97 | print('load model from %s ...' % args.checkpoint_path)
98 | model.module.load_state_dict(torch.load(args.checkpoint_path))
99 | print('Done!')
100 |
101 | # make save folder
102 | if not os.path.exists(args.save_path):
103 | os.makedirs(args.save_path)
104 |
105 | test(args, model, dataloader)
106 |
107 |
108 | if __name__ == '__main__':
109 | params = [
110 | '--checkpoint_path', './log/KITTI_model/usnet_best.pth',
111 | '--data', './data/KITTI',
112 | '--batch_size', '1',
113 | '--backbone_name', 'resnet18',
114 | '--cuda', '0',
115 | '--num_classes', '2',
116 | '--save_path', './result/kitti_test',
117 | ]
118 | main(params)
119 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import cv2
4 | import torch
5 | import tqdm
6 | import argparse
7 | import numpy as np
8 | from torch.utils.data import DataLoader
9 | from dataset.kitti import Kitti_Dataset
10 | from dataset.cityscapes import Cityscapes_Dataset
11 | from model.usnet import USNet
12 | from tensorboardX import SummaryWriter
13 | from utils import poly_lr_scheduler, fast_hist, getScores
14 | from loss import KL, ce_loss, mse_loss
15 | import time
16 |
17 |
18 | # validation
19 | def val(args, model, dataloader):
20 | print('start val!')
21 | with torch.no_grad():
22 | model.eval()
23 | hist = np.zeros((args.num_classes, args.num_classes))
24 | for i, sample in enumerate(dataloader):
25 | image, depth, label = sample['image'], sample['depth'], sample['label']
26 | oriHeight, oriWidth = sample['oriHeight'], sample['oriWidth']
27 | oriWidth = oriWidth.cpu().numpy()[0]
28 | oriHeight = oriHeight.cpu().numpy()[0]
29 |
30 | if torch.cuda.is_available() and args.use_gpu:
31 | image = image.cuda()
32 | depth = depth.cuda()
33 | label = label.cuda()
34 |
35 | # get predict
36 | evidence, evidence_a, alpha, alpha_a = model(image, depth)
37 |
38 | s = torch.sum(alpha_a, dim=1, keepdim=True)
39 | p = alpha_a / (s.expand(alpha_a.shape))
40 |
41 | pred = p[:,1]
42 | pred = pred.view(args.crop_height, args.crop_width)
43 | pred = pred.detach().cpu().numpy()
44 | pred = np.array(pred)
45 | pred = np.uint8(pred > 0.5)
46 | pred = cv2.resize(pred, (oriWidth, oriHeight), interpolation=cv2.INTER_NEAREST)
47 |
48 | # get label
49 | label = label.squeeze()
50 | label = label.cpu().numpy()
51 | label = np.array(label)
52 | label = cv2.resize(np.uint8(label), (oriWidth, oriHeight), interpolation=cv2.INTER_NEAREST)
53 |
54 | hist += fast_hist(label.flatten(), pred.flatten(), args.num_classes)
55 | F_score, pre, recall, fpr, fnr = getScores(hist)
56 | print('F_score: %.3f' % F_score)
57 | print('pre : %.3f' % pre)
58 | print('recall: %.3f' % recall)
59 | print('fpr: %.3f' % fpr)
60 | print('fnr: %.3f' % fnr)
61 | return F_score, pre, recall, fpr, fnr
62 |
63 |
64 | def train(args, model, optimizer, dataloader_train, dataloader_val):
65 | writer = SummaryWriter(comment='_{}'.format(args.backbone_name))
66 | max_F_score = 0
67 | step = 0
68 | lambda_epochs = 50
69 | for epoch in range(args.num_epochs):
70 | lr = poly_lr_scheduler(optimizer, args.learning_rate, iter=epoch, max_iter=args.num_epochs)
71 | model.train()
72 | tq = tqdm.tqdm(total=len(dataloader_train) * args.batch_size)
73 | tq.set_description('epoch %d, lr %f' % (epoch, lr))
74 | loss_record = []
75 | for i, sample in enumerate(dataloader_train):
76 | image, depth, label = sample['image'], sample['depth'], sample['label']
77 | if torch.cuda.is_available() and args.use_gpu:
78 | image = image.cuda()
79 | depth = depth.cuda()
80 | label = label.cuda()
81 |
82 | # network output
83 | evidence_sup, alpha_sup, evidence, evidence_a, alpha, alpha_a = model(image, depth)
84 |
85 | # compute loss
86 | label = label.flatten()
87 | loss = 0
88 | for v_num in range(len(alpha_sup)):
89 | loss += ce_loss(label, alpha_sup[v_num], args.num_classes, epoch, lambda_epochs)
90 | for v_num in range(len(alpha)):
91 | loss += ce_loss(label, alpha[v_num], args.num_classes, epoch, lambda_epochs)
92 | loss += 2 * ce_loss(label, alpha_a, args.num_classes, epoch, lambda_epochs)
93 | loss = torch.mean(loss)
94 |
95 | tq.update(args.batch_size)
96 | tq.set_postfix(loss='%.6f' % loss)
97 | optimizer.zero_grad()
98 | loss.backward()
99 | optimizer.step()
100 | step += 1
101 | writer.add_scalar('loss_step', loss, step)
102 | loss_record.append(loss.item())
103 | tq.close()
104 | loss_train_mean = np.mean(loss_record)
105 | writer.add_scalar('epoch/loss_epoch_train', float(loss_train_mean), epoch)
106 | print('loss for train : %f' % (loss_train_mean))
107 |
108 | # save checkpoints
109 | if epoch % args.checkpoint_step == 0:
110 | if not os.path.isdir(args.save_model_path):
111 | os.mkdir(args.save_model_path)
112 | torch.save(model.module.state_dict(), os.path.join(args.save_model_path, 'usnet_latest.pth'))
113 |
114 | if epoch % args.validation_step == 0:
115 | F_score, pre, recall, fpr, fnr = val(args, model, dataloader_val)
116 | file = open(os.path.join(args.save_model_path, 'F_score.txt'), mode='a+')
117 | file.write('epoch = %d, F_score = %f\n' % (epoch, F_score))
118 | file.close()
119 | if F_score > max_F_score:
120 | max_F_score = F_score
121 | torch.save(model.module.state_dict(), os.path.join(args.save_model_path, 'usnet_best.pth'))
122 | writer.add_scalar('epoch/F_score', F_score, epoch)
123 | writer.add_scalar('epoch/pre', pre, epoch)
124 | writer.add_scalar('epoch/recall', recall, epoch)
125 | writer.add_scalar('epoch/fpr', fpr, epoch)
126 | writer.add_scalar('epoch/fnr', fnr, epoch)
127 |
128 |
129 | def main(params):
130 | # set initialization seed
131 | torch.manual_seed(0)
132 | torch.cuda.manual_seed(0)
133 | torch.cuda.manual_seed_all(0)
134 | np.random.seed(0)
135 | random.seed(0)
136 | torch.backends.cudnn.benchmark = True
137 | torch.backends.cudnn.deterministic = True
138 |
139 | # basic parameters
140 | parser = argparse.ArgumentParser()
141 | parser.add_argument('--num_epochs', type=int, default=500, help='Number of epochs to train for')
142 | parser.add_argument('--checkpoint_step', type=int, default=1, help='How often to save checkpoints (epochs)')
143 | parser.add_argument('--validation_step', type=int, default=1, help='How often to perform validation (epochs)')
144 | parser.add_argument('--data', type=str, default='', help='path of training data')
145 | parser.add_argument('--dataset', type=str, default="Kitti", help='Dataset you are using.')
146 | parser.add_argument('--crop_height', type=int, default=384, help='Height of cropped/resized input image to network')
147 | parser.add_argument('--crop_width', type=int, default=1248, help='Width of cropped/resized input image to network')
148 | parser.add_argument('--batch_size', type=int, default=1, help='Number of images in each batch')
149 | parser.add_argument('--backbone_name', type=str, default="resnet18",
150 | help='The backbone model you are using, resnet18, resnet101.')
151 | parser.add_argument('--learning_rate', type=float, default=0.01, help='learning rate used for train')
152 | parser.add_argument('--num_workers', type=int, default=4, help='num of workers')
153 | parser.add_argument('--num_classes', type=int, default=32, help='num of object classes')
154 | parser.add_argument('--cuda', type=str, default='0', help='GPU ids used for training')
155 | parser.add_argument('--use_gpu', type=bool, default=True, help='whether to use gpu for training')
156 | parser.add_argument('--pretrained_model_path', type=str, default=None, help='path to pretrained model')
157 | parser.add_argument('--save_model_path', type=str, default=None, help='path to save model')
158 |
159 | args = parser.parse_args(params)
160 |
161 | if args.dataset == 'Kitti':
162 | train_set = Kitti_Dataset(args, root=args.data, split='training')
163 | val_set = Kitti_Dataset(args, root=args.data, split='validating')
164 | elif args.dataset == 'Cityscapes':
165 | train_set = Cityscapes_Dataset(args, root=args.data, split='train')
166 | val_set = Cityscapes_Dataset(args, root=args.data, split='val')
167 |
168 | dataloader_train = DataLoader(
169 | train_set,
170 | batch_size=args.batch_size,
171 | shuffle=True,
172 | num_workers=args.num_workers,
173 | drop_last=True
174 | )
175 | dataloader_val = DataLoader(
176 | val_set,
177 | batch_size=1,
178 | shuffle=False,
179 | num_workers=args.num_workers
180 | )
181 |
182 | # build model
183 | os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
184 |
185 | model = USNet(args.num_classes, args.backbone_name)
186 | if torch.cuda.is_available() and args.use_gpu:
187 | model = torch.nn.DataParallel(model).cuda()
188 |
189 | encoder_params = list(map(id, model.module.backbone.parameters()))
190 | base_params = filter(lambda p: id(p) not in encoder_params, model.parameters())
191 |
192 | optimizer = torch.optim.AdamW([{'params': base_params},
193 | {'params':model.module.backbone.parameters(), 'lr': args.learning_rate*0.1}],
194 | lr=args.learning_rate, betas=(0.9,0.999), weight_decay=0.01)
195 |
196 | # load pretrained model if exists
197 | if args.pretrained_model_path is not None:
198 | print('load model from %s ...' % args.pretrained_model_path)
199 | model.module.load_state_dict(torch.load(args.pretrained_model_path))
200 | print('Done!')
201 |
202 | # train
203 | train(args, model, optimizer, dataloader_train, dataloader_val)
204 |
205 | if __name__ == '__main__':
206 | params = [
207 | '--num_epochs', '500',
208 | '--learning_rate', '1e-3',
209 | '--data', './data/KITTI',
210 | '--dataset', 'Kitti',
211 | '--num_workers', '8',
212 | '--num_classes', '2',
213 | '--cuda', '0',
214 | '--batch_size', '2',
215 | '--save_model_path', './log/KITTI_model',
216 | '--backbone_name', 'resnet18', # only support resnet18 and resnet101
217 | '--checkpoint_step', '1',
218 | '--validation_step', '20',
219 | ]
220 | main(params)
221 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def poly_lr_scheduler(optimizer, init_lr, iter, lr_decay_iter=1, max_iter=500, power=0.9):
5 | """Polynomial decay of learning rate
6 | :param init_lr is base learning rate
7 | :param iter is a current iteration
8 | :param lr_decay_iter how frequently decay occurs, default is 1
9 | :param max_iter is number of maximum iterations
10 | :param power is a polynomial power
11 |
12 | """
13 | # if iter % lr_decay_iter or iter > max_iter:
14 | # return optimizer
15 |
16 | lr = init_lr * (1 - iter / max_iter) ** power
17 | optimizer.param_groups[0]['lr'] = lr
18 | optimizer.param_groups[1]['lr'] = lr * 0.1
19 | return lr
20 |
21 |
22 | def fast_hist(a, b, n):
23 | '''
24 | a and b are the ground-truth and predicted label arrays respectively
25 | n is the number of classes
26 | '''
27 | k = (a >= 0) & (a < n)
28 | return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)
29 |
30 |
31 | def getScores(conf_matrix):
32 | if conf_matrix.sum() == 0:
33 | return 0, 0, 0, 0, 0
34 | with np.errstate(divide='ignore',invalid='ignore'):
35 | classpre = np.diag(conf_matrix) / conf_matrix.sum(0).astype(np.float64)
36 | classrecall = np.diag(conf_matrix) / conf_matrix.sum(1).astype(np.float64)
37 | pre = classpre[1]
38 | recall = classrecall[1]
39 | F_score = 2 * (recall * pre) / (recall + pre)
40 | fpr = conf_matrix[0, 1] / float(conf_matrix[0, 0] + conf_matrix[0, 1])
41 | fnr = conf_matrix[1, 0] / float(conf_matrix[1, 0] + conf_matrix[1, 1])
42 | return F_score, pre, recall, fpr, fnr
43 |
--------------------------------------------------------------------------------
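To make the metrics explicit (a restatement of the code above, not new behavior): `fast_hist` accumulates a confusion matrix indexed as `hist[true, pred]`, and `getScores` reports the road-class (class 1) scores, with TP, FP, FN and TN counted for that class:

```
\mathrm{Pre} = \frac{TP}{TP + FP}, \quad
\mathrm{Rec} = \frac{TP}{TP + FN}, \quad
F = \frac{2 \, \mathrm{Pre} \cdot \mathrm{Rec}}{\mathrm{Pre} + \mathrm{Rec}}, \quad
\mathrm{FPR} = \frac{FP}{FP + TN}, \quad
\mathrm{FNR} = \frac{FN}{FN + TP}
```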