├── README.md ├── dataset ├── __init__.py ├── cityscapes.py ├── custom_transforms.py └── kitti.py ├── img ├── KITTI.gif └── usnet.png ├── loss.py ├── model ├── __init__.py ├── backbone.py ├── module.py ├── sne_model.py └── usnet.py ├── requirements.txt ├── test.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # USNet 2 | 3 | 4 | ### Fast Road Segmentation via Uncertainty-aware Symmetric Network 5 | 6 | This repository provides the implementation of USNet [[arxiv]](https://arxiv.org/abs/2203.04537) in PyTorch. 7 | 8 | Road segmentation is significant in self-driving and mobile robot applications. USNet is proposed to achieve a trade-off between speed and accuracy in this task. 9 | 10 |

<img src="img/usnet.png"> 11 | 12 | 13 | 14 | The segmentation result and the uncertainty map are shown below: 15 | 16 | <img src="img/KITTI.gif"> 17 | 18 |
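The road probability used in `test.py`, and one way to derive a per-pixel uncertainty from the same evidential output, are sketched below (placeholder values only; the flattened `(H*W, num_classes)` layout and the `u = num_classes / S` relation follow `test.py` and `DS_Combin` in `model/usnet.py`):

```
import torch

num_classes = 2
# placeholder Dirichlet parameters; in practice this is the `alpha_a` returned by USNet,
# reshaped to (H*W, num_classes) as in test.py
alpha_a = torch.rand(384 * 1248, num_classes) + 1

S = torch.sum(alpha_a, dim=1, keepdim=True)       # Dirichlet strength per pixel
prob = alpha_a / S                                # expected class probabilities (as in test.py)
road_prob = prob[:, 1].view(384, 1248)            # channel 1 is the road class
uncertainty = (num_classes / S).view(384, 1248)   # u = K / S: large where evidence is scarce
```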

19 | 20 | 21 | ## Data Preparation 22 | 23 | 24 | #### KITTI Road Dataset 25 | 26 | You may download the KITTI Road dataset from [Google Drive](https://drive.google.com/file/d/12BvNsVgSZ5cEqetNhkkbrmdUAlADphAm/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1Y0pg85lLeVQADqRCAFk92A) (Code: k5vp). Then please setup dataset according to the following directory structure: 27 | ``` 28 | USNet 29 | |-- data 30 | | |-- KITTI 31 | | | |-- training 32 | | | | |-- calib 33 | | | | |-- depth_u16 34 | | | | |-- gt_image_2 35 | | | | |-- image_2 36 | | | |-- validating 37 | | | | |-- calib 38 | | | | |-- depth_u16 39 | | | | |-- gt_image_2 40 | | | | |-- image_2 41 | | | |-- testing 42 | | | | |-- calib 43 | | | | |-- depth_u16 44 | | | | |-- image_2 45 | |-- models 46 | ... 47 | ``` 48 | 49 | 50 | ## Installation 51 | The code is developed using Python 3.7 with PyTorch 1.6.0. The code is tested using one NVIDIA 1080Ti GPU card. 52 | You can create a conda environment and install the required packages by running: 53 | ``` 54 | $ conda create -n usnet python=3.7 55 | $ pip install -r requirements.txt 56 | ``` 57 | 58 | 59 | ## Training 60 | 61 | For training USNet on KITTI Road dataset, you can run: 62 | 63 | ``` 64 | $ cd $USNET_ROOT 65 | $ python train.py 66 | ``` 67 | When training completed, the checkpoint will be saved to `./log/KITTI_model`. 68 | 69 | 70 | ## Testing 71 | 72 | **Note that** before testing you need to config the necessary paths or variables. Please ensure that the checkpoint exists in `checkpoint_path`. 73 | 74 | To run the test on KITTI Road dataset: 75 | ``` 76 | $ python test.py 77 | ``` 78 | You can download our trained model from [Google Drive](https://drive.google.com/file/d/1qB85Pt-jgnC5wf5-U2ExYBxzjvmaAZNb/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1prA2UsSr5keuCXqewKShCw) (Code: 9zgf). The BEV-results obtained from this released model can be found in [Google Drive](https://drive.google.com/file/d/1MFZwPz141Wgrhk7YW14lPbtwUrJw8LxX/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1c4hT2adWo9-8AIniMVLoXQ) (Code: csar). 79 | 80 | If you submit this result to the KITTI benchmark, you will get a MaxF score of **96.87** for **URBAN**, which is similar to the reported ones in our paper. 81 | 82 | 83 | ### Citation 84 | 85 | If you find USNet useful in your research, please consider citing: 86 | ``` 87 | @inproceedings{Chang22Fast, 88 | title = {Fast Road Segmentation via Uncertainty-aware Symmetric Network}, 89 | author = {Chang, Yicong and Xue, Feng and Sheng, Fei and Liang, Wenteng and Ming, Anlong}, 90 | booktitle = {IEEE International Conference on Robotics and Automation (ICRA)}, 91 | year = {2022} 92 | } 93 | ``` 94 | 95 | ## Acknowledgement 96 | The source code of surface normal estimator in our method follows [SNE-RoadSeg](https://github.com/hlwang1124/SNE-RoadSeg), we do appreciate this great work. Besides, the code of acquiring uncertainty in our method is adapted from [TMC](https://github.com/hanmenghan/TMC). 
97 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/morancyc/USNet/d761158ad42df7dcb62fa257dd02ce11c85f94a5/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/cityscapes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | from torch.utils import data 6 | from torchvision import transforms 7 | import cv2 8 | import glob 9 | import json 10 | from model.sne_model import SNE 11 | from dataset import custom_transforms as tr 12 | import matplotlib.pyplot as plt 13 | 14 | 15 | def Disp2depth(fx, baseline, disp): 16 | delta = 256 17 | disp_mask = disp > 0 18 | depth = disp.astype(np.float32) 19 | depth[disp_mask] = (depth[disp_mask] - 1) / delta 20 | disp_mask = depth > 0 21 | depth[disp_mask] = fx * baseline / depth[disp_mask] 22 | return depth 23 | 24 | 25 | def read_calib_file(filepath): 26 | with open(filepath, 'r') as f: 27 | calib_info = json.load(f) 28 | baseline = calib_info['extrinsic']['baseline'] 29 | fx = calib_info['intrinsic']['fx'] 30 | fy = calib_info['intrinsic']['fy'] 31 | u0 = calib_info['intrinsic']['u0'] 32 | v0 = calib_info['intrinsic']['v0'] 33 | return baseline, fx, fy, u0, v0 34 | 35 | 36 | class Cityscapes_Dataset(data.Dataset): 37 | NUM_CLASSES = 2 38 | 39 | def __init__(self, args, root='./data/Cityscapes/', split='train'): 40 | 41 | self.root = root 42 | self.split = split 43 | self.args = args 44 | self.images = {} 45 | self.disparities = {} 46 | self.labels = {} 47 | self.calibs = {} 48 | 49 | self.image_base = os.path.join(self.root, 'leftImg8bit_trainvaltest/leftImg8bit', self.split) 50 | self.disparity_base = os.path.join(self.root, 'disparity_trainvaltest/disparity', self.split) 51 | self.label_base = os.path.join(self.root, 'gtFine_trainvaltest/gtFine', self.split) 52 | self.calib_base = os.path.join(self.root, 'camera_trainvaltest/camera', self.split) 53 | 54 | self.images[split] = [] 55 | self.images[split] = self.recursive_glob(rootdir=self.image_base, suffix='.png') 56 | self.images[split].sort() 57 | 58 | self.disparities[split] = [] 59 | self.disparities[split] = self.recursive_glob(rootdir=self.disparity_base, suffix='.png') 60 | self.disparities[split].sort() 61 | 62 | self.labels[split] = [] 63 | self.labels[split] = self.recursive_glob(rootdir=self.label_base, suffix='_labelIds.png') 64 | self.labels[split].sort() 65 | 66 | self.calibs[split] = [] 67 | self.calibs[split] = self.recursive_glob(rootdir=self.calib_base, suffix='.json') 68 | self.calibs[split].sort() 69 | 70 | self.sne_model = SNE(crop_top=False) 71 | 72 | if not self.images[split]: 73 | raise Exception("No RGB images for split=[%s] found in %s" % (split, self.image_base)) 74 | if not self.disparities[split]: 75 | raise Exception("No depth images for split=[%s] found in %s" % (split, self.disparity_base)) 76 | 77 | print("Found %d %s RGB images" % (len(self.images[split]), split)) 78 | print("Found %d %s disparity images" % (len(self.disparities[split]), split)) 79 | 80 | def __len__(self): 81 | return len(self.images[self.split]) 82 | 83 | def __getitem__(self, index): 84 | 85 | img_path = self.images[self.split][index].rstrip() 86 | disp_path = self.disparities[self.split][index].rstrip() 87 | calib_path = self.calibs[self.split][index].rstrip() 
88 | lbl_path = self.labels[self.split][index].rstrip() 89 | 90 | label_image = cv2.imread(lbl_path, cv2.IMREAD_GRAYSCALE) 91 | oriHeight, oriWidth = label_image.shape 92 | label = np.zeros((oriHeight, oriWidth), dtype=np.uint8) 93 | # reserve the 'road' class 94 | label[label_image == 7] = 1 95 | 96 | _img = Image.open(img_path).convert('RGB') 97 | disp_image = cv2.imread(disp_path, cv2.IMREAD_ANYDEPTH) 98 | baseline, fx, fy, u0, v0 = read_calib_file(calib_path) 99 | depth = Disp2depth(fx, baseline, disp_image) 100 | _depth = Image.fromarray(depth) 101 | 102 | _target = Image.fromarray(label) 103 | 104 | sample = {'image': _img, 'depth': _depth, 'label': _target} 105 | 106 | if self.split == 'train': 107 | sample = self.transform_tr(sample) 108 | elif self.split == 'val': 109 | sample = self.transform_val(sample) 110 | elif self.split == 'test': 111 | sample = self.transform_ts(sample) 112 | else: 113 | sample = self.transform_ts(sample) 114 | 115 | depth_image = np.array(sample['depth']) 116 | 117 | calib = np.array([[fx, 0, u0], 118 | [0, fy, v0], 119 | [0, 0, 1]]) 120 | camParam = torch.tensor(calib, dtype=torch.float32) 121 | normal = self.sne_model(torch.tensor(depth_image.astype(np.float32)), camParam) 122 | normal = normal.cpu().numpy() 123 | normal = np.transpose(normal, [1, 2, 0]) 124 | normal = cv2.resize(normal, (self.args.crop_width, self.args.crop_height)) 125 | 126 | normal = transforms.ToTensor()(normal) 127 | 128 | sample['depth'] = normal 129 | 130 | sample['label'] = np.array(sample['label']) 131 | sample['label'] = torch.from_numpy(sample['label']).long() 132 | 133 | sample['oriHeight'] = oriHeight 134 | sample['oriWidth'] = oriWidth 135 | 136 | sample['img_path'] = img_path 137 | sample['depth_path'] = disp_path 138 | sample['calib_path'] = calib_path 139 | sample['lbl_path'] = lbl_path 140 | 141 | return sample 142 | 143 | def recursive_glob(self, rootdir='.', suffix=''): 144 | """Performs recursive glob with given suffix and rootdir 145 | :param rootdir is the root directory 146 | :param suffix is the suffix to be searched 147 | """ 148 | return [os.path.join(looproot, filename) 149 | for looproot, _, filenames in os.walk(rootdir) 150 | for filename in filenames if filename.endswith(suffix)] 151 | 152 | def transform_tr(self, sample): 153 | composed_transforms = transforms.Compose([ 154 | tr.RandomHorizontalFlip(), 155 | tr.RandomGaussianBlur(), 156 | tr.RandomGaussianNoise(), 157 | tr.Resize(size=(self.args.crop_width, self.args.crop_height)), 158 | tr.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 159 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 160 | tr.ToTensor()]) 161 | return composed_transforms(sample) 162 | 163 | def transform_val(self, sample): 164 | composed_transforms = transforms.Compose([ 165 | tr.Resize(size=(self.args.crop_width, self.args.crop_height)), 166 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 167 | tr.ToTensor()]) 168 | return composed_transforms(sample) 169 | 170 | def transform_ts(self, sample): 171 | composed_transforms = transforms.Compose([ 172 | tr.Resize(size=(self.args.crop_width, self.args.crop_height)), 173 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 174 | tr.ToTensor()]) 175 | return composed_transforms(sample) 176 | -------------------------------------------------------------------------------- /dataset/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 
| import numpy as np 4 | from torchvision import transforms 5 | from PIL import Image, ImageFilter, ImageEnhance 6 | 7 | 8 | class Normalize(object): 9 | """Normalize a tensor image with mean and standard deviation. 10 | Args: 11 | mean (tuple): means for each channel. 12 | std (tuple): standard deviations for each channel. 13 | """ 14 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): 15 | self.mean = mean 16 | self.std = std 17 | 18 | def __call__(self, sample): 19 | img = sample['image'] 20 | depth = sample['depth'] 21 | mask = sample['label'] 22 | img = np.array(img).astype(np.float32) 23 | depth = np.array(depth).astype(np.float32) 24 | mask = np.array(mask).astype(np.float32) 25 | img /= 255.0 26 | img -= self.mean 27 | img /= self.std 28 | 29 | return {'image': img, 30 | 'depth': depth, 31 | 'label': mask} 32 | 33 | 34 | class ToTensor(object): 35 | """Convert Image object in sample to Tensors.""" 36 | 37 | def __call__(self, sample): 38 | # swap color axis because 39 | # numpy image: H x W x C 40 | # torch image: C X H X W 41 | img = sample['image'] 42 | depth = sample['depth'] 43 | mask = sample['label'] 44 | img = np.array(img).astype(np.float32).transpose((2, 0, 1)) 45 | depth = np.array(depth).astype(np.float32) 46 | mask = np.array(mask).astype(np.float32) 47 | 48 | img = torch.from_numpy(img).float() 49 | depth = torch.from_numpy(depth).float() 50 | mask = torch.from_numpy(mask).float() 51 | 52 | return {'image': img, 53 | 'depth': depth, 54 | 'label': mask} 55 | 56 | 57 | class ColorJitter(object): 58 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): 59 | self.brightness = [max(1 - brightness, 0), 1 + brightness] 60 | self.contrast = [max(1 - contrast, 0), 1 + contrast] 61 | self.saturation = [max(1 - saturation, 0), 1 + saturation] 62 | 63 | def __call__(self, sample): 64 | img = sample['image'] 65 | depth = sample['depth'] 66 | mask = sample['label'] 67 | r_brightness = random.uniform(self.brightness[0], self.brightness[1]) 68 | r_contrast = random.uniform(self.contrast[0], self.contrast[1]) 69 | r_saturation = random.uniform(self.saturation[0], self.saturation[1]) 70 | img = ImageEnhance.Brightness(img).enhance(r_brightness) 71 | img = ImageEnhance.Contrast(img).enhance(r_contrast) 72 | img = ImageEnhance.Color(img).enhance(r_saturation) 73 | return {'image': img, 74 | 'depth': depth, 75 | 'label': mask} 76 | 77 | 78 | class RandomHorizontalFlip(object): 79 | def __call__(self, sample): 80 | img = sample['image'] 81 | depth = sample['depth'] 82 | mask = sample['label'] 83 | if random.random() < 0.5: 84 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 85 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT) 86 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 87 | 88 | return {'image': img, 89 | 'depth': depth, 90 | 'label': mask} 91 | 92 | 93 | class HorizontalFlip(object): 94 | def __call__(self, sample): 95 | img = sample['image'] 96 | depth = sample['depth'] 97 | mask = sample['label'] 98 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 99 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT) 100 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 101 | 102 | return {'image': img, 103 | 'depth': depth, 104 | 'label': mask} 105 | 106 | 107 | class RandomGaussianBlur(object): 108 | def __init__(self, radius=1): 109 | self.radius = radius 110 | 111 | def __call__(self, sample): 112 | img = sample['image'] 113 | depth = sample['depth'] 114 | mask = sample['label'] 115 | if random.random() < 0.5: 116 | img = img.filter(ImageFilter.GaussianBlur( 117 | 
radius=self.radius*random.random())) 118 | 119 | return {'image': img, 120 | 'depth': depth, 121 | 'label': mask} 122 | 123 | 124 | class RandomGaussianNoise(object): 125 | def __init__(self, mean=0, sigma=10): 126 | self.mean = mean 127 | self.sigma = sigma 128 | 129 | def gaussianNoisy(self, im, mean=0, sigma=10): 130 | noise = np.random.normal(mean, sigma, len(im)) 131 | im = im + noise 132 | 133 | im = np.clip(im, 0, 255) 134 | return im 135 | 136 | def __call__(self, sample): 137 | img = sample['image'] 138 | depth = sample['depth'] 139 | mask = sample['label'] 140 | if random.random() < 0.5: 141 | img = np.asarray(img) 142 | img = img.astype(np.int) 143 | width, height = img.shape[:2] 144 | img_r = self.gaussianNoisy(img[:, :, 0].flatten(), self.mean, self.sigma) 145 | img_g = self.gaussianNoisy(img[:, :, 1].flatten(), self.mean, self.sigma) 146 | img_b = self.gaussianNoisy(img[:, :, 2].flatten(), self.mean, self.sigma) 147 | img[:, :, 0] = img_r.reshape([width, height]) 148 | img[:, :, 1] = img_g.reshape([width, height]) 149 | img[:, :, 2] = img_b.reshape([width, height]) 150 | img = Image.fromarray(np.uint8(img)) 151 | return {'image': img, 152 | 'depth': depth, 153 | 'label': mask} 154 | 155 | 156 | class Resize(object): 157 | """Resize rgb and label images, while keep depth image unchanged. """ 158 | def __init__(self, size): 159 | self.size = size # size: (w, h) 160 | 161 | def __call__(self, sample): 162 | img = sample['image'] 163 | depth = sample['depth'] 164 | mask = sample['label'] 165 | 166 | assert img.size == depth.size == mask.size 167 | 168 | # resize rgb and label 169 | img = img.resize(self.size, Image.BILINEAR) 170 | # depth = depth.resize(self.size, Image.BILINEAR) 171 | mask = mask.resize(self.size, Image.NEAREST) 172 | 173 | return {'image': img, 174 | 'depth': depth, 175 | 'label': mask} 176 | -------------------------------------------------------------------------------- /dataset/kitti.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | from torch.utils import data 6 | from torchvision import transforms 7 | import cv2 8 | import glob 9 | from model.sne_model import SNE 10 | from dataset import custom_transforms as tr 11 | import matplotlib.pyplot as plt 12 | 13 | 14 | class kittiCalibInfo(): 15 | """ 16 | Read calibration files in the kitti dataset, 17 | we need to use the intrinsic parameter of the cam2 18 | """ 19 | def __init__(self, filepath): 20 | """ 21 | Args: 22 | filepath ([str]): calibration file path (AAA.txt) 23 | """ 24 | self.data = self._load_calib(filepath) 25 | 26 | def get_cam_param(self): 27 | """ 28 | Returns: 29 | [numpy.array]: intrinsic parameter of the cam2 30 | """ 31 | return self.data['P2'] 32 | 33 | def _load_calib(self, filepath): 34 | rawdata = self._read_calib_file(filepath) 35 | data = {} 36 | P0 = np.reshape(rawdata['P0'], (3,4)) 37 | P1 = np.reshape(rawdata['P1'], (3,4)) 38 | P2 = np.reshape(rawdata['P2'], (3,4)) 39 | P3 = np.reshape(rawdata['P3'], (3,4)) 40 | R0_rect = np.reshape(rawdata['R0_rect'], (3,3)) 41 | Tr_velo_to_cam = np.reshape(rawdata['Tr_velo_to_cam'], (3,4)) 42 | 43 | data['P0'] = P0 44 | data['P1'] = P1 45 | data['P2'] = P2 46 | data['P3'] = P3 47 | data['R0_rect'] = R0_rect 48 | data['Tr_velo_to_cam'] = Tr_velo_to_cam 49 | 50 | return data 51 | 52 | def _read_calib_file(self, filepath): 53 | """Read in a calibration file and parse into a dictionary.""" 54 | data = {} 55 | 56 | with 
open(filepath, 'r') as f: 57 | for line in f.readlines(): 58 | key, value = line.split(':', 1) 59 | # The only non-float values in these files are dates, which 60 | # we don't care about anyway 61 | try: 62 | data[key] = np.array([float(x) for x in value.split()]) 63 | except ValueError: 64 | pass 65 | return data 66 | 67 | 68 | class Kitti_Dataset(data.Dataset): 69 | NUM_CLASSES = 2 70 | 71 | def __init__(self, args, root='./data/KITTI/', split='training'): 72 | 73 | self.root = root 74 | self.split = split 75 | self.args = args 76 | self.images = {} 77 | self.depths = {} 78 | self.labels = {} 79 | self.calibs = {} 80 | 81 | self.image_base = os.path.join(self.root, self.split, 'image_2') 82 | self.depth_base = os.path.join(self.root, self.split, 'depth_u16') 83 | self.label_base = os.path.join(self.root, self.split, 'gt_image_2') 84 | self.calib_base = os.path.join(self.root, self.split, 'calib') 85 | 86 | self.images[split] = [] 87 | self.images[split].extend(glob.glob(os.path.join(self.image_base, '*.png'))) 88 | self.images[split].sort() 89 | 90 | self.depths[split] = [] 91 | self.depths[split].extend(glob.glob(os.path.join(self.depth_base, '*.png'))) 92 | self.depths[split].sort() 93 | 94 | self.labels[split] = [] 95 | self.labels[split].extend(glob.glob(os.path.join(self.label_base, '*.png'))) 96 | self.labels[split].sort() 97 | 98 | self.calibs[split] = [] 99 | self.calibs[split].extend(glob.glob(os.path.join(self.calib_base, '*.txt'))) 100 | self.calibs[split].sort() 101 | 102 | self.sne_model = SNE(crop_top=True) 103 | 104 | if not self.images[split]: 105 | raise Exception("No RGB images for split=[%s] found in %s" % (split, self.image_base)) 106 | if not self.depths[split]: 107 | raise Exception("No depth images for split=[%s] found in %s" % (split, self.depth_base)) 108 | 109 | print("Found %d %s RGB images" % (len(self.images[split]), split)) 110 | print("Found %d %s depth images" % (len(self.depths[split]), split)) 111 | 112 | def __len__(self): 113 | return len(self.images[self.split]) 114 | 115 | def __getitem__(self, index): 116 | 117 | img_path = self.images[self.split][index].rstrip() 118 | depth_path = self.depths[self.split][index].rstrip() 119 | calib_path = self.calibs[self.split][index].rstrip() 120 | 121 | useDir = "/".join(img_path.split('/')[:-2]) 122 | name = img_path.split('/')[-1] 123 | 124 | _img = Image.open(img_path).convert('RGB') 125 | depth_image = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH) 126 | depth_image = depth_image.astype(np.float32) 127 | oriHeight, oriWidth = depth_image.shape 128 | _depth = Image.fromarray(depth_image) 129 | 130 | label = np.zeros((oriHeight, oriWidth), dtype=np.uint8) 131 | if not self.split == 'testing': 132 | lbl_path = os.path.join(useDir, 'gt_image_2', name[:-10] + 'road_' + name[-10:]) 133 | label_image = cv2.cvtColor(cv2.imread(lbl_path), cv2.COLOR_BGR2RGB) 134 | label[label_image[:, :, 2] > 0] = 1 135 | 136 | _target = Image.fromarray(label) 137 | 138 | sample = {'image': _img, 'depth': _depth, 'label': _target} 139 | 140 | if self.split == 'training': 141 | sample = self.transform_tr(sample) 142 | elif self.split == 'validating': 143 | sample = self.transform_val(sample) 144 | elif self.split == 'testing': 145 | sample = self.transform_ts(sample) 146 | else: 147 | sample = self.transform_ts(sample) 148 | 149 | depth_image = np.array(sample['depth']) 150 | 151 | calib = kittiCalibInfo(calib_path) 152 | camParam = torch.tensor(calib.get_cam_param(), dtype=torch.float32) 153 | normal = 
self.sne_model(torch.tensor(depth_image.astype(np.float32) / 1000), camParam) 154 | normal = normal.cpu().numpy() 155 | normal = np.transpose(normal, [1, 2, 0]) 156 | normal = cv2.resize(normal, (self.args.crop_width, self.args.crop_height)) 157 | 158 | normal = transforms.ToTensor()(normal) 159 | 160 | sample['depth'] = normal 161 | 162 | sample['label'] = np.array(sample['label']) 163 | sample['label'] = torch.from_numpy(sample['label']).long() 164 | 165 | sample['oriHeight'] = oriHeight 166 | sample['oriWidth'] = oriWidth 167 | 168 | sample['img_path'] = img_path 169 | sample['depth_path'] = depth_path 170 | sample['calib_path'] = calib_path 171 | 172 | return sample 173 | 174 | def transform_tr(self, sample): 175 | composed_transforms = transforms.Compose([ 176 | tr.RandomHorizontalFlip(), 177 | tr.RandomGaussianBlur(), 178 | tr.RandomGaussianNoise(), 179 | tr.Resize(size=(self.args.crop_width, self.args.crop_height)), 180 | tr.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 181 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 182 | tr.ToTensor()]) 183 | return composed_transforms(sample) 184 | 185 | def transform_val(self, sample): 186 | composed_transforms = transforms.Compose([ 187 | tr.Resize(size=(self.args.crop_width, self.args.crop_height)), 188 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 189 | tr.ToTensor()]) 190 | return composed_transforms(sample) 191 | 192 | def transform_ts(self, sample): 193 | composed_transforms = transforms.Compose([ 194 | tr.Resize(size=(self.args.crop_width, self.args.crop_height)), 195 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 196 | tr.ToTensor()]) 197 | return composed_transforms(sample) 198 | -------------------------------------------------------------------------------- /img/KITTI.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/morancyc/USNet/d761158ad42df7dcb62fa257dd02ce11c85f94a5/img/KITTI.gif -------------------------------------------------------------------------------- /img/usnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/morancyc/USNet/d761158ad42df7dcb62fa257dd02ce11c85f94a5/img/usnet.png -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | # loss function 6 | def KL(alpha, c): 7 | beta = torch.ones((1, c)).cuda() 8 | S_alpha = torch.sum(alpha, dim=1, keepdim=True) 9 | S_beta = torch.sum(beta, dim=1, keepdim=True) 10 | lnB = torch.lgamma(S_alpha) - torch.sum(torch.lgamma(alpha), dim=1, keepdim=True) 11 | lnB_uni = torch.sum(torch.lgamma(beta), dim=1, keepdim=True) - torch.lgamma(S_beta) 12 | dg0 = torch.digamma(S_alpha) 13 | dg1 = torch.digamma(alpha) 14 | kl = torch.sum((alpha - beta) * (dg1 - dg0), dim=1, keepdim=True) + lnB + lnB_uni 15 | return kl 16 | 17 | 18 | def ce_loss(p, alpha, c, global_step, annealing_step): 19 | S = torch.sum(alpha, dim=1, keepdim=True) 20 | E = alpha - 1 21 | label = F.one_hot(p, num_classes=c) 22 | A = torch.sum(label * (torch.digamma(S) - torch.digamma(alpha)), dim=1, keepdim=True) 23 | 24 | annealing_coef = min(1, global_step / annealing_step) 25 | 26 | alp = E * (1 - label) + 1 27 | B = annealing_coef * KL(alp, c) 28 | 29 | return (A + B) 30 | 31 | 32 | def mse_loss(p, 
alpha, c, global_step, annealing_step=1): 33 | S = torch.sum(alpha, dim=1, keepdim=True) 34 | E = alpha - 1 35 | m = alpha / S 36 | label = F.one_hot(p, num_classes=c) 37 | A = torch.sum((label - m) ** 2, dim=1, keepdim=True) 38 | B = torch.sum(alpha * (S - alpha) / (S * S * (S + 1)), dim=1, keepdim=True) 39 | annealing_coef = min(1, global_step / annealing_step) 40 | alp = E * (1 - label) + 1 41 | C = annealing_coef * KL(alp, c) 42 | return (A + B) + C 43 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/morancyc/USNet/d761158ad42df7dcb62fa257dd02ce11c85f94a5/model/__init__.py -------------------------------------------------------------------------------- /model/backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import models 3 | 4 | 5 | class symmetric_backbone(torch.nn.Module): 6 | def __init__(self, name, pretrained=True): 7 | super().__init__() 8 | if name == 'resnet18': 9 | features = models.resnet18(pretrained=pretrained) 10 | features_d = models.resnet18(pretrained=pretrained) 11 | elif name == 'resnet101': 12 | features = models.resnet101(pretrained=pretrained) 13 | features_d = models.resnet101(pretrained=pretrained) 14 | else: 15 | print('Error: unspported backbone \n') 16 | 17 | self.conv1 = features.conv1 18 | self.bn1 = features.bn1 19 | self.relu = features.relu 20 | self.maxpool1 = features.maxpool 21 | self.layer1 = features.layer1 22 | self.layer2 = features.layer2 23 | self.layer3 = features.layer3 24 | self.layer4 = features.layer4 25 | 26 | self.conv1_d = features_d.conv1 27 | self.bn1_d = features_d.bn1 28 | self.relu_d = features_d.relu 29 | self.maxpool1_d = features_d.maxpool 30 | self.layer1_d = features_d.layer1 31 | self.layer2_d = features_d.layer2 32 | self.layer3_d = features_d.layer3 33 | self.layer4_d = features_d.layer4 34 | 35 | def forward(self, input_rgb, input_depth): 36 | # Symmetric Network 37 | x = self.conv1(input_rgb) 38 | x = self.relu(self.bn1(x)) 39 | feature0 = self.maxpool1(x) 40 | 41 | feature1 = self.layer1(feature0) # 1 / 4 42 | feature2 = self.layer2(feature1) # 1 / 8 43 | feature3 = self.layer3(feature2) # 1 / 16 44 | feature4 = self.layer4(feature3) # 1 / 32 45 | 46 | y = self.conv1_d(input_depth) 47 | y = self.relu_d(self.bn1_d(y)) 48 | feature0_d = self.maxpool1_d(y) 49 | 50 | feature1_d = self.layer1_d(feature0_d) # 1 / 4 51 | feature2_d = self.layer2_d(feature1_d) # 1 / 8 52 | feature3_d = self.layer3_d(feature2_d) # 1 / 16 53 | feature4_d = self.layer4_d(feature3_d) # 1 / 32 54 | 55 | return feature1, feature2, feature3, feature4, feature1_d, feature2_d, feature3_d, feature4_d 56 | -------------------------------------------------------------------------------- /model/module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class _ASPPModule(nn.Module): 7 | def __init__(self, in_channels, out_channels, kernel_size, padding, dilation): 8 | super(_ASPPModule, self).__init__() 9 | self.atrous_conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, 10 | stride=1, padding=padding, dilation=dilation, bias=False) 11 | self.bn = nn.BatchNorm2d(out_channels) 12 | self.relu = nn.ReLU() 13 | 14 | self._init_weight() 15 | 16 | def forward(self, x): 17 | x = 
self.atrous_conv(x) 18 | x = self.bn(x) 19 | 20 | return self.relu(x) 21 | 22 | def _init_weight(self): 23 | for m in self.modules(): 24 | if isinstance(m, nn.Conv2d): 25 | torch.nn.init.kaiming_normal_(m.weight) 26 | elif isinstance(m, nn.BatchNorm2d): 27 | m.weight.data.fill_(1) 28 | m.bias.data.zero_() 29 | 30 | 31 | class ASPP(nn.Module): 32 | def __init__(self, in_channels, out_channels, dilations): 33 | super(ASPP, self).__init__() 34 | self.aspp1 = _ASPPModule(in_channels, out_channels, 1, padding=0, dilation=dilations[0]) 35 | self.aspp2 = _ASPPModule(in_channels, out_channels, 3, padding=dilations[1], dilation=dilations[1]) 36 | self.aspp3 = _ASPPModule(in_channels, out_channels, 3, padding=dilations[2], dilation=dilations[2]) 37 | self.aspp4 = _ASPPModule(in_channels, out_channels, 3, padding=dilations[3], dilation=dilations[3]) 38 | 39 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 40 | nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=False), 41 | nn.BatchNorm2d(out_channels), 42 | nn.ReLU()) 43 | self.conv1 = nn.Conv2d(out_channels * 5, out_channels, 1, bias=False) 44 | self.bn1 = nn.BatchNorm2d(out_channels) 45 | self.relu = nn.ReLU() 46 | self.dropout = nn.Dropout(0.5) 47 | self._init_weight() 48 | 49 | def forward(self, x): 50 | x1 = self.aspp1(x) 51 | x2 = self.aspp2(x) 52 | x3 = self.aspp3(x) 53 | x4 = self.aspp4(x) 54 | x5 = self.global_avg_pool(x) 55 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 56 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 57 | 58 | x = self.conv1(x) 59 | x = self.bn1(x) 60 | x = self.relu(x) 61 | 62 | return self.dropout(x) 63 | 64 | def _init_weight(self): 65 | for m in self.modules(): 66 | if isinstance(m, nn.Conv2d): 67 | torch.nn.init.kaiming_normal_(m.weight) 68 | elif isinstance(m, nn.BatchNorm2d): 69 | m.weight.data.fill_(1) 70 | m.bias.data.zero_() 71 | -------------------------------------------------------------------------------- /model/sne_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class SNE(nn.Module): 7 | """Our SNE takes depth and camera intrinsic parameters as input, 8 | and outputs normal estimations. 
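Given a depth map of shape (h, w) and the camera intrinsics (only fx, fy, u0 and v0 are read from camParam), the forward pass back-projects each pixel to 3D, differentiates the inverse depth to obtain the x/y normal components, aggregates the z component over eight neighbourhood difference kernels, and returns a (3, h, w) tensor of unit surface normals.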
9 | """ 10 | def __init__(self, crop_top=True): 11 | super(SNE, self).__init__() 12 | self.crop_top = crop_top 13 | 14 | def forward(self, depth, camParam): 15 | h,w = depth.size() 16 | v_map, u_map = torch.meshgrid(torch.arange(h), torch.arange(w)) 17 | v_map = v_map.type(torch.float32) 18 | u_map = u_map.type(torch.float32) 19 | 20 | Z = depth # h, w 21 | Y = Z.mul((v_map - camParam[1,2])) / camParam[0,0] # h, w 22 | X = Z.mul((u_map - camParam[0,2])) / camParam[0,0] # h, w 23 | 24 | if self.crop_top: 25 | Z[Y <= 0] = 0 26 | Y[Y <= 0] = 0 27 | 28 | Z[torch.isnan(Z)] = 0 29 | D = torch.div(torch.ones(h, w), Z) # h, w 30 | 31 | Gx = torch.tensor([[0,0,0],[-1,0,1],[0,0,0]], dtype=torch.float32) 32 | Gy = torch.tensor([[0,-1,0],[0,0,0],[0,1,0]], dtype=torch.float32) 33 | 34 | Gu = F.conv2d(D.view(1,1,h,w), Gx.view(1,1,3,3), padding=1) 35 | Gv = F.conv2d(D.view(1,1,h,w), Gy.view(1,1,3,3), padding=1) 36 | 37 | nx_t = Gu * camParam[0,0] # 1, 1, h, w 38 | ny_t = Gv * camParam[1,1] # 1, 1, h, w 39 | 40 | phi = torch.atan(torch.div(ny_t, nx_t)) + torch.ones([1,1,h,w])*3.141592657 41 | a = torch.cos(phi) 42 | b = torch.sin(phi) 43 | 44 | diffKernelArray = torch.tensor([[-1, 0, 0, 0, 1, 0, 0, 0, 0], 45 | [ 0,-1, 0, 0, 1, 0, 0, 0, 0], 46 | [ 0, 0,-1, 0, 1, 0, 0, 0, 0], 47 | [ 0, 0, 0,-1, 1, 0, 0, 0, 0], 48 | [ 0, 0, 0, 0, 1,-1, 0, 0, 0], 49 | [ 0, 0, 0, 0, 1, 0,-1, 0, 0], 50 | [ 0, 0, 0, 0, 1, 0, 0,-1, 0], 51 | [ 0, 0, 0, 0, 1, 0, 0, 0,-1]], dtype=torch.float32) 52 | 53 | sum_nx = torch.zeros((1,1,h,w), dtype=torch.float32) 54 | sum_ny = torch.zeros((1,1,h,w), dtype=torch.float32) 55 | sum_nz = torch.zeros((1,1,h,w), dtype=torch.float32) 56 | 57 | for i in range(8): 58 | diffKernel = diffKernelArray[i].view(1,1,3,3) 59 | X_d = F.conv2d(X.view(1,1,h,w), diffKernel, padding=1) 60 | Y_d = F.conv2d(Y.view(1,1,h,w), diffKernel, padding=1) 61 | Z_d = F.conv2d(Z.view(1,1,h,w), diffKernel, padding=1) 62 | 63 | nz_i = torch.div((torch.mul(nx_t, X_d) + torch.mul(ny_t, Y_d)), Z_d) 64 | norm = torch.sqrt(torch.mul(nx_t, nx_t) + torch.mul(ny_t, ny_t) + torch.mul(nz_i, nz_i)) 65 | nx_t_i = torch.div(nx_t, norm) 66 | ny_t_i = torch.div(ny_t, norm) 67 | nz_t_i = torch.div(nz_i, norm) 68 | 69 | nx_t_i[torch.isnan(nx_t_i)] = 0 70 | ny_t_i[torch.isnan(ny_t_i)] = 0 71 | nz_t_i[torch.isnan(nz_t_i)] = 0 72 | 73 | sum_nx = sum_nx + nx_t_i 74 | sum_ny = sum_ny + ny_t_i 75 | sum_nz = sum_nz + nz_t_i 76 | 77 | theta = -torch.atan(torch.div((torch.mul(sum_nx, a) + torch.mul(sum_ny, b)), sum_nz)) 78 | nx = torch.mul(torch.sin(theta), torch.cos(phi)) 79 | ny = torch.mul(torch.sin(theta), torch.sin(phi)) 80 | nz = torch.cos(theta) 81 | 82 | nx[torch.isnan(nz)] = 0 83 | ny[torch.isnan(nz)] = 0 84 | nz[torch.isnan(nz)] = -1 85 | 86 | sign = torch.ones((1,1,h,w), dtype=torch.float32) 87 | sign[ny > 0] = -1 88 | 89 | nx = torch.mul(nx, sign).squeeze(dim=0) 90 | ny = torch.mul(ny, sign).squeeze(dim=0) 91 | nz = torch.mul(nz, sign).squeeze(dim=0) 92 | 93 | return torch.cat([nx, ny, nz], dim=0) 94 | -------------------------------------------------------------------------------- /model/usnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from model.backbone import symmetric_backbone 5 | from model.module import ASPP 6 | from torchsummary import summary 7 | 8 | 9 | class ConvBlock(nn.Module): 10 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1): 11 | super().__init__() 12 | self.conv1 = 
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, 13 | bias=False) 14 | self.bn = nn.BatchNorm2d(out_channels) 15 | self.relu = nn.ReLU() 16 | 17 | def forward(self, input): 18 | x = self.conv1(input) 19 | return self.relu(self.bn(x)) 20 | 21 | 22 | # Feature Compression and Adaptation block 23 | class FCA(nn.Module): 24 | def __init__(self, in_channels, out_channels): 25 | super().__init__() 26 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 27 | self.bn = nn.BatchNorm2d(out_channels) 28 | self.sigmoid = nn.Sigmoid() 29 | self.in_channels = in_channels 30 | self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1)) 31 | 32 | def forward(self, input): 33 | # global average pooling 34 | x = self.avgpool(input) 35 | assert self.in_channels == x.size(1), 'in_channels and out_channels should all be {}'.format(x.size(1)) 36 | x = self.conv(x) 37 | x = self.sigmoid(x) 38 | x = torch.mul(input, x) 39 | return x 40 | 41 | 42 | class Up_layer(nn.Module): 43 | def __init__(self, in_channels, out_channels): 44 | super().__init__() 45 | self.conv = ConvBlock(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0) 46 | self.fca = FCA(in_channels=out_channels, out_channels=out_channels) 47 | 48 | def forward(self, x, x_aux): 49 | x = F.interpolate(x, size=x_aux.size()[-2:], mode='bilinear') 50 | x_aux = self.fca(self.conv(x_aux)) 51 | x = x + x_aux 52 | return x 53 | 54 | 55 | # Multi-scale Evidence Collection Module 56 | class MEC(nn.Module): 57 | def __init__(self, in_channels, out_channels): 58 | super().__init__() 59 | self.scale_1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1) 60 | self.scale_2 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=3, dilation=3) 61 | self.scale_3 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=6, dilation=6) 62 | self.softplus = nn.Softplus() 63 | 64 | def forward(self, x, input): 65 | e_1 = self.scale_1(x) 66 | e_2 = self.scale_2(x) 67 | e_3 = self.scale_3(x) 68 | 69 | e_1 = F.interpolate(e_1, size=input.size()[-2:], mode='bilinear') 70 | e_2 = F.interpolate(e_2, size=input.size()[-2:], mode='bilinear') 71 | e_3 = F.interpolate(e_3, size=input.size()[-2:], mode='bilinear') 72 | 73 | e_1 = self.softplus(e_1) 74 | e_2 = self.softplus(e_2) 75 | e_3 = self.softplus(e_3) 76 | 77 | e = (e_1 + e_2 + e_3) / 3 78 | return e_1, e_2, e_3, e 79 | 80 | 81 | class USNet(nn.Module): 82 | def __init__(self, num_classes, backbone_name): 83 | super().__init__() 84 | # build backbone 85 | self.backbone = symmetric_backbone(name=backbone_name) 86 | 87 | if backbone_name == 'resnet101': 88 | # ASPP 89 | dilations = [1, 6, 12, 18] 90 | self.aspp_r = ASPP(2048, 256, dilations) 91 | self.aspp_d = ASPP(2048, 256, dilations) 92 | 93 | self.conv_r = ConvBlock(in_channels=256, out_channels=64, kernel_size=1, stride=1, padding=0) 94 | self.conv_d = ConvBlock(in_channels=256, out_channels=64, kernel_size=1, stride=1, padding=0) 95 | 96 | self.up_r1 = Up_layer(in_channels=1024, out_channels=64) 97 | self.up_r2 = Up_layer(in_channels=512, out_channels=64) 98 | self.up_r3 = Up_layer(in_channels=256, out_channels=64) 99 | 100 | self.up_d1 = Up_layer(in_channels=1024, out_channels=64) 101 | self.up_d2 = Up_layer(in_channels=512, out_channels=64) 102 | self.up_d3 = Up_layer(in_channels=256, out_channels=64) 103 | 104 | self.mec_r = MEC(in_channels=64, out_channels=num_classes) 105 | 
self.mec_d = MEC(in_channels=64, out_channels=num_classes) 106 | 107 | elif backbone_name == 'resnet18': 108 | # ASPP 109 | dilations = [1, 6, 12, 18] 110 | self.aspp_r = ASPP(512, 256, dilations) 111 | self.aspp_d = ASPP(512, 256, dilations) 112 | 113 | self.conv_r = ConvBlock(in_channels=256, out_channels=64, kernel_size=1, stride=1, padding=0) 114 | self.conv_d = ConvBlock(in_channels=256, out_channels=64, kernel_size=1, stride=1, padding=0) 115 | 116 | self.up_r1 = Up_layer(in_channels=256, out_channels=64) 117 | self.up_r2 = Up_layer(in_channels=128, out_channels=64) 118 | self.up_r3 = Up_layer(in_channels=64, out_channels=64) 119 | 120 | self.up_d1 = Up_layer(in_channels=256, out_channels=64) 121 | self.up_d2 = Up_layer(in_channels=128, out_channels=64) 122 | self.up_d3 = Up_layer(in_channels=64, out_channels=64) 123 | 124 | self.mec_r = MEC(in_channels=64, out_channels=num_classes) 125 | self.mec_d = MEC(in_channels=64, out_channels=num_classes) 126 | else: 127 | print('Error: unspported backbone \n') 128 | 129 | self.num_classes = num_classes 130 | self.softplus = nn.Softplus() 131 | self.init_weight() 132 | 133 | def init_weight(self): 134 | for name, m in self.named_modules(): 135 | if 'backbone' not in name: 136 | if isinstance(m, nn.Conv2d): 137 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 138 | elif isinstance(m, nn.BatchNorm2d): 139 | m.eps = 1e-5 140 | m.momentum = 0.1 141 | nn.init.constant_(m.weight, 1) 142 | nn.init.constant_(m.bias, 0) 143 | 144 | def DS_Combin(self, alpha): 145 | """ 146 | :param alpha: All Dirichlet distribution parameters. 147 | :return: Combined Dirichlet distribution parameters. 148 | """ 149 | 150 | def DS_Combin_two(alpha1, alpha2): 151 | """ 152 | :param alpha1: Dirichlet distribution parameters of view 1 153 | :param alpha2: Dirichlet distribution parameters of view 2 154 | :return: Combined Dirichlet distribution parameters 155 | """ 156 | alpha = dict() 157 | alpha[0], alpha[1] = alpha1, alpha2 158 | b, S, E, u = dict(), dict(), dict(), dict() 159 | for v in range(2): 160 | S[v] = torch.sum(alpha[v], dim=1, keepdim=True) 161 | E[v] = alpha[v] - 1 162 | b[v] = E[v] / (S[v].expand(E[v].shape)) 163 | u[v] = self.num_classes / S[v] 164 | 165 | # b^0 @ b^(0+1) 166 | bb = torch.bmm(b[0].view(-1, self.num_classes, 1), b[1].view(-1, 1, self.num_classes)) 167 | # b^0 * u^1 168 | uv1_expand = u[1].expand(b[0].shape) 169 | bu = torch.mul(b[0], uv1_expand) 170 | # b^1 * u^0 171 | uv_expand = u[0].expand(b[0].shape) 172 | ub = torch.mul(b[1], uv_expand) 173 | # calculate C 174 | bb_sum = torch.sum(bb, dim=(1, 2), out=None) 175 | bb_diag = torch.diagonal(bb, dim1=-2, dim2=-1).sum(-1) 176 | # bb_diag1 = torch.diag(torch.mm(b[v], torch.transpose(b[v+1], 0, 1))) 177 | C = bb_sum - bb_diag 178 | 179 | # calculate b^a 180 | b_a = (torch.mul(b[0], b[1]) + bu + ub) / ((1 - C).view(-1, 1).expand(b[0].shape)) 181 | # calculate u^a 182 | u_a = torch.mul(u[0], u[1]) / ((1 - C).view(-1, 1).expand(u[0].shape)) 183 | 184 | # calculate new S 185 | S_a = self.num_classes / u_a 186 | # calculate new e_k 187 | e_a = torch.mul(b_a, S_a.expand(b_a.shape)) 188 | alpha_a = e_a + 1 189 | return alpha_a 190 | 191 | for v in range(len(alpha) - 1): 192 | if v == 0: 193 | alpha_a = DS_Combin_two(alpha[0], alpha[1]) 194 | else: 195 | alpha_a = DS_Combin_two(alpha_a, alpha[v + 1]) 196 | return alpha_a 197 | 198 | def forward(self, input_rgb, input_depth): 199 | # output of backbone 200 | x_r1, x_r2, x_r3, x_r4, x_d1, x_d2, x_d3, x_d4 = 
self.backbone(input_rgb, input_depth) 201 | 202 | # ASPP 203 | x_r4 = self.aspp_r(x_r4) 204 | x_d4 = self.aspp_d(x_d4) 205 | 206 | # compress channel to 64 207 | x_r4 = self.conv_r(x_r4) 208 | x_d4 = self.conv_d(x_d4) 209 | 210 | # decoder 211 | x_r3 = self.up_r1(x_r4, x_r3) 212 | x_r2 = self.up_r2(x_r3, x_r2) 213 | x_r1 = self.up_r3(x_r2, x_r1) 214 | 215 | x_d3 = self.up_d1(x_d4, x_d3) 216 | x_d2 = self.up_d2(x_d3, x_d2) 217 | x_d1 = self.up_d3(x_d2, x_d1) 218 | 219 | # MEC module 220 | e_r1, e_r2, e_r3, e_r = self.mec_r(x_r1, input_rgb) 221 | e_d1, e_d2, e_d3, e_d = self.mec_d(x_d1, input_depth) 222 | 223 | # compute evidence, alpha 224 | evidence = dict() 225 | evidence[0] = e_r.permute(0, 2, 3, 1).reshape(-1, self.num_classes) # (B*H*W, 2) 226 | evidence[1] = e_d.permute(0, 2, 3, 1).reshape(-1, self.num_classes) # (B*H*W, 2) 227 | 228 | alpha = dict() 229 | for v_num in range(len(evidence)): 230 | alpha[v_num] = evidence[v_num] + 1 231 | 232 | alpha_a = self.DS_Combin(alpha) 233 | evidence_a = alpha_a - 1 234 | 235 | if self.training == True: 236 | evidence_sup = dict() 237 | evidence_sup[0] = e_r1.permute(0, 2, 3, 1).reshape(-1, self.num_classes) # (B*H*W, 2) 238 | evidence_sup[1] = e_d1.permute(0, 2, 3, 1).reshape(-1, self.num_classes) # (B*H*W, 2) 239 | evidence_sup[2] = e_r2.permute(0, 2, 3, 1).reshape(-1, self.num_classes) # (B*H*W, 2) 240 | evidence_sup[3] = e_d2.permute(0, 2, 3, 1).reshape(-1, self.num_classes) # (B*H*W, 2) 241 | evidence_sup[4] = e_r3.permute(0, 2, 3, 1).reshape(-1, self.num_classes) # (B*H*W, 2) 242 | evidence_sup[5] = e_d3.permute(0, 2, 3, 1).reshape(-1, self.num_classes) # (B*H*W, 2) 243 | 244 | alpha_sup = dict() 245 | for v_num in range(len(evidence_sup)): 246 | alpha_sup[v_num] = evidence_sup[v_num] + 1 247 | 248 | return evidence_sup, alpha_sup, evidence, evidence_a, alpha, alpha_a 249 | 250 | return evidence, evidence_a, alpha, alpha_a 251 | 252 | 253 | if __name__ == '__main__': 254 | 255 | model = USNet(2, 'resnet18') 256 | model = model.cuda() 257 | summary(model, [(3, 1248, 384),(3, 1248, 384)]) 258 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.6.0 2 | torchvision>=0.7.0 3 | torchsummary 4 | matplotlib 5 | scikit-image 6 | scipy 7 | tqdm 8 | timm 9 | opencv-python 10 | protobuf 11 | tensorboardX 12 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import argparse 3 | import os 4 | import torch 5 | import numpy as np 6 | from torch.utils.data import DataLoader 7 | from dataset.kitti import Kitti_Dataset 8 | from model.usnet import USNet 9 | from utils import fast_hist, getScores 10 | 11 | 12 | def test(args, model, dataloader): 13 | print('start test!') 14 | with torch.no_grad(): 15 | model.eval() 16 | hist = np.zeros((args.num_classes, args.num_classes)) 17 | for i, sample in enumerate(dataloader): 18 | image, depth, label = sample['image'], sample['depth'], sample['label'] 19 | oriHeight, oriWidth = sample['oriHeight'], sample['oriWidth'] 20 | oriWidth = oriWidth.cpu().numpy()[0] 21 | oriHeight = oriHeight.cpu().numpy()[0] 22 | 23 | if torch.cuda.is_available() and args.use_gpu: 24 | image = image.cuda() 25 | depth = depth.cuda() 26 | label = label.cuda() 27 | 28 | # get predict image 29 | evidence, evidence_a, alpha, alpha_a = model(image, depth) 30 | 31 | s = 
torch.sum(alpha_a, dim=1, keepdim=True) 32 | p = alpha_a / (s.expand(alpha_a.shape)) 33 | 34 | pred = p[:,1] 35 | pred = pred.view(args.crop_height, args.crop_width) 36 | pred = pred.detach().cpu().numpy() 37 | pred = np.array(pred) 38 | 39 | # save predict image 40 | visualize = cv2.resize(pred, (oriWidth, oriHeight)) 41 | visualize = np.floor(255*(visualize - visualize.min()) / (visualize.max()-visualize.min())) 42 | img_path = sample['img_path'][0] 43 | img_name = img_path.split('/')[-1] 44 | save_name = img_name.split('_')[0]+'_road_'+img_name.split('_')[1] 45 | cv2.imwrite(os.path.join(args.save_path, save_name), np.uint8(visualize)) 46 | 47 | pred = np.uint8(pred > 0.5) 48 | pred = cv2.resize(pred, (oriWidth, oriHeight), interpolation=cv2.INTER_NEAREST) 49 | 50 | # get label image 51 | label = label.squeeze() 52 | label = label.cpu().numpy() 53 | label = np.array(label) 54 | label = cv2.resize(np.uint8(label), (oriWidth, oriHeight), interpolation=cv2.INTER_NEAREST) 55 | 56 | hist += fast_hist(label.flatten(), pred.flatten(), args.num_classes) 57 | F_score, pre, recall, fpr, fnr = getScores(hist) 58 | print('F_score: %.3f' % F_score) 59 | print('pre : %.3f' % pre) 60 | print('recall: %.3f' % recall) 61 | print('fpr: %.3f' % fpr) 62 | print('fnr: %.3f' % fpr) 63 | return F_score, pre, recall, fpr, fnr 64 | 65 | 66 | def main(params): 67 | # basic parameters 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument('--checkpoint_path', type=str, default=None, required=True, help='The path to the pretrained weights of model') 70 | parser.add_argument('--crop_height', type=int, default=384, help='Height of cropped/resized input image to network') 71 | parser.add_argument('--crop_width', type=int, default=1248, help='Width of cropped/resized input image to network') 72 | parser.add_argument('--data', type=str, default='', help='Path of testing data') 73 | parser.add_argument('--dataset', type=str, default="KITTI", help='Dataset you are using.') 74 | parser.add_argument('--batch_size', type=int, default=1, help='Number of images in each batch') 75 | parser.add_argument('--backbone_name', type=str, default="resnet18", help='The backbone model you are using.') 76 | parser.add_argument('--cuda', type=str, default='0', help='GPU ids used for training') 77 | parser.add_argument('--use_gpu', type=bool, default=True, help='Whether to user gpu for training') 78 | parser.add_argument('--num_classes', type=int, default=2, help='num of object classes (with void)') 79 | parser.add_argument('--save_path', type=str, default=None, required=True, help='Path to save predict image') 80 | args = parser.parse_args(params) 81 | 82 | dataset = Kitti_Dataset(args, root=args.data, split='validating') 83 | dataloader = DataLoader( 84 | dataset, 85 | batch_size=args.batch_size, 86 | shuffle=False, 87 | num_workers=1, 88 | ) 89 | 90 | # build model 91 | os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda 92 | model = USNet(args.num_classes, args.backbone_name) 93 | if torch.cuda.is_available() and args.use_gpu: 94 | model = torch.nn.DataParallel(model).cuda() 95 | 96 | # load trained model 97 | print('load model from %s ...' 
% args.checkpoint_path) 98 | model.module.load_state_dict(torch.load(args.checkpoint_path)) 99 | print('Done!') 100 | 101 | # make save folder 102 | if not os.path.exists(args.save_path): 103 | os.makedirs(args.save_path) 104 | 105 | test(args, model, dataloader) 106 | 107 | 108 | if __name__ == '__main__': 109 | params = [ 110 | '--checkpoint_path', './log/KITTI_model/usnet_best.pth', 111 | '--data', './data/KITTI', 112 | '--batch_size', '1', 113 | '--backbone_name', 'resnet18', 114 | '--cuda', '0', 115 | '--num_classes', '2', 116 | '--save_path', './result/kitti_test', 117 | ] 118 | main(params) 119 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import cv2 4 | import torch 5 | import tqdm 6 | import argparse 7 | import numpy as np 8 | from torch.utils.data import DataLoader 9 | from dataset.kitti import Kitti_Dataset 10 | from dataset.cityscapes import Cityscapes_Dataset 11 | from model.usnet import USNet 12 | from tensorboardX import SummaryWriter 13 | from utils import poly_lr_scheduler, fast_hist, getScores 14 | from loss import KL, ce_loss, mse_loss 15 | import time 16 | 17 | 18 | # validation 19 | def val(args, model, dataloader): 20 | print('start val!') 21 | with torch.no_grad(): 22 | model.eval() 23 | hist = np.zeros((args.num_classes, args.num_classes)) 24 | for i, sample in enumerate(dataloader): 25 | image, depth, label = sample['image'], sample['depth'], sample['label'] 26 | oriHeight, oriWidth = sample['oriHeight'], sample['oriWidth'] 27 | oriWidth = oriWidth.cpu().numpy()[0] 28 | oriHeight = oriHeight.cpu().numpy()[0] 29 | 30 | if torch.cuda.is_available() and args.use_gpu: 31 | image = image.cuda() 32 | depth = depth.cuda() 33 | label = label.cuda() 34 | 35 | # get predict 36 | evidence, evidence_a, alpha, alpha_a = model(image, depth) 37 | 38 | s = torch.sum(alpha_a, dim=1, keepdim=True) 39 | p = alpha_a / (s.expand(alpha_a.shape)) 40 | 41 | pred = p[:,1] 42 | pred = pred.view(args.crop_height, args.crop_width) 43 | pred = pred.detach().cpu().numpy() 44 | pred = np.array(pred) 45 | pred = np.uint8(pred > 0.5) 46 | pred = cv2.resize(pred, (oriWidth, oriHeight), interpolation=cv2.INTER_NEAREST) 47 | 48 | # get label 49 | label = label.squeeze() 50 | label = label.cpu().numpy() 51 | label = np.array(label) 52 | label = cv2.resize(np.uint8(label), (oriWidth, oriHeight), interpolation=cv2.INTER_NEAREST) 53 | 54 | hist += fast_hist(label.flatten(), pred.flatten(), args.num_classes) 55 | F_score, pre, recall, fpr, fnr = getScores(hist) 56 | print('F_score: %.3f' % F_score) 57 | print('pre : %.3f' % pre) 58 | print('recall: %.3f' % recall) 59 | print('fpr: %.3f' % fpr) 60 | print('fnr: %.3f' % fpr) 61 | return F_score, pre, recall, fpr, fnr 62 | 63 | 64 | def train(args, model, optimizer, dataloader_train, dataloader_val): 65 | writer = SummaryWriter(comment=''.format(args.backbone_name)) 66 | max_F_score = 0 67 | step = 0 68 | lambda_epochs = 50 69 | for epoch in range(args.num_epochs): 70 | lr = poly_lr_scheduler(optimizer, args.learning_rate, iter=epoch, max_iter=args.num_epochs) 71 | model.train() 72 | tq = tqdm.tqdm(total=len(dataloader_train) * args.batch_size) 73 | tq.set_description('epoch %d, lr %f' % (epoch, lr)) 74 | loss_record = [] 75 | for i, sample in enumerate(dataloader_train): 76 | image, depth, label = sample['image'], sample['depth'], sample['label'] 77 | if torch.cuda.is_available() and 
args.use_gpu: 78 | image = image.cuda() 79 | depth = depth.cuda() 80 | label = label.cuda() 81 | 82 | # network output 83 | evidence_sup, alpha_sup, evidence, evidence_a, alpha, alpha_a = model(image, depth) 84 | 85 | # compute loss 86 | label = label.flatten() 87 | loss = 0 88 | for v_num in range(len(alpha_sup)): 89 | loss += ce_loss(label, alpha_sup[v_num], args.num_classes, epoch, lambda_epochs) 90 | for v_num in range(len(alpha)): 91 | loss += ce_loss(label, alpha[v_num], args.num_classes, epoch, lambda_epochs) 92 | loss += 2 * ce_loss(label, alpha_a, args.num_classes, epoch, lambda_epochs) 93 | loss = torch.mean(loss) 94 | 95 | tq.update(args.batch_size) 96 | tq.set_postfix(loss='%.6f' % loss) 97 | optimizer.zero_grad() 98 | loss.backward() 99 | optimizer.step() 100 | step += 1 101 | writer.add_scalar('loss_step', loss, step) 102 | loss_record.append(loss.item()) 103 | tq.close() 104 | loss_train_mean = np.mean(loss_record) 105 | writer.add_scalar('epoch/loss_epoch_train', float(loss_train_mean), epoch) 106 | print('loss for train : %f' % (loss_train_mean)) 107 | 108 | # save checkpoints 109 | if epoch % args.checkpoint_step == 0: 110 | if not os.path.isdir(args.save_model_path): 111 | os.mkdir(args.save_model_path) 112 | torch.save(model.module.state_dict(), os.path.join(args.save_model_path, 'usnet_latest.pth')) 113 | 114 | if epoch % args.validation_step == 0: 115 | F_score, pre, recall, fpr, fnr = val(args, model, dataloader_val) 116 | file = open(os.path.join(args.save_model_path, 'F_score.txt'), mode='a+') 117 | file.write('epoch = %d, F_score = %f\n' % (epoch, F_score)) 118 | file.close() 119 | if F_score > max_F_score: 120 | max_F_score = F_score 121 | torch.save(model.module.state_dict(), os.path.join(args.save_model_path, 'usnet_best.pth')) 122 | writer.add_scalar('epoch/F_score', F_score, epoch) 123 | writer.add_scalar('epoch/pre', pre, epoch) 124 | writer.add_scalar('epoch/recall', recall, epoch) 125 | writer.add_scalar('epoch/fpr', fpr, epoch) 126 | writer.add_scalar('epoch/fnr', fnr, epoch) 127 | 128 | 129 | def main(params): 130 | # set initialization seed 131 | torch.manual_seed(0) 132 | torch.cuda.manual_seed(0) 133 | torch.cuda.manual_seed_all(0) 134 | np.random.seed(0) 135 | random.seed(0) 136 | torch.backends.cudnn.benchmark = True 137 | torch.backends.cudnn.deterministic = True 138 | 139 | # basic parameters 140 | parser = argparse.ArgumentParser() 141 | parser.add_argument('--num_epochs', type=int, default=500, help='Number of epochs to train for') 142 | parser.add_argument('--checkpoint_step', type=int, default=1, help='How often to save checkpoints (epochs)') 143 | parser.add_argument('--validation_step', type=int, default=1, help='How often to perform validation (epochs)') 144 | parser.add_argument('--data', type=str, default='', help='path of training data') 145 | parser.add_argument('--dataset', type=str, default="Kitti", help='Dataset you are using.') 146 | parser.add_argument('--crop_height', type=int, default=384, help='Height of cropped/resized input image to network') 147 | parser.add_argument('--crop_width', type=int, default=1248, help='Width of cropped/resized input image to network') 148 | parser.add_argument('--batch_size', type=int, default=1, help='Number of images in each batch') 149 | parser.add_argument('--backbone_name', type=str, default="resnet18", 150 | help='The backbone model you are using, resnet18, resnet101.') 151 | parser.add_argument('--learning_rate', type=float, default=0.01, help='learning rate used for train') 152 | 
parser.add_argument('--num_workers', type=int, default=4, help='num of workers') 153 | parser.add_argument('--num_classes', type=int, default=32, help='num of object classes') 154 | parser.add_argument('--cuda', type=str, default='0', help='GPU ids used for training') 155 | parser.add_argument('--use_gpu', type=bool, default=True, help='whether to user gpu for training') 156 | parser.add_argument('--pretrained_model_path', type=str, default=None, help='path to pretrained model') 157 | parser.add_argument('--save_model_path', type=str, default=None, help='path to save model') 158 | 159 | args = parser.parse_args(params) 160 | 161 | if args.dataset == 'Kitti': 162 | train_set = Kitti_Dataset(args, root=args.data, split='training') 163 | val_set = Kitti_Dataset(args, root=args.data, split='validating') 164 | elif args.dataset == 'Cityscapes': 165 | train_set = Cityscapes_Dataset(args, root=args.data, split='train') 166 | val_set = Cityscapes_Dataset(args, root=args.data, split='val') 167 | 168 | dataloader_train = DataLoader( 169 | train_set, 170 | batch_size=args.batch_size, 171 | shuffle=True, 172 | num_workers=args.num_workers, 173 | drop_last=True 174 | ) 175 | dataloader_val = DataLoader( 176 | val_set, 177 | batch_size=1, 178 | shuffle=False, 179 | num_workers=args.num_workers 180 | ) 181 | 182 | # build model 183 | os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda 184 | 185 | model = USNet(args.num_classes, args.backbone_name) 186 | if torch.cuda.is_available() and args.use_gpu: 187 | model = torch.nn.DataParallel(model).cuda() 188 | 189 | encoder_params = list(map(id, model.module.backbone.parameters())) 190 | base_params = filter(lambda p: id(p) not in encoder_params, model.parameters()) 191 | 192 | optimizer = torch.optim.AdamW([{'params': base_params}, 193 | {'params':model.module.backbone.parameters(), 'lr': args.learning_rate*0.1}], 194 | lr=args.learning_rate, betas=(0.9,0.999), weight_decay=0.01) 195 | 196 | # load pretrained model if exists 197 | if args.pretrained_model_path is not None: 198 | print('load model from %s ...' 
% args.pretrained_model_path) 199 | model.module.load_state_dict(torch.load(args.pretrained_model_path)) 200 | print('Done!') 201 | 202 | # train 203 | train(args, model, optimizer, dataloader_train, dataloader_val) 204 | 205 | if __name__ == '__main__': 206 | params = [ 207 | '--num_epochs', '500', 208 | '--learning_rate', '1e-3', 209 | '--data', './data/KITTI', 210 | '--dataset', 'Kitti', 211 | '--num_workers', '8', 212 | '--num_classes', '2', 213 | '--cuda', '0', 214 | '--batch_size', '2', 215 | '--save_model_path', './log/KITTI_model', 216 | '--backbone_name', 'resnet18', # only support resnet18 and resnet101 217 | '--checkpoint_step', '1', 218 | '--validation_step', '20', 219 | ] 220 | main(params) 221 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def poly_lr_scheduler(optimizer, init_lr, iter, lr_decay_iter=1, max_iter=500, power=0.9): 5 | """Polynomial decay of learning rate 6 | :param init_lr is base learning rate 7 | :param iter is a current iteration 8 | :param lr_decay_iter how frequently decay occurs, default is 1 9 | :param max_iter is number of maximum iterations 10 | :param power is a polymomial power 11 | 12 | """ 13 | # if iter % lr_decay_iter or iter > max_iter: 14 | # return optimizer 15 | 16 | lr = init_lr * (1 - iter / max_iter) ** power 17 | optimizer.param_groups[0]['lr'] = lr 18 | optimizer.param_groups[1]['lr'] = lr * 0.1 19 | return lr 20 | 21 | 22 | def fast_hist(a, b, n): 23 | ''' 24 | a and b are predict and mask respectively 25 | n is the number of classes 26 | ''' 27 | k = (a >= 0) & (a < n) 28 | return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n) 29 | 30 | 31 | def getScores(conf_matrix): 32 | if conf_matrix.sum() == 0: 33 | return 0, 0, 0, 0, 0 34 | with np.errstate(divide='ignore',invalid='ignore'): 35 | classpre = np.diag(conf_matrix) / conf_matrix.sum(0).astype(np.float) 36 | classrecall = np.diag(conf_matrix) / conf_matrix.sum(1).astype(np.float) 37 | pre = classpre[1] 38 | recall = classrecall[1] 39 | F_score = 2 * (recall * pre) / (recall + pre) 40 | fpr = conf_matrix[0, 1] / np.float(conf_matrix[0, 0] + conf_matrix[0, 1]) 41 | fnr = conf_matrix[1, 0] / np.float(conf_matrix[1, 0] + conf_matrix[1, 1]) 42 | return F_score, pre, recall, fpr, fnr 43 | --------------------------------------------------------------------------------
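For reference, a toy sketch (not part of the repository) of the evaluation flow used in `test.py` and `train.py`: the confusion matrix is accumulated with `fast_hist` (ground-truth labels first, binarized prediction second, matching the call sites) and the road-class metrics come from `getScores`:

```
import numpy as np
from utils import fast_hist, getScores

# toy ground truth and prediction (0 = background, 1 = road)
label = np.array([0, 0, 1, 1, 1, 0])
pred = np.array([0, 1, 1, 1, 0, 0])

hist = np.zeros((2, 2))
hist += fast_hist(label.flatten(), pred.flatten(), 2)   # rows: label, columns: prediction
F_score, pre, recall, fpr, fnr = getScores(hist)
print('F_score: %.3f' % F_score)                        # 0.667 for this toy example
```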