├── config
│   ├── __init__.py
│   └── config_nuscenes.py
├── misc
│   ├── __init__.py
│   └── devkit
│       ├── cpp
│       │   ├── evaluate_depth
│       │   ├── make.sh
│       │   ├── log_colormap.h
│       │   ├── utils.h
│       │   ├── io_depth.h
│       │   └── evaluate_depth.cpp
│       ├── matlab
│       │   └── depth_read.m
│       ├── python
│       │   └── read_depth.py
│       └── readme.txt
├── model
│   ├── __init__.py
│   ├── utils.py
│   ├── multistage_model.py
│   └── models.py
├── dataset
│   ├── __init__.py
│   ├── dense_to_sparse.py
│   ├── nuscenes_export.py
│   ├── radar_preprocessing.py
│   ├── transforms.py
│   └── nuscenes_dataset_torch_new.py
├── evaluation
│   ├── __init__.py
│   ├── criteria.py
│   ├── criteria_new.py
│   └── metrics.py
├── pretrained
│   └── .gitkeep
├── .gitignore
├── requirements.txt
├── train_model.sh
├── LICENSE
├── README.md
├── utils.py
└── main.py
/config/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /misc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pretrained/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /misc/devkit/cpp/evaluate_depth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brade31919/radar_depth/HEAD/misc/devkit/cpp/evaluate_depth -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.ckpt 3 | *.index 4 | *.meta 5 | *.tar 6 | *.png 7 | *.jpg 8 | .idea/ 9 | *.code-workspace 10 | *.pth.tar 11 | __pycache__/ 12 | .vscode/ -------------------------------------------------------------------------------- /misc/devkit/matlab/depth_read.m: -------------------------------------------------------------------------------- 1 | function D = depth_read (filename) 2 | % loads depth map D from png file 3 | % for details see readme.txt 4 | 5 | I = imread(filename); 6 | D = double(I)/256; 7 | D(I==0) = -1; 8 | 9 | -------------------------------------------------------------------------------- /misc/devkit/cpp/make.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo "===========================================================================" 3 | g++ -O3 -DNDEBUG -Wno-unused-result -o evaluate_depth evaluate_depth.cpp -lpng 4 | echo "Built evaluate_depth."
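# Note: building evaluate_depth requires the libpng and libpng++ development headers (see ../readme.txt and io_depth.h); on Debian/Ubuntu these are likely the libpng-dev and libpng++-dev packages.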
-------------------------------------------------------------------------------- /misc/devkit/cpp/log_colormap.h: -------------------------------------------------------------------------------- 1 | #ifndef LOG_COLORMAP_H 2 | #define LOG_COLORMAP_H 3 | 4 | float LC[10][5] = 5 | {{0,0.0625,49,54,149}, 6 | {0.0625,0.125,69,117,180}, 7 | {0.125,0.25,116,173,209}, 8 | {0.25,0.5,171,217,233}, 9 | {0.5,1,224,243,248}, 10 | {1,2,254,224,144}, 11 | {2,4,253,174,97}, 12 | {4,8,244,109,67}, 13 | {8,16,215,48,39}, 14 | {16,1000000000.0,165,0,38}}; 15 | 16 | #endif // LOG_COLORMAP_H 17 | 18 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | attrdict==2.0.1 2 | attrs==19.1.0 3 | dominate==2.4.0 4 | easydict==1.9 5 | h5py==2.10.0 6 | numpy==1.16.3 7 | nuscenes-devkit==1.0.4 8 | opencv-contrib-python==4.0.1.24 9 | opencv-python==4.0.1.24 10 | pandas==0.24.2 11 | Pillow==6.2.1 12 | PyYAML==5.1 13 | scikit-image==0.16.2 14 | scipy==1.2.1 15 | tensorboardX==1.7 16 | torch==1.3.1 17 | torch-dct==0.1.5 18 | torchfile==0.1.0 19 | torchvision==0.4.2 20 | tqdm==4.31.1 -------------------------------------------------------------------------------- /misc/devkit/python/read_depth.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | from PIL import Image 4 | import numpy as np 5 | 6 | 7 | def depth_read(filename): 8 | # loads depth map D from png file 9 | # and returns it as a numpy array, 10 | # for details see readme.txt 11 | 12 | depth_png = np.array(Image.open(filename), dtype=int) 13 | # make sure we have a proper 16bit depth map here.. not 8bit! 14 | assert(np.max(depth_png) > 255) 15 | 16 | depth = depth_png.astype(np.float32) / 256. 17 | depth[depth_png == 0] = -1. 18 | return depth 19 | -------------------------------------------------------------------------------- /train_model.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py \ 3 | --arch resnet18_multistage_uncertainty_fixs \ 4 | --data nuscenes \ 5 | --modality rgbd \ 6 | --decoder upproj \ 7 | -j 16 \ 8 | --epochs 20 \ 9 | -b 8 \ 10 | --num-samples 50 \ 11 | --max-depth 80 \ 12 | --sparsifier radar 13 | # --resume ./results/sparse_to_dense/nuscenes.sparsifier\=radar.samples\=50.modality\=rgbd.arch\=resnet18_latefusion.decoder\=upproj.criterion\=l1.lr\=0.01.bs\=16.pretrained\=True/checkpoint-0.pth.tar \ 14 | 15 | 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 沅 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /misc/devkit/cpp/utils.h: -------------------------------------------------------------------------------- 1 | #ifndef UTILS_H 2 | #define UTILS_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | bool imageFormat(std::string file_name,png::color_type col,size_t depth,int32_t width,int32_t height) { 13 | std::ifstream file_stream; 14 | file_stream.open(file_name.c_str(),std::ios::binary); 15 | png::reader reader(file_stream); 16 | reader.read_info(); 17 | if (reader.get_color_type()!=col) return false; 18 | if (reader.get_bit_depth()!=depth) return false; 19 | if (reader.get_width()!=width) return false; 20 | if (reader.get_height()!=height) return false; 21 | return true; 22 | } 23 | 24 | float statMean(std::vector< std::vector > &errors,int32_t idx) { 25 | float err_mean = 0; 26 | for (int32_t i=0; i > &errors,int32_t idx,int32_t idx_num) { 32 | float err = 0; 33 | float num = 0; 34 | for (int32_t i=0; i > &errors,int32_t idx) { 42 | float err_min = 1; 43 | for (int32_t i=0; i > &errors,int32_t idx) { 49 | float err_max = 0; 50 | for (int32_t i=0; ierr_max) err_max = errors[i][idx]; 52 | return err_max; 53 | } 54 | 55 | #endif // UTILS_H 56 | 57 | -------------------------------------------------------------------------------- /evaluation/criteria.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import ipdb 5 | 6 | 7 | class MaskedMSELoss(nn.Module): 8 | def __init__(self): 9 | super(MaskedMSELoss, self).__init__() 10 | 11 | def forward(self, pred, target): 12 | assert pred.dim() == target.dim(), "inconsistent dimensions" 13 | valid_mask = (target>0).detach() 14 | diff = target - pred 15 | diff = diff[valid_mask] 16 | self.loss = (diff ** 2).mean() 17 | return self.loss 18 | 19 | 20 | class MaskedL1Loss(nn.Module): 21 | def __init__(self): 22 | super(MaskedL1Loss, self).__init__() 23 | 24 | def forward(self, pred, target): 25 | assert pred.dim() == target.dim(), "inconsistent dimensions" 26 | valid_mask = (target>0).detach() 27 | diff = target - pred 28 | diff = diff[valid_mask] 29 | self.loss = diff.abs().mean() 30 | return self.loss 31 | 32 | 33 | class MaskedBerHuLoss(nn.Module): 34 | def __init__(self, thresh=0.2): 35 | super(MaskedBerHuLoss, self).__init__() 36 | self.thresh = thresh 37 | 38 | def forward(self, pred, target): 39 | assert pred.dim() == target.dim(), "inconsistent dimensions" 40 | valid_mask = (target > 0).detach() 41 | 42 | # Mask out the content 43 | pred = pred[valid_mask] 44 | target = target[valid_mask] 45 | 46 | # ipdb.set_trace() 47 | diff = torch.abs(target - pred) 48 | delta = self.thresh * torch.max(diff).item() 49 | 50 | part1 = - torch.nn.functional.threshold(-diff, -delta, 0.) 51 | part2 = torch.nn.functional.threshold(diff ** 2 - delta ** 2, 0., -delta**2.) + delta ** 2 52 | part2 = part2 / (2. 
* delta) 53 | 54 | loss = part1 + part2 55 | loss = torch.mean(loss) 56 | 57 | return loss 58 | -------------------------------------------------------------------------------- /config/config_nuscenes.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains configurations of the NuScenes dataset 3 | """ 4 | from __future__ import print_function 5 | from __future__ import division 6 | from __future__ import absolute_import 7 | 8 | from attrdict import AttrDict 9 | import os 10 | 11 | 12 | # Define the configurations for kitti dataset 13 | class config_nuscenes(object): 14 | PROJECT_ROOT = "YOUR_PATH/radar_depth" 15 | dataset_mode = "full" 16 | 17 | # Data path configuration 18 | DATASET_ROOT = "DATASET_PATH" 19 | 20 | # Some parameters 21 | TRAIN_VAL_SEED = 100 22 | VAL_RATIO = 0.1 23 | 24 | # Define the orientation mode 25 | # ver1: only front and back 26 | # ver2: all directions 27 | version = "ver3" 28 | lidar_sweeps = 1 29 | radar_sweeps = 1 30 | 31 | scaling = True 32 | scale_factor = 0.5 33 | 34 | # [DORN] transform parameters 35 | DORN_transform_config = AttrDict({ 36 | "crop_size_train": [385, 513], 37 | "rotation_factor": 5., 38 | "scale_factor_train": [1., 1.5], 39 | "crop_size_val": [385, 513], 40 | "scale_factor_val": 1. 41 | }) 42 | # [sparse-to-dense] transform parameters 43 | sparse_transform_config = AttrDict({ 44 | "crop_size_train": [450, 800], 45 | "rotation_factor": 5., 46 | "scale_factor_train": [1., 1.5], 47 | "crop_size_val": [450, 800], 48 | "eval_size": [450, 800], 49 | "scale_factor_val": 1. 50 | }) 51 | 52 | # Dataset export path 53 | EXPORT_ROOT = os.path.join(DATASET_ROOT, "Nuscenes_depth") 54 | 55 | # Always do ver2 56 | if version in ["ver1", "ver2"]: 57 | export_name = "ver2_lidar%d_radar%d" % (lidar_sweeps, radar_sweeps) 58 | radar_export_name = None 59 | elif version == "ver3": 60 | export_name = "ver2_lidar%d_radar%d" % (lidar_sweeps, radar_sweeps) 61 | radar_export_name = "ver2_lidar1_radar3_radar_only" 62 | else: 63 | raise ValueError("Unknow dataset version. 
we currently only support ver1~ver3.") -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | 5 | 6 | class Result(object): 7 | def __init__(self): 8 | self.accuracy = 0 9 | self.precision = 0 10 | self.recall = 0 11 | self.loss = 0 12 | 13 | def set_to_worst(self): 14 | self.accuracy = 0 15 | self.precision = 0 16 | self.recall = 0 17 | self.loss = 0 18 | 19 | def update(self, accuracy, precision, recall, loss): 20 | self.accuracy = accuracy 21 | self.precision = precision 22 | self.recall = recall 23 | self.loss = loss 24 | 25 | def evaluate(self, predictions, labels, loss, mask=None): 26 | predictions = np.argmax(predictions.cpu().numpy(), axis=-1) 27 | labels = labels.cpu().numpy() 28 | 29 | if mask is not None: 30 | mask = torch.squeeze(mask, dim=-1).cpu().numpy().astype(np.bool) 31 | predictions = predictions[mask] 32 | labels = labels[mask] 33 | # pdb.set_trace() 34 | tp_count = np.sum((predictions == 1) & (labels == 1)) 35 | tn_count = np.sum((predictions == 0) & (labels == 0)) 36 | fn_count = np.sum((predictions == 0) & (labels == 1)) 37 | fp_count = np.sum((predictions == 1) & (labels == 0)) 38 | 39 | # ToDo: whether we should record precision 40 | self.accuracy = (tp_count + tn_count) / (tp_count + tn_count + fn_count + fp_count) 41 | self.precision = (tp_count) / (tp_count + fp_count) 42 | self.recall = tp_count / (tp_count + fn_count) 43 | 44 | self.loss = loss.cpu().numpy() 45 | 46 | 47 | class AverageMeter(object): 48 | def __init__(self): 49 | self.reset() 50 | 51 | def reset(self): 52 | self.count = 0.0 53 | 54 | self.sum_accuracy = 0 55 | self.sum_precision = 0 56 | self.sum_recall = 0 57 | self.sum_loss = 0 58 | 59 | def update(self, result, n=1): 60 | self.count += n 61 | 62 | self.sum_accuracy += n * result.accuracy 63 | self.sum_precision += n * result.precision 64 | self.sum_recall += n * result.recall 65 | self.sum_loss += n * result.loss 66 | 67 | def average(self): 68 | avg = Result() 69 | avg.update( 70 | self.sum_accuracy / self.count, 71 | self.sum_precision / self.count, 72 | self.sum_recall / self.count, 73 | self.sum_loss / self.count 74 | ) 75 | 76 | return avg 77 | -------------------------------------------------------------------------------- /evaluation/criteria_new.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | 7 | # Define the smoothness loss 8 | class SmoothnessLoss(nn.Module): 9 | def __init__(self): 10 | super(SmoothnessLoss, self).__init__() 11 | 12 | def forward(self, pred_depth, image): 13 | # Normalize the depth with mean 14 | depth_mean = pred_depth.mean(2, True).mean(3, True) 15 | pred_depth_normalized = pred_depth / (depth_mean + 1e-7) 16 | 17 | # Compute the gradient of depth 18 | grad_depth_x = torch.abs(pred_depth_normalized[:, :, :, :-1] - pred_depth_normalized[:, :, :, 1:]) 19 | grad_depth_y = torch.abs(pred_depth_normalized[:, :, :-1, :] - pred_depth_normalized[:, :, 1:, :]) 20 | 21 | # Compute the gradient of the image 22 | grad_image_x = torch.mean(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:]), 1, keepdim=True) 23 | grad_image_y = torch.mean(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]), 1, keepdim=True) 24 | 25 | grad_depth_x *= torch.exp(-grad_image_x) 26 | grad_depth_y *= 
torch.exp(-grad_image_y) 27 | 28 | return grad_depth_x.mean() + grad_depth_y.mean() 29 | 30 | 31 | class MaskedMSELoss(nn.Module): 32 | def __init__(self): 33 | super(MaskedMSELoss, self).__init__() 34 | 35 | def forward(self, pred, target): 36 | assert pred.dim() == target.dim(), "inconsistent dimensions" 37 | valid_mask = (target>0).detach() 38 | diff = target - pred 39 | diff = diff[valid_mask] 40 | self.loss = (diff ** 2).mean() 41 | return self.loss 42 | 43 | 44 | class MaskedL1Loss(nn.Module): 45 | def __init__(self): 46 | super(MaskedL1Loss, self).__init__() 47 | 48 | def forward(self, pred, target): 49 | assert pred.dim() == target.dim(), "inconsistent dimensions" 50 | valid_mask = (target>0).detach() 51 | diff = target - pred 52 | diff = diff[valid_mask] 53 | self.loss = diff.abs().mean() 54 | return self.loss 55 | 56 | 57 | class MaskedBerHuLoss(nn.Module): 58 | def __init__(self, thresh=0.2): 59 | super(MaskedBerHuLoss, self).__init__() 60 | self.thresh = thresh 61 | 62 | def forward(self, pred, target): 63 | assert pred.dim() == target.dim(), "inconsistent dimensions" 64 | valid_mask = (target > 0).detach() 65 | 66 | # Mask out the content 67 | pred = pred[valid_mask] 68 | target = target[valid_mask] 69 | 70 | # ipdb.set_trace() 71 | diff = torch.abs(target - pred) 72 | delta = self.thresh * torch.max(diff).item() 73 | 74 | part1 = - torch.nn.functional.threshold(-diff, -delta, 0.) 75 | part2 = torch.nn.functional.threshold(diff ** 2 - delta ** 2, 0., -delta**2.) + delta ** 2 76 | part2 = part2 / (2. * delta) 77 | 78 | loss = part1 + part2 79 | loss = torch.mean(loss) 80 | 81 | return loss 82 | 83 | 84 | class MaskedCrossEntropyLoss(nn.Module): 85 | def __init__(self): 86 | super(MaskedCrossEntropyLoss, self).__init__() 87 | self.cross_entropy = nn.CrossEntropyLoss().cuda() 88 | 89 | def forward(self, pred, target, mask): 90 | # pdb.set_trace() 91 | mask = torch.squeeze(mask.to(torch.bool), dim=-1) 92 | masked_pred = pred[mask, :] 93 | masked_target = torch.squeeze(target[mask, :]).to(torch.long) 94 | 95 | return self.cross_entropy(masked_pred, masked_target) 96 | -------------------------------------------------------------------------------- /dataset/dense_to_sparse.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class DenseToSparse: 5 | def __init__(self): 6 | pass 7 | 8 | def dense_to_sparse(self, *args): 9 | pass 10 | 11 | def __repr__(self): 12 | pass 13 | 14 | 15 | class UniformSampling(DenseToSparse): 16 | name = "uar" 17 | def __init__(self, num_samples, max_depth=np.inf): 18 | DenseToSparse.__init__(self) 19 | self.num_samples = num_samples 20 | self.max_depth = max_depth 21 | 22 | def __repr__(self): 23 | return "%s{ns=%d,md=%f}" % (self.name, self.num_samples, self.max_depth) 24 | 25 | def dense_to_sparse(self, depth): 26 | """ 27 | Samples pixels with `num_samples`/#pixels probability in `depth`. 28 | Only pixels with a maximum depth of `max_depth` are considered. 
29 | If no `max_depth` is given, samples in all pixels 30 | """ 31 | mask_keep = depth > 0 32 | if self.max_depth is not np.inf: 33 | mask_keep = np.bitwise_and(mask_keep, depth <= self.max_depth) 34 | n_keep = np.count_nonzero(mask_keep) 35 | if n_keep == 0: 36 | return mask_keep 37 | else: 38 | prob = float(self.num_samples) / n_keep 39 | return np.bitwise_and(mask_keep, np.random.uniform(0, 1, depth.shape) < prob) 40 | 41 | 42 | class LidarRadarSampling(DenseToSparse): 43 | name = "lidar_radar" 44 | def __init__(self, num_samples, max_depth=np.inf): 45 | DenseToSparse.__init__(self) 46 | self.num_samples = num_samples 47 | self.max_depth = max_depth 48 | 49 | def __repr__(self): 50 | return "%s{ns=%d,md=%f}" % (self.name, self.num_samples, self.max_depth) 51 | 52 | def dense_to_sparse(self, lidar_depth, radar_depth): 53 | """ 54 | Samples pixels with `num_samples`/#pixels probability in `depth`. 55 | Only pixels with a maximum depth of `max_depth` are considered. 56 | If no `max_depth` is given, samples in all pixels 57 | """ 58 | # Convert to numpy array first 59 | lidar_depth = np.squeeze(lidar_depth.cpu().numpy().transpose(1, 2, 0)) 60 | radar_depth = np.squeeze(radar_depth.cpu().numpy().transpose(1, 2, 0)) 61 | 62 | # h, w, _ = lidar_depth.shape 63 | # h_lin = np.linspace(0, h-1, h) 64 | # w_lin = np.linspace(0, w-1, w) 65 | # h_grid, w_grid = np.meshgrid(h_lin, w_lin, indexing="ij") 66 | # coord_map = np.concatenate((h_grid[..., None], w_grid[..., None]), axis=-1) 67 | 68 | # Find the locations radar > 0 69 | radar_coord_tmp = np.where(radar_depth > 0) 70 | lidar_coord_tmp = np.where(lidar_depth > 0) 71 | 72 | # Concatenate to coordinate map 73 | radar_coord = np.concatenate((radar_coord_tmp[0][..., None], radar_coord_tmp[1][..., None]), axis=-1) 74 | lidar_coord = np.concatenate((lidar_coord_tmp[0][..., None], lidar_coord_tmp[1][..., None]), axis=-1) 75 | 76 | radar_expand = np.expand_dims(radar_coord, axis=1) 77 | lidar_expand = np.expand_dims(lidar_coord, axis=0) 78 | 79 | # Compute the pair-wise distance => (100 v.s. 
3000) 80 | dist = np.sqrt(np.sum((radar_expand - lidar_expand) ** 2, axis=-1)) 81 | 82 | # Get the top 2 nearest points and get unique 83 | lidar_candidates = np.unique(np.argsort(dist, axis=-1)[:, :2]) 84 | mask = lidar_coord[lidar_candidates, :] 85 | output_mask = np.zeros(lidar_depth.shape) 86 | output_mask[mask[:, 0], mask[:, 1]] = 1 87 | 88 | return output_mask -------------------------------------------------------------------------------- /dataset/nuscenes_export.py: -------------------------------------------------------------------------------- 1 | """ 2 | Export the datapoint to disk 3 | """ 4 | from __future__ import print_function 5 | from __future__ import division 6 | from __future__ import absolute_import 7 | 8 | # Add system path for fast debugging 9 | import sys 10 | sys.path.append("../") 11 | 12 | import numpy as np 13 | from dataset.nuscenes_dataset import Nuscenes_dataset 14 | from config.config_nuscenes import config_nuscenes as cfg 15 | from dataset import transforms as transforms 16 | import h5py 17 | import os 18 | from tqdm import tqdm 19 | import ipdb 20 | 21 | 22 | def create_h5_file(output_filename, data, save_keys=None): 23 | # Compress depth maps 24 | if "lidar_depth" in save_keys: 25 | data["lidar_depth"] = (data["lidar_depth"] * 256).astype(np.int16) 26 | if "radar_depth" in save_keys: 27 | data["radar_depth"] = (data["radar_depth"] * 256).astype(np.int16) 28 | 29 | # Create file objects 30 | with h5py.File(output_filename, "w") as f: 31 | # Iterate through all the key value pairs in the object 32 | for key, value in data.items(): 33 | if key in save_keys: 34 | f.create_dataset(name=key, 35 | shape=value.shape, 36 | dtype=value.dtype, 37 | data=value) 38 | 39 | 40 | def parse_h5_file(file_path): 41 | # Check if file exists 42 | if not os.path.exists(file_path): 43 | raise ValueError("[Error] File does not exist.") 44 | 45 | # Read file 46 | output_dict = {} 47 | with h5py.File(file_path, "r") as f: 48 | for key_name in f: 49 | output_dict[key_name] = np.array(f[key_name]) 50 | 51 | # Decompress depth 52 | output_dict["lidar_depth"] = output_dict["lidar_depth"] / 256. 53 | output_dict["radar_depth"] = output_dict["radar_depth"] / 256. 
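    # Note: the depth maps are stored in the h5 files as int16 fixed-point values (meters * 256),
    # following the 16-bit PNG convention described in misc/devkit/readme.txt, so dividing by 256
    # recovers depth in meters.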
54 | 55 | return output_dict 56 | 57 | 58 | # Export dataset according to mode 59 | def export_dataset(datapoints, dataset, mode="train"): 60 | # Define the export path 61 | # export_path = os.path.join(cfg.EXPORT_ROOT, cfg.export_name, mode) 62 | export_path = os.path.join(cfg.EXPORT_ROOT, cfg.export_name + "_radar_only", mode) 63 | 64 | ########################### 65 | ## Add radar only option ## 66 | ########################### 67 | # export_path = export_path + "_radar_only" 68 | save_keys = ['radar_depth', 'radar_depth_points', 'radar_points', 'radar_raw_points'] 69 | ########################### 70 | 71 | if not os.path.exists(export_path): 72 | os.makedirs(export_path) 73 | 74 | # Iterate through all the datapoints 75 | for i in tqdm(range(len(datapoints)), ascii=True): 76 | datapoint = datapoints[i] 77 | 78 | # Get the index and orientation 79 | index = datapoint[0] 80 | orientation = datapoint[1] 81 | 82 | # Get output filename 83 | output_filename = "%07d_%s.h5" % (index, orientation) 84 | output_filename = os.path.join(export_path, output_filename) 85 | 86 | # Get data 87 | data = dataset.get_data(datapoint, mode=mode) 88 | 89 | # ipdb.set_trace() 90 | 91 | # Export the data 92 | create_h5_file(output_filename, data, save_keys=save_keys) 93 | 94 | # print("[ %d / %d ]"%(i, len(datapoints))) 95 | 96 | 97 | if __name__ == "__main__": 98 | # Initialize dataset object 99 | dataset = Nuscenes_dataset(mode=cfg.dataset_mode) 100 | 101 | # Get all train / val samples 102 | train_data_points = dataset.get_datapoints("train") 103 | val_data_points = dataset.get_datapoints("val") 104 | 105 | # Export training set 106 | export_dataset(train_data_points, dataset, "train") 107 | 108 | # Export testing set 109 | export_dataset(val_data_points, dataset, "val") -------------------------------------------------------------------------------- /misc/devkit/readme.txt: -------------------------------------------------------------------------------- 1 | ########################################################################### 2 | # THE KITTI VISION BENCHMARK: DEPTH PREDICTION/COMPLETION BENCHMARKS 2017 # 3 | # based on our publication Sparsity Invariant CNNs (3DV 2017) # 4 | # # 5 | # Jonas Uhrig Nick Schneider Lukas Schneider # 6 | # Uwe Franke Thomas Brox Andreas Geiger # 7 | # # 8 | # Daimler R&D Sindelfingen University of Freiburg # 9 | # KIT Karlsruhe ETH Zürich MPI Tübingen # 10 | # # 11 | ########################################################################### 12 | 13 | This file describes the 2017 KITTI depth completion and single image depth 14 | prediction benchmarks, consisting of 93k training and 1.5k test images. 15 | Ground truth has been acquired by accumulating 3D point clouds from a 16 | 360 degree Velodyne HDL-64 Laserscanner and a consistency check using 17 | stereo camera pairs. Please have a look at our publications for details. 18 | 19 | Dataset description: 20 | ==================== 21 | 22 | If you unzip all downloaded files from the KITTI vision benchmark website 23 | into the same base directory, your folder structure will look like this: 24 | 25 | |-- devkit 26 | |-- test_depth_completion_anonymous 27 | |-- image 28 | |-- 0000000000.png 29 | |-- ... 30 | |-- 0000000999.png 31 | |-- velodyne_raw 32 | |-- 0000000000.png 33 | |-- ... 34 | |-- 0000000999.png 35 | |-- test_depth_prediction_anonymous 36 | |-- image 37 | |-- 0000000000.png 38 | |-- ... 
39 | |-- 0000000999.png 40 | |-- train 41 | |-- 2011_xx_xx_drive_xxxx_sync 42 | |-- proj_depth 43 | |-- groundtruth # "groundtruth" describes our annotated depth maps 44 | |-- image_02 # image_02 is the depth map for the left camera 45 | |-- 0000000005.png # image IDs start at 5 because we accumulate 11 frames 46 | |-- ... # .. which is +-5 around the current frame ;) 47 | |-- image_03 # image_02 is the depth map for the right camera 48 | |-- 0000000005.png 49 | |-- ... 50 | |-- velodyne_raw # this contains projected and temporally unrolled 51 | |-- image_02 # raw Velodyne laser scans 52 | |-- 0000000005.png 53 | |-- ... 54 | |-- image_03 55 | |-- 0000000005.png 56 | |-- ... 57 | |-- ... (all drives of all days in the raw KITTI dataset) 58 | |-- val 59 | |-- (same as in train) 60 | |-- val_selection_cropped # 1000 images of size 1216x352, cropped and manually 61 | |-- groundtruth_depth # selected frames from from the full validation split 62 | |-- 2011_xx_xx_drive_xxxx_sync_groundtruth_depth_xxxxxxxxxx_image_0x.png 63 | |-- ... 64 | |-- image 65 | |-- 2011_xx_xx_drive_xxxx_sync_groundtruth_depth_xxxxxxxxxx_image_0x.png 66 | |-- ... 67 | |-- velodyne_raw 68 | |-- 2011_xx_xx_drive_xxxx_sync_groundtruth_depth_xxxxxxxxxx_image_0x.png 69 | |-- ... 70 | 71 | For train and val splits, the mapping from the KITTI raw dataset to our 72 | generated depth maps and projected raw laser scans can be extracted. All 73 | files are uniquely identified by their recording date, the drive ID as well 74 | as the camera ID (02 for left, 03 for right camera). 75 | 76 | Submission instructions: 77 | ======================== 78 | 79 | NOTE: WHEN SUBMITTING RESULTS, PLEASE STORE THEM IN THE SAME DATA FORMAT IN 80 | WHICH THE GROUND TRUTH DATA IS PROVIDED (SEE BELOW), USING THE FILE NAMES 81 | 0000000000.png TO 0000000999.png (DEPTH COMPLETION) OR 0000000499.png (DEPTH 82 | PREDICTION). CREATE A ZIP ARCHIVE OF THEM AND STORE YOUR RESULTS IN YOUR 83 | ZIP'S ROOT FOLDER: 84 | 85 | |-- zip 86 | |-- 0000000000.png 87 | |-- ... 88 | |-- 0000000999.png 89 | 90 | Data format: 91 | ============ 92 | 93 | Depth maps (annotated and raw Velodyne scans) are saved as uint16 PNG images, 94 | which can be opened with either MATLAB, libpng++ or the latest version of 95 | Python's pillow (from PIL import Image). A 0 value indicates an invalid pixel 96 | (ie, no ground truth exists, or the estimation algorithm didn't produce an 97 | estimate for that pixel). Otherwise, the depth for a pixel can be computed 98 | in meters by converting the uint16 value to float and dividing it by 256.0: 99 | 100 | disp(u,v) = ((float)I(u,v))/256.0; 101 | valid(u,v) = I(u,v)>0; 102 | 103 | Evaluation Code: 104 | ================ 105 | 106 | For transparency we have included the benchmark evaluation code in the 107 | sub-folder 'cpp' of this development kit. It can be compiled by running 108 | the 'make.sh' script. Run it using two arguments: 109 | 110 | ./evaluate_depth gt_dir prediction_dir 111 | 112 | Note that gt_dir is most likely '../../val_selection_cropped/groundtruth_depth' 113 | if you unzipped all files in the same base directory. We also included a sample 114 | result of our proposed approach for the validation split ('predictions/sparseConv_val'). 
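For reference, predictions can be written in the required 16-bit PNG format with a few
lines of Python. The snippet below is only an illustrative sketch (it is not part of the
official devkit) and assumes OpenCV is available; it is simply the inverse of the provided
python/read_depth.py:

  import cv2
  import numpy as np

  def depth_write(filename, depth):
      # depth: float array in meters; non-positive values are treated as invalid
      depth_png = np.round(np.clip(depth, 0, None) * 256.0).astype(np.uint16)
      depth_png[depth <= 0] = 0          # 0 marks invalid pixels
      cv2.imwrite(filename, depth_png)   # uint16 arrays are written as 16-bit PNGs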
115 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Depth Estimation from Monocular Images and Sparse Radar Data 2 | 3 | This is the official implementation of the paper [Depth Estimation from Monocular Images and Sparse Radar Data](https://arxiv.org/abs/2010.00058). In this repo, we provide code for dataset preprocessing, training, and evaluation. 4 | 5 | Some parts of the implementation are adapted from [sparse-to-dense](https://github.com/fangchangma/sparse-to-dense.pytorch). We thank the authors for sharing their implementation. 6 | 7 | ## Updates 8 | 9 | - [x] Training and evaluation code. 10 | 11 | - [x] Trained models. 12 | 13 | - [x] Download instructions for the processed dataset. 14 | 15 | - [ ] Detailed documentation for the processed dataset. 16 | 17 | - [ ] Code and instructions to process data from the official nuScenes dataset. 18 | 19 | ## Installation 20 | 21 | ```bash 22 | git clone https://github.com/brade31919/radar_depth.git 23 | cd radar_depth 24 | ``` 25 | 26 | ### Dataset preparation 27 | 28 | #### Use our processed files 29 | 30 | We provide our processed files specifically for the RGB + Radar depth estimation task. The download and setup instructions are: 31 | 32 | ```bash 33 | mkdir DATASET_PATH # Set the path you want to use on your own PC/cluster. 34 | cd DATASET_PATH 35 | wget https://data.vision.ee.ethz.ch/daid/NuscenesRadar/Nuscenes_depth.tar.gz 36 | tar -zxvf Nuscenes_depth.tar.gz 37 | ``` 38 | 39 | ⚠️ Since the processed dataset is adapted (for non-commercial purposes) from the official [nuScenes dataset](https://www.nuscenes.org/), its contents are also subject to the [official terms of use](https://www.nuscenes.org/terms-of-use) and the [license](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 40 | 41 | ### Package installation 42 | 43 | ```bash 44 | cd radar_depth # Go back to the project root 45 | pip install -r requirements.txt 46 | ``` 47 | 48 | If you encounter an error message like "ImportError: libSM.so.6: cannot open shared object file: No such file or directory" from cv2, you can try: 49 | 50 | ```bash 51 | sudo apt-get install libsm6 libxrender1 libfontconfig1 52 | ``` 53 | 54 | ### Project configuration setting 55 | 56 | We put the important path settings in config/config_nuscenes.py. You need to modify them to the paths you use on your own PC/cluster. 57 | 58 | #### Project and dataset root setting 59 | 60 | In lines 14 and 18, please specify your PROJECT_ROOT and DATASET_ROOT: 61 | 62 | ```python 63 | PROJECT_ROOT = "YOUR_PATH/radar_depth" 64 | DATASET_ROOT = "DATASET_PATH" 65 | ``` 66 | 67 | #### Experiment path setting 68 | 69 | In line 53, please specify your EXPORT_ROOT (the path where you want to put our processed dataset). 70 | 71 | ```python 72 | EXPORT_ROOT = "YOUR_EXP_PATH" 73 | ``` 74 | 75 | ### Training 76 | 77 | #### Download the pre-trained models 78 | 79 | We provide some pretrained models. They are not the original models used to produce the numbers in the paper, but they have similar performance (I lost the original checkpoints due to some cluster issue...).
80 | 81 | Please download the pretrained models from [here](https://drive.google.com/drive/folders/1QDXIZmfEbwzoOjl8KoPZwGyN2JiD7-pg?usp=sharing) and put them in the pretrained/ folder so that the directory structure looks like this: 82 | 83 | ```bash 84 | pretrained/ 85 | ├── resnet18_latefusion.pth.tar 86 | └── resnet18_multistage.pth.tar 87 | ``` 88 | 89 | #### Train the late fusion model yourself 90 | 91 | ```bash 92 | python main.py \ 93 | --arch resnet18_latefusion \ 94 | --data nuscenes \ 95 | --modality rgbd \ 96 | --decoder upproj \ 97 | -j 12 \ 98 | --epochs 20 \ 99 | -b 16 \ 100 | --max-depth 80 \ 101 | --sparsifier radar 102 | ``` 103 | 104 | #### Train the full multi-stage model 105 | 106 | To make sure that the training process is stable, we'll initialize each stage from the resnet18_latefusion model. If you want to skip the training of resnet18_latefusion, you can use our pre-trained models. 107 | 108 | ```bash 109 | python main.py \ 110 | --arch resnet18_multistage_uncertainty_fixs \ 111 | --data nuscenes \ 112 | --modality rgbd \ 113 | --decoder upproj \ 114 | -j 12 \ 115 | --epochs 20 \ 116 | -b 8 \ 117 | --max-depth 80 \ 118 | --sparsifier radar 119 | ``` 120 | 121 | Here we use batch size 8 (instead of 16). This allows us to train the model on cheaper GPUs such as the GTX 1080 Ti or RTX 2080 Ti, and it makes the training process more stable. 122 | 123 | ### Evaluation 124 | 125 | After the training process has finished, you can evaluate the model as follows (replace PATH_TO_CHECKPOINT with the path to the checkpoint file you want to evaluate): 126 | 127 | ```bash 128 | python main.py \ 129 | --evaluate PATH_TO_CHECKPOINT \ 130 | --data nuscenes 131 | ``` 132 | 133 | ## Code Borrowed From 134 | 135 | * [sparse-to-dense](https://github.com/fangchangma/sparse-to-dense.pytorch) 136 | 137 | * [nuscenes-devkit](https://github.com/nutonomy/nuscenes-devkit) 138 | 139 | * [KITTI-devkit](https://github.com/joseph-zhong/KITTI-devkit) 140 | 141 | ## Citation 142 | 143 | Please use the following citation format if you want to reference our paper. 144 | 145 | ``` 146 | @InProceedings{radar:depth:20, 147 | author = {Lin, Juan-Ting and Dai, Dengxin and {Van Gool}, Luc}, 148 | title = {Depth Estimation from Monocular Images and Sparse Radar Data}, 149 | booktitle = {International Conference on Intelligent Robots and Systems (IROS)}, 150 | year = {2020} 151 | } 152 | ``` 153 | 154 | If you use the processed dataset, remember to cite the official nuScenes dataset. 155 | 156 | ``` 157 | @article{nuscenes2019, 158 | title={nuScenes: A multimodal dataset for autonomous driving}, 159 | author={Holger Caesar and Varun Bankiti and Alex H.
Lang and Sourabh Vora and 160 | Venice Erin Liong and Qiang Xu and Anush Krishnan and Yu Pan and 161 | Giancarlo Baldan and Oscar Beijbom}, 162 | journal={arXiv preprint arXiv:1903.11027}, 163 | year={2019} 164 | } 165 | ``` 166 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is adapted from https://github.com/fangchangma/sparse-to-dense.pytorch 3 | """ 4 | import os 5 | import torch 6 | import shutil 7 | import argparse 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | from PIL import Image 11 | from model.models import Decoder 12 | 13 | cmap = plt.cm.viridis 14 | 15 | 16 | def parse_command(): 17 | # Define some constants 18 | model_names = ['resnet18', 'resnet34', 'resnet50', 'resnet18_new', 'resnet18_latefusion', 'resnet18_multistage', 19 | 'resnet18_multistage_uncertainty', 'resnet18_multistage_uncertainty_fixs'] 20 | loss_names = ['l1', 'l2'] 21 | data_names = ['nuscenes'] 22 | sparsifier_names = ["uniform", "lidar_radar", "radar", "radar_filtered", "radar_filtered2"] 23 | decoder_names = Decoder.names 24 | modality_names = ['rgb', 'rgbd', 'd'] 25 | 26 | parser = argparse.ArgumentParser(description='Sparse-to-Dense') 27 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', choices=model_names, 28 | help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)') 29 | parser.add_argument('--data', metavar='DATA', default='nyudepthv2', 30 | choices=data_names, 31 | help='dataset: ' + ' | '.join(data_names) + ' (default: nyudepthv2)') 32 | parser.add_argument('--modality', '-m', metavar='MODALITY', default='rgb', choices=modality_names, 33 | help='modality: ' + ' | '.join(modality_names) + ' (default: rgb)') 34 | parser.add_argument('-s', '--num-samples', default=0, type=int, metavar='N', 35 | help='number of sparse depth samples (default: 0)') 36 | parser.add_argument('--max-depth', default=-1.0, type=float, metavar='D', 37 | help='cut-off depth of sparsifier, negative values means infinity (default: inf [m])') 38 | parser.add_argument('--sparsifier', metavar='SPARSIFIER', default=None, choices=sparsifier_names, 39 | help='sparsifier: ' + ' | '.join(sparsifier_names) + ' (default: None)') 40 | parser.add_argument('--decoder', '-d', metavar='DECODER', default='deconv2', choices=decoder_names, 41 | help='decoder: ' + ' | '.join(decoder_names) + ' (default: deconv2)') 42 | parser.add_argument('-j', '--workers', default=10, type=int, metavar='N', 43 | help='number of data loading workers (default: 10)') 44 | parser.add_argument('--epochs', default=15, type=int, metavar='N', 45 | help='number of total epochs to run (default: 15)') 46 | parser.add_argument('-c', '--criterion', metavar='LOSS', default='l1', choices=loss_names, 47 | help='loss function: ' + ' | '.join(loss_names) + ' (default: l1)') 48 | parser.add_argument('-b', '--batch-size', default=8, type=int, help='mini-batch size (default: 8)') 49 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, 50 | metavar='LR', help='initial learning rate (default 0.01)') 51 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 52 | help='momentum') 53 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 54 | metavar='W', help='weight decay (default: 1e-4)') 55 | parser.add_argument('--print-freq', '-p', default=50, type=int, 56 | metavar='N', help='print frequency (default: 10)') 57 | 
parser.add_argument('--resume', default='', type=str, metavar='PATH', 58 | help='path to latest checkpoint (default: none)') 59 | parser.add_argument('-e', '--evaluate', dest='evaluate', type=str, default='', 60 | help='evaluate model on validation set') 61 | parser.add_argument('--no-pretrain', dest='pretrained', action='store_false', 62 | help='not to use ImageNet pre-trained weights') 63 | parser.add_argument('--no-validation', dest="validation", action='store_false', 64 | help="not do the evaluation during the training.") 65 | parser.set_defaults(pretrained=True) 66 | parser.set_defaults(validation=True) 67 | args = parser.parse_args() 68 | if args.modality == 'rgb' and args.num_samples != 0: 69 | print("number of samples is forced to be 0 when input modality is rgb") 70 | args.num_samples = 0 71 | if args.modality == 'rgb' and args.max_depth != 0.0: 72 | print("max depth is forced to be 0.0 when input modality is rgb/rgbd") 73 | args.max_depth = 0.0 74 | return args 75 | 76 | 77 | def save_checkpoint(state, is_best, epoch, output_directory): 78 | checkpoint_filename = os.path.join(output_directory, 'checkpoint-' + str(epoch) + '.pth.tar') 79 | torch.save(state, checkpoint_filename) 80 | if is_best: 81 | best_filename = os.path.join(output_directory, 'model_best.pth.tar') 82 | shutil.copyfile(checkpoint_filename, best_filename) 83 | 84 | 85 | def adjust_learning_rate(optimizer, epoch, lr_init): 86 | """Sets the learning rate to the initial LR decayed by 10 every 5 epochs""" 87 | lr = lr_init * (0.1 ** (epoch // 5)) 88 | for param_group in optimizer.param_groups: 89 | param_group['lr'] = lr 90 | 91 | 92 | # Explicitly define the learning rate schedule. 93 | def adjust_learning_rate_new(optimizer, epoch, lr_init): 94 | if epoch <= 6: 95 | lr = lr_init 96 | elif (epoch >6) and (epoch <= 15): 97 | lr = 0.1 * lr_init 98 | else: 99 | lr = 0.01 * lr_init 100 | 101 | for param_group in optimizer.param_groups: 102 | param_group['lr'] = lr 103 | 104 | 105 | def get_output_directory(args): 106 | output_directory = os.path.join('results', 107 | '{}.sparsifier={}.samples={}.modality={}.arch={}.decoder={}.criterion={}.lr={}.bs={}.pretrained={}'. 
108 | format(args.data, args.sparsifier, args.num_samples, args.modality, \ 109 | args.arch, args.decoder, args.criterion, args.lr, args.batch_size, \ 110 | args.pretrained)) 111 | return output_directory 112 | 113 | 114 | def colored_depthmap(depth, d_min=None, d_max=None): 115 | if d_min is None: 116 | d_min = np.min(depth) 117 | if d_max is None: 118 | d_max = np.max(depth) 119 | depth_relative = (depth - d_min) / (d_max - d_min) 120 | return 255 * cmap(depth_relative)[:,:,:3] # H, W, C 121 | 122 | 123 | def merge_into_row(input, depth_target, depth_pred): 124 | # ipdb.set_trace() 125 | rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1,2,0)) # H, W, C 126 | depth_target_cpu = np.squeeze(depth_target.cpu().numpy()) 127 | depth_pred_cpu = np.squeeze(depth_pred.data.cpu().numpy()) 128 | 129 | d_min = min(np.min(depth_target_cpu), np.min(depth_pred_cpu)) 130 | d_max = max(np.max(depth_target_cpu), np.max(depth_pred_cpu)) 131 | depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max) 132 | depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max) 133 | img_merge = np.hstack([rgb, depth_target_col, depth_pred_col]) 134 | 135 | return img_merge 136 | 137 | 138 | def merge_into_row_with_gt(input, depth_input, depth_target, depth_pred): 139 | # ipdb.set_trace() 140 | rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1,2,0)) # H, W, C 141 | depth_input_cpu = np.squeeze(depth_input.cpu().numpy()) 142 | depth_target_cpu = np.squeeze(depth_target.cpu().numpy()) 143 | depth_pred_cpu = np.squeeze(depth_pred.data.cpu().numpy()) 144 | 145 | d_min = min(np.min(depth_input_cpu), np.min(depth_target_cpu), np.min(depth_pred_cpu)) 146 | d_max = max(np.max(depth_input_cpu), np.max(depth_target_cpu), np.max(depth_pred_cpu)) 147 | depth_input_col = colored_depthmap(depth_input_cpu, d_min, d_max) 148 | depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max) 149 | depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max) 150 | 151 | img_merge = np.hstack([rgb, depth_input_col, depth_target_col, depth_pred_col]) 152 | 153 | return img_merge 154 | 155 | 156 | def add_row(img_merge, row): 157 | return np.vstack([img_merge, row]) 158 | 159 | 160 | def save_image(img_merge, filename): 161 | img_merge = Image.fromarray(img_merge.astype('uint8')) 162 | img_merge.save(filename) -------------------------------------------------------------------------------- /evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | 5 | 6 | def log10(x): 7 | """Convert a new tensor with the base-10 logarithm of the elements of x. 
""" 8 | return torch.log(x) / math.log(10) 9 | 10 | 11 | # Object to record the evaluation results 12 | class Result(object): 13 | def __init__(self): 14 | self.irmse, self.imae = 0, 0 15 | self.mse, self.rmse, self.mae = 0, 0, 0 16 | self.absrel, self.lg10 = 0, 0 17 | self.delta1, self.delta2, self.delta3 = 0, 0, 0 18 | self.data_time, self.gpu_time = 0, 0 19 | 20 | def set_to_worst(self): 21 | self.irmse, self.imae = np.inf, np.inf 22 | self.mse, self.rmse, self.mae = np.inf, np.inf, np.inf 23 | self.absrel, self.lg10 = np.inf, np.inf 24 | self.delta1, self.delta2, self.delta3 = 0, 0, 0 25 | self.data_time, self.gpu_time = 0, 0 26 | 27 | def update(self, irmse, imae, mse, rmse, mae, absrel, lg10, delta1, delta2, delta3, gpu_time, data_time): 28 | self.irmse, self.imae = irmse, imae 29 | self.mse, self.rmse, self.mae = mse, rmse, mae 30 | self.absrel, self.lg10 = absrel, lg10 31 | self.delta1, self.delta2, self.delta3 = delta1, delta2, delta3 32 | self.data_time, self.gpu_time = data_time, gpu_time 33 | 34 | def evaluate(self, output, target): 35 | valid_mask = target>0 36 | output = output[valid_mask] 37 | target = target[valid_mask] 38 | 39 | abs_diff = (output - target).abs() 40 | 41 | self.mse = float((torch.pow(abs_diff, 2)).mean()) 42 | self.rmse = math.sqrt(self.mse) 43 | self.mae = float(abs_diff.mean()) 44 | self.lg10 = float((log10(output) - log10(target)).abs().mean()) 45 | self.absrel = float((abs_diff / target).mean()) 46 | 47 | maxRatio = torch.max(output / target, target / output) 48 | self.delta1 = float((maxRatio < 1.25).float().mean()) 49 | self.delta2 = float((maxRatio < 1.25 ** 2).float().mean()) 50 | self.delta3 = float((maxRatio < 1.25 ** 3).float().mean()) 51 | self.data_time = 0 52 | self.gpu_time = 0 53 | 54 | inv_output = 1 / output 55 | inv_target = 1 / target 56 | abs_inv_diff = (inv_output - inv_target).abs() 57 | self.irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean()) 58 | self.imae = float(abs_inv_diff.mean()) 59 | 60 | 61 | # Object to record evaluation results in different distance intervals 62 | class Result_multidist(object): 63 | def __init__(self): 64 | # initialize the end points of the dist intervals 65 | self.dist_interval = [10., 20., 30., 40., 50., 60., 70., 80., 90., 100.] 66 | # Initialize result object of each interval 67 | self.result_lst = [Result() for i in range(len(self.dist_interval))] 68 | # Initialize valid label for every distance interval 69 | self.valid_label = [1 for _ in range(len(self.dist_interval))] 70 | 71 | # Set each result to worst 72 | def set_to_worst(self): 73 | for res in self.result_lst: 74 | res.set_to_worst() 75 | 76 | # Update the results given another multi-distance result object 77 | def update(self, result): 78 | # Check if result is multidist object 79 | assert isinstance(result, Result_multidist) 80 | 81 | # Iterate through all dist interval to perform the update 82 | for idx, res in enumerate(result): 83 | self.result_lst[idx].update( 84 | res.irmse, res.imae, res.mse, 85 | res.rmse, res.mae, res.absrel, res.log10, 86 | res.delta1, res.delta2, res.delta3 87 | ) 88 | 89 | # Evaluate the results 90 | def evaluate(self, output, target): 91 | # Compute shared valid mask first 92 | valid_mask = target>0 93 | 94 | # Iterate through all the distance intervals and evaluate 95 | for idx, interval in enumerate(self.dist_interval): 96 | # First interval => min=0 97 | if idx == 0: 98 | dist_min = 0. 
99 | dist_max = interval 100 | # Last interval => max=inf 101 | elif idx == len(self.dist_interval)-1: 102 | dist_min = self.dist_interval[idx - 1] 103 | dist_max = np.inf 104 | else: 105 | dist_min = self.dist_interval[idx - 1] 106 | dist_max = interval 107 | 108 | # Compute distance-aware valid mask 109 | # ToDo: Fix the corner case that no points lies in the interval. 110 | # ToDo: How to balance the point counts in different distance range. 111 | dist_valid_mask = (target >= dist_min) & (target <= dist_max) 112 | valid_mask_final = dist_valid_mask & valid_mask 113 | 114 | # change valid label to 0 if no point in the distance range 115 | if torch.sum(valid_mask_final) == 0: 116 | self.valid_label[idx] = 0 117 | 118 | output_masked = output[valid_mask_final] 119 | target_masked = target[valid_mask_final] 120 | 121 | abs_diff = (output_masked - target_masked).abs() 122 | 123 | self.result_lst[idx].mse = float((torch.pow(abs_diff, 2)).mean()) 124 | self.result_lst[idx].rmse = math.sqrt(self.result_lst[idx].mse) 125 | self.result_lst[idx].mae = float(abs_diff.mean()) 126 | self.result_lst[idx].lg10 = float((log10(output_masked) - log10(target_masked)).abs().mean()) 127 | self.result_lst[idx].absrel = float((abs_diff / target_masked).mean()) 128 | 129 | maxRatio = torch.max(output_masked / target_masked, target_masked / output_masked) 130 | self.result_lst[idx].delta1 = float((maxRatio < 1.25).float().mean()) 131 | self.result_lst[idx].delta2 = float((maxRatio < 1.25 ** 2).float().mean()) 132 | self.result_lst[idx].delta3 = float((maxRatio < 1.25 ** 3).float().mean()) 133 | 134 | inv_output = 1 / output_masked 135 | inv_target = 1 / target_masked 136 | abs_inv_diff = (inv_output - inv_target).abs() 137 | self.result_lst[idx].irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean()) 138 | self.result_lst[idx].imae = float(abs_inv_diff.mean()) 139 | 140 | 141 | class AverageMeter_multidist(object): 142 | def __init__(self): 143 | self.dist_interval = [10., 20., 30., 40., 50., 60., 70., 80., 90., 100.] 
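        # Keep one AverageMeter per distance interval; the intervals must match
        # Result_multidist.dist_interval (this is asserted in update()).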
144 | self.avg_lst = [AverageMeter() for i in range(len(self.dist_interval))] 145 | self.reset() 146 | 147 | def reset(self): 148 | # Reset every average meter 149 | for avg in self.avg_lst: 150 | avg.reset() 151 | 152 | def update(self, result, n=1): 153 | # Fetch the result list 154 | assert isinstance(result, Result_multidist) 155 | assert self.dist_interval == result.dist_interval 156 | result_lst = result.result_lst 157 | 158 | for idx, res in enumerate(result_lst): 159 | # Skip the invalid distanve intervals 160 | if result.valid_label[idx] == 0: 161 | pass 162 | else: 163 | self.avg_lst[idx].update(res, n) 164 | 165 | def average(self): 166 | avg = Result_multidist() 167 | # Iterate through all the avg_lst and perform average 168 | for idx, avg_obj in enumerate(self.avg_lst): 169 | res = avg_obj.average() 170 | avg.result_lst[idx].update( 171 | res.irmse, res.imae, res.mse, 172 | res.rmse, res.mae, res.absrel, res.lg10, 173 | res.delta1, res.delta2, res.delta3 174 | ) 175 | 176 | return avg 177 | 178 | 179 | class AverageMeter(object): 180 | def __init__(self): 181 | self.reset() 182 | 183 | def reset(self): 184 | self.count = 0.0 185 | 186 | self.sum_irmse, self.sum_imae = 0, 0 187 | self.sum_mse, self.sum_rmse, self.sum_mae = 0, 0, 0 188 | self.sum_absrel, self.sum_lg10 = 0, 0 189 | self.sum_delta1, self.sum_delta2, self.sum_delta3 = 0, 0, 0 190 | self.sum_data_time, self.sum_gpu_time = 0, 0 191 | 192 | def update(self, result, gpu_time, data_time, n=1): 193 | self.count += n 194 | 195 | self.sum_irmse += n*result.irmse 196 | self.sum_imae += n*result.imae 197 | self.sum_mse += n*result.mse 198 | self.sum_rmse += n*result.rmse 199 | self.sum_mae += n*result.mae 200 | self.sum_absrel += n*result.absrel 201 | self.sum_lg10 += n*result.lg10 202 | self.sum_delta1 += n*result.delta1 203 | self.sum_delta2 += n*result.delta2 204 | self.sum_delta3 += n*result.delta3 205 | self.sum_data_time += n*data_time 206 | self.sum_gpu_time += n*gpu_time 207 | 208 | def average(self): 209 | avg = Result() 210 | avg.update( 211 | self.sum_irmse / self.count, self.sum_imae / self.count, 212 | self.sum_mse / self.count, self.sum_rmse / self.count, self.sum_mae / self.count, 213 | self.sum_absrel / self.count, self.sum_lg10 / self.count, 214 | self.sum_delta1 / self.count, self.sum_delta2 / self.count, self.sum_delta3 / self.count, 215 | self.sum_gpu_time / self.count, self.sum_data_time / self.count) 216 | return avg -------------------------------------------------------------------------------- /misc/devkit/cpp/io_depth.h: -------------------------------------------------------------------------------- 1 | /* 2 | I/O interface class for loading, storing and manipulating 3 | depth maps in KITTI format. This file requires libpng 4 | and libpng++ to be installed for accessing png files. 
More 5 | detailed format specifications can be found in the readme.txt 6 | 7 | (c) Andreas Geiger 8 | */ 9 | 10 | #ifndef IO_DEPTH_H 11 | #define IO_DEPTH_H 12 | 13 | #include 14 | #include 15 | #include 16 | #include "log_colormap.h" 17 | 18 | class DepthImage { 19 | 20 | public: 21 | 22 | // default constructor 23 | DepthImage () { 24 | data_ = 0; 25 | width_ = 0; 26 | height_ = 0; 27 | } 28 | 29 | // construct depth image from png file 30 | DepthImage (const std::string file_name) { 31 | readDepthMap(file_name); 32 | } 33 | 34 | // copy constructor 35 | DepthImage (const DepthImage &D) { 36 | width_ = D.width_; 37 | height_ = D.height_; 38 | data_ = (float*)malloc(width_*height_*sizeof(float)); 39 | memcpy(data_,D.data_,width_*height_*sizeof(float)); 40 | } 41 | 42 | // construct depth image from data 43 | DepthImage (const float* data, const int32_t width, const int32_t height) : width_(width), height_(height) { 44 | data_ = (float*)malloc(width*height*sizeof(float)); 45 | memcpy(data_,data,width*height*sizeof(float)); 46 | } 47 | 48 | // construct empty (= all pixels invalid) depth map of given width / height 49 | DepthImage (const int32_t width, const int32_t height) : width_(width), height_(height) { 50 | data_ = (float*)malloc(width*height*sizeof(float)); 51 | for (int32_t i=0; i=0; 101 | } 102 | 103 | // set depth at given pixel 104 | inline void setDepth (const int32_t u,const int32_t v,const float val) { 105 | data_[v*width_+u] = val; 106 | } 107 | 108 | // is depth at given pixel to invalid 109 | inline bool setInvalid (const int32_t u,const int32_t v) { 110 | data_[v*width_+u] = -1; 111 | } 112 | 113 | // get maximal depth 114 | float maxDepth () { 115 | float max_depth = -1; 116 | for (int32_t i=0; imax_depth) 118 | max_depth = data_[i]; 119 | return max_depth; 120 | } 121 | 122 | // simple arithmetic operations 123 | DepthImage operator+ (const DepthImage &B) { 124 | DepthImage C(*this); 125 | for (int32_t i=0; i=1) { 159 | 160 | // first and last value for interpolation 161 | int32_t u1 = u-count; 162 | int32_t u2 = u-1; 163 | 164 | // set pixel to min depth 165 | if (u1>0 && u2=0; u--) { 192 | if (isValid(u,v)) { 193 | for (int32_t u2=u+1; u2<=width_-1; u2++) 194 | setDepth(u2,v,getDepth(u,v)); 195 | break; 196 | } 197 | } 198 | } 199 | 200 | // for each column do 201 | for (int32_t u=0; u=0; v--) { 214 | if (isValid(u,v)) { 215 | for (int32_t v2=v+1; v2<=height_-1; v2++) 216 | setDepth(u,v2,getDepth(u,v)); 217 | break; 218 | } 219 | } 220 | } 221 | } 222 | 223 | // compute error map of current image, given the 224 | // ground truth depth maps. stores result as color png image. 
225 | png::image errorImage (DepthImage &D_gt,bool log_colors=false) { 226 | png::image image(width(),height()); 227 | for (int32_t v=1; v=LC[i][0] && n_err image(file_name); 271 | width_ = image.get_width(); 272 | height_ = image.get_height(); 273 | data_ = (float*)malloc(width_*height_*sizeof(float)); 274 | for (int32_t v=0; v image(width_,height_); 285 | for (int32_t v=0; v image(width_,height_); 313 | 314 | // for all pixels do 315 | for (int32_t v=0; v invalid, 1 => valid, 2 => unknown 61 | valid_labels = np.zeros([point_count, 1]) 62 | # Iterate through all the radar points 63 | for i in range(point_count): 64 | # ipdb.set_trace() 65 | depth_thresh_new = sid_depth_thresh(depth_value[i, :dist_valid_count[i, 0]]) 66 | depth_valid_count = np.sum((depth_dist[i, :dist_valid_count[i, 0]] < depth_thresh_new).astype(np.int16)) 67 | # ipdb.set_trace() 68 | if dist_valid_count[i, 0] == 0: 69 | valid_labels[i, 0] = 2 70 | elif depth_valid_count >= np.ceil(dist_valid_count[i, 0] / 2): 71 | valid_labels[i, 0] = 1 72 | 73 | return valid_labels 74 | 75 | 76 | # Filter radar points using the groundtruth LiDAR points 77 | def filter_radar_points_gt(radar_points, radar_depth_points, lidar_points, lidar_depth_points): 78 | # Find k nearest neighbors whithin distance threshold 79 | k = 3 80 | dist_thresh = 10 81 | 82 | # Fetch only the x, y coord 83 | radar_points = radar_points[:2, :].transpose(1, 0) 84 | lidar_points = lidar_points[:2, :].transpose(1, 0) 85 | 86 | # Mask out points > 80m 87 | # radar_mask = radar_depth_points < 80. 88 | # radar_depth_points = radar_depth_points[radar_mask] 89 | # radar_points = radar_points[radar_mask, :] 90 | 91 | radar_points_exp = np.expand_dims(radar_points, axis=1) 92 | lidar_points_exp = np.expand_dims(lidar_points, axis=0) 93 | 94 | dist = np.sqrt(np.sum((radar_points_exp - lidar_points_exp) ** 2, axis=-1)) 95 | 96 | # Fetch the topk index 97 | dist_topk_index = np.argsort(dist)[:, :k][..., None] 98 | # Get dist topk value 99 | dist_topk_val = np.sort(dist, axis=-1)[:, :k] 100 | # Get depth topk depth value 101 | depth_topk_val = np.squeeze(select_topk_depth(lidar_depth_points, dist_topk_index)) 102 | 103 | # Get depth-aware dist thresh 104 | dist_thresh_sid = sid_dist_thresh(depth_topk_val) 105 | dist_valid_count = np.sum((dist_topk_val <= dist_thresh_sid).astype(np.int16), axis=-1)[..., None] 106 | 107 | # print(sid_dist_thresh(depth_topk_val)) 108 | depth_dist = radar_depth_points[..., None] - depth_topk_val 109 | valid_labels = check_valid_depth(depth_dist, dist_valid_count, depth_topk_val) 110 | 111 | # ipdb.set_trace() 112 | # Perform the filtering 113 | valid_mask_final = np.squeeze(valid_labels > 0) 114 | radar_points_filtered = radar_points[valid_mask_final, :].transpose(1, 0) 115 | radar_depth_points_filtered = radar_depth_points[valid_mask_final] 116 | 117 | return { 118 | 'valid_labels': valid_labels, 119 | 'valid_mask': valid_mask_final, 120 | 'radar_points': radar_points_filtered, 121 | 'radar_depth': radar_depth_points_filtered 122 | } 123 | 124 | 125 | # Filter radar points using the groundtruth LiDAR points 126 | def filter_radar_points_analysis(radar_points, radar_depth_points, lidar_points, lidar_depth_points): 127 | # Find k nearest neighbors whithin distance threshold 128 | k = 3 129 | dist_thresh = 10 130 | 131 | # Fetch only the x, y coord 132 | radar_points = radar_points[:2, :].transpose(1, 0) 133 | lidar_points = lidar_points[:2, :].transpose(1, 0) 134 | 135 | radar_points_exp = np.expand_dims(radar_points, axis=1) 136 | 
lidar_points_exp = np.expand_dims(lidar_points, axis=0) 137 | 138 | dist = np.sqrt(np.sum((radar_points_exp - lidar_points_exp) ** 2, axis=-1)) 139 | 140 | # Fetch the topk index 141 | dist_topk_index = np.argsort(dist)[:, :k][..., None] 142 | # Get dist topk value 143 | dist_topk_val = np.sort(dist, axis=-1)[:, :k] 144 | # Get depth topk depth value 145 | depth_topk_val = np.squeeze(select_topk_depth(lidar_depth_points, dist_topk_index)) 146 | 147 | # ipdb.set_trace() 148 | 149 | # Get depth-aware dist thresh 150 | dist_thresh_sid = sid_dist_thresh(depth_topk_val) 151 | dist_valid_count = np.sum((dist_topk_val <= dist_thresh_sid).astype(np.int16), axis=-1)[..., None] 152 | 153 | # print(sid_dist_thresh(depth_topk_val)) 154 | depth_dist = radar_depth_points[..., None] - depth_topk_val 155 | valid_labels = check_valid_depth(depth_dist, dist_valid_count, depth_topk_val) 156 | 157 | # ipdb.set_trace() 158 | # Perform the filtering 159 | valid_mask_final = np.squeeze(valid_labels > 0) 160 | radar_points_filtered = radar_points[valid_mask_final, :].transpose(1, 0) 161 | radar_depth_points_filtered = radar_depth_points[valid_mask_final] 162 | 163 | ########################### 164 | ## Compute some analysis ## 165 | ########################### 166 | # Compute inconsistencies using top-1 nearest neighbor 167 | # ipdb.set_trace() 168 | depth_top1_val = depth_topk_val[:, 0][..., None] 169 | depth_inconsist_raw = radar_depth_points[..., None] - depth_top1_val 170 | 171 | # depth_inconsist = depth_inconsist_raw[np.squeeze(valid_labels < 2), :] 172 | # depth_inconsist_filtered = depth_inconsist_raw[np.squeeze(valid_labels == 1), :] 173 | # ipdb.set_trace() 174 | 175 | return { 176 | 'valid_labels': valid_labels, 177 | 'valid_mask': valid_mask_final, 178 | 'radar_points': radar_points_filtered, 179 | 'radar_depth': radar_depth_points_filtered, 180 | 'depth_top1_val': depth_top1_val, 181 | 'depth_inconsist_raw': depth_inconsist_raw 182 | } 183 | 184 | 185 | # Filter radar points using the groundtruth LiDAR points 186 | def filter_radar_points_classify(radar_points, radar_depth_points, radar_raw_points, classifyer=None): 187 | # Find k nearest neighbors whithin distance threshold 188 | assert classifyer is not None 189 | 190 | # Fetch only the x, y coord 191 | radar_points = radar_points[:2, :].transpose(1, 0) 192 | lidar_points = lidar_points[:2, :].transpose(1, 0) 193 | 194 | # Mask out points > 80m 195 | # radar_mask = radar_depth_points < 80. 
196 | # radar_depth_points = radar_depth_points[radar_mask] 197 | # radar_points = radar_points[radar_mask, :] 198 | 199 | radar_points_exp = np.expand_dims(radar_points, axis=1) 200 | lidar_points_exp = np.expand_dims(lidar_points, axis=0) 201 | 202 | dist = np.sqrt(np.sum((radar_points_exp - lidar_points_exp) ** 2, axis=-1)) 203 | 204 | # Fetch the topk index 205 | dist_topk_index = np.argsort(dist)[:, :k][..., None] 206 | # Get dist topk value 207 | dist_topk_val = np.sort(dist, axis=-1)[:, :k] 208 | # Get depth topk depth value 209 | depth_topk_val = np.squeeze(select_topk_depth(lidar_depth_points, dist_topk_index)) 210 | 211 | # Get depth-aware dist thresh 212 | dist_thresh_sid = sid_dist_thresh(depth_topk_val) 213 | dist_valid_count = np.sum((dist_topk_val <= dist_thresh_sid).astype(np.int16), axis=-1)[..., None] 214 | 215 | # print(sid_dist_thresh(depth_topk_val)) 216 | depth_dist = radar_depth_points[..., None] - depth_topk_val 217 | valid_labels = check_valid_depth(depth_dist, dist_valid_count, depth_topk_val) 218 | 219 | # ipdb.set_trace() 220 | # Perform the filtering 221 | valid_mask_final = np.squeeze(valid_labels > 0) 222 | radar_points_filtered = radar_points[valid_mask_final, :].transpose(1, 0) 223 | radar_depth_points_filtered = radar_depth_points[valid_mask_final] 224 | 225 | return { 226 | 'valid_labels': valid_labels, 227 | 'valid_mask': valid_mask_final, 228 | 'radar_points': radar_points_filtered, 229 | 'radar_depth': radar_depth_points_filtered 230 | } 231 | 232 | 233 | # Plot LiDAR depth 234 | def plot_lidar_depth(image, points, depth_points, vmin=0., vmax=80.): 235 | # Plot depth 236 | fig = plt.figure(figsize=(10,6)) 237 | ax = plt.gca() 238 | plt.imshow(image) 239 | im = plt.scatter(points[0, :], points[1, :], c=depth_points, s=3, vmin=vmin, vmax=vmax, cmap="jet") 240 | fig.colorbar(im, ax=ax, cmap="jet", pad=0.01) 241 | plt.show() 242 | 243 | 244 | # Plot Radar depth 245 | def plot_radar_depth(image, points, depth_points, vmin=0., vmax=80.): 246 | # Plot depth 247 | fig = plt.figure(figsize=(10,6)) 248 | ax = plt.gca() 249 | plt.imshow(image) 250 | im = plt.scatter(points[0, :], points[1, :], c=depth_points, s=30, vmin=vmin, vmax=vmax, cmap="jet") 251 | fig.colorbar(im, ax=ax, cmap="jet", pad=0.01) 252 | plt.show() 253 | 254 | 255 | # Plot valid labels 256 | def plot_valid_labels(image, points, valid_labels, vmin=0., vmax=2.): 257 | # Plot depth 258 | fig = plt.figure(figsize=(10,6)) 259 | ax = plt.gca() 260 | plt.imshow(image) 261 | im = plt.scatter(points[0, :], points[1, :], c=valid_labels, s=30, vmin=vmin, vmax=vmax, cmap="jet") 262 | fig.colorbar(im, ax=ax, cmap="jet", pad=0.01) 263 | plt.show() 264 | 265 | 266 | if __name__ == "__main__": 267 | # Construct the dataset object 268 | nuscene_dataset = Nuscenes_dataset() 269 | 270 | # Get samples 271 | sample_obj = nuscene_dataset.samples[90] 272 | 273 | orientation = "front" 274 | num_sweeps = 1 275 | 276 | # Get LiDAR and RADAR points separately 277 | lidar_data = nuscene_dataset.get_lidar_depth_map_multi_bidirectional(sample_obj, orientation, 1) 278 | radar_data = nuscene_dataset.get_radar_depth_map_multi_bidirectional(sample_obj, orientation, 3) 279 | 280 | filtered_data = filter_radar_points_gt(radar_data['points'], 281 | radar_data['depth_points'], 282 | lidar_data['points'], 283 | lidar_data['depth_points']) 284 | 285 | plot_lidar_depth(lidar_data['image'], 286 | lidar_data['points'], 287 | lidar_data['depth_points'], 0, 100) 288 | plot_radar_depth(radar_data['image'], 289 | 
filtered_data['radar_points'], 290 | filtered_data['radar_depth'], 0, 100) -------------------------------------------------------------------------------- /model/multistage_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append("../") 4 | sys.path.append("../results/") 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision.models 10 | from torchvision.models.resnet import Bottleneck, conv1x1, conv3x3 11 | from model.models import Unpool, weights_init, weights_init_kaiming, weights_init_kaiming_leaky 12 | from model.models import BasicBlock, Decoder, DeConv, UpConv, UpProj, choose_decoder 13 | from model.models import ResNet_latefusion 14 | from config.config_nuscenes import config_nuscenes as cfg 15 | import collections 16 | import math 17 | 18 | ################################ 19 | ## Define the full model here ## 20 | ################################ 21 | # The multistage network 22 | class ResNet_multistage(nn.Module): 23 | def __init__(self, layers, decoder, output_size, pretrained=True): 24 | if layers not in [18, 34, 50, 101, 152]: 25 | raise RuntimeError('Only 18, 34, 50, 101, and 152 layer model are defined for ResNet. Got {}'.format(layers)) 26 | super(ResNet_multistage, self).__init__() 27 | 28 | # Define the model here 29 | self.stage1 = ResNet_latefusion2(layers, decoder, output_size, in_channels=4, pretrained=True) 30 | self.stage2 = ResNet_latefusion2(layers, decoder, output_size, in_channels=5, pretrained=True) 31 | self.filter_layer = Filter_layer() 32 | 33 | # ToDo: add method to load pretrained latefusion model 34 | if pretrained is True: 35 | # Get pretrained weights 36 | pretrained_path = os.path.join(cfg.PROJECT_ROOT, "pretrained/resnet18_latefusion.pth.tar") 37 | if not os.path.exists(pretrained_path): 38 | raise ValueError("[Error] Can't find pretrained latefusion model. 
"\ 39 | "Please follow the instructions in README.md to download the weights!") 40 | checkpoint = torch.load(pretrained_path) 41 | pretrain_weight = checkpoint["model_state_dict"] 42 | 43 | # Load state dict 44 | # Stage1 is the same so no problem 45 | self.stage1.load_state_dict(pretrain_weight) 46 | 47 | # Stage2 has some inconsistencies 48 | pretrain_weight_filtered = self.filter_state_dict(pretrain_weight, self.stage2.state_dict()) 49 | self.stage2.load_state_dict(pretrain_weight_filtered, strict=False) 50 | 51 | def filter_state_dict(self, pretrain_dict, target_dict): 52 | # iterate throught all the pretrain element 53 | del_keys = [] 54 | for key, value in pretrain_dict.items(): 55 | if target_dict[key].shape != value.shape: 56 | del_keys.append(key) 57 | 58 | for key in del_keys: 59 | pretrain_dict.pop(key) 60 | 61 | return pretrain_dict 62 | 63 | def forward(self, x): 64 | # Fetch inputs from different dimensions 65 | x_img = x[:, :3, :, :] 66 | x_d = x[:, 3:, :, :] 67 | 68 | # Stage 1 inference 69 | depth_stage1 = self.stage1(x) 70 | 71 | # Perform filtering 72 | x_d_filtered, mask = self.filter_layer(x_d, depth_stage1) 73 | 74 | # Stage 2 inference 75 | x_stage2 = torch.cat((x_img, x_d_filtered, depth_stage1), dim=1) 76 | depth_stage2 = self.stage2(x_stage2) 77 | 78 | return { 79 | "stage1": depth_stage1, 80 | "stage2": depth_stage2, 81 | "mask": mask, 82 | "radar_filtered": x_d_filtered 83 | } 84 | 85 | 86 | # Filter intermediate outputs 87 | class Filter_layer(nn.Module): 88 | def __init__(self): 89 | super(Filter_layer, self).__init__() 90 | # Define some filter parameters 91 | self.alpha = torch.tensor(5.) 92 | self.beta = torch.tensor(18.) 93 | self.K = torch.tensor(100.) 94 | 95 | # Convert to SID depth threshold 96 | def sid_depth_thresh(self, input_depth): 97 | # Compute depth threshold 98 | depth_thresh = torch.exp(((input_depth * torch.log(self.beta / self.alpha)) / self.K) + torch.log(self.alpha)) 99 | 100 | return depth_thresh 101 | 102 | # Compute valid mask 103 | def compute_valid_mask(self, sparse_depth, dense_depth): 104 | # Compute depth distance 105 | diff = torch.abs(dense_depth - sparse_depth) 106 | 107 | # Compute depth threshold 108 | depth_thresh = self.sid_depth_thresh(dense_depth) 109 | 110 | valid_mask = diff <= depth_thresh 111 | 112 | return valid_mask 113 | 114 | # Forward pass of the filtering 115 | def forward(self, sparse_depth, dense_depth): 116 | # Get valid mask 117 | mask = self.compute_valid_mask(sparse_depth, dense_depth).to(torch.float32) 118 | # pdb.set_trace() 119 | return sparse_depth * mask, mask 120 | 121 | 122 | # The original latefusion model 123 | class ResNet_latefusion2(nn.Module): 124 | def __init__(self, layers, decoder, output_size, in_channels=4, pretrained=True): 125 | if layers not in [18, 34, 50, 101, 152]: 126 | raise RuntimeError('Only 18, 34, 50, 101, and 152 layer model are defined for ResNet. 
Got {}'.format(layers)) 127 | 128 | super(ResNet_latefusion2, self).__init__() 129 | pretrained_model = torchvision.models.__dict__['resnet{}'.format(layers)](pretrained=pretrained) 130 | 131 | # Configurations required by resnet 132 | self.in_channels = in_channels 133 | self._norm_layer = nn.BatchNorm2d 134 | self.dilation = 1 135 | self.inplanes = 16 136 | self.groups = 1 137 | self.base_width = 16 138 | 139 | assert in_channels > 3 140 | ################ 141 | ## RGB Branch ## 142 | ################ 143 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 144 | self.bn1 = nn.BatchNorm2d(64) 145 | weights_init(self.conv1) 146 | weights_init(self.bn1) 147 | 148 | self.output_size = output_size 149 | 150 | self.relu = pretrained_model._modules['relu'] 151 | self.maxpool = pretrained_model._modules['maxpool'] 152 | self.layer1 = pretrained_model._modules['layer1'] 153 | self.layer2 = pretrained_model._modules['layer2'] 154 | self.layer3 = pretrained_model._modules['layer3'] 155 | self.layer4 = pretrained_model._modules['layer4'] 156 | 157 | # clear memory 158 | del pretrained_model 159 | 160 | ################## 161 | ## Depth Branch ## 162 | ################## 163 | depth_input_dim = self.in_channels - 3 164 | self.conv1_depth = nn.Conv2d(depth_input_dim, 16, kernel_size=7, stride=2, padding=3, bias=False) 165 | self.bn1_depth = nn.BatchNorm2d(16) 166 | weights_init_kaiming_leaky(self.conv1) 167 | weights_init_kaiming(self.bn1) 168 | 169 | self.relu_depth = nn.LeakyReLU(0.2, inplace=True) 170 | self.maxpool_depth = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 171 | self.layer1_depth = self._make_layer(BasicBlock, 16, 2, stride=1, dilate=False) 172 | self.layer2_depth = self._make_layer(BasicBlock, 32, 2, stride=2) 173 | self.layer3_depth = self._make_layer(BasicBlock, 64, 2, stride=2) 174 | self.layer4_depth = self._make_layer(BasicBlock, 128, 2, stride=2) 175 | 176 | # ToDo: If we need one more convolution to do the fusion 177 | # Define the fusion operator 178 | self.conv_fusion = nn.Conv2d(512 + 128, 512, kernel_size=1, bias=False) 179 | self.bn_fusion = nn.BatchNorm2d(512) 180 | 181 | # define number of intermediate channels 182 | if layers <= 34: 183 | num_channels = 512 184 | elif layers >= 50: 185 | num_channels = 2048 186 | 187 | self.conv2 = nn.Conv2d(num_channels,num_channels//2,kernel_size=1,bias=False) 188 | self.bn2 = nn.BatchNorm2d(num_channels//2) 189 | self.decoder = choose_decoder(decoder, num_channels//2) 190 | 191 | # setting bias=true doesn't improve accuracy 192 | self.conv3 = nn.Conv2d(num_channels//32,1,kernel_size=3,stride=1,padding=1,bias=False) 193 | self.bilinear = nn.Upsample(size=self.output_size, mode='bilinear', align_corners=True) 194 | 195 | # weight init 196 | self.conv2.apply(weights_init) 197 | self.bn2.apply(weights_init) 198 | self.decoder.apply(weights_init) 199 | self.conv3.apply(weights_init) 200 | 201 | # Make layer function adapted from resnet 202 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 203 | norm_layer = self._norm_layer 204 | downsample = None 205 | previous_dilation = self.dilation 206 | if dilate: 207 | self.dilation *= stride 208 | stride = 1 209 | if stride != 1 or self.inplanes != planes * block.expansion: 210 | downsample = nn.Sequential( 211 | conv1x1(self.inplanes, planes * block.expansion, stride), 212 | norm_layer(planes * block.expansion), 213 | ) 214 | 215 | layers = [] 216 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 217 | 
self.base_width, previous_dilation, norm_layer)) 218 | self.inplanes = planes * block.expansion 219 | for _ in range(1, blocks): 220 | layers.append(block(self.inplanes, planes, groups=self.groups, 221 | base_width=self.base_width, dilation=self.dilation, 222 | norm_layer=norm_layer)) 223 | 224 | layers = nn.Sequential(*layers) 225 | 226 | # Explicitly initialize layers after construction 227 | for m in layers.modules(): 228 | weights_init_kaiming(m) 229 | 230 | return layers 231 | 232 | def forward(self, x): 233 | assert x.shape[1] >= 4 234 | x_img = x[:, :3, :, :] 235 | 236 | if self.in_channels == 4: 237 | x_d = x[:, 3:, :, :] 238 | else: 239 | x_d_sparse = x[:, 3:4, :, :] 240 | x_d_dense = x[:, 4:5, :, :] 241 | x_d = torch.cat((x_d_sparse, x_d_dense), dim=1) 242 | 243 | # ipdb.set_trace() 244 | # RGB 245 | x_img = self.conv1(x_img) 246 | x_img = self.bn1(x_img) 247 | x_img = self.relu(x_img) 248 | x_img = self.maxpool(x_img) # 113 x 200 x 64 249 | x_img = self.layer1(x_img) # 113 x 200 x 64 250 | x_img = self.layer2(x_img) # 57 x 100 x 128 251 | x_img = self.layer3(x_img) # 29 x 50 x 256 252 | x_img = self.layer4(x_img) # 15 x 25 x 512 253 | 254 | # Depth 255 | x_d = self.conv1_depth(x_d) 256 | x_d = self.bn1_depth(x_d) 257 | x_d = self.relu_depth(x_d) 258 | x_d = self.maxpool_depth(x_d) # 113 x 200 x 16 259 | x_d = self.layer1_depth(x_d) # 113 x 200 x 16 260 | x_d = self.layer2_depth(x_d) # 57 x 100 x 32 261 | x_d = self.layer3_depth(x_d) # 29 x 50 x 64 262 | x_d = self.layer4_depth(x_d) # 15 x 25 x 128 263 | 264 | x_fused = torch.cat((x_img, x_d), dim=1) 265 | x_fused = self.conv_fusion(x_fused) 266 | x_fused = self.bn_fusion(x_fused) 267 | 268 | x_fused = self.conv2(x_fused) 269 | x_fused = self.bn2(x_fused) 270 | 271 | # decoder 272 | x_fused = self.decoder(x_fused) 273 | x_fused = self.conv3(x_fused) 274 | x_fused = self.bilinear(x_fused) 275 | 276 | return x_fused 277 | 278 | 279 | if __name__ == "__main__": 280 | # Create fake inputs 281 | inputs = torch.rand([16, 4, 450, 800]).to(torch.float32).cuda() 282 | 283 | # Create model 284 | model = ResNet_multistage(18, "upproj", [450, 800], True).cuda() 285 | 286 | # Run the inference 287 | outputs = model(inputs) -------------------------------------------------------------------------------- /misc/devkit/cpp/evaluate_depth.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "io_depth.h" 6 | #include "utils.h" 7 | //iterate over files in directory 8 | #include "dirent.h" 9 | #include 10 | 11 | using namespace std; 12 | 13 | /** \brief Method to calculate depth average error. 14 | * \param D_gt the ground truth depth image 15 | * \param D_ipol the interpolated depth image to be benchmarked 16 | * \return mae between original and ground truth depth (2 entries: occluded and non-occluded) 17 | * 18 | */ 19 | std::vector depthError (DepthImage &D_gt, DepthImage &D_ipol) { 20 | 21 | // check file size 22 | if (D_gt.width() != D_ipol.width() || D_gt.height() != D_ipol.height()) { 23 | cout << "ERROR: Wrong file size!" << endl; 24 | throw 1; 25 | } 26 | 27 | // extract width and height 28 | uint32_t width = D_gt.width(); 29 | uint32_t height = D_gt.height(); 30 | 31 | //init errors 32 | // 1. mae 33 | // 2. rmse 34 | // 3. inverse mae 35 | // 4. inverse rmse 36 | // 5. log mae 37 | // 6. log rmse 38 | // 7. scale invariant log 39 | // 8. abs relative 40 | // 9. 
squared relative 41 | 42 | std::vector errors(9, 0.f); 43 | // 44 | uint32_t num_pixels = 0; 45 | uint32_t num_pixels_result = 0; 46 | 47 | //log sum for scale invariant metric 48 | float logSum = 0.0; 49 | 50 | // for all pixels do 51 | for (uint32_t u = 0; u < width; u++) { 52 | for (uint32_t v = 0; v < height; v++) { 53 | if (D_gt.isValid(u, v)) { 54 | const float depth_ipol_m = D_ipol.getDepth(u, v); 55 | const float depth_gt_m = D_gt.getDepth(u, v); 56 | 57 | //error if gt is valid 58 | const float d_err = fabs(depth_gt_m - depth_ipol_m); 59 | 60 | const float d_err_squared = d_err * d_err; 61 | const float d_err_inv = fabs( 1.0 / depth_gt_m - 1.0 / depth_ipol_m); 62 | const float d_err_inv_squared = d_err_inv * d_err_inv; 63 | const float d_err_log = fabs(log(depth_gt_m) - log(depth_ipol_m)); 64 | const float d_err_log_squared = d_err_log * d_err_log; 65 | 66 | //mae 67 | errors[0] += d_err; 68 | //rmse 69 | errors[1] += d_err_squared; 70 | //inv_mae 71 | errors[2] += d_err_inv; 72 | //inv_rmse 73 | errors[3] += d_err_inv_squared; 74 | //log 75 | errors[4] += d_err_log; 76 | //log rmse 77 | errors[5] += d_err_log_squared; 78 | //log diff for scale invariant metric 79 | logSum += (log(depth_gt_m) - log(depth_ipol_m)); 80 | //abs relative 81 | errors[7] += d_err/depth_gt_m; 82 | //squared relative 83 | errors[8] += d_err_squared/(depth_gt_m*depth_gt_m); 84 | 85 | //increase valid gt pixels 86 | num_pixels++; 87 | } 88 | } //end for v 89 | } //end for u 90 | 91 | // check number of pixels 92 | if (num_pixels == 0) { 93 | cout << "ERROR: Ground truth defect => Please write me an email!" << endl; 94 | throw 1; 95 | } 96 | 97 | //normalize mae 98 | errors[0] /= (float)num_pixels; 99 | //normalize and take sqrt for rmse 100 | errors[1] /= (float)num_pixels; 101 | errors[1] = sqrt(errors[1]); 102 | //normalize inverse absoulte error 103 | errors[2] /= (float)num_pixels; 104 | //normalize and take sqrt for inverse rmse 105 | errors[3] /= (float)num_pixels; 106 | errors[3] = sqrt(errors[3]); 107 | //normalize log mae 108 | errors[4] /= (float)num_pixels; 109 | //first normalize log rmse -> we need this result later 110 | const float normalizedSquaredLog = errors[5] / (float)num_pixels; 111 | errors[5] = sqrt(normalizedSquaredLog); 112 | //calculate scale invariant metric 113 | errors[6] = sqrt(normalizedSquaredLog - (logSum*logSum / ((float)num_pixels*(float)num_pixels))); 114 | //normalize abs relative 115 | errors[7] /= (float)num_pixels; 116 | //normalize squared relative 117 | errors[8] /= (float)num_pixels; 118 | // return errors 119 | return errors; 120 | } 121 | 122 | /** \brief Helper function for png file selection. 123 | * \param entry direct struct to be compared 124 | * 125 | */ 126 | int png_select(const dirent *entry) 127 | { 128 | const char* fileName = entry->d_name; 129 | 130 | //check that this is not a directory 131 | if ((strcmp(fileName, ".")== 0) || (strcmp(fileName, "..") == 0)) 132 | return false; 133 | 134 | /* Check for png filename extensions */ 135 | const char* ptr = strrchr(fileName, '.'); 136 | if ((ptr != NULL) && (strcmp(ptr, ".png") == 0)) 137 | return true; 138 | else 139 | return false; 140 | } 141 | 142 | /** \brief Method to evaluate depth maps. 143 | * \param prediction_dir The directory containing predicted depth maps. 144 | * \return success If true, writes a txt file containing all depth error metrics. 
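 * \param gt_img_dir The directory containing the ground-truth depth maps; both directories must contain the same number of correspondingly sorted png files.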
145 | */ 146 | bool eval (string gt_img_dir, string prediction_dir) { 147 | // make sure all directories have ending slashes 148 | gt_img_dir += "/"; 149 | prediction_dir += "/"; 150 | 151 | // for all evaluation files do 152 | struct dirent **namelist_gt; 153 | struct dirent **namelist_prediction; 154 | int num_files_gt = scandir(gt_img_dir.c_str(), &namelist_gt, png_select, alphasort); 155 | int num_files_prediction = scandir(prediction_dir.c_str(), &namelist_prediction, png_select, alphasort); 156 | 157 | if( num_files_gt != num_files_prediction ){ 158 | std::cout << "Number of groundtruth (" << num_files_gt << ") and prediction files (" << num_files_prediction << ") mismatch!" << std::endl; 159 | free(namelist_gt); 160 | free(namelist_prediction); 161 | return false; 162 | } 163 | std::cout << "Found " << num_files_gt << " groundtruth and " << num_files_prediction << " prediction files!" << std::endl; 164 | 165 | if( num_files_prediction < 0 ){ 166 | perror("scandir"); 167 | } 168 | 169 | // std::vector for storing the errors 170 | std::vector< std::vector > errors_out; 171 | // create output directories 172 | system(("mkdir " + prediction_dir + "/errors_out/").c_str()); 173 | system(("mkdir " + prediction_dir + "/errors_img/").c_str()); 174 | system(("mkdir " + prediction_dir + "/depth_orig/").c_str()); 175 | system(("mkdir " + prediction_dir + "/depth_ipol/").c_str()); 176 | system(("mkdir " + prediction_dir + "/image_0/").c_str()); 177 | 178 | for( int i = 0; i < num_files_prediction; ++i ){ 179 | //Be aware: we use the same index here, the files have to be correctly sorted!!!! 180 | std::string fileName_gt = gt_img_dir + namelist_gt[i]->d_name; 181 | std::string fileName_prediction = prediction_dir + namelist_prediction[i]->d_name; 182 | 183 | if( strcmp(fileName_gt.c_str(), ".") == 0 || strcmp(fileName_gt.c_str(), "..") == 0 ) continue; 184 | std::string filePath = gt_img_dir + fileName_gt; 185 | //std::string fileName_gt = path.back(); 186 | 187 | size_t lastindex = std::string(namelist_gt[i]->d_name).find_last_of("."); 188 | std::string prefix = std::string(namelist_gt[i]->d_name).substr(0, lastindex); 189 | 190 | // output 191 | std::cout << "Processing: " << prefix.c_str() << std::endl; 192 | 193 | // catch errors, when loading fails 194 | try { 195 | // load ground truth depth maps 196 | DepthImage D_gt(fileName_gt); 197 | 198 | // check file format 199 | if (!imageFormat(fileName_gt, png::color_type_gray, 16, D_gt.width(), D_gt.height())) { 200 | std::cout << "ERROR: Input must be png, 1 channel, 16 bit, " << D_gt.width() << " x " << D_gt.height() << "px" << std::endl; 201 | free(namelist_gt); 202 | free(namelist_prediction); 203 | return false; 204 | } 205 | // load prediction and interpolate missing values 206 | DepthImage D_orig(fileName_prediction); 207 | DepthImage D_ipol(D_orig); 208 | D_ipol.interpolateBackground(); 209 | 210 | // add depth errors 211 | std::vector errors_out_curr = depthError(D_gt, D_ipol); 212 | errors_out.push_back(errors_out_curr); 213 | 214 | // save detailed infos for first 20 images 215 | if (i < 20) { 216 | // save errors of error images to text file 217 | FILE *errors_out_file = fopen((prediction_dir + "/errors_out/" + prefix + ".txt").c_str(), "w"); 218 | if (errors_out_file == NULL) { 219 | std::cout << "ERROR: Couldn't generate/store output statistics!" 
<< std::endl; 220 | return false; 221 | } 222 | for (int32_t j = 0; j < errors_out_curr.size(); j++) { 223 | fprintf(errors_out_file, "%f ", errors_out_curr[j]); 224 | } 225 | fclose(errors_out_file); 226 | 227 | // save error image 228 | png::image D_err = D_ipol.errorImage(D_gt, true); 229 | D_err.write(prediction_dir + "/errors_img/" + prefix + ".png"); 230 | 231 | // compute maximum depth 232 | float max_depth = D_gt.maxDepth(); 233 | 234 | // save original depth image false color coded 235 | D_orig.writeColor(prediction_dir + "/depth_orig/" + prefix + ".png", max_depth); 236 | 237 | // save interpolated depth image false color coded 238 | D_ipol.writeColor(prediction_dir + "/depth_ipol/" + prefix + ".png", max_depth); 239 | 240 | // copy left camera image 241 | string img_src = gt_img_dir + "/" + prefix + ".png"; 242 | string img_dst = prediction_dir + "/image_0/" + prefix + ".png"; 243 | system(("cp " + img_src + " " + img_dst).c_str()); 244 | } 245 | 246 | // on error, exit 247 | } catch (...) { 248 | std::cout << "ERROR: Couldn't read: " << prefix.c_str() << ".png" << std::endl; 249 | free(namelist_gt); 250 | free(namelist_prediction); 251 | return false; 252 | } 253 | } 254 | // open stats file for writing 255 | string stats_out_file_name = prediction_dir + "/stats_depth.txt"; 256 | FILE *stats_out_file = fopen(stats_out_file_name.c_str(), "w"); 257 | 258 | if (stats_out_file == NULL || errors_out.size() == 0) { 259 | std::cout << "ERROR: Couldn't generate/store output statistics!" << std::endl; 260 | free(namelist_gt); 261 | free(namelist_prediction); 262 | return false; 263 | } 264 | 265 | const char *metrics[] = {"mae", 266 | "rmse", 267 | "inverse mae", 268 | "inverse rmse", 269 | "log mae", 270 | "log rmse", 271 | "scale invariant log", 272 | "abs relative", 273 | "squared relative"}; 274 | // write mean, min and max 275 | std::cout << "Done. Your evaluation results are:" << std::endl; 276 | for (int32_t i = 0; i < errors_out[0].size(); i++) { 277 | std::cout << "mean " << metrics[i] << ": " << statMean(errors_out, i) << std::endl; 278 | fprintf(stats_out_file, "mean %s: %f \n", metrics[i], statMean(errors_out, i)); 279 | fprintf(stats_out_file, "min %s: %f \n", metrics[i], statMin(errors_out, i)); 280 | fprintf(stats_out_file, "max %s: %f \n", metrics[i], statMax(errors_out, i)); 281 | } 282 | 283 | // close file 284 | fclose(stats_out_file); 285 | //free memory of scandir calls 286 | free(namelist_gt); 287 | free(namelist_prediction); 288 | 289 | // success 290 | return true; 291 | } 292 | 293 | int32_t main (int32_t argc, char *argv[]) { 294 | 295 | // we need 3 arguments! 296 | for (int32_t i = 0; i < argc; ++i){ 297 | std::cout << argv[i] << " "; 298 | } 299 | std::cout << std::endl; 300 | if (argc != 3) { 301 | cout << "Usage: ./evaluate_depth gt_dir prediction_dir" << endl; 302 | return 1; 303 | } 304 | 305 | // read arguments 306 | string gt_img_dir = argv[1]; 307 | string prediction_dir = argv[2]; 308 | std::cout << "Starting depth evaluation.." << std::endl; 309 | // run evaluation 310 | bool success = eval(gt_img_dir, prediction_dir); 311 | 312 | if (success) { 313 | std::cout << "Your evaluation results are available at:" << std::endl; 314 | std::cout << prediction_dir + "/stats_depth.txt" << std::endl; 315 | } else { 316 | std::cout << "An error occured while processing your results." 
<< std::endl; 317 | std::cout << "Please make sure that the data in your result directory has the right format (compare to prediction/sparseConv_val)" << std::endl; 318 | } 319 | 320 | // exit 321 | return 0; 322 | } 323 | 324 | -------------------------------------------------------------------------------- /dataset/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapt from fangchangma/sparse_to_dense.pytorch on github 3 | """ 4 | from __future__ import print_function 5 | from __future__ import division 6 | from __future__ import absolute_import 7 | 8 | import torch 9 | import torchvision 10 | import math 11 | import random 12 | 13 | from PIL import Image, ImageOps, ImageEnhance 14 | try: 15 | import accimage 16 | except ImportError: 17 | accimage = None 18 | 19 | import numpy as np 20 | import numbers 21 | import types 22 | import collections 23 | 24 | import scipy.ndimage.interpolation as itpl 25 | import scipy.misc as misc 26 | 27 | 28 | # Check whether the input is a numpy array 29 | def _is_numpy_image(img): 30 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 31 | 32 | 33 | # Check whether the input is a PIL image 34 | def _is_pil_image(img): 35 | if accimage is not None: 36 | return isinstance(img, (Image.Image, accimage.Image)) 37 | else: 38 | return isinstance(img, Image.Image) 39 | 40 | 41 | # Check wheter the input is a tensor 42 | def _is_tensor_image(img): 43 | return torch.is_tensor(img) and img.ndimension() == 3 44 | 45 | 46 | def adjust_brightness(img, brightness_factor): 47 | """Adjust brightness of an Image. 48 | Args: 49 | img (PIL Image): PIL Image to be adjusted. 50 | brightness_factor (float): How much to adjust the brightness. Can be 51 | any non negative number. 0 gives a black image, 1 gives the 52 | original image while 2 increases the brightness by a factor of 2. 53 | Returns: 54 | PIL Image: Brightness adjusted image. 55 | """ 56 | if not _is_pil_image(img): 57 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 58 | 59 | enhancer = ImageEnhance.Brightness(img) 60 | img = enhancer.enhance(brightness_factor) 61 | return img 62 | 63 | 64 | def adjust_contrast(img, contrast_factor): 65 | """Adjust contrast of an Image. 66 | Args: 67 | img (PIL Image): PIL Image to be adjusted. 68 | contrast_factor (float): How much to adjust the contrast. Can be any 69 | non negative number. 0 gives a solid gray image, 1 gives the 70 | original image while 2 increases the contrast by a factor of 2. 71 | Returns: 72 | PIL Image: Contrast adjusted image. 73 | """ 74 | if not _is_pil_image(img): 75 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 76 | 77 | enhancer = ImageEnhance.Contrast(img) 78 | img = enhancer.enhance(contrast_factor) 79 | return img 80 | 81 | 82 | def adjust_saturation(img, saturation_factor): 83 | """Adjust color saturation of an image. 84 | Args: 85 | img (PIL Image): PIL Image to be adjusted. 86 | saturation_factor (float): How much to adjust the saturation. 0 will 87 | give a black and white image, 1 will give the original image while 88 | 2 will enhance the saturation by a factor of 2. 89 | Returns: 90 | PIL Image: Saturation adjusted image. 91 | """ 92 | if not _is_pil_image(img): 93 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 94 | 95 | enhancer = ImageEnhance.Color(img) 96 | img = enhancer.enhance(saturation_factor) 97 | return img 98 | 99 | 100 | def adjust_hue(img, hue_factor): 101 | """Adjust hue of an image. 
102 | The image hue is adjusted by converting the image to HSV and 103 | cyclically shifting the intensities in the hue channel (H). 104 | The image is then converted back to original image mode. 105 | `hue_factor` is the amount of shift in H channel and must be in the 106 | interval `[-0.5, 0.5]`. 107 | See https://en.wikipedia.org/wiki/Hue for more details on Hue. 108 | Args: 109 | img (PIL Image): PIL Image to be adjusted. 110 | hue_factor (float): How much to shift the hue channel. Should be in 111 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in 112 | HSV space in positive and negative direction respectively. 113 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image 114 | with complementary colors while 0 gives the original image. 115 | Returns: 116 | PIL Image: Hue adjusted image. 117 | """ 118 | if not(-0.5 <= hue_factor <= 0.5): 119 | raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) 120 | 121 | if not _is_pil_image(img): 122 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 123 | 124 | input_mode = img.mode 125 | if input_mode in {'L', '1', 'I', 'F'}: 126 | return img 127 | 128 | h, s, v = img.convert('HSV').split() 129 | 130 | np_h = np.array(h, dtype=np.uint8) 131 | # uint8 addition take cares of rotation across boundaries 132 | with np.errstate(over='ignore'): 133 | np_h += np.uint8(hue_factor * 255) 134 | h = Image.fromarray(np_h, 'L') 135 | 136 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 137 | return img 138 | 139 | 140 | def adjust_gamma(img, gamma, gain=1): 141 | """Perform gamma correction on an image. 142 | Also known as Power Law Transform. Intensities in RGB mode are adjusted 143 | based on the following equation: 144 | I_out = 255 * gain * ((I_in / 255) ** gamma) 145 | See https://en.wikipedia.org/wiki/Gamma_correction for more details. 146 | Args: 147 | img (PIL Image): PIL Image to be adjusted. 148 | gamma (float): Non negative real number. gamma larger than 1 make the 149 | shadows darker, while gamma smaller than 1 make dark regions 150 | lighter. 151 | gain (float): The constant multiplier. 152 | """ 153 | if not _is_pil_image(img): 154 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 155 | 156 | if gamma < 0: 157 | raise ValueError('Gamma should be a non-negative real number') 158 | 159 | input_mode = img.mode 160 | img = img.convert('RGB') 161 | 162 | np_img = np.array(img, dtype=np.float32) 163 | np_img = 255 * gain * ((np_img / 255) ** gamma) 164 | np_img = np.uint8(np.clip(np_img, 0, 255)) 165 | 166 | img = Image.fromarray(np_img, 'RGB').convert(input_mode) 167 | return img 168 | 169 | 170 | class Compose(object): 171 | """Composes several transforms together. 172 | Args: 173 | transforms (list of ``Transform`` objects): list of transforms to compose. 174 | Example: 175 | >>> transforms.Compose([ 176 | >>> transforms.CenterCrop(10), 177 | >>> transforms.ToTensor(), 178 | >>> ]) 179 | """ 180 | 181 | def __init__(self, transforms): 182 | self.transforms = transforms 183 | 184 | def __call__(self, img): 185 | for t in self.transforms: 186 | img = t(img) 187 | return img 188 | 189 | 190 | class ToTensor(object): 191 | """Convert a ``numpy.ndarray`` to tensor. 192 | Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W). 193 | """ 194 | 195 | def __call__(self, img): 196 | """Convert a ``numpy.ndarray`` to tensor. 197 | Args: 198 | img (numpy.ndarray): Image to be converted to tensor. 
199 | Returns: 200 | Tensor: Converted image. 201 | """ 202 | if not(_is_numpy_image(img)): 203 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 204 | 205 | if isinstance(img, np.ndarray): 206 | # handle numpy array 207 | if img.ndim == 3: 208 | img = torch.from_numpy(img.transpose((2, 0, 1)).copy()) 209 | elif img.ndim == 2: 210 | img = torch.from_numpy(img.copy()) 211 | else: 212 | raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim)) 213 | 214 | # backward compatibility 215 | # return img.float().div(255) 216 | return img.float() 217 | 218 | 219 | class NormalizeNumpyArray(object): 220 | """Normalize a ``numpy.ndarray`` with mean and standard deviation. 221 | Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform 222 | will normalize each channel of the input ``numpy.ndarray`` i.e. 223 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 224 | Args: 225 | mean (sequence): Sequence of means for each channel. 226 | std (sequence): Sequence of standard deviations for each channel. 227 | """ 228 | 229 | def __init__(self, mean, std): 230 | self.mean = mean 231 | self.std = std 232 | 233 | def __call__(self, img): 234 | """ 235 | Args: 236 | img (numpy.ndarray): Image of size (H, W, C) to be normalized. 237 | Returns: 238 | Tensor: Normalized image. 239 | """ 240 | if not(_is_numpy_image(img)): 241 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 242 | # TODO: make efficient 243 | print(img.shape) 244 | for i in range(3): 245 | img[:,:,i] = (img[:,:,i] - self.mean[i]) / self.std[i] 246 | return img 247 | 248 | 249 | class NormalizeTensor(object): 250 | """Normalize an tensor image with mean and standard deviation. 251 | Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform 252 | will normalize each channel of the input ``torch.*Tensor`` i.e. 253 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 254 | Args: 255 | mean (sequence): Sequence of means for each channel. 256 | std (sequence): Sequence of standard deviations for each channel. 257 | """ 258 | 259 | def __init__(self, mean, std): 260 | self.mean = mean 261 | self.std = std 262 | 263 | def __call__(self, tensor): 264 | """ 265 | Args: 266 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 267 | Returns: 268 | Tensor: Normalized Tensor image. 269 | """ 270 | if not _is_tensor_image(tensor): 271 | raise TypeError('tensor is not a torch image.') 272 | # TODO: make efficient 273 | for t, m, s in zip(tensor, self.mean, self.std): 274 | t.sub_(m).div_(s) 275 | return tensor 276 | 277 | 278 | class Rotate(object): 279 | """Rotates the given ``numpy.ndarray``. 280 | Args: 281 | angle (float): The rotation angle in degrees. 282 | """ 283 | 284 | def __init__(self, angle): 285 | self.angle = angle 286 | 287 | def __call__(self, img): 288 | """ 289 | Args: 290 | img (numpy.ndarray (C x H x W)): Image to be rotated. 291 | Returns: 292 | img (numpy.ndarray (C x H x W)): Rotated image. 293 | """ 294 | 295 | # order=0 means nearest-neighbor type interpolation 296 | return itpl.rotate(img, self.angle, reshape=False, prefilter=False, order=0) 297 | 298 | 299 | class Resize(object): 300 | """Resize the the given ``numpy.ndarray`` to the given size. 301 | Args: 302 | size (sequence or int): Desired output size. If size is a sequence like 303 | (h, w), output size will be matched to this. 
If size is an int, 304 | smaller edge of the image will be matched to this number. 305 | i.e, if height > width, then image will be rescaled to 306 | (size * height / width, size) 307 | interpolation (int, optional): Desired interpolation. Default is 308 | ``PIL.Image.BILINEAR`` 309 | """ 310 | 311 | def __init__(self, size, interpolation='nearest'): 312 | assert isinstance(size, int) or isinstance(size, float) or \ 313 | (isinstance(size, collections.Iterable) and len(size) == 2) 314 | self.size = size 315 | self.interpolation = interpolation 316 | 317 | def __call__(self, img): 318 | """ 319 | Args: 320 | img (PIL Image): Image to be scaled. 321 | Returns: 322 | PIL Image: Rescaled image. 323 | """ 324 | if img.ndim == 3: 325 | return misc.imresize(img, self.size, self.interpolation) 326 | elif img.ndim == 2: 327 | return misc.imresize(img, self.size, self.interpolation, 'F') 328 | else: 329 | RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim)) 330 | 331 | 332 | class CenterCrop(object): 333 | """Crops the given ``numpy.ndarray`` at the center. 334 | Args: 335 | size (sequence or int): Desired output size of the crop. If size is an 336 | int instead of sequence like (h, w), a square crop (size, size) is 337 | made. 338 | """ 339 | 340 | def __init__(self, size): 341 | if isinstance(size, numbers.Number): 342 | self.size = (int(size), int(size)) 343 | else: 344 | self.size = size 345 | 346 | @staticmethod 347 | def get_params(img, output_size): 348 | """Get parameters for ``crop`` for center crop. 349 | Args: 350 | img (numpy.ndarray (C x H x W)): Image to be cropped. 351 | output_size (tuple): Expected output size of the crop. 352 | Returns: 353 | tuple: params (i, j, h, w) to be passed to ``crop`` for center crop. 354 | """ 355 | h = img.shape[0] 356 | w = img.shape[1] 357 | th, tw = output_size 358 | i = int(round((h - th) / 2.)) 359 | j = int(round((w - tw) / 2.)) 360 | 361 | # # randomized cropping 362 | # i = np.random.randint(i-3, i+4) 363 | # j = np.random.randint(j-3, j+4) 364 | 365 | return i, j, th, tw 366 | 367 | def __call__(self, img): 368 | """ 369 | Args: 370 | img (numpy.ndarray (C x H x W)): Image to be cropped. 371 | Returns: 372 | img (numpy.ndarray (C x H x W)): Cropped image. 373 | """ 374 | i, j, h, w = self.get_params(img, self.size) 375 | 376 | """ 377 | i: Upper pixel coordinate. 378 | j: Left pixel coordinate. 379 | h: Height of the cropped image. 380 | w: Width of the cropped image. 381 | """ 382 | if not(_is_numpy_image(img)): 383 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 384 | if img.ndim == 3: 385 | return img[i:i+h, j:j+w, :] 386 | elif img.ndim == 2: 387 | return img[i:i + h, j:j + w] 388 | else: 389 | raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim)) 390 | 391 | 392 | class Lambda(object): 393 | """Apply a user-defined lambda as a transform. 394 | Args: 395 | lambd (function): Lambda/function to be used for transform. 396 | """ 397 | 398 | def __init__(self, lambd): 399 | assert isinstance(lambd, types.LambdaType) 400 | self.lambd = lambd 401 | 402 | def __call__(self, img): 403 | return self.lambd(img) 404 | 405 | 406 | class HorizontalFlip(object): 407 | """Horizontally flip the given ``numpy.ndarray``. 408 | Args: 409 | do_flip (boolean): whether or not do horizontal flip. 
410 | """ 411 | 412 | def __init__(self, do_flip): 413 | self.do_flip = do_flip 414 | 415 | def __call__(self, img): 416 | """ 417 | Args: 418 | img (numpy.ndarray (C x H x W)): Image to be flipped. 419 | Returns: 420 | img (numpy.ndarray (C x H x W)): flipped image. 421 | """ 422 | if not(_is_numpy_image(img)): 423 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 424 | 425 | if self.do_flip: 426 | return np.fliplr(img) 427 | else: 428 | return img 429 | 430 | 431 | class ColorJitter(object): 432 | """Randomly change the brightness, contrast and saturation of an image. 433 | Args: 434 | brightness (float): How much to jitter brightness. brightness_factor 435 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 436 | contrast (float): How much to jitter contrast. contrast_factor 437 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 438 | saturation (float): How much to jitter saturation. saturation_factor 439 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 440 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 441 | [-hue, hue]. Should be >=0 and <= 0.5. 442 | """ 443 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 444 | self.brightness = brightness 445 | self.contrast = contrast 446 | self.saturation = saturation 447 | self.hue = hue 448 | 449 | @staticmethod 450 | def get_params(brightness, contrast, saturation, hue): 451 | """Get a randomized transform to be applied on image. 452 | Arguments are same as that of __init__. 453 | Returns: 454 | Transform which randomly adjusts brightness, contrast and 455 | saturation in a random order. 456 | """ 457 | transforms = [] 458 | if brightness > 0: 459 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) 460 | transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor))) 461 | 462 | if contrast > 0: 463 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) 464 | transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor))) 465 | 466 | if saturation > 0: 467 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) 468 | transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor))) 469 | 470 | if hue > 0: 471 | hue_factor = np.random.uniform(-hue, hue) 472 | transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor))) 473 | 474 | np.random.shuffle(transforms) 475 | transform = Compose(transforms) 476 | 477 | return transform 478 | 479 | def __call__(self, img): 480 | """ 481 | Args: 482 | img (numpy.ndarray (C x H x W)): Input image. 483 | Returns: 484 | img (numpy.ndarray (C x H x W)): Color jittered image. 485 | """ 486 | if not(_is_numpy_image(img)): 487 | raise TypeError('img should be ndarray. 
Got {}'.format(type(img))) 488 | 489 | pil = Image.fromarray(img) 490 | transform = self.get_params(self.brightness, self.contrast, 491 | self.saturation, self.hue) 492 | return np.array(transform(pil)) 493 | 494 | 495 | # Easier version of color jitter 496 | def Colorjitter2(brightness=0, contrast=0, saturation=0): 497 | return torchvision.transforms.ColorJitter( 498 | brightness=brightness, 499 | contrast=contrast, 500 | saturation=saturation 501 | ) 502 | 503 | 504 | # Normalization using imagenet mean and variance 505 | def normalization_imagenet(inputs): 506 | # Construct the normalization 507 | normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 508 | 509 | return normalize(inputs) 510 | 511 | # Denormalization using imagenet mean and variance 512 | def denormalization_imagenet(inputs): 513 | # Construct the denormalization 514 | mean_r = 0.485 515 | mean_g = 0.456 516 | mean_b = 0.406 517 | std_r = 0.229 518 | std_g = 0.224 519 | std_b = 0.225 520 | denormalize = torchvision.transforms.Normalize(mean=[-mean_r/std_r, -mean_g/std_g, -mean_b/std_b], 521 | std=[1/std_r, 1/std_g, 1/std_b]) 522 | return denormalize(inputs) 523 | 524 | 525 | # Denormalize batch of tensors 526 | def denormalization_batch(inputs): 527 | # Get the batch size 528 | batch_size = inputs.shape[0] 529 | tensor_list = [] 530 | for i in range(batch_size): 531 | tensor_list.append(torch.unsqueeze(denormalization_imagenet(inputs[i, :, :, :]), dim=0)) 532 | 533 | return torch.cat(tuple(tensor_list), dim=0) 534 | 535 | 536 | class Crop(object): 537 | """Crops the given PIL Image to a rectangular region based on a given 538 | 4-tuple defining the left, upper pixel coordinated, hight and width size. 539 | Args: 540 | a tuple: (upper pixel coordinate, left pixel coordinate, hight, width)-tuple 541 | """ 542 | 543 | def __init__(self, i, j, h, w): 544 | """ 545 | i: Upper pixel coordinate. 546 | j: Left pixel coordinate. 547 | h: Height of the cropped image. 548 | w: Width of the cropped image. 549 | """ 550 | self.i = i 551 | self.j = j 552 | self.h = h 553 | self.w = w 554 | 555 | def __call__(self, img): 556 | """ 557 | Args: 558 | img (numpy.ndarray (C x H x W)): Image to be cropped. 559 | Returns: 560 | img (numpy.ndarray (C x H x W)): Cropped image. 561 | """ 562 | 563 | i, j, h, w = self.i, self.j, self.h, self.w 564 | 565 | if not(_is_numpy_image(img)): 566 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 567 | if img.ndim == 3: 568 | return img[i:i + h, j:j + w, :] 569 | elif img.ndim == 2: 570 | return img[i:i + h, j:j + w] 571 | else: 572 | raise RuntimeError( 573 | 'img should be ndarray with 2 or 3 dimensions. 
Got {}'.format(img.ndim)) 574 | 575 | def __repr__(self): 576 | return self.__class__.__name__ + '(i={0},j={1},h={2},w={3})'.format( 577 | self.i, self.j, self.h, self.w) -------------------------------------------------------------------------------- /dataset/nuscenes_dataset_torch_new.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file implement the dataset object for pytorch 3 | """ 4 | from __future__ import print_function 5 | from __future__ import division 6 | from __future__ import absolute_import 7 | 8 | # Add system path for fast debugging 9 | import sys 10 | sys.path.insert(0, "/cluster/home/julin/workspace/Semester_project_release") 11 | 12 | import os 13 | import torch 14 | import numpy as np 15 | from torch.utils.data import Dataset 16 | from dataset.nuscenes_dataset import Nuscenes_dataset 17 | from config.config_nuscenes import config_nuscenes as cfg 18 | from dataset.dense_to_sparse import UniformSampling, LidarRadarSampling 19 | from dataset import transforms as transforms 20 | from dataset.radar_preprocessing import filter_radar_points_gt 21 | import math 22 | import h5py 23 | import pickle 24 | import matplotlib.pyplot as plt 25 | to_tensor = transforms.ToTensor() 26 | 27 | 28 | #################################### 29 | ## Sparsifier Documentations: 30 | ## 1. uniform: Uniformly sampled LiDAR points. 31 | ## 2. lidar_radar: Sampled LiDAR points using the radar pattern. 32 | ## 3. radar: raw radar points (accumulated from three time steps. 33 | ## 4. radar_filtered: Filtered radar points using the heuristic algorithm. 34 | ## 5. radar_filtered2: Filtered radar points using the trained point classifier. 35 | #################################### 36 | 37 | # Define the dataset object for torch 38 | class nuscenes_dataset_torch(Dataset): 39 | def __init__(self, 40 | mode="train", transform_mode="DORN", modality="rgb", 41 | sparsifier=None, num_samples=0, max_depth=100., 42 | ): 43 | super(nuscenes_dataset_torch, self).__init__() 44 | print("[Info] Initializing exported nuscenes dataset") 45 | self.mode = mode 46 | self.filename_dataset = self.get_filename_dataset() 47 | print("\t Mode: ", self.mode) 48 | print("\t Version: ", cfg.version) 49 | print("\t Data counts: ", self.filename_dataset["length"]) 50 | 51 | # Check modalities 52 | self.avail_modality = ["rgb", "rgbd"] 53 | if not modality in self.avail_modality: 54 | raise ValueError("[Error] Unsupported modality. 
Consider ", self.avail_modality) 55 | 56 | self.modality = modality 57 | print("\t Modality: ", self.modality) 58 | 59 | # Check sparsifier and modality 60 | if (self.modality == "rgb"): 61 | self.sparsifier = "radar" 62 | self.num_samples = num_samples 63 | self.max_depth = max_depth 64 | 65 | elif (self.modality == "rgbd") and (sparsifier is None): 66 | # If rgbd and not sparsifier, then use radar 67 | self.sparsifier = "radar" 68 | self.num_samples = num_samples 69 | self.max_depth = max_depth 70 | 71 | elif (self.modality == "rgbd") and (sparsifier is not None): 72 | # If sparsifier is provided then check if it's valid 73 | if not sparsifier in ["uniform", "lidar_radar", "radar", 74 | "radar_filtered", "radar_filtered2"]: 75 | raise ValueError("[Error] Invalid sparsifier.") 76 | 77 | assert num_samples is not None 78 | self.sparsifier = sparsifier 79 | self.num_samples = num_samples 80 | self.max_depth = max_depth 81 | # Initialize uniform sampler 82 | if self.sparsifier == "uniform": 83 | self.sparsifier_func = UniformSampling(num_samples, max_depth) 84 | # Initialize lidar_radar sampler 85 | elif self.sparsifier == "lidar_radar": 86 | self.sparsifier_func = LidarRadarSampling(num_samples, max_depth) 87 | # Radar will be handled in the end of transform, no sparsifier_func is required 88 | elif (self.sparsifier == "radar") or \ 89 | (self.sparsifier == "radar_filtered") or \ 90 | (self.sparsifier == "radar_filtered2"): 91 | pass 92 | else: 93 | raise NotImplementedError 94 | 95 | print("\t Sparsifier: ", self.sparsifier) 96 | 97 | # Further get the day-night split table. 98 | day_night_info = self.get_day_night_info() 99 | self.train_daynight_table = day_night_info["train"] 100 | self.train_daynight_count = day_night_info["train_count"] 101 | self.test_daynight_table = day_night_info["test"] 102 | self.test_daynight_count = day_night_info["test_count"] 103 | print("\t Day-Night info:") 104 | print("\t\t Train:") 105 | print("\t\t\t Day: %d" %(self.train_daynight_count["day"])) 106 | print("\t\t\t Day + Rain: %d" %(self.train_daynight_count["day_rain"])) 107 | print("\t\t\t Night: %d" %(self.train_daynight_count["night"])) 108 | print("\t\t\t Night + Rain: %d" %(self.train_daynight_count["night_rain"])) 109 | print("\t\t Test:") 110 | print("\t\t\t Day: %d" %(self.test_daynight_count["day"])) 111 | print("\t\t\t Day + Rain: %d" %(self.test_daynight_count["day_rain"])) 112 | print("\t\t\t Night: %d" %(self.test_daynight_count["night"])) 113 | print("\t\t\t Night + Rain: %d" %(self.test_daynight_count["night_rain"])) 114 | print("-----------------------------------------") 115 | 116 | # Check transform mode 117 | assert transform_mode in ["DORN", "sparse-to-dense"] 118 | self.transform_mode = transform_mode 119 | if self.transform_mode == "DORN": 120 | self.t_cfg = cfg.DORN_transform_config 121 | elif self.transform_mode == "sparse-to-dense": 122 | self.t_cfg = cfg.sparse_transform_config 123 | 124 | # Define outut size 125 | if self.mode == "train": 126 | self.output_size = self.t_cfg.crop_size_train 127 | else: 128 | self.output_size = self.t_cfg.crop_size_val 129 | 130 | # Create filename dataset from the exported nuscenes dataset 131 | def get_filename_dataset(self): 132 | dataset_root = os.path.join(cfg.EXPORT_ROOT, cfg.export_name) 133 | 134 | # Use different root for different mode 135 | if self.mode == "train": 136 | dataset_root = os.path.join(dataset_root, "train") 137 | elif self.mode == "val": 138 | dataset_root = os.path.join(dataset_root, "val") 139 | else: 140 | raise 
ValueError("[Error] Unknown dataset mode") 141 | 142 | # Add different radar root for version 3 143 | if cfg.version == "ver3": 144 | assert cfg.radar_export_name is not None 145 | dataset_root_radar = os.path.join(cfg.EXPORT_ROOT, cfg.radar_export_name) 146 | if self.mode == "train": 147 | dataset_root_radar = os.path.join(dataset_root_radar, "train") 148 | elif self.mode == "val": 149 | dataset_root_radar = os.path.join(dataset_root_radar, "val") 150 | 151 | # Get all filenames in the dataroot 152 | filenames = os.listdir(dataset_root) 153 | filenames = [_ for _ in filenames if _.endswith(".h5")] 154 | 155 | # Get subset of filenames given dataset version 156 | if cfg.version in ["ver1", "ver3"]: 157 | ver1_ori = ["front", "back"] 158 | filenames = [_ for _ in filenames if os.path.splitext(_)[0].split("_")[-1] in ver1_ori] 159 | 160 | assert len(filenames) > 0 161 | # Add to full data path 162 | filenames_original = [os.path.join(dataset_root, _) for _ in filenames] 163 | 164 | # Add special case for version3 165 | if cfg.version == "ver3": 166 | filenames_radar = [os.path.join(dataset_root_radar, _) for _ in filenames] 167 | return { 168 | "datapoints": filenames_original, 169 | "datapoints_radar": filenames_radar, 170 | "length": len(filenames) 171 | } 172 | 173 | else: 174 | return { 175 | "datapoints": filenames_original, 176 | "length": len(filenames) 177 | } 178 | 179 | # Get h5 data given data path 180 | def get_data(self, datapoint): 181 | # Check if file exists 182 | if not os.path.exists(datapoint): 183 | raise ValueError("[Error] File does not exist.") 184 | 185 | # Read file 186 | output_dict = {} 187 | with h5py.File(datapoint, "r") as f: 188 | for key_name in f: 189 | output_dict[key_name] = np.array(f[key_name]) 190 | 191 | # Decompress depth 192 | if "lidar_depth" in output_dict.keys(): 193 | output_dict["lidar_depth"] = output_dict["lidar_depth"] / 256. 194 | if "radar_depth" in output_dict.keys(): 195 | output_dict["radar_depth"] = output_dict["radar_depth"] / 256. 
196 | 197 | return output_dict 198 | 199 | # Sampling sparse depth points from lidar 200 | def get_sparse_depth(self, input_depth, radar_depth=None): 201 | # Check if the sparsifier is valid 202 | if not self.sparsifier in ["uniform", "lidar_radar"]: 203 | raise ValueError("[Error] Invalid lidar sparsifier.") 204 | 205 | if self.sparsifier == "uniform": 206 | mask_keep = self.sparsifier_func.dense_to_sparse(input_depth) 207 | 208 | elif self.sparsifier == "lidar_radar": 209 | assert radar_depth is not None 210 | mask_keep = self.sparsifier_func.dense_to_sparse(input_depth, radar_depth) 211 | mask_keep = torch.tensor(mask_keep[..., None].transpose(2, 0, 1)).to(torch.bool) 212 | 213 | # ipdb.set_trace() 214 | sparse_depth = torch.zeros(input_depth.shape) 215 | sparse_depth[mask_keep] = input_depth[mask_keep] 216 | return sparse_depth 217 | 218 | # Get the exported day night info 219 | def get_day_night_info(self): 220 | # Check the default 221 | file_path = os.path.join(cfg.EXPORT_ROOT, "nuscenes_day_night_info.pkl") 222 | 223 | if not os.path.exists(file_path): 224 | raise ValueError("[Error] Can't find the day-night info pickle file in %s" % (file_path)) 225 | 226 | # Load the file 227 | with open(file_path, "rb") as f: 228 | data = pickle.load(f) 229 | 230 | return data 231 | 232 | # Return the length of a dataset 233 | def __len__(self): 234 | return self.filename_dataset["length"] 235 | 236 | # Define the transform for train 237 | def transform_train(self, input_data): 238 | # import ipdb; ipdb.set_trace() 239 | # Fetch the data 240 | rgb = np.array(input_data["image"]).astype(np.float32) 241 | lidar_depth = np.array(input_data["lidar_depth"]).astype(np.float32) 242 | radar_depth = np.array(input_data["radar_depth"]).astype(np.float32) 243 | if 'index_map' in input_data.keys(): 244 | index_map = np.array(input_data["index_map"]).astype(np.int) 245 | 246 | # Define augmentation factor 247 | scale_factor = np.random.uniform(self.t_cfg.scale_factor_train[0], self.t_cfg.scale_factor_train[1]) # random scaling 248 | angle_factor = np.random.uniform(-self.t_cfg.rotation_factor, self.t_cfg.rotation_factor) # random rotation degrees 249 | flip_factor = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip 250 | 251 | # Compose customized transform for RGB and Depth separately 252 | color_jitter = transforms.ColorJitter(0.2, 0.2, 0.2) 253 | resize_image = transforms.Resize(scale_factor, interpolation="bilinear") 254 | resize_depth = transforms.Resize(scale_factor, interpolation="nearest") 255 | 256 | # # First, we uniformly downsample all the images by half 257 | # resize_image_initial = transforms.Resize(0.5, interpolation="bilinear") 258 | # resize_depth_initial = transforms.Resize(0.5, interpolation="nearest") 259 | 260 | # Then, we add model-aware resizing 261 | if self.transform_mode == "DORN": 262 | if cfg.scaling is True: 263 | h, w, _ = tuple((np.array(rgb.shape)).astype(np.int32)) 264 | else: 265 | h, w, _ = tuple((np.array(rgb.shape)* 0.5).astype(np.int32)) 266 | 267 | # ipdb.set_trace() 268 | h_new = self.t_cfg.crop_size_train[0] 269 | w_new = w 270 | resize_image_method = transforms.Resize([h_new, w_new], interpolation="bilinear") 271 | resize_depth_method = transforms.Resize([h_new, w_new], interpolation="nearest") 272 | elif self.transform_mode == "sparse-to-dense": 273 | h_new = self.t_cfg.crop_size_train[0] 274 | w_new = self.t_cfg.crop_size_train[1] 275 | resize_image_method = transforms.Resize([h_new, w_new], interpolation="bilinear") 276 | resize_depth_method = 
transforms.Resize([h_new, w_new], interpolation="nearest") 277 | 278 | # Get the border of random crop 279 | h_scaled, w_scaled = math.floor(h_new * scale_factor), math.floor((w_new * scale_factor)) 280 | h_bound, w_bound = h_scaled - self.t_cfg.crop_size_train[0], w_scaled - self.t_cfg.crop_size_train[1] 281 | h_startpoint = round(np.random.uniform(0, h_bound)) 282 | w_startpoint = round(np.random.uniform(0, w_bound)) 283 | 284 | # Compose the transforms for RGB 285 | transform_rgb = transforms.Compose([ 286 | transforms.Rotate(angle_factor), 287 | resize_image, 288 | transforms.Crop(h_startpoint, w_startpoint, self.t_cfg.crop_size_train[0], self.t_cfg.crop_size_train[1]), 289 | transforms.HorizontalFlip(flip_factor) 290 | ]) 291 | 292 | # Compose the transforms for Depth 293 | transform_depth = transforms.Compose([ 294 | transforms.Rotate(angle_factor), 295 | resize_depth, 296 | transforms.Crop(h_startpoint, w_startpoint, self.t_cfg.crop_size_train[0], self.t_cfg.crop_size_train[1]), 297 | transforms.HorizontalFlip(flip_factor) 298 | ]) 299 | 300 | # Perform transform on rgb data 301 | # ToDo: whether we need to - imagenet mean here 302 | rgb = transform_rgb(rgb) 303 | rgb = color_jitter(rgb) 304 | rgb = rgb / 255. 305 | 306 | # Perform transform on lidar depth data 307 | lidar_depth /= float(scale_factor) 308 | lidar_depth = transform_depth(lidar_depth) 309 | 310 | rgb = np.array(rgb).astype(np.float32) 311 | lidar_depth = np.array(lidar_depth).astype(np.float32) 312 | 313 | rgb = to_tensor(rgb) 314 | lidar_depth = to_tensor(lidar_depth) 315 | 316 | # Perform transform on radar depth data 317 | radar_depth /= float(scale_factor) 318 | radar_depth = transform_depth(radar_depth) 319 | 320 | radar_depth = np.array(radar_depth).astype(np.float32) 321 | radar_depth = to_tensor(radar_depth) 322 | 323 | # Perform transform on index map 324 | if 'index_map' in input_data.keys(): 325 | index_map = transform_depth(index_map) 326 | index_map = np.array(index_map).astype(np.int) 327 | index_map = to_tensor(index_map) 328 | index_map = index_map.unsqueeze(0) 329 | 330 | # Normalize rgb using imagenet mean and std 331 | # ToDo: only do imagenet normalization on DORN 332 | if self.transform_mode == "DORN": 333 | rgb = transforms.normalization_imagenet(rgb) 334 | 335 | if self.sparsifier == "radar_filtered": 336 | #################### 337 | ## Filtering part ## 338 | #################### 339 | # Indicating the invalid entries 340 | invalid_mask = ~ input_data['valid_mask'] 341 | invalid_index = np.where(invalid_mask)[0] 342 | invalid_index_mask = invalid_index[None, None, ...].transpose(2, 0, 1) 343 | 344 | # Constructing mask for dense depth 345 | dense_mask = torch.ByteTensor(np.sum(index_map.numpy() == invalid_index_mask, axis=0)) 346 | radar_depth_filtered = radar_depth.clone() 347 | radar_depth_filtered[dense_mask.to(torch.bool)] = 0. 348 | radar_depth_filtered = radar_depth_filtered.unsqueeze(0) 349 | 350 | if self.sparsifier == "radar_filtered2": 351 | ###################################### 352 | ## Filtering using predicted labels ## 353 | ###################################### 354 | invalid_mask = ~ input_data['pred_labels'] 355 | invalid_index = np.where(invalid_mask)[0] 356 | invalid_index_mask = invalid_index[None, None, ...].transpose(2, 0, 1) 357 | 358 | dense_mask = torch.ByteTensor(np.sum(index_map.numpy() == invalid_index_mask, axis=0)) 359 | radar_depth_filtered2 = radar_depth.clone() 360 | radar_depth_filtered2[dense_mask.to(torch.bool)] = 0. 
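# For reference: index_map stores, per pixel, the index of the radar point projected there (-1 where no
# point landed). Broadcasting the rejected point indices (shape [K, 1, 1]) against the per-pixel index map
# and summing over the first axis gives a dense mask that is nonzero exactly at pixels whose depth came
# from a rejected point; those pixels are zeroed above. A minimal sketch of the same idea, using
# hypothetical toy values rather than real dataset content:
#   index_map = np.array([[0, 1], [2, -1]])        # point index per pixel
#   invalid_index = np.array([1, 2])               # indices of rejected points
#   hit = (index_map[None] == invalid_index[:, None, None]).sum(axis=0)
#   # hit -> [[0, 1], [1, 0]]; nonzero entries mark the pixels to clear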
361 | radar_depth_filtered2 = radar_depth_filtered2.unsqueeze(0) 362 | ###################################### 363 | 364 | lidar_depth = lidar_depth.unsqueeze(0) 365 | radar_depth = radar_depth.unsqueeze(0) 366 | 367 | # Return different data for different modality 368 | if self.modality == "rgb": 369 | inputs = rgb 370 | elif self.modality == "rgbd": 371 | if self.sparsifier == "radar": 372 | # Filter out the the points exceeding max_depth 373 | mask = (radar_depth > self.max_depth) 374 | radar_depth[mask] = 0 375 | inputs = torch.cat((rgb, radar_depth), dim=0) 376 | # Using the generated groundtruth 377 | elif self.sparsifier == "radar_filtered": 378 | # Filter out the points exceeding max_depth 379 | mask = (radar_depth_filtered > self.max_depth) 380 | radar_depth_filtered[mask] = 0 381 | inputs = torch.cat((rgb, radar_depth_filtered), dim=0) 382 | # Using the learned classifyer 383 | elif self.sparsifier == "radar_filtered2": 384 | # Filter out the points exceeding max_depth 385 | mask = (radar_depth_filtered2 > self.max_depth) 386 | radar_depth_filtered2[mask] = 0 387 | inputs = torch.cat((rgb, radar_depth_filtered2), dim=0) 388 | else: 389 | s_depth = self.get_sparse_depth(lidar_depth, radar_depth) 390 | inputs = torch.cat((rgb, s_depth), dim=0) 391 | else: 392 | raise ValueError("[Error] Unsupported modality. Consider ", self.avail_modality) 393 | labels = lidar_depth 394 | 395 | # Gathering output results 396 | output_dict = { 397 | "rgb": rgb, 398 | "lidar_depth": lidar_depth, 399 | "radar_depth": radar_depth, 400 | "inputs": inputs, 401 | "labels": labels 402 | } 403 | if self.sparsifier == "radar_filtered": 404 | output_dict["radar_depth_filtered"] = radar_depth_filtered 405 | 406 | if self.sparsifier == "radar_filtered2": 407 | output_dict["radar_depth_filtered2"] = radar_depth_filtered2 408 | 409 | if 'index_map' in input_data.keys(): 410 | output_dict["index_map"] = index_map 411 | 412 | return output_dict 413 | 414 | # Define the transform for val 415 | def transform_val(self, input_data): 416 | rgb = np.array(input_data["image"]).astype(np.float32) 417 | lidar_depth = np.array(input_data["lidar_depth"]).astype(np.float32) 418 | radar_depth = np.array(input_data["radar_depth"]).astype(np.float32) 419 | if 'index_map' in input_data.keys(): 420 | index_map = np.array(input_data["index_map"]).astype(np.int) 421 | 422 | # Then, we add model-aware resizing 423 | if self.transform_mode == "DORN": 424 | if cfg.scaling is True: 425 | h, w, _ = tuple((np.array(rgb.shape)).astype(np.int32)) 426 | else: 427 | h, w, _ = tuple((np.array(rgb.shape) * 0.5).astype(np.int32)) 428 | 429 | h_new = self.t_cfg.crop_size_train[0] 430 | w_new = w 431 | resize_image_method = transforms.Resize([h_new, w_new], interpolation="bilinear") 432 | resize_depth_method = transforms.Resize([h_new, w_new], interpolation="nearest") 433 | elif self.transform_mode == "sparse-to-dense": 434 | h_new = self.t_cfg.crop_size_train[0] 435 | w_new = self.t_cfg.crop_size_train[1] 436 | resize_image_method = transforms.Resize([h_new, w_new], interpolation="bilinear") 437 | resize_depth_method = transforms.Resize([h_new, w_new], interpolation="nearest") 438 | 439 | transform_rgb = transforms.Compose([ 440 | # resize_image_method, 441 | transforms.CenterCrop(self.t_cfg.crop_size_val) 442 | ]) 443 | transform_depth = transforms.Compose([ 444 | # resize_depth_method, 445 | transforms.CenterCrop(self.t_cfg.crop_size_val) 446 | ]) 447 | 448 | rgb = transform_rgb(rgb) 449 | rgb = rgb / 255. 
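# Unlike transform_train, the validation transform is deterministic: no random scaling, rotation,
# flipping or color jitter is applied, only a center crop to crop_size_val, with RGB rescaled to [0, 1].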
450 | lidar_depth = transform_depth(lidar_depth) 451 | 452 | rgb = np.array(rgb).astype(np.float32) 453 | lidar_depth = np.array(lidar_depth).astype(np.float32) 454 | 455 | rgb = to_tensor(rgb) 456 | lidar_depth = to_tensor(lidar_depth) 457 | 458 | radar_depth = transform_depth(radar_depth) 459 | radar_depth = np.array(radar_depth).astype(np.float32) 460 | radar_depth = to_tensor(radar_depth) 461 | 462 | # Perform transform on index map 463 | if 'index_map' in input_data.keys(): 464 | index_map = transform_depth(index_map) 465 | index_map = np.array(index_map).astype(np.int) 466 | index_map = to_tensor(index_map) 467 | index_map = index_map.unsqueeze(0) 468 | 469 | # Normalize to imagenet mean and std 470 | if self.transform_mode == "DORN": 471 | rgb = transforms.normalization_imagenet(rgb) 472 | 473 | #################### 474 | ## Filtering part ## 475 | #################### 476 | if self.sparsifier == "radar_filtered": 477 | # Indicating the invalid entries 478 | invalid_mask = ~ input_data['valid_mask'] 479 | invalid_index = np.where(invalid_mask)[0] 480 | invalid_index_mask = invalid_index[None, None, ...].transpose(2, 0, 1) 481 | 482 | # Constructing mask for dense depth 483 | dense_mask = torch.ByteTensor(np.sum(index_map.numpy() == invalid_index_mask, axis=0)) 484 | radar_depth_filtered = radar_depth.clone() 485 | radar_depth_filtered[dense_mask.to(torch.bool)] = 0. 486 | radar_depth_filtered = radar_depth_filtered.unsqueeze(0) 487 | # ipdb.set_trace() 488 | #################### 489 | 490 | ###################################### 491 | ## Filtering using predicted labels ## 492 | ###################################### 493 | if self.sparsifier == "radar_filtered2": 494 | # ipdb.set_trace() 495 | invalid_mask = ~ input_data['pred_labels'] 496 | invalid_index = np.where(invalid_mask)[0] 497 | invalid_index_mask = invalid_index[None, None, ...].transpose(2, 0, 1) 498 | 499 | dense_mask = torch.ByteTensor(np.sum(index_map.numpy() == invalid_index_mask, axis=0)) 500 | radar_depth_filtered2 = radar_depth.clone() 501 | radar_depth_filtered2[dense_mask.to(torch.bool)] = 0. 502 | radar_depth_filtered2 = radar_depth_filtered2.unsqueeze(0) 503 | ###################################### 504 | 505 | lidar_depth = lidar_depth.unsqueeze(0) 506 | radar_depth = radar_depth.unsqueeze(0) 507 | 508 | # Return different data for different modality 509 | ################ Input sparsifier ######### 510 | if self.modality == "rgb": 511 | inputs = rgb 512 | elif self.modality == "rgbd": 513 | if self.sparsifier == "radar": 514 | # Filter out the the points exceeding max_depth 515 | mask = (radar_depth > self.max_depth) 516 | radar_depth[mask] = 0 517 | inputs = torch.cat((rgb, radar_depth), dim=0) 518 | elif self.sparsifier == "radar_filtered": 519 | # Filter out the points exceeding max_depth 520 | mask = (radar_depth_filtered > self.max_depth) 521 | radar_depth_filtered[mask] = 0 522 | inputs = torch.cat((rgb, radar_depth_filtered), dim=0) 523 | # Using the learned classifyer 524 | elif self.sparsifier == "radar_filtered2": 525 | # Filter out the points exceeding max_depth 526 | mask = (radar_depth_filtered2 > self.max_depth) 527 | radar_depth_filtered2[mask] = 0 528 | inputs = torch.cat((rgb, radar_depth_filtered2), dim=0) 529 | else: 530 | s_depth = self.get_sparse_depth(lidar_depth, radar_depth) 531 | inputs = torch.cat((rgb, s_depth), dim=0) 532 | else: 533 | raise ValueError("[Error] Unsupported modality. 
Consider ", self.avail_modality) 534 | labels = lidar_depth 535 | 536 | output_dict = { 537 | "rgb": rgb, 538 | "lidar_depth": lidar_depth, 539 | "radar_depth": radar_depth, 540 | "inputs": inputs, 541 | "labels": labels 542 | } 543 | 544 | if self.sparsifier == "radar_filtered": 545 | output_dict["radar_depth_filtered"] = radar_depth_filtered 546 | 547 | if self.sparsifier == "radar_filtered2": 548 | output_dict["radar_depth_filtered2"] = radar_depth_filtered2 549 | 550 | # For 'index_map' compatibility 551 | if 'index_map' in input_data.keys(): 552 | output_dict["index_map"] = index_map 553 | 554 | return output_dict 555 | 556 | # Add index map and perform valid check 557 | def filter_radar_points(self, input_data): 558 | # Fetch data 559 | radar_points = input_data['radar_points'] 560 | radar_depth_points = input_data['radar_depth_points'] 561 | 562 | # Construct index map 563 | depth_loc = (radar_points[:2, :].T).astype(np.int32) 564 | point_index = np.arange(0, radar_points.shape[1], 1) 565 | index_map = - np.ones(input_data['image'].shape[:2]) 566 | index_map[depth_loc[:, 1], depth_loc[:, 0]] = point_index 567 | 568 | input_data['index_map'] = index_map 569 | 570 | # Filter the radar points 571 | filtered_data = filter_radar_points_gt(input_data['radar_points'], 572 | input_data['radar_depth_points'], 573 | input_data['lidar_points'], 574 | input_data['lidar_depth_points']) 575 | 576 | # Record the masked depth points 577 | input_data['radar_points_filtered'] = filtered_data['radar_points'] 578 | input_data['radar_depth_points_filtered'] = filtered_data['radar_depth'] 579 | input_data['valid_mask'] = filtered_data['valid_mask'] 580 | 581 | if self.sparsifier == "radar_filtered2": 582 | raise NotImplementedError("[Error] The filtering method using point classifier is not supported in the released code.") 583 | 584 | return input_data 585 | 586 | # Perform transform on radar pointclouds 587 | def transform_point(self, point_data): 588 | points = point_data["radar_points_raw"] 589 | labels = point_data["radar_points_label"][..., None] 590 | 591 | # shuffle the points 592 | if self.mode == "train": 593 | tmp = np.concatenate((points, labels), axis=-1) 594 | np.random.shuffle(tmp) 595 | points = tmp[:, :-1] 596 | labels = tmp[:, -1][..., None] 597 | 598 | # Pad the points to 512 599 | num_points = points.shape[0] 600 | target_size = 512 601 | output_points = np.repeat(points[-1, :][None, ...], target_size, axis=0) 602 | output_points[:num_points, :] = points 603 | output_labels = np.repeat(labels[-1, :][None, ...], target_size, axis=0) 604 | output_labels[:num_points, :] = labels 605 | 606 | # Create valid mask 607 | mask = np.zeros([target_size, 1]) 608 | mask[:num_points] = 1 609 | 610 | return { 611 | "radar_points_raw": to_tensor(output_points), 612 | "radar_points_label": to_tensor(output_labels), 613 | "radar_points_mask": to_tensor(mask) 614 | } 615 | 616 | # Define the getitem method 617 | def __getitem__(self, index): 618 | # Get data from given index 619 | datapoint = self.filename_dataset["datapoints"][index] 620 | data = self.get_data(datapoint) 621 | 622 | # Get the daynight info 623 | daynight_key = os.path.basename(datapoint) 624 | if self.mode == "train": 625 | daynight_info = self.train_daynight_table[daynight_key] 626 | else: 627 | daynight_info = self.test_daynight_table[daynight_key] 628 | 629 | # Further get radar datapoints 630 | if cfg.version == "ver3": 631 | datapoint_radar = self.filename_dataset["datapoints_radar"][index] 632 | data_radar = 
self.get_data(datapoint_radar) 633 | 634 | for key in data_radar.keys(): 635 | data[key] = data_radar[key] 636 | 637 | # Filter radar points here 638 | data = self.filter_radar_points(data) 639 | 640 | # Apply transforms given mode 641 | if self.mode == "train": 642 | outputs = self.transform_train(data) 643 | 644 | else: 645 | outputs = self.transform_val(data) 646 | 647 | # Add daynight info 648 | outputs["daynight_info"] = daynight_info 649 | 650 | return outputs 651 | 652 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import csv 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | from torchvision.utils import make_grid 11 | cudnn.benchmark = True 12 | 13 | from model.utils import Result as Result_point 14 | from model.utils import AverageMeter as AverageMeter_point 15 | # Models and modified models from sparse to dense 16 | from model.models import ( 17 | ResNet, 18 | ResNet2, 19 | ResNet_latefusion 20 | ) 21 | # The multi-stage model variants 22 | from model.multistage_model import ResNet_multistage 23 | from evaluation.metrics import AverageMeter, Result 24 | from tensorboardX import SummaryWriter 25 | import utils 26 | from evaluation.criteria_new import ( 27 | MaskedCrossEntropyLoss, 28 | SmoothnessLoss, 29 | MaskedMSELoss, 30 | MaskedL1Loss 31 | ) 32 | from dataset.nuscenes_dataset_torch_new import nuscenes_dataset_torch 33 | import torch.utils.data.dataloader as torch_loader 34 | 35 | args = utils.parse_command() 36 | 37 | fieldnames = ['mse', 'rmse', 'absrel', 'lg10', 'mae', 38 | 'delta1', 'delta2', 'delta3', 39 | 'data_time', 'gpu_time'] 40 | 41 | best_result = Result() 42 | best_result.set_to_worst() 43 | 44 | multistage_group = ['resnet18_multistage', 'resnet18_multistage_uncertainty', 'resnet18_multistage_uncertainty_fixs'] 45 | uncertainty_group = ['resnet18_multistage_uncertainty', 'resnet18_multistage_uncertainty_fixs'] 46 | torch.cuda.empty_cache() 47 | torch.backends.cudnn.benchmark = True 48 | 49 | # Define the customized collate_fn 50 | def customized_collate(batch): 51 | list_keys = ["daynight_info"] 52 | batch_keys = ['rgb', 'lidar_depth', 'radar_depth', 'inputs', \ 53 | 'labels', 'index_map'] 54 | outputs = {} 55 | for key in batch_keys: 56 | outputs[key] = torch_loader.default_collate([b[key] for b in batch]) 57 | for key in list_keys: 58 | outputs[key] = [b[key] for b in batch] 59 | 60 | return outputs 61 | 62 | 63 | # Create dataloader given input arguments 64 | def create_data_loaders(args): 65 | # Data loading code 66 | print("[Info] Creating data loaders ...") 67 | train_loader = None 68 | val_loader = None 69 | 70 | # sparsifier is a class for generating random sparse depth input from the ground truth 71 | max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf 72 | 73 | if args.data == "nuscenes": 74 | if not args.evaluate: 75 | train_dataset = nuscenes_dataset_torch( 76 | "train", 77 | transform_mode="sparse-to-dense", 78 | modality=args.modality, 79 | sparsifier=args.sparsifier, 80 | num_samples=args.num_samples, 81 | max_depth=max_depth 82 | ) 83 | if args.validation: 84 | val_dataset = nuscenes_dataset_torch( 85 | "val", 86 | transform_mode="sparse-to-dense", 87 | modality=args.modality, 88 | sparsifier=args.sparsifier, 89 | num_samples=args.num_samples, 90 | max_depth=max_depth 91 | ) 92 | else: 93 | 
raise RuntimeError('[Error] Dataset not found. The dataset must be nuscenes') 94 | 95 | if args.validation: 96 | # Always use batch_size=1 in validation 97 | val_loader = torch.utils.data.DataLoader(val_dataset, 98 | batch_size=1, num_workers=4, shuffle=False, pin_memory=True, 99 | collate_fn=customized_collate 100 | ) 101 | 102 | # put construction of train loader here, for those who are interested in testing only 103 | if not args.evaluate: 104 | train_loader = torch.utils.data.DataLoader( 105 | train_dataset, batch_size=args.batch_size, shuffle=True, 106 | num_workers=args.workers, pin_memory=True, sampler=None, 107 | worker_init_fn=lambda work_id:np.random.seed(work_id) 108 | ) # worker_init_fn ensures different sampling patterns for each data loading thread 109 | 110 | print("=> data loaders created.") 111 | if args.validation: 112 | return train_loader, val_loader 113 | else: 114 | return train_loader 115 | 116 | 117 | # Create model given input arguments and output size 118 | def create_model(args, output_size): 119 | print(f"[Info] Creating Model ({args.arch}-{args.decoder}) ...") 120 | in_channels = len(args.modality) 121 | if args.arch == 'resnet50': 122 | model = ResNet(layers=50, decoder=args.decoder, output_size=output_size, 123 | in_channels=in_channels, pretrained=args.pretrained) 124 | elif args.arch == 'resnet18': 125 | model = ResNet(layers=18, decoder=args.decoder, output_size=output_size, 126 | in_channels=in_channels, pretrained=args.pretrained) 127 | elif args.arch == "resnet34": 128 | model = ResNet(layers=34, decoder=args.decoder, output_size=output_size, 129 | in_channels=in_channels, pretrained=args.pretrained) 130 | elif args.arch == 'resnet18_new': 131 | model = ResNet2(layers=18, decoder=args.decoder, output_size=output_size, 132 | in_channels=in_channels, pretrained=args.pretrained) 133 | elif args.arch == "resnet18_latefusion": 134 | model = ResNet_latefusion(layers=18, decoder=args.decoder, output_size=output_size, 135 | in_channels=in_channels, pretrained=args.pretrained) 136 | elif args.arch == "resnet18_multistage": 137 | model = ResNet_multistage(layers=18, decoder=args.decoder, output_size=output_size, 138 | pretrained=args.pretrained) 139 | # If uncertainty model, we need to add weighting parameters to the model 140 | elif args.arch == "resnet18_multistage_uncertainty": 141 | model = ResNet_multistage(layers=18, decoder=args.decoder, output_size=output_size, 142 | pretrained=args.pretrained) 143 | # Get loss weights 144 | w_stage1 = nn.Parameter(torch.tensor(1., dtype=torch.float32), requires_grad=True) 145 | w_stage2 = nn.Parameter(torch.tensor(1., dtype=torch.float32), requires_grad=True) 146 | w_smooth = nn.Parameter(torch.tensor(0.1, dtype=torch.float32), requires_grad=True) 147 | 148 | # Register the parameters to the model 149 | model.register_parameter("w_stage1", w_stage1) 150 | model.register_parameter("w_stage2", w_stage2) 151 | model.register_parameter("w_smooth", w_smooth) 152 | 153 | loss_weights = { 154 | "w_stage1": w_stage1, 155 | "w_stage2": w_stage2, 156 | "w_smooth": w_smooth 157 | } 158 | 159 | return model, loss_weights 160 | 161 | # If the fixs model, we have deterministic weights for smoothness loss 162 | elif args.arch == "resnet18_multistage_uncertainty_fixs": 163 | model = ResNet_multistage(layers=18, decoder=args.decoder, output_size=output_size, 164 | pretrained=args.pretrained) 165 | # Get loss weights 166 | w_stage1 = nn.Parameter(torch.tensor(1., dtype=torch.float32), requires_grad=True) 167 | w_stage2 = 
nn.Parameter(torch.tensor(1., dtype=torch.float32), requires_grad=True) 168 | w_smooth = 0.1 169 | 170 | # Register the parameters to the model 171 | model.register_parameter("w_stage1", w_stage1) 172 | model.register_parameter("w_stage2", w_stage2) 173 | 174 | loss_weights = { 175 | "w_stage1": w_stage1, 176 | "w_stage2": w_stage2, 177 | "w_smooth": w_smooth 178 | } 179 | 180 | return model, loss_weights 181 | 182 | else: 183 | raise ValueError("[Error] Unknown model!!") 184 | print("[Info] model created.") 185 | 186 | return model 187 | 188 | 189 | def main(): 190 | global args, best_result, output_directory, train_csv, test_csv 191 | 192 | # evaluation mode 193 | start_epoch = 0 194 | if args.evaluate: 195 | assert os.path.isfile(args.evaluate), \ 196 | f"[Error] Can't find the specified checkpoint at '{args.evaluate}'" 197 | print(f"[Info] loading the model '{args.evaluate}'") 198 | checkpoint = torch.load(args.evaluate) 199 | output_directory = os.path.dirname(args.evaluate) 200 | args = checkpoint['args'] 201 | print(args) 202 | train_loader, val_loader = create_data_loaders(args) 203 | model_weights = checkpoint['model_state_dict'] 204 | # Create model 205 | if args.arch == "resnet18_multistage_uncertainty" or \ 206 | args.arch == "resnet18_multistage_uncertainty_fixs": 207 | model, loss_weights = create_model(args, output_size=train_loader.dataset.output_size) 208 | else: 209 | model = create_model(args, output_size=train_loader.dataset.output_size) 210 | loss_weights = None 211 | model.load_state_dict(model_weights, strict=False) 212 | model = model.cuda() 213 | print(f"[Info] Loaded best model (epoch {checkpoint['epoch']})") 214 | args.evaluate = True 215 | validate(val_loader, model, checkpoint['epoch'], write_to_file=False) 216 | return 217 | 218 | # optionally resume from a checkpoint 219 | elif args.resume: 220 | chkpt_path = args.resume 221 | assert os.path.isfile(chkpt_path), \ 222 | f"[Info] No checkpoint found at '{chkpt_path}'" 223 | print(f"=> loading checkpoint '{chkpt_path}'") 224 | checkpoint = torch.load(chkpt_path) 225 | args = checkpoint['args'] 226 | print(args) 227 | start_epoch = checkpoint['epoch'] + 1 228 | try: 229 | best_result = checkpoint['best_result'] 230 | except: 231 | best_result.set_to_worst() 232 | 233 | # Create dataloader first 234 | args.validation = True 235 | args.workers = 8 236 | 237 | if (args.data == "nuscenes") and (args.modality == "rgbd") and (args.sparsifier == "uar"): 238 | args.sparsifier = None 239 | # Create dataloader 240 | if args.validation: 241 | train_loader, val_loader = create_data_loaders(args) 242 | else: 243 | train_loader = create_data_loaders(args) 244 | # Load from model's state dict instead 245 | model_weights = checkpoint['model_state_dict'] 246 | # Create model 247 | if args.arch == "resnet18_multistage_uncertainty" or \ 248 | args.arch == "resnet18_multistage_uncertainty_fixs": 249 | model, loss_weights = create_model(args, output_size=train_loader.dataset.output_size) 250 | else: 251 | model = create_model(args, output_size=train_loader.dataset.output_size) 252 | loss_weights = None 253 | model.load_state_dict(model_weights, strict=False) 254 | model = model.cuda() 255 | 256 | # Create optimizer 257 | optimizer = torch.optim.SGD( 258 | model.parameters(), 259 | args.lr, 260 | momentum=args.momentum, 261 | weight_decay=args.weight_decay 262 | ) 263 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 264 | output_directory = os.path.dirname(os.path.abspath(chkpt_path)) 265 | print("=> loaded checkpoint 
(epoch {})".format(checkpoint['epoch'])) 266 | args.resume = True 267 | # Create new model 268 | else: 269 | print(args) 270 | # Create dataloader 271 | if args.validation: 272 | train_loader, val_loader = create_data_loaders(args) 273 | else: 274 | train_loader = create_data_loaders(args) 275 | 276 | # Create model 277 | if args.arch == "resnet18_multistage_uncertainty" or \ 278 | args.arch == "resnet18_multistage_uncertainty_fixs": 279 | model, loss_weights = create_model(args, output_size=train_loader.dataset.output_size) 280 | else: 281 | model = create_model(args, output_size=train_loader.dataset.output_size) 282 | loss_weights = None 283 | 284 | # Create optimizer 285 | optimizer = torch.optim.SGD( 286 | model.parameters(), 287 | args.lr, 288 | momentum=args.momentum, 289 | weight_decay=args.weight_decay 290 | ) 291 | model = model.cuda() 292 | 293 | # Define loss function (criterion) and optimizer 294 | criterion = {} 295 | if args.criterion == 'l2': 296 | criterion["depth"] = MaskedMSELoss().cuda() 297 | elif args.criterion == 'l1': 298 | criterion["depth"] = MaskedL1Loss().cuda() 299 | else: 300 | raise ValueError("[Error] Unknown criterion...") 301 | 302 | # Add smoothness loss to the criterion 303 | if args.arch == "resnet18_multistage_uncertainty" or \ 304 | args.arch == "resnet18_multistage_uncertainty_fixs": 305 | criterion["smooth"] = SmoothnessLoss().cuda() 306 | 307 | # Create results folder, if not already exists 308 | output_directory = utils.get_output_directory(args) 309 | if not os.path.exists(output_directory): 310 | os.makedirs(output_directory) 311 | train_csv = os.path.join(output_directory, 'train.csv') 312 | test_csv = os.path.join(output_directory, 'test.csv') 313 | best_txt = os.path.join(output_directory, 'best.txt') 314 | 315 | # Create new csv files with only header 316 | if not args.resume: 317 | with open(train_csv, 'w') as csvfile: 318 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 319 | writer.writeheader() 320 | with open(test_csv, 'w') as csvfile: 321 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 322 | writer.writeheader() 323 | 324 | # Create summary writer 325 | log_path = os.path.join(output_directory, "logs") 326 | if not os.path.exists(log_path): 327 | os.makedirs(log_path) 328 | logger = SummaryWriter(log_path) 329 | 330 | # Main training loop 331 | for epoch in range(start_epoch, args.epochs): 332 | # Adjust the learning rate 333 | utils.adjust_learning_rate(optimizer, epoch, args.lr) 334 | 335 | # Record the learning rate summary 336 | for i, param_group in enumerate(optimizer.param_groups): 337 | old_lr = float(param_group['lr']) 338 | logger.add_scalar('Lr/lr_' + str(i), old_lr, epoch) 339 | 340 | # Perform training (train for one epoch) 341 | train(train_loader, model, criterion, optimizer, epoch, loss_weights, logger=logger) 342 | 343 | # Perform evaluation 344 | if args.validation: 345 | result, img_merge = validate(val_loader, model, epoch, logger=logger) 346 | 347 | is_best = result.rmse < best_result.rmse 348 | if is_best: 349 | best_result = result 350 | with open(best_txt, 'w') as txtfile: 351 | txtfile.write("epoch={}\nmse={:.3f}\nrmse={:.3f}\nabsrel={:.3f}\nlg10={:.3f}\nmae={:.3f}\ndelta1={:.3f}\nt_gpu={:.4f}\n". 
352 | format(epoch, result.mse, result.rmse, result.absrel, result.lg10, result.mae, result.delta1, result.gpu_time)) 353 | if img_merge is not None: 354 | img_filename = output_directory + '/comparison_best.png' 355 | utils.save_image(img_merge, img_filename) 356 | 357 | # Save different things in different mode 358 | if args.validation: 359 | utils.save_checkpoint({ 360 | 'args': args, 361 | 'epoch': epoch, 362 | 'arch': args.arch, 363 | 'model_state_dict': model.state_dict(), 364 | 'best_result': best_result, 365 | 'optimizer_state_dict' : optimizer.state_dict(), 366 | }, is_best, epoch, output_directory) 367 | else: 368 | utils.save_checkpoint({ 369 | 'args': args, 370 | 'epoch': epoch, 371 | 'arch': args.arch, 372 | 'model_state_dict': model.state_dict(), 373 | 'optimizer_state_dict': optimizer.state_dict(), 374 | }, False, epoch, output_directory) 375 | 376 | 377 | def train(train_loader, model, criterion, optimizer, epoch, loss_weights=None, logger=None): 378 | # pdb.set_trace() 379 | average_meter = AverageMeter() 380 | if args.arch in multistage_group: 381 | average_meter_stage1 = AverageMeter() 382 | 383 | model.train() # switch to train mode 384 | end = time.time() 385 | 386 | # Record number of batches 387 | batch_num = len(train_loader) 388 | for i, data in enumerate(train_loader): 389 | ############ Fetch input data ################ 390 | # Add compatibility for nuscenes 391 | if args.data != "nuscenes": 392 | inputs, target = data[0].cuda(), data[1].cuda() 393 | else: 394 | inputs, target = data["inputs"].cuda(), data["labels"].cuda() 395 | 396 | torch.cuda.synchronize() 397 | data_time = time.time() - end 398 | 399 | # Training step 400 | end = time.time() 401 | if args.arch == "resnet18_multistage_uncertainty": 402 | pred_ = model(inputs) 403 | pred1 = pred_["stage1"] 404 | pred = pred_["stage2"] 405 | depth_loss1 = criterion["depth"](pred1, target) 406 | depth_loss2 = criterion["depth"](pred, target) 407 | smooth_loss = criterion["smooth"](pred1, inputs) 408 | weight_loss = loss_weights["w_stage1"] + loss_weights["w_stage2"] + loss_weights["w_smooth"] 409 | 410 | # Weighted sum to total loss 411 | loss = torch.exp(-loss_weights["w_stage1"]) * depth_loss1 + \ 412 | torch.exp(-loss_weights["w_stage2"]) * depth_loss2 + \ 413 | torch.exp(-loss_weights["w_smooth"]) * smooth_loss + \ 414 | weight_loss 415 | 416 | elif args.arch == "resnet18_multistage_uncertainty_fixs": 417 | pred_ = model(inputs) 418 | pred1 = pred_["stage1"] 419 | pred = pred_["stage2"] 420 | depth_loss1 = criterion["depth"](pred1, target) 421 | depth_loss2 = criterion["depth"](pred, target) 422 | smooth_loss = criterion["smooth"](pred1, inputs) 423 | weight_loss = loss_weights["w_stage1"] + loss_weights["w_stage2"] 424 | 425 | # Weighted sum to total loss 426 | stage1_weighted_loss = torch.exp(-loss_weights["w_stage1"]) * (depth_loss1 + (loss_weights["w_smooth"] * smooth_loss)) 427 | stage2_weighted_loss = torch.exp(-loss_weights["w_stage2"]) * depth_loss2 428 | loss = stage1_weighted_loss + stage2_weighted_loss + \ 429 | weight_loss 430 | 431 | elif args.arch in multistage_group: 432 | pred_ = model(inputs) 433 | pred1 = pred_["stage1"] 434 | pred = pred_["stage2"] 435 | depth_loss1 = criterion["depth"](pred1, target) 436 | depth_loss2 = criterion["depth"](pred, target) 437 | loss = depth_loss1 + depth_loss2 438 | 439 | else: 440 | pred = model(inputs) 441 | loss = criterion["depth"](pred, target) 442 | 443 | optimizer.zero_grad() 444 | loss.backward() # compute gradient and do SGD step 445 |
optimizer.step() 446 | torch.cuda.synchronize() 447 | gpu_time = time.time() - end 448 | 449 | # [Depth] Measure error and record loss 450 | result = Result() 451 | result.evaluate(pred.data, target.data) 452 | average_meter.update(result, gpu_time, data_time, inputs.size(0)) 453 | 454 | # [Depth] Measure stage1 error 455 | if args.arch in multistage_group: 456 | result_stage1 = Result() 457 | result_stage1.evaluate(pred1.data, target.data) 458 | average_meter_stage1.update(result_stage1, gpu_time, data_time, inputs.size(0)) 459 | 460 | if (i + 1) % args.print_freq == 0: 461 | print('=> output: {}'.format(output_directory)) 462 | print('Train Epoch: {0} [{1}/{2}]\t' 463 | 't_Data={data_time:.3f}({average.data_time:.3f}) ' 464 | 't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t' 465 | 'RMSE={result.rmse:.2f}({average.rmse:.2f}) ' 466 | 'MAE={result.mae:.2f}({average.mae:.2f}) ' 467 | 'Delta1={result.delta1:.3f}({average.delta1:.3f}) ' 468 | 'REL={result.absrel:.3f}({average.absrel:.3f}) ' 469 | 'Lg10={result.lg10:.3f}({average.lg10:.3f})'.format( 470 | epoch, i+1, len(train_loader), data_time=data_time, 471 | gpu_time=gpu_time, result=result, average=average_meter.average())) 472 | 473 | if ((i + 1) % 100 == 0) and (logger is not None): 474 | current_step = epoch * batch_num + i 475 | # Add scalar summaries 476 | logger.add_scalar('Train_loss/Loss', loss.item(), current_step) 477 | record_scalar_summary(result, average_meter, current_step, logger, "Train") 478 | 479 | # Further add some scalar summaries 480 | if args.arch == "resnet18_multistage_uncertainty": 481 | # Add weight summaries 482 | logger.add_scalar("Train_weights/w_stage1", torch.exp(-loss_weights["w_stage1"]).item(), current_step) 483 | logger.add_scalar("Train_weights/w_stage2", torch.exp(-loss_weights["w_stage2"]).item(), current_step) 484 | logger.add_scalar("Train_weights/w_smooth", torch.exp(-loss_weights["w_smooth"]).item(), current_step) 485 | 486 | # Add loss summary 487 | logger.add_scalar("Train_loss/Smoothness_loss", smooth_loss.item(), current_step) 488 | logger.add_scalar("Train_loss/Weight_loss", weight_loss.item(), current_step) 489 | 490 | # Some scalar summaries for the uncertainty fixs model 491 | if args.arch == "resnet18_multistage_uncertainty_fixs": 492 | # Add weight summaries 493 | logger.add_scalar("Train_weights/w_stage1", torch.exp(-loss_weights["w_stage1"]).item(), current_step) 494 | logger.add_scalar("Train_weights/w_stage2", torch.exp(-loss_weights["w_stage2"]).item(), current_step) 495 | 496 | # Add loss summary 497 | logger.add_scalar("Train_loss/Smoothness_loss", smooth_loss.item(), current_step) 498 | logger.add_scalar("Train_loss/Weight_loss", weight_loss.item(), current_step) 499 | 500 | # Add weighted loss 501 | logger.add_scalar("Train_loss_weighted/stage1", stage1_weighted_loss.item(), current_step) 502 | logger.add_scalar("Train_loss_weighted/stage2", stage2_weighted_loss.item(), current_step) 503 | 504 | if args.arch in multistage_group: 505 | logger.add_scalar('Train_loss/Depth_loss1', depth_loss1.item(), current_step) 506 | logger.add_scalar('Train_loss/Depth_loss2', depth_loss2.item(), current_step) 507 | # Record error summaries for stage1 508 | record_scalar_summary(result_stage1, average_meter_stage1, current_step, logger, "Train_stage1") 509 | 510 | # Add system info 511 | logger.add_scalar('System/gpu_time', average_meter.average().gpu_time, current_step) 512 | logger.add_scalar('System/data_time', average_meter.average().data_time, current_step) 513 | 514 | # Add some image 
summary 515 | if args.modality == "rgb": 516 | input_images = inputs.cpu() 517 | else: 518 | input_images = inputs[:, 0:3, :, :].cpu() 519 | input_depth = torch.unsqueeze(inputs[:, 3, :, :], dim=1).cpu() 520 | rgb_grid = make_grid(input_images[0:6, :, :, :], nrow=3, normalize=False), 521 | target_grid = make_grid(target.cpu()[0:6, :, :, :], nrow=3, normalize=True, range=(0, 80)) 522 | pred_grid = make_grid(pred.cpu()[0:6, :, :, :], nrow=3, normalize=True, range=(0,80)) 523 | logger.add_image('Train/RGB', rgb_grid[0].data.numpy()) 524 | logger.add_image('Train/Depth_gt', target_grid.data.numpy()) 525 | logger.add_image('Train/Depth_pred', pred_grid.data.numpy()) 526 | 527 | # Also record depth predictions from stage1 528 | if args.arch in multistage_group: 529 | pred_grid1 = make_grid(pred1.cpu()[0:6, :, :, :], nrow=3, normalize=True, range=(0,80)) 530 | logger.add_image('Train/Depth_pred1', pred_grid1.data.numpy()) 531 | if args.modality == "rgbd": 532 | depth_grid = make_grid(input_depth[0:6, :, :, :], nrow=3, normalize=True, range=(0, 80)) 533 | logger.add_image('Train/Depth_input', depth_grid.data.numpy()) 534 | 535 | end = time.time() 536 | 537 | avg = average_meter.average() 538 | with open(train_csv, 'a') as csvfile: 539 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 540 | writer.writerow({'mse': avg.mse, 'rmse': avg.rmse, 'absrel': avg.absrel, 'lg10': avg.lg10, 541 | 'mae': avg.mae, 'delta1': avg.delta1, 'delta2': avg.delta2, 'delta3': avg.delta3, 542 | 'gpu_time': avg.gpu_time, 'data_time': avg.data_time}) 543 | 544 | 545 | def validate(val_loader, model, epoch, write_to_file=True, logger=None): 546 | average_meter = AverageMeter() 547 | if args.arch in multistage_group: 548 | average_meter_stage1 = AverageMeter() 549 | 550 | # Include daynight info and rain condition 551 | avg_meter_day = AverageMeter() 552 | avg_meter_night = AverageMeter() 553 | 554 | # day, night, sun, rain combinations 555 | avg_meter_day_sun = AverageMeter() 556 | avg_meter_day_rain = AverageMeter() 557 | avg_meter_night_sun = AverageMeter() 558 | avg_meter_night_rain = AverageMeter() 559 | 560 | # sun and rain 561 | avg_meter_sun = AverageMeter() 562 | avg_meter_rain = AverageMeter() 563 | 564 | model.eval() # switch to evaluate mode 565 | end = time.time() 566 | 567 | # Save something to draw?? 
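# When validate() is called without a logger (i.e. in evaluation mode), per-sample inputs, depths and
# predictions are written to results.h5 below so qualitative results can be visualized later.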
568 | if logger is None: 569 | import h5py 570 | output_path = os.path.join(output_directory, "results.h5") 571 | h5_writer = h5py.File(output_path, "w", libver="latest", swmr=True) 572 | 573 | for i, data in enumerate(val_loader): 574 | # Add compatibility for nuscenes 575 | if args.data != "nuscenes": 576 | inputs, target = data[0].cuda(), data[1].cuda() 577 | else: 578 | inputs, target = data["inputs"].cuda(), data["labels"].cuda() 579 | 580 | torch.cuda.synchronize() 581 | data_time = time.time() - end 582 | 583 | # Compute output 584 | end = time.time() 585 | with torch.no_grad(): 586 | if args.arch in multistage_group: 587 | pred_ = model(inputs) 588 | pred1 = pred_["stage1"] 589 | pred = pred_["stage2"] 590 | else: 591 | pred = model(inputs) 592 | pred_ = None 593 | 594 | torch.cuda.synchronize() 595 | gpu_time = time.time() - end 596 | 597 | # Record for qualitative results 598 | if (logger is None) and (i % 5 == 0): 599 | pred_np = {} 600 | if pred_ is None: 601 | pred_np = pred.cpu().numpy() 602 | else: 603 | for key in pred_.keys(): 604 | pred_np[key] = pred_[key][0, ...].cpu().numpy() 605 | res = { 606 | "inputs": data["inputs"][0, ...].cpu().numpy(), 607 | "lidar_depth": data["lidar_depth"][0, ...].cpu().numpy(), 608 | "radar_depth": data["radar_depth"][0, ...].cpu().numpy(), 609 | "pred": pred_np 610 | } 611 | file_key = "%05d"%(i) 612 | f_group = h5_writer.create_group(file_key) 613 | # Store data 614 | for key, output_data in res.items(): 615 | if isinstance(output_data, dict): 616 | for key, data_ in output_data.items(): 617 | if key in res.keys(): 618 | key = key + "*" 619 | f_group.create_dataset(key, data=data_, compression="gzip") 620 | elif output_data is None: 621 | pass 622 | else: 623 | f_group.create_dataset(key, data=output_data, compression="gzip") 624 | 625 | # Measure accuracy and record loss 626 | result = Result() 627 | result.evaluate(pred.data, target.data) 628 | average_meter.update(result, gpu_time, data_time, inputs.size(0)) 629 | if args.arch in multistage_group: 630 | result_stage1 = Result() 631 | result_stage1.evaluate(pred1.data, target.data) 632 | average_meter_stage1.update(result_stage1, gpu_time, data_time, inputs.size(0)) 633 | end = time.time() 634 | 635 | # Record the day, night, rain info 636 | assert inputs.size(0) == 1 637 | daynight_info = data["daynight_info"][0] 638 | if ("day" in daynight_info) and ("rain" in daynight_info): 639 | avg_meter_day_rain.update(result, gpu_time, data_time, inputs.size(0)) 640 | avg_meter_day.update(result, gpu_time, data_time, inputs.size(0)) 641 | avg_meter_rain.update(result, gpu_time, data_time, inputs.size(0)) 642 | elif "day" in daynight_info: 643 | avg_meter_day_sun.update(result, gpu_time, data_time, inputs.size(0)) 644 | avg_meter_day.update(result, gpu_time, data_time, inputs.size(0)) 645 | avg_meter_sun.update(result, gpu_time, data_time, inputs.size(0)) 646 | 647 | if ("night" in daynight_info) and ("rain" in daynight_info): 648 | avg_meter_night_rain.update(result, gpu_time, data_time, inputs.size(0)) 649 | avg_meter_night.update(result, gpu_time, data_time, inputs.size(0)) 650 | avg_meter_rain.update(result, gpu_time, data_time, inputs.size(0)) 651 | elif "night" in daynight_info: 652 | avg_meter_night_sun.update(result, gpu_time, data_time, inputs.size(0)) 653 | avg_meter_night.update(result, gpu_time, data_time, inputs.size(0)) 654 | avg_meter_sun.update(result, gpu_time, data_time, inputs.size(0)) 655 | 656 | 657 | # save 8 images for visualization 658 | skip = 50 659 | if args.modality == 
'd': 660 | img_merge = None 661 | else: 662 | if args.modality == 'rgb': 663 | rgb = inputs 664 | elif args.modality == 'rgbd': 665 | rgb = inputs[:,:3,:,:] 666 | depth = inputs[:,3:,:,:] 667 | 668 | if i == 0: 669 | if args.modality == 'rgbd': 670 | img_merge = utils.merge_into_row_with_gt(rgb, depth, target, pred) 671 | else: 672 | img_merge = utils.merge_into_row(rgb, target, pred) 673 | elif (i < 8*skip) and (i % skip == 0): 674 | if args.modality == 'rgbd': 675 | row = utils.merge_into_row_with_gt(rgb, depth, target, pred) 676 | else: 677 | row = utils.merge_into_row(rgb, target, pred) 678 | img_merge = utils.add_row(img_merge, row) 679 | elif i == 8*skip: 680 | filename = output_directory + '/comparison_' + str(epoch) + '.png' 681 | utils.save_image(img_merge, filename) 682 | 683 | if (i+1) % args.print_freq == 0: 684 | print('Test: [{0}/{1}]\t' 685 | 't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t' 686 | 'RMSE={result.rmse:.2f}({average.rmse:.2f}) ' 687 | 'MAE={result.mae:.2f}({average.mae:.2f}) ' 688 | 'Delta1={result.delta1:.3f}({average.delta1:.3f}) ' 689 | 'REL={result.absrel:.3f}({average.absrel:.3f}) ' 690 | 'Lg10={result.lg10:.3f}({average.lg10:.3f}) '.format( 691 | i+1, len(val_loader), gpu_time=gpu_time, result=result, average=average_meter.average())) 692 | 693 | # Save the result to pkl file 694 | if logger is None: 695 | h5_writer.close() 696 | avg = average_meter.average() 697 | if args.arch in multistage_group: 698 | avg_stage1 = average_meter_stage1.average() 699 | if logger is not None: 700 | record_test_scalar_summary(avg_stage1, epoch, logger, "Test_stage1") 701 | 702 | print('\n*\n' 703 | 'RMSE={average.rmse:.3f}\n' 704 | 'Rel={average.absrel:.3f}\n' 705 | 'Log10={average.lg10:.3f}\n' 706 | 'Delta1={average.delta1:.3f}\n' 707 | 'Delta2={average.delta2:.3f}\n' 708 | 'Delta3={average.delta3:.3f}\n' 709 | 't_GPU={time:.3f}\n'.format( 710 | average=avg, time=avg.gpu_time)) 711 | 712 | if logger is not None: 713 | # Record summaries 714 | record_test_scalar_summary(avg, epoch, logger, "Test") 715 | 716 | print('\n*\n' 717 | 'RMSE={average.rmse:.3f}\n' 718 | 'MAE={average.mae:.3f}\n' 719 | 'Delta1={average.delta1:.3f}\n' 720 | 'REL={average.absrel:.3f}\n' 721 | 'Lg10={average.lg10:.3f}\n' 722 | 't_GPU={time:.3f}\n'.format( 723 | average=avg, time=avg.gpu_time)) 724 | 725 | if write_to_file: 726 | with open(test_csv, 'a') as csvfile: 727 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 728 | writer.writerow({'mse': avg.mse, 'rmse': avg.rmse, 'absrel': avg.absrel, 'lg10': avg.lg10, 729 | 'mae': avg.mae, 'delta1': avg.delta1, 'delta2': avg.delta2, 'delta3': avg.delta3, 730 | 'data_time': avg.data_time, 'gpu_time': avg.gpu_time}) 731 | 732 | return avg, img_merge 733 | 734 | 735 | def record_scalar_summary(result, avg, current_step, logger, prefix="Train"): 736 | # Add scalar summaries 737 | logger.add_scalar(prefix + '_error/RMSE', result.rmse, current_step) 738 | logger.add_scalar(prefix + '_error/rel', result.absrel, current_step) 739 | logger.add_scalar(prefix + '_error/mae', result.mae, current_step) 740 | logger.add_scalar(prefix + '_delta/Delta1', result.delta1, current_step) 741 | logger.add_scalar(prefix + '_delta/Delta2', result.delta2, current_step) 742 | logger.add_scalar(prefix + '_delta/Delta3', result.delta3, current_step) 743 | 744 | # Add smoothed summaries 745 | average = avg.average() 746 | logger.add_scalar(prefix + '_error_smoothed/RMSE', average.rmse, current_step) 747 | logger.add_scalar(prefix + '_error_smoothed/rml', average.absrel, 
current_step) 748 | logger.add_scalar(prefix + '_error_smoothed/mae', average.mae, current_step) 749 | logger.add_scalar(prefix + '_delta_smoothed/Delta1', average.delta1, current_step) 750 | logger.add_scalar(prefix + '_delta_smoothed/Delta2', average.delta2, current_step) 751 | logger.add_scalar(prefix + '_delta_smoothed/Delta3', average.delta3, current_step) 752 | 753 | 754 | def record_test_scalar_summary(avg, epoch, logger, prefix="Test"): 755 | logger.add_scalar(prefix + '/rmse', avg.rmse, epoch) 756 | logger.add_scalar(prefix + '/Rel', avg.absrel, epoch) 757 | logger.add_scalar(prefix + '/log10', avg.lg10, epoch) 758 | logger.add_scalar(prefix + '/Delta1', avg.delta1, epoch) 759 | logger.add_scalar(prefix + '/Delta2', avg.delta2, epoch) 760 | logger.add_scalar(prefix + '/Delta3', avg.delta3, epoch) 761 | 762 | def display_results(avg_meter): 763 | avg = avg_meter.average() 764 | print("RMSE:", avg.rmse) 765 | print("MAE:", avg.mae) 766 | print("REL", avg.absrel) 767 | print("log10", avg.lg10) 768 | print("delta1", avg.delta1) 769 | print("delta2", avg.delta2) 770 | print("delta3", avg.delta3) 771 | 772 | 773 | if __name__ == '__main__': 774 | main() -------------------------------------------------------------------------------- /model/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is adapted from https://github.com/fangchangma/sparse-to-dense.pytorch 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torchvision.models 8 | from torchvision.models.resnet import Bottleneck, conv1x1, conv3x3 9 | import collections 10 | import math 11 | 12 | 13 | class Unpool(nn.Module): 14 | # Unpool: 2*2 unpooling with zero padding 15 | def __init__(self, num_channels, stride=2): 16 | super(Unpool, self).__init__() 17 | 18 | self.num_channels = num_channels 19 | self.stride = stride 20 | 21 | # create kernel [1, 0; 0, 0] 22 | 23 | self.weights = torch.autograd.Variable(torch.zeros(num_channels, 1, stride, stride).cuda()) # currently not compatible with running on CPU 24 | self.weights[:,:,0,0] = 1 25 | 26 | def forward(self, x): 27 | return F.conv_transpose2d(x, self.weights, stride=self.stride, groups=self.num_channels) 28 | 29 | 30 | def weights_init(m): 31 | # Initialize filters with Gaussian random weights 32 | if isinstance(m, nn.Conv2d): 33 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 34 | m.weight.data.normal_(0, math.sqrt(2. / n)) 35 | if m.bias is not None: 36 | m.bias.data.zero_() 37 | elif isinstance(m, nn.ConvTranspose2d): 38 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 39 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 40 | if m.bias is not None: 41 | m.bias.data.zero_() 42 | elif isinstance(m, nn.BatchNorm2d): 43 | m.weight.data.fill_(1) 44 | m.bias.data.zero_() 45 | 46 | 47 | def weights_init_kaiming(m): 48 | if isinstance(m, nn.Conv2d): 49 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 50 | if m.bias is not None: 51 | m.bias.data.zero_() 52 | elif isinstance(m, nn.ConvTranspose2d): 53 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 54 | if m.bias is not None: 55 | m.bias.data.zero_() 56 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 57 | nn.init.constant_(m.weight, 1) 58 | nn.init.constant_(m.bias, 0) 59 | 60 | 61 | def weights_init_kaiming_leaky(m): 62 | if isinstance(m, nn.Conv2d): 63 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') 64 | if m.bias is not None: 65 | m.bias.data.zero_() 66 | elif isinstance(m, nn.ConvTranspose2d): 67 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') 68 | if m.bias is not None: 69 | m.bias.data.zero_() 70 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 71 | nn.init.constant_(m.weight, 1) 72 | nn.init.constant_(m.bias, 0) 73 | 74 | 75 | class BasicBlock(nn.Module): 76 | expansion = 1 77 | 78 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 79 | base_width=64, dilation=1, norm_layer=None): 80 | super(BasicBlock, self).__init__() 81 | if norm_layer is None: 82 | norm_layer = nn.BatchNorm2d 83 | # if groups != 1 or base_width != 64: 84 | # raise ValueError('BasicBlock only supports groups=1 and base_width=64') 85 | if dilation > 1: 86 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 87 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 88 | self.conv1 = conv3x3(inplanes, planes, stride) 89 | self.bn1 = norm_layer(planes) 90 | self.relu = nn.ReLU(inplace=True) 91 | self.conv2 = conv3x3(planes, planes) 92 | self.bn2 = norm_layer(planes) 93 | self.downsample = downsample 94 | self.stride = stride 95 | 96 | def forward(self, x): 97 | identity = x 98 | 99 | out = self.conv1(x) 100 | out = self.bn1(out) 101 | out = self.relu(out) 102 | 103 | out = self.conv2(out) 104 | out = self.bn2(out) 105 | 106 | if self.downsample is not None: 107 | identity = self.downsample(x) 108 | 109 | out += identity 110 | out = self.relu(out) 111 | 112 | return out 113 | 114 | 115 | class Decoder(nn.Module): 116 | # Decoder is the base class for all decoders 117 | 118 | names = ['deconv2', 'deconv3', 'upconv', 'upproj'] 119 | 120 | def __init__(self): 121 | super(Decoder, self).__init__() 122 | 123 | self.layer1 = None 124 | self.layer2 = None 125 | self.layer3 = None 126 | self.layer4 = None 127 | 128 | def forward(self, x): 129 | x = self.layer1(x) 130 | x = self.layer2(x) 131 | x = self.layer3(x) 132 | x = self.layer4(x) 133 | return x 134 | 135 | class DeConv(Decoder): 136 | def __init__(self, in_channels, kernel_size): 137 | assert kernel_size>=2, "kernel_size out of range: {}".format(kernel_size) 138 | super(DeConv, self).__init__() 139 | 140 | def convt(in_channels): 141 | stride = 2 142 | padding = (kernel_size - 1) // 2 143 | output_padding = kernel_size % 2 144 | assert -2 - 2*padding + kernel_size + output_padding == 0, "deconv parameters incorrect" 145 | 146 | module_name = "deconv{}".format(kernel_size) 147 | return nn.Sequential(collections.OrderedDict([ 148 | (module_name, nn.ConvTranspose2d(in_channels,in_channels//2,kernel_size, 149 | 
stride,padding,output_padding,bias=False)), 150 | ('batchnorm', nn.BatchNorm2d(in_channels//2)), 151 | ('relu', nn.ReLU(inplace=True)), 152 | ])) 153 | 154 | self.layer1 = convt(in_channels) 155 | self.layer2 = convt(in_channels // 2) 156 | self.layer3 = convt(in_channels // (2 ** 2)) 157 | self.layer4 = convt(in_channels // (2 ** 3)) 158 | 159 | class UpConv(Decoder): 160 | # UpConv decoder consists of 4 upconv modules with decreasing number of channels and increasing feature map size 161 | def upconv_module(self, in_channels): 162 | # UpConv module: unpool -> 5*5 conv -> batchnorm -> ReLU 163 | upconv = nn.Sequential(collections.OrderedDict([ 164 | ('unpool', Unpool(in_channels)), 165 | ('conv', nn.Conv2d(in_channels,in_channels//2,kernel_size=5,stride=1,padding=2,bias=False)), 166 | ('batchnorm', nn.BatchNorm2d(in_channels//2)), 167 | ('relu', nn.ReLU()), 168 | ])) 169 | return upconv 170 | 171 | def __init__(self, in_channels): 172 | super(UpConv, self).__init__() 173 | self.layer1 = self.upconv_module(in_channels) 174 | self.layer2 = self.upconv_module(in_channels//2) 175 | self.layer3 = self.upconv_module(in_channels//4) 176 | self.layer4 = self.upconv_module(in_channels//8) 177 | 178 | class UpProj(Decoder): 179 | # UpProj decoder consists of 4 upproj modules with decreasing number of channels and increasing feature map size 180 | 181 | class UpProjModule(nn.Module): 182 | # UpProj module has two branches, with a Unpool at the start and a ReLu at the end 183 | # upper branch: 5*5 conv -> batchnorm -> ReLU -> 3*3 conv -> batchnorm 184 | # bottom branch: 5*5 conv -> batchnorm 185 | 186 | def __init__(self, in_channels): 187 | super(UpProj.UpProjModule, self).__init__() 188 | out_channels = in_channels//2 189 | self.unpool = Unpool(in_channels) 190 | self.upper_branch = nn.Sequential(collections.OrderedDict([ 191 | ('conv1', nn.Conv2d(in_channels,out_channels,kernel_size=5,stride=1,padding=2,bias=False)), 192 | ('batchnorm1', nn.BatchNorm2d(out_channels)), 193 | ('relu', nn.ReLU()), 194 | ('conv2', nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=1,padding=1,bias=False)), 195 | ('batchnorm2', nn.BatchNorm2d(out_channels)), 196 | ])) 197 | self.bottom_branch = nn.Sequential(collections.OrderedDict([ 198 | ('conv', nn.Conv2d(in_channels,out_channels,kernel_size=5,stride=1,padding=2,bias=False)), 199 | ('batchnorm', nn.BatchNorm2d(out_channels)), 200 | ])) 201 | self.relu = nn.ReLU() 202 | 203 | def forward(self, x): 204 | x = self.unpool(x) 205 | x1 = self.upper_branch(x) 206 | x2 = self.bottom_branch(x) 207 | x = x1 + x2 208 | x = self.relu(x) 209 | return x 210 | 211 | def __init__(self, in_channels): 212 | super(UpProj, self).__init__() 213 | self.layer1 = self.UpProjModule(in_channels) 214 | self.layer2 = self.UpProjModule(in_channels//2) 215 | self.layer3 = self.UpProjModule(in_channels//4) 216 | self.layer4 = self.UpProjModule(in_channels//8) 217 | 218 | 219 | def choose_decoder(decoder, in_channels): 220 | # iheight, iwidth = 10, 8 221 | if decoder[:6] == 'deconv': 222 | assert len(decoder)==7 223 | kernel_size = int(decoder[6]) 224 | return DeConv(in_channels, kernel_size) 225 | elif decoder == "upproj": 226 | return UpProj(in_channels) 227 | elif decoder == "upconv": 228 | return UpConv(in_channels) 229 | else: 230 | assert False, "invalid option for decoder: {}".format(decoder) 231 | 232 | 233 | class ResNet(nn.Module): 234 | def __init__(self, layers, decoder, output_size, in_channels=3, pretrained=True): 235 | 236 | if layers not in [18, 34, 50, 101, 152]: 237 | raise 
RuntimeError('Only 18, 34, 50, 101, and 152 layer model are defined for ResNet. Got {}'.format(layers)) 238 | 239 | super(ResNet, self).__init__() 240 | pretrained_model = torchvision.models.__dict__['resnet{}'.format(layers)](pretrained=pretrained) 241 | 242 | if in_channels == 3: 243 | self.conv1 = pretrained_model._modules['conv1'] 244 | self.bn1 = pretrained_model._modules['bn1'] 245 | else: 246 | self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 247 | self.bn1 = nn.BatchNorm2d(64) 248 | weights_init(self.conv1) 249 | weights_init(self.bn1) 250 | 251 | self.output_size = output_size 252 | 253 | self.relu = pretrained_model._modules['relu'] 254 | self.maxpool = pretrained_model._modules['maxpool'] 255 | self.layer1 = pretrained_model._modules['layer1'] 256 | self.layer2 = pretrained_model._modules['layer2'] 257 | self.layer3 = pretrained_model._modules['layer3'] 258 | self.layer4 = pretrained_model._modules['layer4'] 259 | 260 | # clear memory 261 | del pretrained_model 262 | 263 | # define number of intermediate channels 264 | if layers <= 34: 265 | num_channels = 512 266 | elif layers >= 50: 267 | num_channels = 2048 268 | 269 | self.conv2 = nn.Conv2d(num_channels,num_channels//2,kernel_size=1,bias=False) 270 | self.bn2 = nn.BatchNorm2d(num_channels//2) 271 | self.decoder = choose_decoder(decoder, num_channels//2) 272 | 273 | # setting bias=true doesn't improve accuracy 274 | self.conv3 = nn.Conv2d(num_channels//32,1,kernel_size=3,stride=1,padding=1,bias=False) 275 | self.bilinear = nn.Upsample(size=self.output_size, mode='bilinear', align_corners=True) 276 | 277 | # weight init 278 | self.conv2.apply(weights_init) 279 | self.bn2.apply(weights_init) 280 | self.decoder.apply(weights_init) 281 | self.conv3.apply(weights_init) 282 | 283 | def forward(self, x): 284 | # resnet 285 | x = self.conv1(x) 286 | x = self.bn1(x) 287 | x = self.relu(x) 288 | x = self.maxpool(x) 289 | # ipdb.set_trace() 290 | x = self.layer1(x) 291 | x = self.layer2(x) 292 | x = self.layer3(x) 293 | x = self.layer4(x) 294 | 295 | x = self.conv2(x) 296 | x = self.bn2(x) 297 | 298 | # decoder 299 | x = self.decoder(x) 300 | x = self.conv3(x) 301 | x = self.bilinear(x) 302 | 303 | return x 304 | 305 | 306 | class ResNet_pnp(nn.Module): 307 | def __init__(self, layers, decoder, output_size, in_channels=3, pretrained=True): 308 | 309 | if layers not in [18, 34, 50, 101, 152]: 310 | raise RuntimeError('Only 18, 34, 50, 101, and 152 layer model are defined for ResNet. 
Got {}'.format(layers)) 311 | 312 | super(ResNet_pnp, self).__init__() 313 | pretrained_model = torchvision.models.__dict__['resnet{}'.format(layers)](pretrained=pretrained) 314 | 315 | if in_channels == 3: 316 | self.conv1 = pretrained_model._modules['conv1'] 317 | self.bn1 = pretrained_model._modules['bn1'] 318 | else: 319 | self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 320 | self.bn1 = nn.BatchNorm2d(64) 321 | weights_init(self.conv1) 322 | weights_init(self.bn1) 323 | 324 | self.output_size = output_size 325 | 326 | self.relu = pretrained_model._modules['relu'] 327 | self.maxpool = pretrained_model._modules['maxpool'] 328 | self.layer1 = pretrained_model._modules['layer1'] 329 | self.layer2 = pretrained_model._modules['layer2'] 330 | self.layer3 = pretrained_model._modules['layer3'] 331 | self.layer4 = pretrained_model._modules['layer4'] 332 | 333 | # clear memory 334 | del pretrained_model 335 | 336 | # define number of intermediate channels 337 | if layers <= 34: 338 | num_channels = 512 339 | elif layers >= 50: 340 | num_channels = 2048 341 | 342 | self.conv2 = nn.Conv2d(num_channels,num_channels//2,kernel_size=1,bias=False) 343 | self.bn2 = nn.BatchNorm2d(num_channels//2) 344 | self.decoder = choose_decoder(decoder, num_channels//2) 345 | 346 | # setting bias=true doesn't improve accuracy 347 | self.conv3 = nn.Conv2d(num_channels//32,1,kernel_size=3,stride=1,padding=1,bias=False) 348 | self.bilinear = nn.Upsample(size=self.output_size, mode='bilinear', align_corners=True) 349 | 350 | # weight init 351 | self.conv2.apply(weights_init) 352 | self.bn2.apply(weights_init) 353 | self.decoder.apply(weights_init) 354 | self.conv3.apply(weights_init) 355 | 356 | def forward(self, x): 357 | # resnet 358 | x = self.conv1(x) 359 | x = self.bn1(x) 360 | x = self.relu(x) 361 | x = self.maxpool(x) 362 | x = self.layer1(x) 363 | x = self.layer2(x) 364 | x = self.layer3(x) 365 | x = self.layer4(x) 366 | 367 | x = self.conv2(x) 368 | x = self.bn2(x) 369 | 370 | # decoder 371 | x = self.decoder(x) 372 | x = self.conv3(x) 373 | x = self.bilinear(x) 374 | 375 | return x 376 | 377 | ####################### 378 | ## PnP-Depth forward ## 379 | ####################### 380 | def pnp_forward_front(self, x): 381 | x = self.conv1(x) 382 | x = self.bn1(x) 383 | x = self.relu(x) 384 | x = self.maxpool(x) 385 | x = self.layer1(x) 386 | x = self.layer2(x) 387 | x = self.layer3(x) 388 | x = self.layer4(x) 389 | 390 | x = self.conv2(x) 391 | x = self.bn2(x) 392 | 393 | return x 394 | 395 | def pnp_forward_rear(self, x): 396 | x = self.decoder(x) 397 | x = self.conv3(x) 398 | x = self.bilinear(x) 399 | 400 | return x 401 | 402 | 403 | class ResNet2(nn.Module): 404 | def __init__(self, layers, decoder, output_size, in_channels=3, pretrained=True): 405 | 406 | if layers not in [18, 34, 50, 101, 152]: 407 | raise RuntimeError('Only 18, 34, 50, 101, and 152 layer model are defined for ResNet. 
Got {}'.format(layers)) 408 | 409 | super(ResNet2, self).__init__() 410 | pretrained_model = torchvision.models.__dict__['resnet{}'.format(layers)](pretrained=pretrained) 411 | 412 | if in_channels == 3: 413 | self.conv1 = pretrained_model._modules['conv1'] 414 | self.bn1 = pretrained_model._modules['bn1'] 415 | else: 416 | self.conv1_d = conv_bn_relu(1, 64 // 4, kernel_size=3, stride=2, padding=1) 417 | self.conv1_img = conv_bn_relu(3, 64 * 3 // 4, kernel_size=3, stride=2, padding=1) 418 | 419 | self.output_size = output_size 420 | self.in_channels = in_channels 421 | 422 | self.relu = pretrained_model._modules['relu'] 423 | self.maxpool = pretrained_model._modules['maxpool'] 424 | self.layer1 = pretrained_model._modules['layer1'] 425 | self.layer2 = pretrained_model._modules['layer2'] 426 | self.layer3 = pretrained_model._modules['layer3'] 427 | self.layer4 = pretrained_model._modules['layer4'] 428 | 429 | # clear memory 430 | del pretrained_model 431 | 432 | # define number of intermediate channels 433 | if layers <= 34: 434 | num_channels = 512 435 | elif layers >= 50: 436 | num_channels = 2048 437 | 438 | self.conv2 = nn.Conv2d(num_channels,num_channels//2,kernel_size=1,bias=False) 439 | self.bn2 = nn.BatchNorm2d(num_channels//2) 440 | self.decoder = choose_decoder(decoder, num_channels//2) 441 | 442 | # setting bias=true doesn't improve accuracy 443 | self.conv3 = nn.Conv2d(num_channels//32,1,kernel_size=3,stride=1,padding=1,bias=False) 444 | self.bilinear = nn.Upsample(size=self.output_size, mode='bilinear', align_corners=True) 445 | 446 | # weight init 447 | self.conv2.apply(weights_init) 448 | self.bn2.apply(weights_init) 449 | self.decoder.apply(weights_init) 450 | self.conv3.apply(weights_init) 451 | 452 | def forward(self, x): 453 | # resnet 454 | if self.in_channels == 3: 455 | x = self.conv1(x) 456 | x = self.bn1(x) 457 | x = self.relu(x) 458 | x = self.maxpool(x) 459 | else: 460 | x_d = self.conv1_d(x[:, 3:, :, :]) 461 | x_img = self.conv1_img(x[:, :3, :, :]) 462 | x = torch.cat((x_img, x_d), 1) 463 | # x = self.relu(x) 464 | # x = self.maxpool(x) 465 | 466 | # ipdb.set_trace() 467 | x = self.layer1(x) 468 | x = self.layer2(x) 469 | x = self.layer3(x) 470 | x = self.layer4(x) 471 | 472 | x = self.conv2(x) 473 | x = self.bn2(x) 474 | 475 | # decoder 476 | x = self.decoder(x) 477 | # ipdb.set_trace() 478 | x = self.conv3(x) 479 | x = self.bilinear(x) 480 | 481 | return x 482 | 483 | ####################### 484 | ## PnP-Depth forward ## 485 | ####################### 486 | def pnp_forward_front(self, x): 487 | # resnet 488 | if self.in_channels == 3: 489 | x = self.conv1(x) 490 | x = self.bn1(x) 491 | x = self.relu(x) 492 | x = self.maxpool(x) 493 | else: 494 | x_d = self.conv1_d(x[:, 3:, :, :]) 495 | x_img = self.conv1_img(x[:, :3, :, :]) 496 | x = torch.cat((x_img, x_d), 1) 497 | # x = self.relu(x) 498 | # x = self.maxpool(x) 499 | 500 | # ipdb.set_trace() 501 | x = self.layer1(x) 502 | x = self.layer2(x) 503 | x = self.layer3(x) 504 | x = self.layer4(x) 505 | 506 | x = self.conv2(x) 507 | x = self.bn2(x) 508 | 509 | return x 510 | 511 | def pnp_forward_rear(self, x): 512 | x = self.decoder(x) 513 | x = self.conv3(x) 514 | x = self.bilinear(x) 515 | 516 | return x 517 | 518 | 519 | class ResNet_latefusion(nn.Module): 520 | def __init__(self, layers, decoder, output_size, in_channels=4, pretrained=True): 521 | 522 | if layers not in [18, 34, 50, 101, 152]: 523 | raise RuntimeError('Only 18, 34, 50, 101, and 152 layer model are defined for ResNet. 
Got {}'.format(layers))
524 | 
525 |         super(ResNet_latefusion, self).__init__()
526 |         pretrained_model = torchvision.models.__dict__['resnet{}'.format(layers)](pretrained=pretrained)
527 | 
528 |         # Configurations required by resnet
529 |         self._norm_layer = nn.BatchNorm2d
530 |         self.dilation = 1
531 |         self.inplanes = 16
532 |         self.groups = 1
533 |         self.base_width = 16
534 | 
535 |         assert in_channels > 3
536 |         ################
537 |         ## RGB Branch ##
538 |         ################
539 |         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
540 |         self.bn1 = nn.BatchNorm2d(64)
541 |         weights_init(self.conv1)
542 |         weights_init(self.bn1)
543 | 
544 |         self.output_size = output_size
545 | 
546 |         self.relu = pretrained_model._modules['relu']
547 |         self.maxpool = pretrained_model._modules['maxpool']
548 |         self.layer1 = pretrained_model._modules['layer1']
549 |         self.layer2 = pretrained_model._modules['layer2']
550 |         self.layer3 = pretrained_model._modules['layer3']
551 |         self.layer4 = pretrained_model._modules['layer4']
552 | 
553 |         # clear memory
554 |         del pretrained_model
555 | 
556 |         ##################
557 |         ## Depth Branch ##
558 |         ##################
559 |         self.conv1_depth = nn.Conv2d(1, 16, kernel_size=7, stride=2, padding=3, bias=False)
560 |         self.bn1_depth = nn.BatchNorm2d(16)
561 |         weights_init_kaiming_leaky(self.conv1_depth)
562 |         weights_init_kaiming(self.bn1_depth)
563 | 
564 |         self.relu_depth = nn.LeakyReLU(0.2, inplace=True)
565 |         self.maxpool_depth = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
566 |         self.layer1_depth = self._make_layer(BasicBlock, 16, 2, stride=1, dilate=False)
567 |         self.layer2_depth = self._make_layer(BasicBlock, 32, 2, stride=2)
568 |         self.layer3_depth = self._make_layer(BasicBlock, 64, 2, stride=2)
569 |         self.layer4_depth = self._make_layer(BasicBlock, 128, 2, stride=2)
570 | 
571 |         # TODO: decide whether an extra convolution is needed for the fusion
572 |         # Define the fusion operator
573 |         self.conv_fusion = nn.Conv2d(512 + 128, 512, kernel_size=1, bias=False)
574 |         self.bn_fusion = nn.BatchNorm2d(512)
575 | 
576 |         # define number of intermediate channels
577 |         if layers <= 34:
578 |             num_channels = 512
579 |         elif layers >= 50:
580 |             num_channels = 2048
581 | 
582 |         self.conv2 = nn.Conv2d(num_channels, num_channels//2, kernel_size=1, bias=False)
583 |         self.bn2 = nn.BatchNorm2d(num_channels//2)
584 |         self.decoder = choose_decoder(decoder, num_channels//2)
585 | 
586 |         # setting bias=True doesn't improve accuracy
587 |         self.conv3 = nn.Conv2d(num_channels//32, 1, kernel_size=3, stride=1, padding=1, bias=False)
588 |         self.bilinear = nn.Upsample(size=self.output_size, mode='bilinear', align_corners=True)
589 | 
590 |         # weight init
591 |         self.conv2.apply(weights_init)
592 |         self.bn2.apply(weights_init)
593 |         self.decoder.apply(weights_init)
594 |         self.conv3.apply(weights_init)
595 | 
596 |     # Make layer function adapted from resnet
597 |     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
598 |         norm_layer = self._norm_layer
599 |         downsample = None
600 |         previous_dilation = self.dilation
601 |         if dilate:
602 |             self.dilation *= stride
603 |             stride = 1
604 |         if stride != 1 or self.inplanes != planes * block.expansion:
605 |             downsample = nn.Sequential(
606 |                 conv1x1(self.inplanes, planes * block.expansion, stride),
607 |                 norm_layer(planes * block.expansion),
608 |             )
609 | 
610 |         layers = []
611 |         layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
612 |                             self.base_width, previous_dilation, norm_layer))
613 |         self.inplanes = planes * block.expansion
614 |         for _ in
range(1, blocks): 615 | layers.append(block(self.inplanes, planes, groups=self.groups, 616 | base_width=self.base_width, dilation=self.dilation, 617 | norm_layer=norm_layer)) 618 | 619 | layers = nn.Sequential(*layers) 620 | 621 | # Explicitly initialize layers after construction 622 | for m in layers.modules(): 623 | weights_init_kaiming(m) 624 | 625 | return layers 626 | 627 | def forward(self, x): 628 | x_img = x[:, :3, :, :] 629 | x_d = x[:, 3:, :, :] 630 | 631 | # ipdb.set_trace() 632 | # RGB 633 | x_img = self.conv1(x_img) 634 | x_img = self.bn1(x_img) 635 | x_img = self.relu(x_img) 636 | x_img = self.maxpool(x_img) # 113 x 200 x 64 637 | x_img = self.layer1(x_img) # 113 x 200 x 64 638 | x_img = self.layer2(x_img) # 57 x 100 x 128 639 | x_img = self.layer3(x_img) # 29 x 50 x 256 640 | x_img = self.layer4(x_img) # 15 x 25 x 512 641 | 642 | # Depth 643 | x_d = self.conv1_depth(x_d) 644 | x_d = self.bn1_depth(x_d) 645 | x_d = self.relu_depth(x_d) 646 | x_d = self.maxpool_depth(x_d) # 113 x 200 x 16 647 | x_d = self.layer1_depth(x_d) # 113 x 200 x 16 648 | x_d = self.layer2_depth(x_d) # 57 x 100 x 32 649 | x_d = self.layer3_depth(x_d) # 29 x 50 x 64 650 | x_d = self.layer4_depth(x_d) # 15 x 25 x 128 651 | 652 | x_fused = torch.cat((x_img, x_d), dim=1) 653 | x_fused = self.conv_fusion(x_fused) 654 | x_fused = self.bn_fusion(x_fused) 655 | 656 | x_fused = self.conv2(x_fused) 657 | x_fused = self.bn2(x_fused) 658 | 659 | # decoder 660 | x_fused = self.decoder(x_fused) 661 | x_fused = self.conv3(x_fused) 662 | x_fused = self.bilinear(x_fused) 663 | 664 | return x_fused 665 | 666 | ####################### 667 | ## PnP-Depth forward ## 668 | ####################### 669 | def pnp_forward_front(self, x): 670 | x_img = x[:, :3, :, :] 671 | x_d = x[:, 3:, :, :] 672 | 673 | # RGB 674 | x_img = self.conv1(x_img) 675 | x_img = self.bn1(x_img) 676 | x_img = self.relu(x_img) 677 | x_img = self.maxpool(x_img) 678 | x_img = self.layer1(x_img) 679 | x_img = self.layer2(x_img) 680 | x_img = self.layer3(x_img) 681 | x_img = self.layer4(x_img) 682 | 683 | # Depth 684 | x_d = self.conv1_depth(x_d) 685 | x_d = self.bn1_depth(x_d) 686 | x_d = self.relu_depth(x_d) 687 | x_d = self.maxpool_depth(x_d) 688 | x_d = self.layer1_depth(x_d) 689 | x_d = self.layer2_depth(x_d) 690 | x_d = self.layer3_depth(x_d) 691 | x_d = self.layer4_depth(x_d) 692 | 693 | x_fused = torch.cat((x_img, x_d), dim=1) 694 | x_fused = self.conv_fusion(x_fused) 695 | x_fused = self.bn_fusion(x_fused) 696 | 697 | x_fused = self.conv2(x_fused) 698 | x_fused = self.bn2(x_fused) 699 | 700 | return x_fused 701 | 702 | def pnp_forward_rear(self, x): 703 | x = self.decoder(x) 704 | x = self.conv3(x) 705 | x = self.bilinear(x) 706 | 707 | return x 708 | 709 | 710 | class ResNet_multifusion(nn.Module): 711 | def __init__(self, layers, decoder, output_size, in_channels=4, pretrained=True): 712 | 713 | if layers not in [18, 34, 50, 101, 152]: 714 | raise RuntimeError('Only 18, 34, 50, 101, and 152 layer model are defined for ResNet. 
Got {}'.format(layers))
715 | 
716 |         super(ResNet_multifusion, self).__init__()
717 |         pretrained_model = torchvision.models.__dict__['resnet{}'.format(layers)](pretrained=pretrained)
718 | 
719 |         # Configurations required by resnet
720 |         self._norm_layer = nn.BatchNorm2d
721 |         self.dilation = 1
722 |         self.inplanes = 16
723 |         self.groups = 1
724 |         self.base_width = 16
725 | 
726 |         assert in_channels > 3
727 |         ################
728 |         ## RGB Branch ##
729 |         ################
730 |         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
731 |         self.bn1 = nn.BatchNorm2d(64)
732 |         weights_init(self.conv1)
733 |         weights_init(self.bn1)
734 | 
735 |         self.output_size = output_size
736 | 
737 |         self.relu = pretrained_model._modules['relu']
738 |         self.maxpool = pretrained_model._modules['maxpool']
739 |         self.layer1 = pretrained_model._modules['layer1']
740 |         self.layer2 = pretrained_model._modules['layer2']
741 |         self.layer3 = pretrained_model._modules['layer3']
742 |         self.layer4 = pretrained_model._modules['layer4']
743 | 
744 |         # clear memory
745 |         del pretrained_model
746 | 
747 |         ##################
748 |         ## Depth Branch ##
749 |         ##################
750 |         self.conv1_depth = nn.Conv2d(1, 16, kernel_size=7, stride=2, padding=3, bias=False)
751 |         self.bn1_depth = nn.BatchNorm2d(16)
752 |         weights_init_kaiming_leaky(self.conv1_depth)
753 |         weights_init_kaiming(self.bn1_depth)
754 | 
755 |         self.relu_depth = nn.LeakyReLU(0.2, inplace=True)
756 |         self.maxpool_depth = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
757 |         self.layer1_depth = self._make_layer(BasicBlock, 16, 2, stride=1, dilate=False)
758 |         self.layer2_depth = self._make_layer(BasicBlock, 32, 2, stride=2)
759 |         self.layer3_depth = self._make_layer(BasicBlock, 64, 2, stride=2)
760 |         self.layer4_depth = self._make_layer(BasicBlock, 128, 2, stride=2)
761 | 
762 |         # TODO: decide whether an extra convolution is needed for the fusion
763 |         # Define the fusion operator
764 |         self.conv_fusion1 = nn.Conv2d(64 + 16, 64, kernel_size=1, bias=False)
765 |         self.bn_fusion1 = nn.BatchNorm2d(64)
766 | 
767 |         self.conv_fusion2 = nn.Conv2d(128 + 32, 128, kernel_size=1, bias=False)
768 |         self.bn_fusion2 = nn.BatchNorm2d(128)
769 | 
770 |         self.conv_fusion3 = nn.Conv2d(256 + 64, 256, kernel_size=1, bias=False)
771 |         self.bn_fusion3 = nn.BatchNorm2d(256)
772 | 
773 |         self.conv_fusion4 = nn.Conv2d(512 + 128, 512, kernel_size=1, bias=False)
774 |         self.bn_fusion4 = nn.BatchNorm2d(512)
775 | 
776 |         # define number of intermediate channels
777 |         if layers <= 34:
778 |             num_channels = 512
779 |         elif layers >= 50:
780 |             num_channels = 2048
781 | 
782 |         self.conv2 = nn.Conv2d(num_channels, num_channels//2, kernel_size=1, bias=False)
783 |         self.bn2 = nn.BatchNorm2d(num_channels//2)
784 |         self.decoder = choose_decoder(decoder, num_channels//2)
785 | 
786 |         # setting bias=True doesn't improve accuracy
787 |         self.conv3 = nn.Conv2d(num_channels//32, 1, kernel_size=3, stride=1, padding=1, bias=False)
788 |         self.bilinear = nn.Upsample(size=self.output_size, mode='bilinear', align_corners=True)
789 | 
790 |         # weight init
791 |         self.conv_fusion1.apply(weights_init_kaiming)
792 |         self.conv_fusion2.apply(weights_init_kaiming)
793 |         self.conv_fusion3.apply(weights_init_kaiming)
794 |         self.conv_fusion4.apply(weights_init_kaiming)
795 | 
796 |         self.bn_fusion1.apply(weights_init_kaiming)
797 |         self.bn_fusion2.apply(weights_init_kaiming)
798 |         self.bn_fusion3.apply(weights_init_kaiming)
799 |         self.bn_fusion4.apply(weights_init_kaiming)
800 | 
801 |         self.conv2.apply(weights_init)
802 |         self.bn2.apply(weights_init)
803 | 
self.decoder.apply(weights_init) 804 | self.conv3.apply(weights_init) 805 | 806 | # Make layer function adapted from resnet 807 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 808 | norm_layer = self._norm_layer 809 | downsample = None 810 | previous_dilation = self.dilation 811 | if dilate: 812 | self.dilation *= stride 813 | stride = 1 814 | if stride != 1 or self.inplanes != planes * block.expansion: 815 | downsample = nn.Sequential( 816 | conv1x1(self.inplanes, planes * block.expansion, stride), 817 | norm_layer(planes * block.expansion), 818 | ) 819 | 820 | layers = [] 821 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 822 | self.base_width, previous_dilation, norm_layer)) 823 | self.inplanes = planes * block.expansion 824 | for _ in range(1, blocks): 825 | layers.append(block(self.inplanes, planes, groups=self.groups, 826 | base_width=self.base_width, dilation=self.dilation, 827 | norm_layer=norm_layer)) 828 | 829 | layers = nn.Sequential(*layers) 830 | 831 | # Explicitly initialize layers after construction 832 | for m in layers.modules(): 833 | weights_init_kaiming(m) 834 | 835 | return layers 836 | 837 | def forward(self, x): 838 | x_img = x[:, :3, :, :] 839 | x_d = x[:, 3:, :, :] 840 | 841 | # ipdb.set_trace() 842 | 843 | # RGB layer1 844 | x_img = self.conv1(x_img) 845 | x_img = self.bn1(x_img) 846 | x_img = self.relu(x_img) 847 | x_img = self.maxpool(x_img) 848 | x_img = self.layer1(x_img) 849 | 850 | # Depth layer1 851 | x_d = self.conv1_depth(x_d) 852 | x_d = self.bn1_depth(x_d) 853 | x_d = self.relu_depth(x_d) 854 | x_d = self.maxpool_depth(x_d) 855 | x_d = self.layer1_depth(x_d) 856 | 857 | # Fusion layer1 858 | x_fused1 = torch.cat((x_img, x_d), dim=1) 859 | x_fused1 = self.conv_fusion1(x_fused1) 860 | x_fused1 = self.bn_fusion1(x_fused1) 861 | 862 | # RGB layer2 863 | x_img = self.layer2(x_fused1) 864 | # Depth layer2 865 | x_d = self.layer2_depth(x_d) 866 | # Fusion layer2 867 | x_fused2 = torch.cat((x_img, x_d), dim=1) 868 | x_fused2 = self.conv_fusion2(x_fused2) 869 | x_fused2 = self.bn_fusion2(x_fused2) 870 | 871 | # RGB layer3 872 | x_img = self.layer3(x_fused2) 873 | # Depth layer3 874 | x_d = self.layer3_depth(x_d) 875 | # Fusion layer3 876 | x_fused3 = torch.cat((x_img, x_d), dim=1) 877 | x_fused3 = self.conv_fusion3(x_fused3) 878 | x_fused3 = self.bn_fusion3(x_fused3) 879 | 880 | # ipdb.set_trace() 881 | # RGB layer4 882 | x_img = self.layer4(x_fused3) 883 | # Depth layer4 884 | x_d = self.layer4_depth(x_d) 885 | # Fusion layer4 886 | x_fused4 = torch.cat((x_img, x_d), dim=1) 887 | x_fused4 = self.conv_fusion4(x_fused4) 888 | x_fused4 = self.bn_fusion4(x_fused4) 889 | 890 | x_fused = self.conv2(x_fused4) 891 | x_fused = self.bn2(x_fused) 892 | 893 | # decoder 894 | x_fused = self.decoder(x_fused) 895 | x_fused = self.conv3(x_fused) 896 | x_fused = self.bilinear(x_fused) 897 | 898 | return x_fused 899 | 900 | ####################### 901 | ## PnP-Depth forward ## 902 | ####################### 903 | def pnp_forward_front(self, x): 904 | x_img = x[:, :3, :, :] 905 | x_d = x[:, 3:, :, :] 906 | 907 | # ipdb.set_trace() 908 | 909 | # RGB layer1 910 | x_img = self.conv1(x_img) 911 | x_img = self.bn1(x_img) 912 | x_img = self.relu(x_img) 913 | x_img = self.maxpool(x_img) 914 | x_img = self.layer1(x_img) 915 | 916 | # Depth layer1 917 | x_d = self.conv1_depth(x_d) 918 | x_d = self.bn1_depth(x_d) 919 | x_d = self.relu_depth(x_d) 920 | x_d = self.maxpool_depth(x_d) 921 | x_d = self.layer1_depth(x_d) 922 | 923 | # Fusion 
layer1 924 | x_fused1 = torch.cat((x_img, x_d), dim=1) 925 | x_fused1 = self.conv_fusion1(x_fused1) 926 | x_fused1 = self.bn_fusion1(x_fused1) 927 | 928 | # RGB layer2 929 | x_img = self.layer2(x_fused1) 930 | # Depth layer2 931 | x_d = self.layer2_depth(x_d) 932 | # Fusion layer2 933 | x_fused2 = torch.cat((x_img, x_d), dim=1) 934 | x_fused2 = self.conv_fusion2(x_fused2) 935 | x_fused2 = self.bn_fusion2(x_fused2) 936 | 937 | # RGB layer3 938 | x_img = self.layer3(x_fused2) 939 | # Depth layer3 940 | x_d = self.layer3_depth(x_d) 941 | # Fusion layer3 942 | x_fused3 = torch.cat((x_img, x_d), dim=1) 943 | x_fused3 = self.conv_fusion3(x_fused3) 944 | x_fused3 = self.bn_fusion3(x_fused3) 945 | 946 | # ipdb.set_trace() 947 | # RGB layer4 948 | x_img = self.layer4(x_fused3) 949 | # Depth layer4 950 | x_d = self.layer4_depth(x_d) 951 | # Fusion layer4 952 | x_fused4 = torch.cat((x_img, x_d), dim=1) 953 | x_fused4 = self.conv_fusion4(x_fused4) 954 | x_fused4 = self.bn_fusion4(x_fused4) 955 | 956 | x_fused = self.conv2(x_fused4) 957 | x_fused = self.bn2(x_fused) 958 | 959 | return x_fused 960 | 961 | def pnp_forward_rear(self, x): 962 | x = self.decoder(x) 963 | x = self.conv3(x) 964 | x = self.bilinear(x) 965 | 966 | return x 967 | 968 | 969 | def conv_bn_relu(in_channels, out_channels, kernel_size, stride=1, padding=0, bn=True, relu=True): 970 | bias = not bn 971 | layers = [] 972 | layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, 973 | padding, bias=bias)) 974 | if bn: 975 | layers.append(nn.BatchNorm2d(out_channels)) 976 | if relu: 977 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 978 | layers = nn.Sequential(*layers) 979 | 980 | # initialize the weights 981 | for m in layers.modules(): 982 | weights_init(m) 983 | 984 | return layers 985 | --------------------------------------------------------------------------------
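Editor's note: the conv_bn_relu helper at the end of this file builds a Conv2d (bias disabled because a BatchNorm follows), BatchNorm2d, LeakyReLU(0.2) block; it is what ResNet2 uses for its separate image and depth stems (64 * 3 // 4 = 48 and 64 // 4 = 16 channels, concatenated back to 64 before layer1). The sketch below is illustrative only; the import path and tensor sizes are assumptions, not taken from the repository's entry points.

# Illustrative sketch; input sizes are chosen only for the example.
import torch
from model.models import conv_bn_relu  # assumed module path (repo root on PYTHONPATH)

stem_img = conv_bn_relu(3, 64 * 3 // 4, kernel_size=3, stride=2, padding=1)  # image stem: 48 channels
stem_d = conv_bn_relu(1, 64 // 4, kernel_size=3, stride=2, padding=1)        # depth stem: 16 channels

x = torch.rand(2, 4, 64, 64)                                                 # RGB-D batch, arbitrary size
fused = torch.cat((stem_img(x[:, :3]), stem_d(x[:, 3:])), dim=1)             # same concatenation ResNet2 performs
assert fused.shape == (2, 64, 32, 32)                                        # 48 + 16 = 64 channels, stride 2 halves H and W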
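Editor's note: the fusion variants above all expose the same three entry points: forward, plus pnp_forward_front / pnp_forward_rear, which split forward at the bottleneck (after conv2/bn2) so that PnP-Depth can re-inject refined sparse depth between encoder and decoder. The smoke-test sketch below exercises that contract for ResNet_latefusion; it assumes the classes live in model/models.py, that 'upproj' is one of the names accepted by choose_decoder, and a 450x800 input crop consistent with the shape comments in forward. It is a minimal sketch under those assumptions, not part of the training pipeline.

# Illustrative smoke test only (not part of the repository).
import torch
from model.models import ResNet_latefusion  # assumed module path

model = ResNet_latefusion(layers=18, decoder='upproj', output_size=(450, 800),
                          in_channels=4, pretrained=False)  # pretrained=False avoids a weight download
model.eval()

x = torch.rand(1, 4, 450, 800)                # channels 0-2: RGB image, channel 3: sparse radar depth

with torch.no_grad():
    pred = model(x)                           # full RGB+depth forward pass
    front = model.pnp_forward_front(x)        # encoder + fusion, up to conv2/bn2
    rear = model.pnp_forward_rear(front)      # decoder + 3x3 conv + bilinear upsample

assert pred.shape == (1, 1, 450, 800)         # one depth map at the requested output size
assert torch.allclose(pred, rear)             # the PnP split reproduces forward()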