├── README.md ├── dataset ├── DatasetLidarCam.py ├── __init__.py └── data_utils.py ├── environment ├── __init__.py ├── buffer.py ├── environment.py └── transformations.py ├── generate_depth_gt.py ├── images └── Graphic abstract.jpg ├── ip_basic ├── depth_map_utils.py └── vis_utils.py ├── ipcv_utils └── utils │ └── __init__.py ├── models ├── base │ ├── R_MSFM.py │ ├── resnet_encoder.py │ └── update.py └── model.py ├── requirements.txt ├── save_fig ├── 20_dg.png ├── 20_gt.png ├── 20_iter0.png ├── 20_iter1.png ├── 20_iter2.png ├── 20_iter3.png ├── 20_pred_d.png └── 20_rgb.png ├── test.py ├── test.txt ├── test ├── test.png ├── test_dataset.py ├── test_depth.png ├── test_model.py ├── test_pose_rotate.py └── test_utils.py ├── train.py ├── train.txt ├── utility ├── __init__.py ├── logger.py ├── metrics.py ├── quaternion_distances.py └── utils.py └── visual_test.py /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # CalibDepth 4 | 5 | — LiDAR-Camera online calibration is of great significance for building a stable autonomous driving perception 6 | system. For online calibration, a key challenge lies in constructing a unified and robust representation between multimodal sensor data. Most methods extract features manually 7 | or implicitly with an end-to-end deep learning method. The 8 | former suffers poor robustness, while the latter has poor 9 | interpretability. In this paper, we propose CalibDepth, which 10 | uses depth maps as the unified representation for image and 11 | LiDAR point cloud. CalibDepth introduces a sub-network for 12 | monocular depth estimation to assist online calibration tasks. To 13 | further improve the performance, we regard online calibration 14 | as a sequence prediction problem, and introduce global and 15 | local losses to optimize the calibration results. CalibDepth shows 16 | excellent performance in different experimental setups. 
17 | 18 | 19 | 20 | 26 | 27 | 28 | 50 | 51 | 52 | 论文链接:https://ieeexplore.ieee.org/document/10161575 53 | 54 | ## 目录 55 | 56 | - [环境配置](#环境配置) 57 | - [文件目录说明](#文件目录说明) 58 | - [数据准备](#数据准备) 59 | - [运行](#运行) 60 | - [鸣谢](#鸣谢) 61 | 62 | ### 环境配置 63 | 64 | 65 | 1. 创建虚拟环境(python 3.6.13) 66 | 2. Clone the repo 67 | ```sh 68 | git clone https://github.com/Brickzhuantou/CalibDepth 69 | ``` 70 | 3. 安装依赖 71 | ```sh 72 | pip install -r requirements.txt 73 | ``` 74 | 75 | ### 文件目录说明 76 | 77 | ``` 78 | filetree 79 | 80 | ├── /dataset/ ----数据读取 81 | ├── /environment/ ----标定执行相关函数 82 | ├── /ip_basic/ ----深度补全相关 83 | ├── /ipcv_utils/ ----可视化函数 84 | ├── /models/ ----模型搭建 85 | ├── /test/ ----函数功能测试 86 | ├── /utility/ ----通用函数 87 | ├── train.py ----训练脚本 88 | ├── test.py ----评测脚本 89 | ├── visual_test.py ----可视化脚本 90 | ├── generate_depth_gt.py ----深度图标签生成脚本 91 | └── README.md 92 | ``` 93 | 94 | ### 数据准备 95 | 1. Kitti官网下载数据 https://www.cvlibs.net/datasets/kitti/raw_data.php?type=road 96 | (也可以参考[CalibNet](https://github.com/epiception/CalibNet/tree/main)下载) 97 | 98 | 2. 数据集组织和命名如下; 99 | ``` 100 | ├── /dataset/ 101 | |── /kitti_raw 102 | |── /2011_09_26/ 103 | |── /2011_09_26_drive_0001_sync/ 104 | |── /depth_gt/ 105 | |── /image_00/ 106 | |── /image_01/ 107 | |── /image_02/ 108 | |── /image_03/ 109 | |── /oxts/ 110 | |── /velodyne_points/ 111 | |── /2011_09_26_drive_0002_sync/ 112 | |── /2011_09_26_drive_0005_sync/ 113 | |── /2011_09_26_drive_0009_sync/ 114 | |── ... 115 | |── calib_cam_to_cam.txt 116 | |── calib_imu_to_velo.txt 117 | |── calib_velo_to_cam.txt 118 | 119 | |── /2011_09_28/ 120 | |── ... 121 | |── /2011_09_29/ 122 | |── ... 123 | |── /2011_09_30/ 124 | |── ... 125 | |── /2011_10_03/ 126 | |── ...
127 | |── train.txt ----训练数据路径 128 | |── test.txt ----测试数据路径 129 | 130 | ``` 131 | 说明: 132 | /depth_gt/存储用于单目深度估计的深度图标签,原始KITTI数据没有提供,可以参考generate_depth_gt.py脚本生成; 133 | train.txt和test.txt为随机采样生成的训练数据与测试数据路径; 134 | 135 | 136 | 137 | 138 | 139 | ### 运行 140 | 修改train.py中的参数,执行 python train.py 即可; 141 | test.py同理; 142 | 143 | 144 | ### 鸣谢 145 | 146 | 147 | - [CalibNet](https://github.com/epiception/CalibNet/tree/main) 148 | - [LCCNet](https://github.com/IIPCVLAB/LCCNet) 149 | - [reagent](https://github.com/dornik/reagent) 150 | - [ip_basic](https://github.com/kujason/ip_basic) 151 | - [Best_README_template](https://github.com/shaojintian/Best_README_template) 152 | - [R-MSFM](https://github.com/jsczzzk/R-MSFM) 153 | 154 | 155 | [your-project-path]:https://github.com/Brickzhuantou/CalibDepth 156 | -------------------------------------------------------------------------------- /dataset/DatasetLidarCam.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import mathutils 4 | import csv 5 | import math 6 | from math import radians 7 | import numpy as np 8 | import pandas as pd 9 | import torch 10 | from torch.utils.data import Dataset 11 | import torchvision 12 | from torchvision import transforms 13 | import torchvision.transforms.functional as TTF 14 | from PIL import Image 15 | 16 | 17 | import pykitti 18 | from pykitti import odometry 19 | 20 | from utility.utils import ( invert_pose, quaternion_from_matrix, rotate_back, rotate_forward, 21 | quaternion_from_matrix, rotate_back ) 22 | 23 | def get_2D_lidar_projection(pcl, cam_intrinsic): 24 | """Project a camera-frame point cloud onto the image plane. 25 | 26 | Args: 27 | pcl (np.ndarray): point cloud in camera coordinates, shape (3, N) 28 | cam_intrinsic (np.ndarray): camera intrinsic matrix K, applied as K @ pcl 29 | Returns: 30 | pcl_uv: pixel coordinates of every point, shape (N, 2); a 1e-10 epsilon guards the division by depth 31 | pcl_z: third component of K @ pcl for every point, i.e. its depth when K's last row is [0, 0, 1] 32 | 33 | """ 34 | pcl_xyz = cam_intrinsic @ pcl 35 | pcl_xyz = pcl_xyz.T 36 | pcl_z = pcl_xyz[:, 2] 37 | pcl_xyz = pcl_xyz / (pcl_xyz[:, 2, None] + 1e-10) 38 | pcl_uv = pcl_xyz[:, :2] 39 | return pcl_uv, pcl_z 40 | 41 |
def lidar_project_depth(pc, cam_calib, img_shape):
    """Render a sparse LiDAR depth image from a camera-frame point cloud.

    Args:
        pc (torch.Tensor): point cloud already transformed into the camera
            frame; only its first three rows (x, y, z) are used.
        cam_calib (torch.Tensor): camera intrinsic matrix.
        img_shape: image size, ``img_shape[0]`` = height, ``img_shape[1]`` = width.

    Returns:
        depth_img (torch.Tensor): float32 depth image of shape (1, H, W);
            pixels hit by no point are 0.
        pcl_uv (np.ndarray): uint32 pixel coordinates (u, v) of the points
            that were kept, shape (M, 2).
    """
    pc = pc[:3, :].detach().cpu().numpy()
    cam_intrinsic = cam_calib.detach().cpu().numpy()
    pcl_uv, pcl_z = get_2D_lidar_projection(pc, cam_intrinsic)
    # Keep points that project strictly inside the image and have positive
    # depth (i.e. lie in front of the camera).
    mask = ((pcl_uv[:, 0] > 0) & (pcl_uv[:, 0] < img_shape[1])
            & (pcl_uv[:, 1] > 0) & (pcl_uv[:, 1] < img_shape[0])
            & (pcl_z > 0))
    pcl_uv = pcl_uv[mask].astype(np.uint32)
    pcl_z = pcl_z[mask].reshape(-1, 1)
    depth_img = np.zeros((img_shape[0], img_shape[1], 1), dtype=np.float32)
    depth_img[pcl_uv[:, 1], pcl_uv[:, 0]] = pcl_z
    depth_img = torch.from_numpy(depth_img).permute(2, 0, 1)
    return depth_img, pcl_uv


class Resampler:
    """Resample a point cloud containing N points to one containing M points."""

    def __init__(self, num: int):
        """Create a resampler.

        The result has no repeated points if M <= N; otherwise every input
        point is guaranteed to appear at least once.

        Args:
            num (int): number of points to resample to, i.e. M.
        """
        self.num = num

    def __call__(self, sample):
        """Resample the point arrays stored in ``sample`` and return ``sample``.

        If ``sample`` has a 'points' entry, only that entry is resampled.
        Otherwise 'points_src' and 'points_ref' are resampled, and an optional
        'crop_proportion' (1 or 2 values) scales their target sizes. A truthy
        'deterministic' entry seeds NumPy's global RNG with ``sample['idx']``.

        Raises:
            ValueError: if 'crop_proportion' has neither 1 nor 2 elements.
        """
        if 'deterministic' in sample and sample['deterministic']:
            np.random.seed(sample['idx'])

        if 'points' in sample:
            sample['points'] = self._resample(sample['points'], self.num)
            return sample

        if 'crop_proportion' not in sample:
            src_size, ref_size = self.num, self.num
        elif len(sample['crop_proportion']) == 1:
            src_size = math.ceil(sample['crop_proportion'][0] * self.num)
            ref_size = self.num
        elif len(sample['crop_proportion']) == 2:
            src_size = math.ceil(sample['crop_proportion'][0] * self.num)
            ref_size = math.ceil(sample['crop_proportion'][1] * self.num)
        else:
            raise ValueError('Crop proportion must have 1 or 2 elements')

        sample['points_src'] = self._resample(sample['points_src'], src_size)
        sample['points_ref'] = self._resample(sample['points_ref'], ref_size)
        return sample

    @staticmethod
    def _resample(points, k):
        """Resample ``points`` (rows are points) so that exactly k rows remain.

        If the input has >= k points, k distinct rows are drawn without
        replacement (for k == N this is a random permutation). If it has
        fewer than k points, every point is kept once and the remainder is
        filled with randomly repeated points.
        """
        n = points.shape[0]
        if k <= n:
            rand_idxs = np.random.choice(n, k, replace=False)
            return points[rand_idxs, :]
        rand_idxs = np.concatenate([np.random.choice(n, n, replace=False),
                                    np.random.choice(n, k - n, replace=True)])
        return points[rand_idxs, :]


class DatasetKittiRawCalibNet(Dataset):
    """KITTI-raw dataset for LiDAR-camera (cam02) calibration.

    Each sample pairs an image_02 frame with its velodyne scan (moved into the
    cam02 frame), a randomly perturbed copy of that scan, the pose of the
    perturbation, a sparse depth image rendered from the perturbed and
    downsampled scan, and a depth label read from ``depth_gt/data``.
    """

    def __init__(self, dataset_dir, transform=None, augmentation=False,
                 use_reflectance=False, max_t=1.5, max_r=15.0,
                 split='val', device='cpu',
                 val_sequence=('2011_09_26_drive_0005_sync', '2011_09_26_drive_0070_sync')):
        """Load per-date calibrations and the train/test sample lists.

        Args:
            dataset_dir: KITTI-raw root containing the date folders and the
                ``train.txt`` / ``test.txt`` sample lists.
            transform, augmentation, device: stored only.
            use_reflectance: stored only; reflectance is not returned.
            max_t (float): max absolute translation perturbation per axis.
            max_r (float): max absolute rotation perturbation per axis, degrees.
            split (str): 'train' reads train.txt; anything else reads test.txt.
            val_sequence: kept for API compatibility; the split is defined by
                train.txt / test.txt.
        """
        super().__init__()
        self.use_reflectance = use_reflectance
        self.maps_folder = ''
        self.device = device
        self.max_r = max_r
        self.max_t = max_t
        self.augmentation = augmentation
        self.root_dir = dataset_dir
        self.transform = transform
        self.split = split
        self.GTs_R = {}
        self.GTs_T = {}
        self.GTs_T_cam02_velo = {}
        self.max_depth = 80
        self.K_list = {}
        self.all_files = []

        # One drive per recording date is enough to read that date's calibration.
        date_list = ['2011_09_26', '2011_09_28', '2011_09_29', '2011_09_30', '2011_10_03']
        data_drive_list = ['0001', '0002', '0004', '0016', '0027']
        self.calib_date = {}
        for date, data_drive in zip(date_list, data_drive_list):
            data = pykitti.raw(self.root_dir, date, data_drive)
            self.calib_date[date] = {'K2': data.calib.K_cam2, 'K3': data.calib.K_cam3,
                                     'RT2': data.calib.T_cam2_velo, 'RT3': data.calib.T_cam3_velo}

        # Pre-generated lists of training and test samples.
        self.train_array = np.loadtxt(os.path.join(self.root_dir, 'train.txt'), dtype=str)
        self.test_array = np.loadtxt(os.path.join(self.root_dir, 'test.txt'), dtype=str)

    def custom_transform(self, rgb, img_rotation, h_mirror, flip=False):
        """Convert a PIL image into a normalized tensor.

        On the 'train' split a ColorJitter(0.1, 0.1, 0.1) is applied first.
        ``flip`` (not ``h_mirror``, which is unused here) controls a horizontal
        flip; the image is then rotated by ``img_rotation`` degrees, converted
        to a tensor and normalized with the ImageNet mean/std.
        """
        to_tensor = transforms.ToTensor()
        normalization = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
        if self.split == 'train':
            color_transform = transforms.ColorJitter(0.1, 0.1, 0.1)
            rgb = color_transform(rgb)
        if flip:
            rgb = TTF.hflip(rgb)
        rgb = TTF.rotate(rgb, img_rotation)
        rgb = to_tensor(rgb)
        rgb = normalization(rgb)
        return rgb

    def __len__(self):
        """Number of samples in the active split."""
        if self.split == 'train':
            return len(self.train_array)
        return len(self.test_array)

    def __getitem__(self, idx):
        """Build one sample dict.

        The 'test' split additionally contains 'extrin' and 'initial_RT', and
        its 'rgb_name' carries the '.png' extension.
        """
        item = self.train_array[idx] if self.split == 'train' else self.test_array[idx]

        # Fields 0, 1 and 4 of the '/'-separated entry: date, drive, frame name.
        fields = item.split('/')
        date = str(fields[0])
        seq = str(fields[1])
        rgb_name = str(fields[4])
        img_path = os.path.join(self.root_dir, date, seq, 'image_02/data', rgb_name + '.png')
        lidar_path = os.path.join(self.root_dir, date, seq, 'velodyne_points/data', rgb_name + '.bin')

        # Raw scan as rows of 4 float32 values; drop every point inside the
        # 6 m x 6 m square around the sensor (the ego vehicle).
        pc = np.fromfile(lidar_path, dtype=np.float32).reshape(-1, 4)
        valid_indices = ((pc[:, 0] < -3.) | (pc[:, 0] > 3.)
                         | (pc[:, 1] < -3.) | (pc[:, 1] > 3.))
        pc = pc[valid_indices].copy()
        pc_org = torch.from_numpy(pc.astype(np.float32))

        # Calibration for this date: velodyne -> cam02 extrinsics, cam02 intrinsics.
        calib = self.calib_date[date]
        RT_cam02 = calib['RT2'].astype(np.float32)
        calib_cam02 = calib['K2']

        # Make the cloud 4xN homogeneous with a last row of ones.
        if pc_org.shape[1] == 4 or pc_org.shape[1] == 3:
            pc_org = pc_org.t()
        if pc_org.shape[0] == 3:
            homogeneous = torch.ones(pc_org.shape[1]).unsqueeze(0)
            pc_org = torch.cat((pc_org, homogeneous), 0)
        elif pc_org.shape[0] == 4:
            if not torch.all(pc_org[3, :] == 1.):
                pc_org[3, :] = 1.
        else:
            raise TypeError("Wrong PointCloud shape")

        # Move the cloud into the camera frame.
        pc_rot = np.matmul(RT_cam02, pc_org.numpy()).astype(np.float32)
        pc_in = torch.from_numpy(pc_rot)

        # Load and normalize the image; close the file before any retry.
        img_rotation = 0.
        h_mirror = False
        with Image.open(img_path) as pil_img:
            try:
                img = self.custom_transform(pil_img, img_rotation, h_mirror)
            except OSError:
                img = None
        if img is None:
            # Unreadable image data: substitute a random other sample.
            return self.__getitem__(np.random.randint(0, len(self)))

        # Random perturbation: uniform in [-max_r, max_r] degrees per axis and
        # [-max_t, max_t] per translation axis.
        max_angle = self.max_r
        rotz = np.random.uniform(-max_angle, max_angle) * (np.pi / 180.0)
        roty = np.random.uniform(-max_angle, max_angle) * (np.pi / 180.0)
        rotx = np.random.uniform(-max_angle, max_angle) * (np.pi / 180.0)
        transl_x = np.random.uniform(-self.max_t, self.max_t)
        transl_y = np.random.uniform(-self.max_t, self.max_t)
        transl_z = np.random.uniform(-self.max_t, self.max_t)
        initial_RT = 0.0

        R = mathutils.Euler((rotx, roty, rotz), 'XYZ')
        T = mathutils.Vector((transl_x, transl_y, transl_z))
        R, T = invert_pose(R, T)  # quaternion and translation of the inverted pose
        R, T = torch.tensor(R), torch.tensor(T)

        # Copy the cached intrinsics so per-sample edits never leak into
        # self.calib_date.
        calib = np.array(calib_cam02, dtype=np.float32)
        if h_mirror:
            # Horizontal mirroring moves the principal point: cx' = W - cx (K[0, 2]).
            calib[0, 2] = img.shape[2] - calib[0, 2]
        calib = torch.tensor(calib, dtype=torch.float32)

        max_depth = 80.
        real_shape = [img.shape[1], img.shape[2], img.shape[0]]

        # Downsample the camera-frame cloud to exactly 100000 points.
        transformes = torchvision.transforms.Compose([Resampler(100000)])
        pc_temp = {'points': pc_in.transpose(0, 1)}
        ds_pc = transformes(pc_temp)['points'].transpose(0, 1)

        # Depth label; ToTensor scales to [0, 1], so multiply back by 255.
        depth_path = os.path.join(self.root_dir, date, seq, 'depth_gt/data', rgb_name + '.jpg')
        to_tensor = torchvision.transforms.ToTensor()
        with Image.open(depth_path) as depth_file:
            depth_gt = to_tensor(depth_file) * 255

        # Perturbed clouds and the sparse depth image of the perturbed cloud.
        R_m = mathutils.Quaternion(R).to_matrix()
        R_m.resize_4x4()
        T_m = mathutils.Matrix.Translation(T)
        # `*` is the matrix product only in the pre-2.8 mathutils API.
        RT_m = T_m * R_m

        pc_rotated = rotate_back(pc_in, RT_m)      # Pc' = RT * Pc
        ds_pc_rotated = rotate_back(ds_pc, RT_m)   # downsampled, perturbed cloud

        depth_img, _ = lidar_project_depth(ds_pc_rotated, calib, real_shape)
        depth_img /= max_depth  # normalize depth

        pc_target = pc_in        # unperturbed camera-frame cloud
        pc_source = pc_rotated   # perturbed cloud

        # Target pose = inverse of RT_m, i.e. [R^T | -R^T t].
        i_pose_target = np.array(RT_m, dtype=np.float32)
        pose_target = i_pose_target.copy()
        pose_target[:3, :3] = pose_target[:3, :3].T
        pose_target[:3, 3] = -np.matmul(pose_target[:3, :3], pose_target[:3, 3])
        pose_target = torch.from_numpy(pose_target)

        pose_source = torch.eye(4)

        sample = {'rgb': img,
                  'calib': calib,
                  'rgb_name': rgb_name,  # TODO: no extension on non-test splits; verify
                  'item': item,
                  'img_path': img_path,
                  'tr_error': T, 'rot_error': R,
                  'pc_target': pc_target,
                  'pc_source': pc_source,
                  'pose_target': pose_target,
                  'pose_source': pose_source,
                  'ds_pc_target': ds_pc,
                  'ds_pc_source': ds_pc_rotated,
                  'depth_gt': depth_gt,
                  'depth_img': depth_img,
                  }
        if self.split == 'test':
            sample['rgb_name'] = rgb_name + '.png'
            sample['extrin'] = RT_cam02
            sample['initial_RT'] = initial_RT

        return sample

# --------------------------------------------------------------------------------
# /dataset/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Brickzhuantou/CalibDepth/45dc0252f01353d4a897e414e7bf2a63d273e1f3/dataset/__init__.py
# --------------------------------------------------------------------------------
# /dataset/data_utils.py:
# --------------------------------------------------------------------------------
# -------------------------------------------------------------------
# Copyright (C) 2020 Università degli studi di Milano-Bicocca, iralab
# Author: Daniele Cattaneo (d.cattaneo10@campus.unimib.it)
# Released under Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
6 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | # ------------------------------------------------------------------- 8 | 9 | # Modified Author: Xudong Lv 10 | # based on github.com/cattaneod/CMRNet/blob/master/utils.py 11 | 12 | import math 13 | 14 | import mathutils 15 | import numpy as np 16 | import torch 17 | import torch.nn.functional as F 18 | from matplotlib import cm 19 | from torch.utils.data.dataloader import default_collate 20 | 21 | 22 | def rotate_points(PC, R, T=None, inverse=True): 23 | if T is not None: 24 | R = R.to_matrix() 25 | R.resize_4x4() 26 | T = mathutils.Matrix.Translation(T) 27 | RT = T*R 28 | else: 29 | RT=R.copy() 30 | if inverse: 31 | RT.invert_safe() 32 | RT = torch.tensor(RT, device=PC.device, dtype=torch.float) 33 | 34 | if PC.shape[0] == 4: 35 | PC = torch.mm(RT, PC) 36 | elif PC.shape[1] == 4: 37 | PC = torch.mm(RT, PC.t()) 38 | PC = PC.t() 39 | else: 40 | raise TypeError("Point cloud must have shape [Nx4] or [4xN] (homogeneous coordinates)") 41 | return PC 42 | 43 | 44 | def rotate_points_torch(PC, R, T=None, inverse=True): 45 | if T is not None: 46 | R = quat2mat(R) 47 | T = tvector2mat(T) 48 | RT = torch.mm(T, R) 49 | else: 50 | RT = R.clone() 51 | if inverse: 52 | RT = RT.inverse() 53 | 54 | if PC.shape[0] == 4: 55 | PC = torch.mm(RT, PC) 56 | elif PC.shape[1] == 4: 57 | PC = torch.mm(RT, PC.t()) 58 | PC = PC.t() 59 | else: 60 | raise TypeError("Point cloud must have shape [Nx4] or [4xN] (homogeneous coordinates)") 61 | return PC 62 | 63 | 64 | def rotate_forward(PC, R, T=None): 65 | """ 66 | Transform the point cloud PC, so to have the points 'as seen from' the new 67 | pose T*R 68 | Args: 69 | PC (torch.Tensor): Point Cloud to be transformed, shape [4xN] or [Nx4] 70 | R (torch.Tensor/mathutils.Euler): can be either: 71 | * (mathutils.Euler) euler angles of the rotation part, in this case T cannot be None 72 | * (torch.Tensor shape [4]) quaternion representation of the rotation part, in this case T cannot be None 
73 | * (mathutils.Matrix shape [4x4]) Rotation matrix, 74 | in this case it should contains the translation part, and T should be None 75 | * (torch.Tensor shape [4x4]) Rotation matrix, 76 | in this case it should contains the translation part, and T should be None 77 | T (torch.Tensor/mathutils.Vector): Translation of the new pose, shape [3], or None (depending on R) 78 | 79 | Returns: 80 | torch.Tensor: Transformed Point Cloud 'as seen from' pose T*R 81 | """ 82 | if isinstance(R, torch.Tensor): 83 | return rotate_points_torch(PC, R, T, inverse=True) 84 | else: 85 | return rotate_points(PC, R, T, inverse=True) 86 | 87 | 88 | def rotate_back(PC_ROTATED, R, T=None): 89 | """ 90 | Inverse of :func:`~utils.rotate_forward`. 91 | """ 92 | if isinstance(R, torch.Tensor): 93 | return rotate_points_torch(PC_ROTATED, R, T, inverse=False) 94 | else: 95 | return rotate_points(PC_ROTATED, R, T, inverse=False) 96 | 97 | 98 | def invert_pose(R, T): 99 | """ 100 | Given the 'sampled pose' (aka H_init), we want CMRNet to predict inv(H_init). 101 | inv(T*R) will be used as ground truth for the network. 
102 | Args: 103 | R (mathutils.Euler): Rotation of 'sampled pose' 104 | T (mathutils.Vector): Translation of 'sampled pose' 105 | 106 | Returns: 107 | (R_GT, T_GT) = (mathutils.Quaternion, mathutils.Vector) 108 | """ 109 | R = R.to_matrix() 110 | R.resize_4x4() 111 | T = mathutils.Matrix.Translation(T) 112 | RT = T * R 113 | RT.invert_safe() 114 | T_GT, R_GT, _ = RT.decompose() 115 | return R_GT.normalized(), T_GT 116 | 117 | 118 | def merge_inputs(queries): 119 | pc_target = [] 120 | pc_source = [] 121 | depth_gt= [] 122 | depth_img = [] 123 | imgs = [] 124 | # reflectances = [] 125 | returns = {key: default_collate([d[key] for d in queries]) for key in queries[0] 126 | if key != 'pc_target' and key != 'rgb' and key != 'pc_source' and key != 'depth_gt' and key != 'depth_img'} 127 | for input in queries: 128 | pc_target.append(input['pc_target']) 129 | pc_source.append(input['pc_source']) 130 | imgs.append(input['rgb']) 131 | depth_gt.append(input['depth_gt']) 132 | depth_img.append(input['depth_img']) 133 | # if 'reflectance' in input: 134 | # reflectances.append(input['reflectance']) 135 | returns['pc_target'] = pc_target 136 | returns['pc_source'] = pc_source 137 | returns['rgb'] = imgs 138 | returns['depth_gt'] = depth_gt 139 | returns['depth_img'] = depth_img 140 | # if len(reflectances) > 0: 141 | # returns['reflectance'] = reflectances 142 | return returns 143 | 144 | 145 | def quaternion_from_matrix(matrix): 146 | """ 147 | Convert a rotation matrix to quaternion. 148 | Args: 149 | matrix (torch.Tensor): [4x4] transformation matrix or [3,3] rotation matrix. 
150 | 151 | Returns: 152 | torch.Tensor: shape [4], normalized quaternion 153 | """ 154 | if matrix.shape == (4, 4): 155 | R = matrix[:-1, :-1] 156 | elif matrix.shape == (3, 3): 157 | R = matrix 158 | else: 159 | raise TypeError("Not a valid rotation matrix") 160 | tr = R[0, 0] + R[1, 1] + R[2, 2] 161 | q = torch.zeros(4, device=matrix.device) 162 | if tr > 0.: 163 | S = (tr+1.0).sqrt() * 2 164 | q[0] = 0.25 * S 165 | q[1] = (R[2, 1] - R[1, 2]) / S 166 | q[2] = (R[0, 2] - R[2, 0]) / S 167 | q[3] = (R[1, 0] - R[0, 1]) / S 168 | elif R[0, 0] > R[1, 1] and R[0, 0] > R[2, 2]: 169 | S = (1.0 + R[0, 0] - R[1, 1] - R[2, 2]).sqrt() * 2 170 | q[0] = (R[2, 1] - R[1, 2]) / S 171 | q[1] = 0.25 * S 172 | q[2] = (R[0, 1] + R[1, 0]) / S 173 | q[3] = (R[0, 2] + R[2, 0]) / S 174 | elif R[1, 1] > R[2, 2]: 175 | S = (1.0 + R[1, 1] - R[0, 0] - R[2, 2]).sqrt() * 2 176 | q[0] = (R[0, 2] - R[2, 0]) / S 177 | q[1] = (R[0, 1] + R[1, 0]) / S 178 | q[2] = 0.25 * S 179 | q[3] = (R[1, 2] + R[2, 1]) / S 180 | else: 181 | S = (1.0 + R[2, 2] - R[0, 0] - R[1, 1]).sqrt() * 2 182 | q[0] = (R[1, 0] - R[0, 1]) / S 183 | q[1] = (R[0, 2] + R[2, 0]) / S 184 | q[2] = (R[1, 2] + R[2, 1]) / S 185 | q[3] = 0.25 * S 186 | return q / q.norm() 187 | 188 | 189 | def quatmultiply(q, r): 190 | """ 191 | Multiply two quaternions 192 | Args: 193 | q (torch.Tensor/nd.ndarray): shape=[4], first quaternion 194 | r (torch.Tensor/nd.ndarray): shape=[4], second quaternion 195 | 196 | Returns: 197 | torch.Tensor: shape=[4], normalized quaternion q*r 198 | """ 199 | t = torch.zeros(4, device=q.device) 200 | t[0] = r[0] * q[0] - r[1] * q[1] - r[2] * q[2] - r[3] * q[3] 201 | t[1] = r[0] * q[1] + r[1] * q[0] - r[2] * q[3] + r[3] * q[2] 202 | t[2] = r[0] * q[2] + r[1] * q[3] + r[2] * q[0] - r[3] * q[1] 203 | t[3] = r[0] * q[3] - r[1] * q[2] + r[2] * q[1] + r[3] * q[0] 204 | return t / t.norm() 205 | 206 | 207 | def quat2mat(q): 208 | """ 209 | Convert a quaternion to a rotation matrix 210 | Args: 211 | q (torch.Tensor): shape 
[4], input quaternion 212 | 213 | Returns: 214 | torch.Tensor: [4x4] homogeneous rotation matrix 215 | """ 216 | assert q.shape == torch.Size([4]), "Not a valid quaternion" 217 | if q.norm() != 1.: 218 | q = q / q.norm() 219 | mat = torch.zeros((4, 4), device=q.device) 220 | mat[0, 0] = 1 - 2*q[2]**2 - 2*q[3]**2 221 | mat[0, 1] = 2*q[1]*q[2] - 2*q[3]*q[0] 222 | mat[0, 2] = 2*q[1]*q[3] + 2*q[2]*q[0] 223 | mat[1, 0] = 2*q[1]*q[2] + 2*q[3]*q[0] 224 | mat[1, 1] = 1 - 2*q[1]**2 - 2*q[3]**2 225 | mat[1, 2] = 2*q[2]*q[3] - 2*q[1]*q[0] 226 | mat[2, 0] = 2*q[1]*q[3] - 2*q[2]*q[0] 227 | mat[2, 1] = 2*q[2]*q[3] + 2*q[1]*q[0] 228 | mat[2, 2] = 1 - 2*q[1]**2 - 2*q[2]**2 229 | mat[3, 3] = 1. 230 | return mat 231 | 232 | 233 | def tvector2mat(t): 234 | """ 235 | Translation vector to homogeneous transformation matrix with identity rotation 236 | Args: 237 | t (torch.Tensor): shape=[3], translation vector 238 | 239 | Returns: 240 | torch.Tensor: [4x4] homogeneous transformation matrix 241 | 242 | """ 243 | assert t.shape == torch.Size([3]), "Not a valid translation" 244 | mat = torch.eye(4, device=t.device) 245 | mat[0, 3] = t[0] 246 | mat[1, 3] = t[1] 247 | mat[2, 3] = t[2] 248 | return mat 249 | 250 | 251 | def mat2xyzrpy(rotmatrix): 252 | """ 253 | Decompose transformation matrix into components 254 | Args: 255 | rotmatrix (torch.Tensor/np.ndarray): [4x4] transformation matrix 256 | 257 | Returns: 258 | torch.Tensor: shape=[6], contains xyzrpy 259 | """ 260 | roll = math.atan2(-rotmatrix[1, 2], rotmatrix[2, 2]) 261 | pitch = math.asin ( rotmatrix[0, 2]) 262 | yaw = math.atan2(-rotmatrix[0, 1], rotmatrix[0, 0]) 263 | x = rotmatrix[:3, 3][0] 264 | y = rotmatrix[:3, 3][1] 265 | z = rotmatrix[:3, 3][2] 266 | 267 | return torch.tensor([x, y, z, roll, pitch, yaw], device=rotmatrix.device, dtype=rotmatrix.dtype) 268 | 269 | 270 | def to_rotation_matrix(R, T): 271 | R = quat2mat(R) 272 | T = tvector2mat(T) 273 | RT = torch.mm(T, R) 274 | return RT 275 | 276 | 277 | def 
overlay_imgs(rgb, lidar, idx=0): 278 | std = [0.229, 0.224, 0.225] 279 | mean = [0.485, 0.456, 0.406] 280 | 281 | rgb = rgb.clone().cpu().permute(1,2,0).numpy() 282 | rgb = rgb*std+mean 283 | lidar = lidar.clone() 284 | 285 | lidar[lidar == 0] = 1000. 286 | lidar = -lidar 287 | #lidar = F.max_pool2d(lidar, 3, 1, 1) 288 | lidar = F.max_pool2d(lidar, 3, 1, 1) 289 | lidar = -lidar 290 | lidar[lidar == 1000.] = 0. 291 | 292 | #lidar = lidar.squeeze() 293 | lidar = lidar[0][0] 294 | lidar = (lidar*255).int().cpu().numpy() 295 | lidar_color = cm.jet(lidar) 296 | lidar_color[:, :, 3] = 0.5 297 | lidar_color[lidar == 0] = [0, 0, 0, 0] 298 | blended_img = lidar_color[:, :, :3] * (np.expand_dims(lidar_color[:, :, 3], 2)) + \ 299 | rgb * (1. - np.expand_dims(lidar_color[:, :, 3], 2)) 300 | blended_img = blended_img.clip(min=0., max=1.) 301 | #io.imshow(blended_img) 302 | #io.show() 303 | #plt.figure() 304 | #plt.imshow(blended_img) 305 | #io.imsave(f'./IMGS/{idx:06d}.png', blended_img) 306 | return blended_img 307 | -------------------------------------------------------------------------------- /environment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Brickzhuantou/CalibDepth/45dc0252f01353d4a897e414e7bf2a63d273e1f3/environment/__init__.py -------------------------------------------------------------------------------- /environment/buffer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import functools 4 | 5 | def cat(list_of_tensors, dim=0): 6 | """ 7 | Concatenate a list of tensors. 8 | """ 9 | return functools.reduce(lambda x, y: torch.cat([x, y], dim=dim), list_of_tensors) 10 | 11 | 12 | def catcat(list_of_lists_of_tensors, dim_outer=0, dim_inner=0): 13 | """ 14 | Recursively concatenate a list of tensors. 
15 | """ 16 | return cat([cat(inner_list, dim_inner) for inner_list in list_of_lists_of_tensors], dim_outer) 17 | 18 | class Buffer: 19 | """replay buffer, to generate trajectories 20 | """ 21 | 22 | def __init__(self): 23 | self.count = 0 24 | self.sources = [] 25 | self.targets = [] 26 | self.target_depth = [] 27 | self.expert_actions = [] 28 | 29 | def __len__(self): 30 | return self.count 31 | 32 | def start_trajectory(self): 33 | self.count += 1 34 | self.sources += [[]] 35 | self.targets += [[]] 36 | self.target_depth += [[]] 37 | self.expert_actions += [[]] 38 | 39 | def log_step(self, observation, expert_action): 40 | self.sources[-1].append(observation[0].detach()) 41 | self.targets[-1].append(observation[1].detach()) 42 | self.target_depth[-1].append(observation[2].detach()) 43 | self.expert_actions[-1].append(expert_action.detach()) 44 | 45 | def get_samples(self): 46 | samples = [self.sources, self.targets, self.target_depth, self.expert_actions] 47 | return [catcat(sample) for sample in samples] 48 | 49 | def clear(self): 50 | self.count = 0 51 | self.source.clear() 52 | self.target.clear() 53 | self.target_depth.clear() 54 | self.expert_actions.clear() 55 | 56 | 57 | -------------------------------------------------------------------------------- /environment/environment.py: -------------------------------------------------------------------------------- 1 | import environment.transformations as tra 2 | import torch.nn.functional as F 3 | import torch 4 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 5 | 6 | 7 | def init(data): 8 | """数据进行预处理格式化 9 | 10 | Args: 11 | data (_type_): _description_ 12 | """ 13 | img_shape = (384, 1280) # 目标图像尺寸 14 | 15 | pose_target = data['pose_target'].to(DEVICE) 16 | pose_source = data['pose_source'].to(DEVICE) 17 | 18 | ds_pc_target_input = [] 19 | ds_pc_source_input = [] 20 | calib_input = [] 21 | depth_img_input = [] 22 | depth_gt_input = [] 23 | rgb_input = [] 24 | 25 | for idx in 
range(len(data['rgb'])): 26 | rgb = data['rgb'][idx].cuda() 27 | depth_gt = data['depth_gt'][idx].cuda() 28 | depth_img = data['depth_img'][idx].cuda() 29 | ds_pc_source = data['ds_pc_source'][idx].cuda() 30 | ds_pc_target = data['ds_pc_target'][idx].cuda() 31 | calib = data['calib'][idx].cuda() 32 | 33 | shape_pad = [0, 0, 0, 0] 34 | shape_pad[3] = (img_shape[0] - rgb.shape[1]) # // 2 35 | shape_pad[1] = (img_shape[1] - rgb.shape[2]) # // 2 + 1 36 | 37 | rgb = F.pad(rgb, shape_pad) 38 | depth_img = F.pad(depth_img, shape_pad) # 填充为目标尺寸 39 | 40 | rgb_input.append(rgb) 41 | depth_img_input.append(depth_img) 42 | depth_gt_input.append(depth_gt) 43 | 44 | ds_pc_target_input.append(ds_pc_target.transpose(0, 1)[:, :3]) 45 | ds_pc_source_input.append(ds_pc_source.transpose(0, 1)[:, :3]) # (4xN->Nx3) 46 | calib_input.append(calib) 47 | 48 | depth_img_input = torch.stack(depth_img_input) 49 | depth_gt_input = torch.stack(depth_gt_input) 50 | rgb_input = torch.stack(rgb_input) 51 | ds_pc_source_input = torch.stack(ds_pc_source_input) 52 | ds_pc_target_input = torch.stack(ds_pc_target_input) 53 | calib_input = torch.stack(calib_input) 54 | 55 | rgb_input = F.interpolate(rgb_input, size=[256, 512], mode = 'bilinear', align_corners=False) 56 | depth_img_input = F.interpolate(depth_img_input, size=[256, 512], mode = 'bilinear', align_corners=False) 57 | 58 | return rgb_input, depth_img_input, depth_gt_input, pose_target, pose_source, ds_pc_target_input, ds_pc_source_input, calib_input 59 | 60 | 61 | def step_continous(source, actions, pose_source): 62 | """ 63 | Update the state (source and accumulator) using the given actions. 
64 | """ 65 | steps_t, steps_r = actions[:, 0], actions[:, 1] 66 | pose_update = torch.eye(4, device=DEVICE).repeat(pose_source.shape[0], 1, 1) 67 | pose_update[:, :3, :3] = tra.axis_angle_to_matrix(steps_r) 68 | pose_update[:, :3, 3] = steps_t 69 | pose_source = pose_update @ pose_source.to(DEVICE) 70 | new_source = tra.apply_trafo(source.to(DEVICE), pose_source, False) 71 | 72 | return new_source, pose_source 73 | 74 | 75 | def expert_step_real(pose_source, targets, mode='steady'): 76 | """ 77 | Get the expert action in the current state. 直接输出当前的source和target的偏差作为专家动作 78 | """ 79 | delta_t = targets[:, :3, 3] - pose_source[:, :3, 3] 80 | delta_R = targets[:, :3, :3] @ pose_source[:, :3, :3].transpose(2, 1) # global accumulator 81 | 82 | delta_r = tra.matrix_to_axis_angle(delta_R) # 旋转矩阵到旋转向量 83 | 84 | steps_t = delta_t.unsqueeze(1) 85 | steps_r = delta_r.unsqueeze(1) 86 | action = torch.cat([steps_t, steps_r], dim=1) 87 | 88 | return action 89 | 90 | 91 | -------------------------------------------------------------------------------- /environment/transformations.py: -------------------------------------------------------------------------------- 1 | from cmath import pi 2 | import torch 3 | import functools 4 | import torch.nn.functional as F 5 | """ 6 | Implements the disentangled transformations. Additionally, provides rotation conversions from pytorch3d. 7 | """ 8 | 9 | 10 | def apply_trafo(pcd, trafo, disentangled=True): 11 | """ 12 | Applies transformation to clone of pcd tensor - see eq. 9 in paper. 
13 | """ 14 | ret = pcd.clone() 15 | 16 | if disentangled: # to origin 17 | ret_mean = ret[..., :3].mean(dim=1)[:, None, :] 18 | ret[..., :3] -= ret_mean 19 | ret[..., :3] = (trafo[:, :3, :3] @ ret[..., :3].transpose(-1, -2)).transpose(-1, -2) # rotate 20 | if disentangled: # from origin 21 | ret[..., :3] += ret_mean 22 | ret[..., :3] += trafo[:, :3, 3][..., None, :] # translate 23 | 24 | return ret 25 | 26 | 27 | def to_disentangled(poses, pcd): 28 | """ 29 | Add rotation-induced translation to translation vector - see eq. 11 in paper. 30 | """ 31 | poses[:, :3, 3] = poses[:, :3, 3] - pcd[..., :3].mean(dim=1) + (poses[:, :3, :3] @ pcd[..., :3].mean(dim=1)[:, :, None]).view(-1, 3) 32 | return poses 33 | 34 | 35 | def to_global(poses, pcd): 36 | """ 37 | Remove rotation-induced translation from translation vector - see eq. 11 in paper. 38 | """ 39 | poses[:, :3, 3] = poses[:, :3, 3] + pcd[..., :3].mean(dim=1) \ 40 | - (poses[:, :3, :3] @ pcd[..., :3].mean(dim=1)[:, :, None]).view(-1, 3) 41 | return poses 42 | 43 | 44 | def square_distance(pcd1, pcd2): 45 | """ 46 | Squared distance between any two points in the two point clouds. 47 | """ 48 | return torch.sum((pcd1[:, :, None, :].contiguous() - pcd2[:, None, :, :].contiguous()) ** 2, dim=-1) 49 | 50 | 51 | # --- euler-matrix conversions from pytorch3d 52 | # via https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html 53 | 54 | def _axis_angle_rotation(axis: str, angle): 55 | """ 56 | Return the rotation matrices for one of the rotations about an axis 57 | of which Euler angles describe, for each value of the angle given. 58 | 59 | Args: 60 | axis: Axis label "X" or "Y or "Z". 61 | angle: any shape tensor of Euler angles in radians 62 | 63 | Returns: 64 | Rotation matrices as tensor of shape (..., 3, 3). 
65 | """ 66 | 67 | cos = torch.cos(angle) 68 | sin = torch.sin(angle) 69 | one = torch.ones_like(angle) 70 | zero = torch.zeros_like(angle) 71 | 72 | if axis == "X": 73 | R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) 74 | if axis == "Y": 75 | R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) 76 | if axis == "Z": 77 | R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) 78 | 79 | return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) 80 | 81 | 82 | def euler_angles_to_matrix(euler_angles, convention: str): 83 | """ 84 | Convert rotations given as Euler angles in radians to rotation matrices. 85 | 86 | Args: 87 | euler_angles: Euler angles in radians as tensor of shape (..., 3). 88 | convention: Convention string of three uppercase letters from 89 | {"X", "Y", and "Z"}. 90 | 91 | Returns: 92 | Rotation matrices as tensor of shape (..., 3, 3). 93 | """ 94 | if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: 95 | raise ValueError("Invalid input euler angles.") 96 | if len(convention) != 3: 97 | raise ValueError("Convention must have 3 letters.") 98 | if convention[1] in (convention[0], convention[2]): 99 | raise ValueError(f"Invalid convention {convention}.") 100 | for letter in convention: 101 | if letter not in ("X", "Y", "Z"): 102 | raise ValueError(f"Invalid letter {letter} in convention string.") 103 | matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) 104 | return functools.reduce(torch.matmul, matrices) 105 | 106 | 107 | def _angle_from_tan( 108 | axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool 109 | ): 110 | """ 111 | Extract the first or third Euler angle from the two members of 112 | the matrix which are positive constant times its sine and cosine. 113 | 114 | Args: 115 | axis: Axis label "X" or "Y or "Z" for the angle we are finding. 116 | other_axis: Axis label "X" or "Y or "Z" for the middle axis in the 117 | convention. 
118 | data: Rotation matrices as tensor of shape (..., 3, 3). 119 | horizontal: Whether we are looking for the angle for the third axis, 120 | which means the relevant entries are in the same row of the 121 | rotation matrix. If not, they are in the same column. 122 | tait_bryan: Whether the first and third axes in the convention differ. 123 | 124 | Returns: 125 | Euler Angles in radians for each matrix in data as a tensor 126 | of shape (...). 127 | """ 128 | 129 | i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] 130 | if horizontal: 131 | i2, i1 = i1, i2 132 | even = (axis + other_axis) in ["XY", "YZ", "ZX"] 133 | if horizontal == even: 134 | return torch.atan2(data[..., i1], data[..., i2]) 135 | if tait_bryan: 136 | return torch.atan2(-data[..., i2], data[..., i1]) 137 | return torch.atan2(data[..., i2], -data[..., i1]) 138 | 139 | 140 | def _index_from_letter(letter: str): 141 | if letter == "X": 142 | return 0 143 | if letter == "Y": 144 | return 1 145 | if letter == "Z": 146 | return 2 147 | 148 | 149 | def matrix_to_euler_angles(matrix, convention: str): 150 | """ 151 | Convert rotations given as rotation matrices to Euler angles in radians. 152 | 153 | Args: 154 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 155 | convention: Convention string of three uppercase letters. 156 | 157 | Returns: 158 | Euler angles in radians as tensor of shape (..., 3). 
159 | """ 160 | if len(convention) != 3: 161 | raise ValueError("Convention must have 3 letters.") 162 | if convention[1] in (convention[0], convention[2]): 163 | raise ValueError(f"Invalid convention {convention}.") 164 | for letter in convention: 165 | if letter not in ("X", "Y", "Z"): 166 | raise ValueError(f"Invalid letter {letter} in convention string.") 167 | if matrix.size(-1) != 3 or matrix.size(-2) != 3: 168 | raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") 169 | i0 = _index_from_letter(convention[0]) 170 | i2 = _index_from_letter(convention[2]) 171 | tait_bryan = i0 != i2 172 | if tait_bryan: 173 | central_angle = torch.asin( 174 | matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) 175 | ) 176 | else: 177 | central_angle = torch.acos(matrix[..., i0, i0]) 178 | 179 | o = ( 180 | _angle_from_tan( 181 | convention[0], convention[1], matrix[..., i2], False, tait_bryan 182 | ), 183 | central_angle, 184 | _angle_from_tan( 185 | convention[2], convention[1], matrix[..., i0, :], True, tait_bryan 186 | ), 187 | ) 188 | return torch.stack(o, -1) 189 | 190 | 191 | def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: 192 | """ 193 | Convert rotations given as axis/angle to rotation matrices. 194 | 195 | Args: 196 | axis_angle: Rotations given as a vector in axis angle form, 197 | as a tensor of shape (..., 3), where the magnitude is 198 | the angle turned anticlockwise in radians around the 199 | vector's direction. 200 | 201 | Returns: 202 | Rotation matrices as tensor of shape (..., 3, 3). 203 | """ 204 | return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) 205 | 206 | 207 | 208 | def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor: 209 | """ 210 | Convert rotations given as rotation matrices to axis/angle. 211 | 212 | Args: 213 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 
214 | 215 | Returns: 216 | Rotations given as a vector in axis angle form, as a tensor 217 | of shape (..., 3), where the magnitude is the angle 218 | turned anticlockwise in radians around the vector's 219 | direction. 220 | """ 221 | return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) 222 | 223 | 224 | 225 | def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: 226 | """ 227 | Convert rotations given as axis/angle to quaternions. 228 | 229 | Args: 230 | axis_angle: Rotations given as a vector in axis angle form, 231 | as a tensor of shape (..., 3), where the magnitude is 232 | the angle turned anticlockwise in radians around the 233 | vector's direction. 234 | 235 | Returns: 236 | quaternions with real part first, as tensor of shape (..., 4). 237 | """ 238 | angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) 239 | half_angles = angles * 0.5 240 | eps = 1e-6 241 | small_angles = angles.abs() < eps 242 | sin_half_angles_over_angles = torch.empty_like(angles) 243 | sin_half_angles_over_angles[~small_angles] = ( 244 | torch.sin(half_angles[~small_angles]) / angles[~small_angles] 245 | ) 246 | # for x small, sin(x/2) is about x/2 - (x/2)^3/6 247 | # so sin(x/2)/x is about 1/2 - (x*x)/48 248 | sin_half_angles_over_angles[small_angles] = ( 249 | 0.5 - (angles[small_angles] * angles[small_angles]) / 48 250 | ) 251 | quaternions = torch.cat( 252 | [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 253 | ) 254 | return quaternions 255 | 256 | 257 | 258 | def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor: 259 | """ 260 | Convert rotations given as quaternions to axis/angle. 261 | 262 | Args: 263 | quaternions: quaternions with real part first, 264 | as tensor of shape (..., 4). 265 | 266 | Returns: 267 | Rotations given as a vector in axis angle form, as a tensor 268 | of shape (..., 3), where the magnitude is the angle 269 | turned anticlockwise in radians around the vector's 270 | direction. 
271 | """ 272 | norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) 273 | half_angles = torch.atan2(norms, quaternions[..., :1]) 274 | angles = 2 * half_angles 275 | eps = 1e-6 276 | small_angles = angles.abs() < eps 277 | sin_half_angles_over_angles = torch.empty_like(angles) 278 | sin_half_angles_over_angles[~small_angles] = ( 279 | torch.sin(half_angles[~small_angles]) / angles[~small_angles] 280 | ) 281 | # for x small, sin(x/2) is about x/2 - (x/2)^3/6 282 | # so sin(x/2)/x is about 1/2 - (x*x)/48 283 | sin_half_angles_over_angles[small_angles] = ( 284 | 0.5 - (angles[small_angles] * angles[small_angles]) / 48 285 | ) 286 | return quaternions[..., 1:] / sin_half_angles_over_angles 287 | 288 | def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: 289 | """ 290 | Convert rotations given as rotation matrices to quaternions. 291 | 292 | Args: 293 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 294 | 295 | Returns: 296 | quaternions with real part first, as tensor of shape (..., 4). 
297 | """ 298 | if matrix.size(-1) != 3 or matrix.size(-2) != 3: 299 | raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") 300 | 301 | batch_dim = matrix.shape[:-2] 302 | m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( 303 | matrix.reshape(batch_dim + (9,)), dim=-1 304 | ) 305 | 306 | q_abs = _sqrt_positive_part( 307 | torch.stack( 308 | [ 309 | 1.0 + m00 + m11 + m22, 310 | 1.0 + m00 - m11 - m22, 311 | 1.0 - m00 + m11 - m22, 312 | 1.0 - m00 - m11 + m22, 313 | ], 314 | dim=-1, 315 | ) 316 | ) 317 | 318 | # we produce the desired quaternion multiplied by each of r, i, j, k 319 | quat_by_rijk = torch.stack( 320 | [ 321 | torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), 322 | torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), 323 | torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), 324 | torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), 325 | ], 326 | dim=-2, 327 | ) 328 | 329 | # We floor here at 0.1 but the exact level is not important; if q_abs is small, 330 | # the candidate won't be picked. 331 | flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) 332 | quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) 333 | 334 | # if not for numerical problems, quat_candidates[i] should be same (up to a sign), 335 | # forall i; we pick the best-conditioned one (with the largest denominator) 336 | 337 | return quat_candidates[ 338 | F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16] 339 | ].reshape(batch_dim + (4,)) 340 | 341 | 342 | def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: 343 | """ 344 | Returns torch.sqrt(torch.max(0, x)) 345 | but with a zero subgradient where x is 0. 
346 | """ 347 | ret = torch.zeros_like(x) 348 | positive_mask = x > 0 349 | ret[positive_mask] = torch.sqrt(x[positive_mask]) 350 | return ret 351 | 352 | def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: 353 | """ 354 | Convert rotations given as quaternions to rotation matrices. 355 | 356 | Args: 357 | quaternions: quaternions with real part first, 358 | as tensor of shape (..., 4). 359 | 360 | Returns: 361 | Rotation matrices as tensor of shape (..., 3, 3). 362 | """ 363 | r, i, j, k = torch.unbind(quaternions, -1) 364 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 365 | 366 | o = torch.stack( 367 | ( 368 | 1 - two_s * (j * j + k * k), 369 | two_s * (i * j - k * r), 370 | two_s * (i * k + j * r), 371 | two_s * (i * j + k * r), 372 | 1 - two_s * (i * i + k * k), 373 | two_s * (j * k - i * r), 374 | two_s * (i * k - j * r), 375 | two_s * (j * k + i * r), 376 | 1 - two_s * (i * i + j * j), 377 | ), 378 | -1, 379 | ) 380 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 381 | 382 | 383 | 384 | # test = torch.tensor((20/180*pi,20/180*pi,20/180*pi)) 385 | # test_axis = matrix_to_axis_angle(euler_angles_to_matrix(test,'XYZ')) 386 | # print(test_axis) 387 | 388 | # test = torch.tensor((0.0,0.0,0.0)) 389 | # test_axis = matrix_to_quaternion(euler_angles_to_matrix(test,'XYZ')) 390 | # print(test_axis) 391 | 392 | # test = torch.tensor((0/180*pi,100/180*pi,0/180*pi)) 393 | # test_axis = matrix_to_quaternion(euler_angles_to_matrix(test,'XYZ')) 394 | # print(test_axis) 395 | # print(test_axis[0]**2+test_axis[1]**2+test_axis[2]**2+test_axis[3]**2) 396 | -------------------------------------------------------------------------------- /generate_depth_gt.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import random 4 | from ip_basic import depth_map_utils 5 | from ip_basic import vis_utils 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import torch 9 | import pykitti 10 | 
from dataset.DatasetLidarCam import Resampler
import torchvision
import cv2

#%% Point-cloud projection + depth completion


def get_2D_lidar_projection(pcl, cam_intrinsic):
    """Project camera-frame points onto the image plane.

    Args:
        pcl: (N, 3) array of points expressed in the camera frame.
        cam_intrinsic: (3, 3) camera intrinsic matrix.

    Returns:
        pcl_uv: (N, 2) un-rounded pixel coordinates (u, v).
        pcl_z: (N,) third component of ``K @ p`` taken before the
            perspective division (the camera-frame z for a standard K).
    """
    pcl_xyz = (cam_intrinsic @ pcl.T).T
    pcl_z = pcl_xyz[:, 2]
    # Perspective division; the epsilon avoids dividing by an exact zero.
    pcl_xyz = pcl_xyz / (pcl_xyz[:, 2, None] + 1e-10)
    pcl_uv = pcl_xyz[:, :2]
    return pcl_uv, pcl_z


def lidar_project_depth(pc_rotated, cam_calib, img_shape):
    """Render a LiDAR scan into a densified depth image.

    Args:
        pc_rotated: tensor whose first three rows are the x, y, z coordinates
            of the points in the camera frame, shape (>=3, N).
        cam_calib: (3, 3) intrinsic matrix tensor.
        img_shape: (height, width) of the output depth image.

    Returns:
        depth_img: (1, H, W) float32 CUDA tensor, densified with
            ``depth_map_utils.fill_in_multiscale``.
        pcl_uv: (M, 2) uint32 pixel coordinates of the points that fell
            inside the image with positive depth.
    """
    pc_rotated = pc_rotated[:3, :].detach().cpu().numpy()
    cam_intrinsic = cam_calib.detach().cpu().numpy()
    pcl_uv, pcl_z = get_2D_lidar_projection(pc_rotated.T, cam_intrinsic)

    # Keep only points in front of the camera that land inside the image.
    mask = ((pcl_uv[:, 0] > 0) & (pcl_uv[:, 0] < img_shape[1])
            & (pcl_uv[:, 1] > 0) & (pcl_uv[:, 1] < img_shape[0])
            & (pcl_z > 0))
    pcl_uv = pcl_uv[mask].astype(np.uint32)
    pcl_z = pcl_z[mask]

    depth_img = np.zeros((img_shape[0], img_shape[1]))
    # Several points may hit one pixel; only one is kept (no min-depth
    # z-buffering is performed).
    depth_img[pcl_uv[:, 1], pcl_uv[:, 0]] = pcl_z

    # Densify the sparse projection with IP-Basic multi-scale completion.
    depth_img, _ = depth_map_utils.fill_in_multiscale(
        depth_img, extrapolate=False, blur_type='bilateral',
        show_process=False)

    depth_img = torch.from_numpy(depth_img.astype(np.float32)).cuda()
    depth_img = depth_img.unsqueeze(0)
    return depth_img, pcl_uv


def lidar_project_depth_batch(pc, calib, img_shape):
    """Project a batch of point clouds and resize the depth maps to 256x512.

    Args:
        pc: batched point clouds; ``pc[i]`` is transposed to (>=3, N) before
            projection.
        calib: batched (3, 3) intrinsic matrices, indexed like ``pc``.
        img_shape: (height, width) of the full-resolution depth canvas.

    Returns:
        (B, 1, 256, 512) tensor of bilinearly resized depth maps.
    """
    depth_img_out = torch.stack([
        lidar_project_depth(pc[idx].transpose(0, 1), calib[idx], img_shape)[0]
        for idx in range(pc.shape[0])
    ])
    return F.interpolate(depth_img_out, size=[256, 512], mode='bilinear',
                         align_corners=False)


#%% Collect input file paths

root_dir = '/home/zhujt/dataset_zjt/kitti_raw/'  # dataset root directory
date = '2011_09_26'  # process the 2011-09-26 recordings as an example
dataset_dir = root_dir
seq_list = os.listdir(os.path.join(root_dir, date))

# Relative "<date>/<seq>/image_02/data/<frame>" paths, without extension.
all_files = []

for seq in seq_list:
    seq_dir = os.path.join(dataset_dir, date, seq)
    if not os.path.isdir(seq_dir):
        continue
    image_list = sorted(os.listdir(os.path.join(seq_dir, 'image_02/data')))
    for image_name in image_list:
        frame = image_name.split('.')[0]
        # Keep only frames that have both a LiDAR scan and a PNG image.
        if not os.path.exists(os.path.join(seq_dir, 'velodyne_points/data',
                                           frame + '.bin')):
            continue
        if not os.path.exists(os.path.join(seq_dir, 'image_02/data',
                                           frame + '.png')):
            continue
        all_files.append(os.path.join(date, seq, 'image_02/data', frame))
# random.shuffle(all_files)


#%% Iterate over frames, project the scans and write depth-map labels

# Read the calibration. KITTI raw calibration files are stored per date, so
# drive '0001' only serves to locate them.
data = pykitti.raw(root_dir, date, '0001')
calib = {'K2': data.calib.K_cam2, 'K3': data.calib.K_cam3,
         'RT2': data.calib.T_cam2_velo, 'RT3': data.calib.T_cam3_velo}

# Loop-invariant setup, hoisted out of the per-frame loop.
RT_cam02 = calib['RT2'].astype(np.float32)  # LiDAR -> camera-2 extrinsic
calib_cam02 = calib['K2']  # camera-2 intrinsics, 3x3
E_RT = RT_cam02
calib_cal = torch.tensor(calib_cam02, dtype=torch.float)
# Downsample each cloud with Resampler(100000) for the calibration pipeline.
# NOTE(review): exact Resampler semantics live in dataset/DatasetLidarCam.py.
transforms = torchvision.transforms.Compose([Resampler(100000)])

for item in all_files:
    parts = item.split('/')
    # Separate name so the module-level `date` is not rebound.
    item_date, seq, rgb_name = parts[0], parts[1], parts[4]
    seq_dir = os.path.join(root_dir, item_date, seq)

    # Load the LiDAR scan as float32 rows of 4 values
    # (assumed KITTI layout: x, y, z, reflectance).
    lidar_path = os.path.join(seq_dir, 'velodyne_points/data',
                              rgb_name + '.bin')
    pc = np.fromfile(lidar_path, dtype=np.float32).reshape((-1, 4))
    # Drop points inside the 3 m box around the sensor (|x| <= 3 and |y| <= 3).
    valid_indices = ((pc[:, 0] < -3.) | (pc[:, 0] > 3.)
                     | (pc[:, 1] < -3.) | (pc[:, 1] > 3.))
    pc = pc[valid_indices].copy()
    pc_org = torch.from_numpy(pc.astype(np.float32))

    # Convert to a homogeneous (4, N) matrix whose last row is all ones.
    if pc_org.shape[1] == 4 or pc_org.shape[1] == 3:
        pc_org = pc_org.t()
    if pc_org.shape[0] == 3:
        homogeneous = torch.ones(pc_org.shape[1]).unsqueeze(0)
        pc_org = torch.cat((pc_org, homogeneous), 0)
    elif pc_org.shape[0] == 4:
        if not torch.all(pc_org[3, :] == 1.):
            pc_org[3, :] = 1.
    else:
        raise TypeError("Wrong PointCloud shape")

    # Move the points into the camera-2 frame.
    pc_rot = np.matmul(E_RT, pc_org.numpy()).astype(np.float32)
    pc_in = torch.from_numpy(pc_rot)

    # The transform takes a dict with an (N, 4) 'points' entry, matching the
    # data-augmentation format.
    pc_temp = {'points': pc_in.transpose(0, 1)}
    ds_pc = transforms(pc_temp)['points']

    # Render on a 384x1280 canvas; the batch helper resizes to 256x512.
    depth_gt = lidar_project_depth_batch(ds_pc.unsqueeze(0),
                                         calib_cal.unsqueeze(0), (384, 1280))

    save_dir = os.path.join(seq_dir, 'depth_gt/data')
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, rgb_name + '.jpg')
    # NOTE: JPEG stores 8-bit samples, so the float depth label is quantized
    # and lossily compressed; the .jpg format is kept for compatibility with
    # the consumer of these labels (TODO(review): confirm the loader's format).
    if not cv2.imwrite(save_path, depth_gt[0].permute(1, 2, 0).cpu().numpy()):
        raise IOError(f"Failed to write depth label to {save_path}")


#%%
# --------------------------------------------------------------------------------
# /images/Graphic abstract.jpg:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Brickzhuantou/CalibDepth/45dc0252f01353d4a897e414e7bf2a63d273e1f3/images/Graphic abstract.jpg
# --------------------------------------------------------------------------------
/ip_basic/depth_map_utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | # Full kernels 7 | FULL_KERNEL_3 = np.ones((3, 3), np.uint8) 8 | FULL_KERNEL_5 = np.ones((5, 5), np.uint8) 9 | FULL_KERNEL_7 = np.ones((7, 7), np.uint8) 10 | FULL_KERNEL_9 = np.ones((9, 9), np.uint8) 11 | FULL_KERNEL_31 = np.ones((31, 31), np.uint8) 12 | 13 | # 3x3 cross kernel 14 | CROSS_KERNEL_3 = np.asarray( 15 | [ 16 | [0, 1, 0], 17 | [1, 1, 1], 18 | [0, 1, 0], 19 | ], dtype=np.uint8) 20 | 21 | # 5x5 cross kernel 22 | CROSS_KERNEL_5 = np.asarray( 23 | [ 24 | [0, 0, 1, 0, 0], 25 | [0, 0, 1, 0, 0], 26 | [1, 1, 1, 1, 1], 27 | [0, 0, 1, 0, 0], 28 | [0, 0, 1, 0, 0], 29 | ], dtype=np.uint8) 30 | 31 | # 5x5 diamond kernel 32 | DIAMOND_KERNEL_5 = np.array( 33 | [ 34 | [0, 0, 1, 0, 0], 35 | [0, 1, 1, 1, 0], 36 | [1, 1, 1, 1, 1], 37 | [0, 1, 1, 1, 0], 38 | [0, 0, 1, 0, 0], 39 | ], dtype=np.uint8) 40 | 41 | # 7x7 cross kernel 42 | CROSS_KERNEL_7 = np.asarray( 43 | [ 44 | [0, 0, 0, 1, 0, 0, 0], 45 | [0, 0, 0, 1, 0, 0, 0], 46 | [0, 0, 0, 1, 0, 0, 0], 47 | [1, 1, 1, 1, 1, 1, 1], 48 | [0, 0, 0, 1, 0, 0, 0], 49 | [0, 0, 0, 1, 0, 0, 0], 50 | [0, 0, 0, 1, 0, 0, 0], 51 | ], dtype=np.uint8) 52 | 53 | # 7x7 diamond kernel 54 | DIAMOND_KERNEL_7 = np.asarray( 55 | [ 56 | [0, 0, 0, 1, 0, 0, 0], 57 | [0, 0, 1, 1, 1, 0, 0], 58 | [0, 1, 1, 1, 1, 1, 0], 59 | [1, 1, 1, 1, 1, 1, 1], 60 | [0, 1, 1, 1, 1, 1, 0], 61 | [0, 0, 1, 1, 1, 0, 0], 62 | [0, 0, 0, 1, 0, 0, 0], 63 | ], dtype=np.uint8) 64 | 65 | 66 | def fill_in_fast(depth_map, max_depth=100.0, custom_kernel=DIAMOND_KERNEL_5, 67 | extrapolate=False, blur_type='bilateral'): 68 | """Fast, in-place depth completion. 
69 | 70 | Args: 71 | depth_map: projected depths 72 | max_depth: max depth value for inversion 73 | custom_kernel: kernel to apply initial dilation 74 | extrapolate: whether to extrapolate by extending depths to top of 75 | the frame, and applying a 31x31 full kernel dilation 76 | blur_type: 77 | 'bilateral' - preserves local structure (recommended) 78 | 'gaussian' - provides lower RMSE 79 | 80 | Returns: 81 | depth_map: dense depth map 82 | """ 83 | 84 | # Invert 85 | valid_pixels = (depth_map > 0.1) 86 | depth_map[valid_pixels] = max_depth - depth_map[valid_pixels] 87 | 88 | # Dilate 89 | depth_map = cv2.dilate(depth_map, custom_kernel) 90 | 91 | # Hole closing 92 | depth_map = cv2.morphologyEx(depth_map, cv2.MORPH_CLOSE, FULL_KERNEL_5) 93 | 94 | # Fill empty spaces with dilated values 95 | empty_pixels = (depth_map < 0.1) 96 | dilated = cv2.dilate(depth_map, FULL_KERNEL_7) 97 | depth_map[empty_pixels] = dilated[empty_pixels] 98 | 99 | # Extend highest pixel to top of image 100 | if extrapolate: 101 | top_row_pixels = np.argmax(depth_map > 0.1, axis=0) 102 | top_pixel_values = depth_map[top_row_pixels, range(depth_map.shape[1])] 103 | 104 | for pixel_col_idx in range(depth_map.shape[1]): 105 | depth_map[0:top_row_pixels[pixel_col_idx], pixel_col_idx] = \ 106 | top_pixel_values[pixel_col_idx] 107 | 108 | # Large Fill 109 | empty_pixels = depth_map < 0.1 110 | dilated = cv2.dilate(depth_map, FULL_KERNEL_31) 111 | depth_map[empty_pixels] = dilated[empty_pixels] 112 | 113 | # Median blur 114 | depth_map = cv2.medianBlur(depth_map, 5) 115 | 116 | # Bilateral or Gaussian blur 117 | if blur_type == 'bilateral': 118 | # Bilateral blur 119 | depth_map = cv2.bilateralFilter(depth_map, 5, 1.5, 2.0) 120 | elif blur_type == 'gaussian': 121 | # Gaussian blur 122 | valid_pixels = (depth_map > 0.1) 123 | blurred = cv2.GaussianBlur(depth_map, (5, 5), 0) 124 | depth_map[valid_pixels] = blurred[valid_pixels] 125 | 126 | # Invert 127 | valid_pixels = (depth_map > 0.1) 128 | 
depth_map[valid_pixels] = max_depth - depth_map[valid_pixels] 129 | 130 | return depth_map 131 | 132 | 133 | def fill_in_multiscale(depth_map, max_depth=100.0, 134 | dilation_kernel_far=CROSS_KERNEL_3, 135 | dilation_kernel_med=CROSS_KERNEL_5, 136 | dilation_kernel_near=CROSS_KERNEL_7, 137 | extrapolate=False, 138 | blur_type='bilateral', 139 | show_process=False): 140 | """Slower, multi-scale dilation version with additional noise removal that 141 | provides better qualitative results. 142 | 143 | Args: 144 | depth_map: projected depths 145 | max_depth: max depth value for inversion 146 | dilation_kernel_far: dilation kernel to use for 30.0 < depths < 80.0 m 147 | dilation_kernel_med: dilation kernel to use for 15.0 < depths < 30.0 m 148 | dilation_kernel_near: dilation kernel to use for 0.1 < depths < 15.0 m 149 | extrapolate:whether to extrapolate by extending depths to top of 150 | the frame, and applying a 31x31 full kernel dilation 151 | blur_type: 152 | 'gaussian' - provides lower RMSE 153 | 'bilateral' - preserves local structure (recommended) 154 | show_process: saves process images into an OrderedDict 155 | 156 | Returns: 157 | depth_map: dense depth map 158 | process_dict: OrderedDict of process images 159 | """ 160 | 161 | # Convert to float32 162 | depths_in = np.float32(depth_map) 163 | 164 | # Calculate bin masks before inversion 165 | valid_pixels_near = (depths_in > 0.1) & (depths_in <= 15.0) 166 | valid_pixels_med = (depths_in > 15.0) & (depths_in <= 30.0) 167 | valid_pixels_far = (depths_in > 30.0) 168 | 169 | # Invert (and offset) 170 | s1_inverted_depths = np.copy(depths_in) 171 | valid_pixels = (s1_inverted_depths > 0.1) 172 | s1_inverted_depths[valid_pixels] = \ 173 | max_depth - s1_inverted_depths[valid_pixels] 174 | 175 | # Multi-scale dilation 176 | dilated_far = cv2.dilate( 177 | np.multiply(s1_inverted_depths, valid_pixels_far), 178 | dilation_kernel_far) 179 | dilated_med = cv2.dilate( 180 | np.multiply(s1_inverted_depths, 
valid_pixels_med), 181 | dilation_kernel_med) 182 | dilated_near = cv2.dilate( 183 | np.multiply(s1_inverted_depths, valid_pixels_near), 184 | dilation_kernel_near) 185 | 186 | # Find valid pixels for each binned dilation 187 | valid_pixels_near = (dilated_near > 0.1) 188 | valid_pixels_med = (dilated_med > 0.1) 189 | valid_pixels_far = (dilated_far > 0.1) 190 | 191 | # Combine dilated versions, starting farthest to nearest 192 | s2_dilated_depths = np.copy(s1_inverted_depths) 193 | s2_dilated_depths[valid_pixels_far] = dilated_far[valid_pixels_far] 194 | s2_dilated_depths[valid_pixels_med] = dilated_med[valid_pixels_med] 195 | s2_dilated_depths[valid_pixels_near] = dilated_near[valid_pixels_near] 196 | 197 | # Small hole closure 198 | s3_closed_depths = cv2.morphologyEx( 199 | s2_dilated_depths, cv2.MORPH_CLOSE, FULL_KERNEL_5) 200 | 201 | # Median blur to remove outliers 202 | s4_blurred_depths = np.copy(s3_closed_depths) 203 | blurred = cv2.medianBlur(s3_closed_depths, 5) 204 | valid_pixels = (s3_closed_depths > 0.1) 205 | s4_blurred_depths[valid_pixels] = blurred[valid_pixels] 206 | 207 | # Calculate a top mask 208 | top_mask = np.ones(depths_in.shape, dtype=np.bool) 209 | for pixel_col_idx in range(s4_blurred_depths.shape[1]): 210 | pixel_col = s4_blurred_depths[:, pixel_col_idx] 211 | top_pixel_row = np.argmax(pixel_col > 0.1) 212 | top_mask[0:top_pixel_row, pixel_col_idx] = False 213 | 214 | # Get empty mask 215 | valid_pixels = (s4_blurred_depths > 0.1) 216 | empty_pixels = ~valid_pixels & top_mask 217 | 218 | # Hole fill 219 | dilated = cv2.dilate(s4_blurred_depths, FULL_KERNEL_9) 220 | s5_dilated_depths = np.copy(s4_blurred_depths) 221 | s5_dilated_depths[empty_pixels] = dilated[empty_pixels] 222 | 223 | # Extend highest pixel to top of image or create top mask 224 | s6_extended_depths = np.copy(s5_dilated_depths) 225 | top_mask = np.ones(s5_dilated_depths.shape, dtype=np.bool) 226 | 227 | top_row_pixels = np.argmax(s5_dilated_depths > 0.1, axis=0) 228 | 
top_pixel_values = s5_dilated_depths[top_row_pixels, 229 | range(s5_dilated_depths.shape[1])] 230 | 231 | for pixel_col_idx in range(s5_dilated_depths.shape[1]): 232 | if extrapolate: 233 | s6_extended_depths[0:top_row_pixels[pixel_col_idx], 234 | pixel_col_idx] = top_pixel_values[pixel_col_idx] 235 | else: 236 | # Create top mask 237 | top_mask[0:top_row_pixels[pixel_col_idx], pixel_col_idx] = False 238 | 239 | # Fill large holes with masked dilations 240 | s7_blurred_depths = np.copy(s6_extended_depths) 241 | for i in range(6): 242 | empty_pixels = (s7_blurred_depths < 0.1) & top_mask 243 | dilated = cv2.dilate(s7_blurred_depths, FULL_KERNEL_5) 244 | s7_blurred_depths[empty_pixels] = dilated[empty_pixels] 245 | 246 | # Median blur 247 | blurred = cv2.medianBlur(s7_blurred_depths, 5) 248 | valid_pixels = (s7_blurred_depths > 0.1) & top_mask 249 | s7_blurred_depths[valid_pixels] = blurred[valid_pixels] 250 | 251 | if blur_type == 'gaussian': 252 | # Gaussian blur 253 | blurred = cv2.GaussianBlur(s7_blurred_depths, (5, 5), 0) 254 | valid_pixels = (s7_blurred_depths > 0.1) & top_mask 255 | s7_blurred_depths[valid_pixels] = blurred[valid_pixels] 256 | elif blur_type == 'bilateral': 257 | # Bilateral blur 258 | blurred = cv2.bilateralFilter(s7_blurred_depths, 5, 0.5, 2.0) 259 | s7_blurred_depths[valid_pixels] = blurred[valid_pixels] 260 | 261 | # Invert (and offset) 262 | s8_inverted_depths = np.copy(s7_blurred_depths) 263 | valid_pixels = np.where(s8_inverted_depths > 0.1) 264 | s8_inverted_depths[valid_pixels] = \ 265 | max_depth - s8_inverted_depths[valid_pixels] 266 | 267 | depths_out = s8_inverted_depths 268 | 269 | process_dict = None 270 | if show_process: 271 | process_dict = collections.OrderedDict() 272 | 273 | process_dict['s0_depths_in'] = depths_in 274 | 275 | process_dict['s1_inverted_depths'] = s1_inverted_depths 276 | process_dict['s2_dilated_depths'] = s2_dilated_depths 277 | process_dict['s3_closed_depths'] = s3_closed_depths 278 | 
process_dict['s4_blurred_depths'] = s4_blurred_depths 279 | process_dict['s5_combined_depths'] = s5_dilated_depths 280 | process_dict['s6_extended_depths'] = s6_extended_depths 281 | process_dict['s7_blurred_depths'] = s7_blurred_depths 282 | process_dict['s8_inverted_depths'] = s8_inverted_depths 283 | 284 | process_dict['s9_depths_out'] = depths_out 285 | 286 | return depths_out, process_dict 287 | -------------------------------------------------------------------------------- /ip_basic/vis_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | 4 | def cv2_show_image(window_name, image, 5 | size_wh=None, location_xy=None): 6 | """Helper function for specifying window size and location when 7 | displaying images with cv2. 8 | 9 | Args: 10 | window_name: str window name 11 | image: ndarray image to display 12 | size_wh: window size (w, h) 13 | location_xy: window location (x, y) 14 | """ 15 | 16 | if size_wh is not None: 17 | cv2.namedWindow(window_name, 18 | cv2.WINDOW_KEEPRATIO | cv2.WINDOW_GUI_NORMAL) 19 | cv2.resizeWindow(window_name, *size_wh) 20 | else: 21 | cv2.namedWindow(window_name, cv2.WINDOW_AUTOSIZE) 22 | 23 | if location_xy is not None: 24 | cv2.moveWindow(window_name, *location_xy) 25 | 26 | cv2.imshow(window_name, image) 27 | -------------------------------------------------------------------------------- /ipcv_utils/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------- 2 | import matplotlib 3 | import matplotlib.pyplot as plt 4 | #------------------------------------------------------------------- 5 | 6 | def imshow(img, cmap="gray", vmin=0, vmax=1, frameon=False, zoom=1.0): 7 | 8 | dpi = float(matplotlib.rcParams['figure.dpi'])/zoom 9 | 10 | fig = plt.figure(figsize=[img.shape[1]/dpi, img.shape[0]/dpi], 11 | frameon=frameon) 12 | ax = fig.add_axes([0, 0, 1, 1]) 13 | 
ax.axis('off') 14 | ax.imshow(img, cmap=cmap, vmin=vmin, vmax=vmax) 15 | 16 | # plt.savefig('contrast.png', dpi=300) 17 | plt.show() 18 | #------------------------------------------------------------------- 19 | 20 | def show(close=None, block=None): 21 | 22 | plt.show(close, block) 23 | #------------------------------------------------------------------- 24 | 25 | def imwrite(img, name, cmap="gray", vmin=0, vmax=1, frameon=False, zoom=1.0): 26 | 27 | # dpi = float(matplotlib.rcParams['figure.dpi'])/zoom 28 | 29 | # fig = plt.figure(figsize=[img.shape[1]/dpi, img.shape[0]/dpi], 30 | # frameon=frameon) 31 | fig = plt.figure() 32 | ax = fig.add_axes([0, 0, 1, 1]) 33 | ax.axis('off') 34 | ax.imshow(img, cmap=cmap, vmin=vmin, vmax=vmax) 35 | 36 | plt.savefig(name+'.png', dpi=300) 37 | # plt.show() -------------------------------------------------------------------------------- /models/base/R_MSFM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models.base.update import BasicUpdateBlock 6 | 7 | 8 | try: 9 | autocast = torch.cuda.amp.autocast 10 | except: 11 | # dummy autocast for PyTorch < 1.6 12 | class autocast: 13 | def __init__(self, enabled): 14 | pass 15 | 16 | def __enter__(self): 17 | pass 18 | 19 | def __exit__(self, *args): 20 | pass 21 | 22 | 23 | 24 | 25 | 26 | class SepConvGRU(nn.Module): 27 | def __init__(self): 28 | super(SepConvGRU, self).__init__() 29 | hidden_dim = 128 30 | catt = 256 31 | 32 | self.convz1 = nn.Conv2d(catt, hidden_dim, (1, 3), padding=(0, 1)) 33 | self.convr1 = nn.Conv2d(catt, hidden_dim, (1, 3), padding=(0, 1)) 34 | self.convq1 = nn.Conv2d(catt, hidden_dim, (1, 3), padding=(0, 1)) 35 | 36 | self.convz2 = nn.Conv2d(catt, hidden_dim, (3, 1), padding=(1, 0)) 37 | self.convr2 = nn.Conv2d(catt, hidden_dim, (3, 1), padding=(1, 0)) 38 | self.convq2 = nn.Conv2d(catt, hidden_dim, (3, 1), padding=(1, 0)) 39 | 40 | def 
forward(self, h, x): 41 | # horizontal 42 | hx = torch.cat([h, x], dim=1) 43 | z = torch.sigmoid(self.convz1(hx)) 44 | r = torch.sigmoid(self.convr1(hx)) 45 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) 46 | h = (1 - z) * h + z * q 47 | 48 | # vertical 49 | hx = torch.cat([h, x], dim=1) 50 | z = torch.sigmoid(self.convz2(hx)) 51 | r = torch.sigmoid(self.convr2(hx)) 52 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) 53 | h = (1 - z) * h + z * q 54 | 55 | return h 56 | 57 | 58 | class R_MSFM3(nn.Module): 59 | def __init__(self, x): 60 | super(R_MSFM3, self).__init__() 61 | 62 | self.convX11 = torch.nn.Sequential( 63 | nn.ReflectionPad2d(1), 64 | torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=0, bias=True), 65 | torch.nn.LeakyReLU(inplace=True), 66 | nn.ReflectionPad2d(1), 67 | torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=0, bias=True), 68 | torch.nn.Tanh()) 69 | if x: 70 | self.convX21 = torch.nn.Sequential( 71 | nn.ReflectionPad2d(1), 72 | torch.nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=0, bias=True), 73 | torch.nn.Tanh()) 74 | self.convX31 = torch.nn.Sequential( 75 | nn.ReflectionPad2d(1), 76 | torch.nn.Conv2d(in_channels=512, out_channels=128, kernel_size=3, stride=1, padding=0, bias=True), 77 | torch.nn.Tanh()) 78 | else: 79 | self.convX21 = torch.nn.Sequential( 80 | nn.ReflectionPad2d(1), 81 | torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=0, bias=True), 82 | torch.nn.Tanh()) 83 | self.convX31 = torch.nn.Sequential( 84 | nn.ReflectionPad2d(1), 85 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=0, bias=True), 86 | torch.nn.Tanh()) 87 | 88 | self.sigmoid = nn.Sigmoid() 89 | 90 | self.update_block = BasicUpdateBlock() 91 | self.gruc = SepConvGRU() 92 | def upsample_depth(self, flow, mask): 93 | """ Upsample depth field [H/8, W/8, 2] -> [H, W, 2] using convex 
combination """ 94 | N, _, H, W = flow.shape 95 | mask = mask.view(N, 1, 9, 8, 8, H, W) 96 | mask = torch.softmax(mask, dim=2) 97 | 98 | up_flow = F.unfold(flow, [3, 3], padding=1) 99 | up_flow = up_flow.view(N, 1, 9, 1, 1, H, W) 100 | 101 | up_flow = torch.sum(mask * up_flow, dim=2) 102 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 103 | return up_flow.reshape(N, 1, 8 * H, 8 * W) 104 | 105 | def forward(self, features, iters=3): 106 | """ Estimate depth for a single image """ 107 | 108 | x1, x2, x3 = features 109 | 110 | disp_predictions = {} 111 | b, c, h, w = x3.shape 112 | dispFea = torch.zeros([b, 1, h, w], requires_grad=True).to(x1.device) 113 | net = torch.zeros([b, 256, h, w], requires_grad=True).to(x1.device) 114 | 115 | for itr in range(iters): 116 | if itr in [0]: 117 | corr = self.convX31(x3) 118 | elif itr in [1]: 119 | corrh = corr 120 | corr = self.convX21(x2) 121 | corr = self.gruc(corrh, corr) 122 | elif itr in [2]: 123 | corrh = corr 124 | corr = self.convX11(x1) 125 | corr = self.gruc(corrh, corr) 126 | 127 | net, up_mask, delta_disp = self.update_block(net, corr, dispFea) 128 | dispFea = dispFea + delta_disp 129 | 130 | disp = self.sigmoid(dispFea) 131 | # upsample predictions 132 | if self.training: 133 | disp_up = self.upsample_depth(disp, up_mask) 134 | disp_predictions[("disp_up", itr)] = disp_up 135 | else: 136 | if (iters-1)==itr: 137 | disp_up = self.upsample_depth(disp, up_mask) 138 | disp_predictions[("disp_up", itr)] = disp_up 139 | 140 | 141 | return disp_predictions 142 | 143 | 144 | class R_MSFM6(nn.Module): 145 | def __init__(self,x): 146 | super(R_MSFM6, self).__init__() 147 | 148 | self.convX11 = torch.nn.Sequential( 149 | nn.ReflectionPad2d(1), 150 | torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=0, bias=True), 151 | torch.nn.LeakyReLU(inplace=True), 152 | nn.ReflectionPad2d(1), 153 | torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=0, bias=True), 154 | 
torch.nn.Tanh()) 155 | 156 | self.convX12 = torch.nn.Sequential( 157 | nn.Conv2d(128, 128, (1, 3), padding=(0, 1)), 158 | torch.nn.Tanh(), 159 | nn.Conv2d(128, 128, (3, 1), padding=(1, 0)), 160 | torch.nn.Tanh()) 161 | 162 | 163 | if x: 164 | self.convX21 = torch.nn.Sequential( 165 | nn.ReflectionPad2d(1), 166 | torch.nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=0, bias=True), 167 | torch.nn.Tanh()) 168 | self.convX31 = torch.nn.Sequential( 169 | nn.ReflectionPad2d(1), 170 | torch.nn.Conv2d(in_channels=512, out_channels=128, kernel_size=3, stride=1, padding=0, bias=True), 171 | torch.nn.Tanh()) 172 | else: 173 | self.convX21 = torch.nn.Sequential( 174 | nn.ReflectionPad2d(1), 175 | torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=0, bias=True), 176 | torch.nn.Tanh()) 177 | self.convX31 = torch.nn.Sequential( 178 | nn.ReflectionPad2d(1), 179 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=0, dilation=1, 180 | bias=True), 181 | torch.nn.Tanh()) 182 | 183 | 184 | 185 | self.convX22 = torch.nn.Sequential( 186 | nn.Conv2d(128, 128, (1, 3), padding=(0, 1)), 187 | torch.nn.Tanh(), 188 | nn.Conv2d(128, 128, (3, 1), padding=(1, 0)), 189 | torch.nn.Tanh()) 190 | 191 | self.convX32 = torch.nn.Sequential( 192 | nn.Conv2d(128, 128, (1, 3), padding=(0, 1)), 193 | torch.nn.Tanh(), 194 | nn.Conv2d(128, 128, (3, 1), padding=(1, 0)), 195 | torch.nn.Tanh()) 196 | 197 | self.sigmoid = nn.Sigmoid() 198 | self.gruc = SepConvGRU() 199 | self.update_block = BasicUpdateBlock() 200 | 201 | def upsample_depth(self, flow, mask): 202 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 203 | N, _, H, W = flow.shape 204 | mask = mask.view(N, 1, 9, 8, 8, H, W) 205 | mask = torch.softmax(mask, dim=2) 206 | 207 | up_flow = F.unfold(flow, [3, 3], padding=1) 208 | up_flow = up_flow.view(N, 1, 9, 1, 1, H, W) 209 | 210 | up_flow = torch.sum(mask * up_flow, dim=2) 211 | 
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 212 | return up_flow.reshape(N, 1, 8 * H, 8 * W) 213 | 214 | def forward(self, features, iters=6): 215 | """ Estimate depth for a single image """ 216 | 217 | x1, x2, x3 = features 218 | 219 | disp_predictions = {} 220 | b, c, h, w = x3.shape 221 | dispFea = torch.zeros([b, 1, h, w], requires_grad=True).to(x1.device) 222 | net = torch.zeros([b, 256, h, w], requires_grad=True).to(x1.device) 223 | 224 | for itr in range(iters): 225 | if itr in [0]: 226 | corr = self.convX31(x3) 227 | elif itr in [1]: 228 | corrh = corr 229 | corr = self.convX32(corr) 230 | corr = self.gruc(corrh, corr) 231 | elif itr in [2]: 232 | corrh = corr 233 | corr = self.convX21(x2) 234 | corr = self.gruc(corrh, corr) 235 | elif itr in [3]: 236 | corrh = corr 237 | corr = self.convX22(corr) 238 | corr = self.gruc(corrh, corr) 239 | elif itr in [4]: 240 | corrh = corr 241 | corr = self.convX11(x1) 242 | corr = self.gruc(corrh, corr) 243 | elif itr in [5]: 244 | corrh = corr 245 | corr = self.convX12(corr) 246 | corr = self.gruc(corrh, corr) 247 | 248 | net, up_mask, delta_disp = self.update_block(net, corr, dispFea) 249 | dispFea = dispFea + delta_disp 250 | 251 | disp = self.sigmoid(dispFea) 252 | # upsample predictions 253 | 254 | if self.training: 255 | disp_up = self.upsample_depth(disp, up_mask) 256 | disp_predictions[("disp_up", itr)] = disp_up 257 | else: 258 | if (iters-1)==itr: 259 | disp_up = self.upsample_depth(disp, up_mask) 260 | disp_predictions[("disp_up", itr)] = disp_up 261 | 262 | 263 | return disp_predictions 264 | -------------------------------------------------------------------------------- /models/base/resnet_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 
2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | import os 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torchvision.models as models 14 | import torch.utils.model_zoo as model_zoo 15 | 16 | 17 | class ResNetMultiImageInput(models.ResNet): 18 | """Constructs a resnet model with varying number of input images. 19 | Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 20 | """ 21 | def __init__(self, block, layers, num_classes=1000, num_input_images=1): 22 | super(ResNetMultiImageInput, self).__init__(block, layers) 23 | self.inplanes = 64 24 | self.conv1 = nn.Conv2d( 25 | num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False) 26 | self.bn1 = nn.BatchNorm2d(64) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 29 | self.layer1 = self._make_layer(block, 64, layers[0]) 30 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 31 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 32 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 33 | 34 | for m in self.modules(): 35 | if isinstance(m, nn.Conv2d): 36 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 37 | elif isinstance(m, nn.BatchNorm2d): 38 | nn.init.constant_(m.weight, 1) 39 | nn.init.constant_(m.bias, 0) 40 | 41 | 42 | def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1): 43 | """Constructs a ResNet model. 44 | Args: 45 | num_layers (int): Number of resnet layers. 
Must be 18 or 50 46 | pretrained (bool): If True, returns a model pre-trained on ImageNet 47 | num_input_images (int): Number of frames stacked as input 48 | """ 49 | assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet" 50 | blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers] 51 | block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers] 52 | model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images) 53 | 54 | if pretrained: 55 | loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)]) 56 | loaded['conv1.weight'] = torch.cat( 57 | [loaded['conv1.weight']] * num_input_images, 1) / num_input_images 58 | model.load_state_dict(loaded) 59 | return model 60 | 61 | 62 | class ResnetEncoder(nn.Module): 63 | """Pytorch module for a resnet encoder 64 | """ 65 | def __init__(self, num_layers, pretrained, num_input_images=1): 66 | super(ResnetEncoder, self).__init__() 67 | 68 | self.num_ch_enc = np.array([64, 64, 128, 256, 512]) 69 | 70 | resnets = {18: models.resnet18, 71 | 34: models.resnet34, 72 | 50: models.resnet50, 73 | 101: models.resnet101, 74 | 152: models.resnet152} 75 | 76 | if num_layers not in resnets: 77 | raise ValueError("{} is not a valid number of resnet layers".format(num_layers)) 78 | 79 | if num_input_images > 1: 80 | self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images) 81 | else: 82 | self.encoder = resnets[num_layers](pretrained) 83 | 84 | if num_layers > 34: 85 | self.num_ch_enc[1:] *= 4 86 | 87 | def forward(self, input_image): 88 | self.features = [] 89 | x = (input_image - 0.45) / 0.225 90 | x = self.encoder.conv1(x) 91 | x = self.encoder.bn1(x) 92 | self.features.append(self.encoder.relu(x)) 93 | self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1]))) 94 | self.features.append(self.encoder.layer2(self.features[-1])) 95 | # self.features.append(self.encoder.layer3(self.features[-1])) 96 | # 
self.features.append(self.encoder.layer4(self.features[-1])) 97 | 98 | return self.features 99 | 100 | 101 | class ResnetEncoder2(nn.Module): 102 | """Pytorch module for a resnet encoder 103 | """ 104 | def __init__(self, num_layers, pretrained, num_input_images=1): 105 | super(ResnetEncoder2, self).__init__() 106 | 107 | self.num_ch_enc = np.array([64, 64, 128, 256, 512]) 108 | 109 | resnets = {18: models.resnet18, 110 | 34: models.resnet34, 111 | 50: models.resnet50, 112 | 101: models.resnet101, 113 | 152: models.resnet152} 114 | 115 | if num_layers not in resnets: 116 | raise ValueError("{} is not a valid number of resnet layers".format(num_layers)) 117 | 118 | if num_input_images > 1: 119 | self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images) 120 | else: 121 | self.encoder = resnets[num_layers](pretrained) 122 | 123 | if num_layers > 34: 124 | self.num_ch_enc[1:] *= 4 125 | 126 | def forward(self, input_image): 127 | self.features = [] 128 | x = (input_image - 0.45) / 0.225 129 | x = self.encoder.conv1(x) 130 | x = self.encoder.bn1(x) 131 | self.features.append(self.encoder.relu(x)) 132 | self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1]))) 133 | self.features.append(self.encoder.layer2(self.features[-1])) 134 | self.features.append(self.encoder.layer3(self.features[-1])) 135 | self.features.append(self.encoder.layer4(self.features[-1])) 136 | 137 | return self.features 138 | -------------------------------------------------------------------------------- /models/base/update.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('..') 4 | import torch.nn as nn 5 | import torch 6 | 7 | 8 | class ConvBlock(nn.Module): 9 | """Layer to perform a convolution followed by LeakyReLU 10 | """ 11 | 12 | def __init__(self, in_channels, out_channels): 13 | super(ConvBlock, self).__init__() 14 | 15 | self.conv = Conv3x3(in_channels, out_channels) 16 
| self.nonlin = nn.LeakyReLU(inplace=True) 17 | 18 | def forward(self, x): 19 | out = self.conv(x) 20 | out = self.nonlin(out) 21 | return out 22 | 23 | 24 | class Conv3x3(nn.Module): 25 | """Layer to pad and convolve input 26 | """ 27 | 28 | def __init__(self, in_channels, out_channels, use_refl=True): 29 | super(Conv3x3, self).__init__() 30 | 31 | if use_refl: 32 | self.pad = nn.ReflectionPad2d(1) 33 | else: 34 | self.pad = nn.ZeroPad2d(1) 35 | self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3) 36 | 37 | def forward(self, x): 38 | out = self.pad(x) 39 | out = self.conv(out) 40 | return out 41 | 42 | 43 | class dispHead(nn.Module): 44 | def __init__(self): 45 | super(dispHead, self).__init__() 46 | outD = 1 47 | 48 | self.covd1 = torch.nn.Sequential(nn.ReflectionPad2d(1), 49 | torch.nn.Conv2d(in_channels=192, out_channels=256, kernel_size=3, stride=1, 50 | padding=0, bias=True), 51 | torch.nn.LeakyReLU(inplace=True)) 52 | 53 | self.covd2 = torch.nn.Sequential(nn.ReflectionPad2d(1), 54 | torch.nn.Conv2d(in_channels=256, out_channels=outD, kernel_size=3, stride=1, 55 | padding=0, bias=True)) 56 | 57 | def forward(self, x): 58 | return self.covd2(self.covd1(x)) 59 | 60 | 61 | class BasicMotionEncoder(nn.Module): 62 | def __init__(self): 63 | super(BasicMotionEncoder, self).__init__() 64 | # inD = 1 65 | 66 | self.convc1 = ConvBlock(128, 160) 67 | self.convc2 = ConvBlock(160, 128) 68 | 69 | self.convf1 = torch.nn.Sequential( 70 | nn.ReflectionPad2d(3), 71 | torch.nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, padding=0, bias=True), 72 | torch.nn.LeakyReLU(inplace=True)) 73 | self.convf2 = torch.nn.Sequential( 74 | nn.ReflectionPad2d(1), 75 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=0, bias=True), 76 | torch.nn.LeakyReLU(inplace=True)) 77 | 78 | self.conv = ConvBlock(128 + 32, 192 - 1) 79 | 80 | def forward(self, depth, corr): 81 | cor = self.convc1(corr) 82 | cor = self.convc2(cor) 83 | 84 | dep = self.convf1(depth) 
85 | dep = self.convf2(dep) 86 | 87 | cor_depth = torch.cat([cor, dep], dim=1) 88 | out = self.conv(cor_depth) 89 | return torch.cat([out, depth], dim=1) 90 | 91 | 92 | class BasicUpdateBlock(nn.Module): 93 | def __init__(self): 94 | super(BasicUpdateBlock, self).__init__() 95 | self.encoder = BasicMotionEncoder() 96 | 97 | self.flow_head = dispHead() 98 | 99 | self.mask = nn.Sequential( 100 | nn.ReflectionPad2d(1), 101 | nn.Conv2d(192, 324, 3), 102 | nn.LeakyReLU(inplace=True), 103 | nn.Conv2d(324, 64 * 9, 1, padding=0)) 104 | 105 | def forward(self, net, corr, depth): 106 | net = self.encoder(depth, corr) 107 | delta_depth = self.flow_head(net) 108 | 109 | # scale mask to balence gradients 110 | mask = .25 * self.mask(net) 111 | 112 | return net, mask, delta_depth 113 | 114 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | from re import M 2 | from turtle import forward 3 | import torch 4 | import torch.nn as nn 5 | import torchvision 6 | import torchvision.models as models 7 | from torch.distributions.normal import Normal 8 | import numpy as np 9 | import torch.nn.functional as F 10 | 11 | # depth 12 | import models.base.resnet_encoder as resnet_encoder 13 | from models.base.R_MSFM import R_MSFM3 14 | 15 | def disp_to_depth(disp, min_depth, max_depth): 16 | """Convert network's sigmoid output into depth prediction 17 | """ 18 | min_disp = 1 / max_depth 19 | max_disp = 1 / min_depth 20 | scaled_disp = min_disp + (max_disp - min_disp) * disp 21 | depth = 1 / scaled_disp 22 | return scaled_disp, depth 23 | 24 | def MyConv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 25 | return nn.Sequential( 26 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, 27 | groups=1, bias=True), 28 | nn.ReLU()) 29 | 30 | 31 | def nin_block(in_channels, out_channels, 
kernel_size, strides, padding): 32 | return nn.Sequential( 33 | nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding), 34 | # nn.LeakyReLU(0.1), 35 | nn.ReLU(), 36 | nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0), nn.ReLU(), 37 | nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0), nn.ReLU()) 38 | 39 | 40 | 41 | # 特征提取部分 42 | class StateEmbed(nn.Module): 43 | def __init__(self): 44 | super(StateEmbed, self).__init__() 45 | 46 | # 图像分支 47 | self.resnet_encoder = resnet_encoder.ResnetEncoder(18, True) 48 | self.rgb_depth_decoder = R_MSFM3(False) 49 | self.rgb_dep_encoder = nn.Sequential( # 对图像分支深度图编码 50 | MyConv(1,16,8,4,2), 51 | MyConv(16,32,4,2,1), 52 | ) 53 | self.rgb_res_encoder = MyConv(128, 32, # 用于残差特征提 54 | kernel_size=1, padding=0, stride=1) 55 | 56 | # 点云深度图分支 57 | self.depth_backbone = resnet_encoder.ResnetEncoder(18, True) 58 | self.depth_encoder = MyConv(128,32,kernel_size=1,padding=0,stride=1) # 对点云分支深度图编码 59 | 60 | # 特征融合 61 | self.match_layer = nin_block(64, 32, 8, 4, 2) 62 | 63 | self.match_block = nn.Sequential( 64 | MyConv(32, 64, kernel_size=1, padding=0, stride=1), 65 | MyConv(64, 64, kernel_size=3, padding=1, stride=1), 66 | MyConv(64, 32, kernel_size=1, padding=0, stride=1) 67 | ) 68 | self.leakyRELU = nn.LeakyReLU(0.1) 69 | 70 | def forward(self, rgb_img, depth_img): 71 | # 图像分支 72 | res_emb = self.resnet_encoder(rgb_img) 73 | rgb_dep = self.rgb_depth_decoder(res_emb)[("disp_up", 2)] 74 | _, scale_dep = disp_to_depth(rgb_dep, 0.1, 80) # 获取中间深度图 75 | scale_dep_emb = self.rgb_dep_encoder(scale_dep/80) # 图像分支深度图编码 76 | _, _, x3 = res_emb 77 | rgb_ori_emb = self.rgb_res_encoder(x3) # 用于残差特征提取 78 | rgb_emb = torch.add(rgb_ori_emb, scale_dep_emb) # 残差特征提取与深度图编码融合 79 | 80 | 81 | # 点云分支 82 | _, _, depth_emb = self.depth_backbone(depth_img.expand(-1,3,-1,-1)) 83 | depth_emb = self.depth_encoder(depth_emb) 84 | 85 | # 特征融合 86 | match_emb = torch.cat((rgb_emb, depth_emb), dim=1) 87 | 
match_emb = self.match_layer(match_emb) 88 | match_emb = match_emb + self.match_block(match_emb) 89 | match_emb = self.leakyRELU(match_emb) 90 | 91 | return match_emb, scale_dep # 返回融合特征和中间深度图 92 | 93 | 94 | # 标定动作预测部分 95 | class CalibActionHead(nn.Module): 96 | def __init__(self): 97 | super(CalibActionHead, self).__init__() 98 | self.activation = nn.ReLU() 99 | self.input_dim = 32*8*16 100 | self.head_dim = 128 101 | 102 | self.lstm = nn.LSTM(input_size = self.input_dim, hidden_size = 2*self.head_dim, 103 | num_layers = 2, batch_first = True, dropout = 0.5) 104 | 105 | self.emb_r = nn.Sequential( 106 | nn.Linear(self.head_dim*2, self.head_dim), 107 | self.activation 108 | ) 109 | 110 | self.emb_t = nn.Sequential( 111 | nn.Linear(self.head_dim*2, self.head_dim), 112 | self.activation 113 | ) 114 | self.action_t = nn.Linear(self.head_dim, 3) 115 | self.action_r = nn.Linear(self.head_dim, 3) 116 | 117 | def forward(self, state, h_n, c_n): 118 | state = state.view(state.shape[0], -1) 119 | 120 | output, (h_n, c_n) = self.lstm(state.unsqueeze(1), (h_n, c_n)) 121 | emb_t = self.emb_t(output) 122 | emb_r = self.emb_r(output) 123 | 124 | action_mean_t = self.action_t(emb_t).squeeze(1) 125 | action_mean_r = self.action_r(emb_r).squeeze(1) 126 | action_mean = [action_mean_t, action_mean_r] 127 | 128 | return action_mean, (h_n, c_n) 129 | 130 | 131 | class Agent(nn.Module): 132 | def __init__(self): 133 | super(Agent, self).__init__() 134 | self.state_emb = StateEmbed() 135 | self.calib_action = CalibActionHead() 136 | def forward(self, rgb_img, depth_img, h_last, c_last): 137 | state_emb, predict_depth = self.state_emb(rgb_img, depth_img) 138 | action_mean, hc = self.calib_action(state_emb, h_last, c_last) 139 | 140 | return action_mean, predict_depth, hc 141 | 142 | class berHuLoss(nn.Module): 143 | def __init__(self): 144 | super(berHuLoss, self).__init__() 145 | 146 | def forward(self, pred, target): 147 | assert pred.dim() == target.dim(), "inconsistent dimensions" 148 
| 149 | huber_c = torch.max(torch.abs(pred - target)) 150 | huber_c = 0.2 * huber_c 151 | 152 | valid_mask = (target > 0).detach() 153 | diff = target - pred 154 | diff = diff[valid_mask] 155 | diff = diff.abs() 156 | 157 | huber_mask = (diff > huber_c).detach() 158 | 159 | diff2 = diff[huber_mask] 160 | diff2 = diff2 ** 2 161 | 162 | self.loss = torch.cat((diff, diff2)).mean() 163 | 164 | return self.loss 165 | 166 | 167 | # --- model helpers 168 | def load(model, path): 169 | infos = torch.load(path) 170 | model.load_state_dict(infos['model_state_dict']) 171 | return infos 172 | 173 | 174 | def save(model, path, infos={}): 175 | infos['model_state_dict'] = model.state_dict() 176 | torch.save(infos, path) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | aiohttp==3.8.1 3 | aiosignal==1.2.0 4 | apptools==5.2.0 5 | argon2-cffi==21.3.0 6 | argon2-cffi-bindings==21.2.0 7 | async-generator==1.10 8 | async-timeout==4.0.2 9 | asynctest==0.13.0 10 | attrs==21.4.0 11 | backcall==0.2.0 12 | bleach==4.1.0 13 | cached-property==1.5.2 14 | cachetools==4.2.4 15 | certifi==2021.5.30 16 | cffi==1.15.0 17 | charset-normalizer==2.0.12 18 | configobj==5.0.6 19 | cycler==0.11.0 20 | dataclasses @ file:///tmp/build/80754af9/dataclasses_1614363715916/work 21 | decorator==4.4.2 22 | defusedxml==0.7.1 23 | descartes==1.1.0 24 | docopt==0.6.2 25 | drawnow==0.72.5 26 | entrypoints==0.4 27 | envisage==6.1.0 28 | fire==0.4.0 29 | frozenlist==1.2.0 30 | future==0.18.2 31 | google-auth==1.35.0 32 | google-auth-oauthlib==0.4.6 33 | grpcio==1.44.0 34 | h5py==3.1.0 35 | idna==3.3 36 | idna-ssl==1.1.0 37 | imageio==2.15.0 38 | imageio-ffmpeg==0.4.7 39 | importlib-metadata==4.8.3 40 | importlib-resources==5.4.0 41 | iniconfig==1.1.1 42 | ipykernel==5.5.6 43 | ipython==7.16.3 44 | ipython-genutils==0.2.0 45 | ipywidgets==7.7.0 46 | jedi==0.17.2 47 
| Jinja2==3.0.3 48 | joblib==1.1.0 49 | jsonpickle==0.9.6 50 | jsonschema==3.2.0 51 | jupyter==1.0.0 52 | jupyter-client==7.1.2 53 | jupyter-console==6.4.3 54 | jupyter-core==4.9.2 55 | jupyterlab-pygments==0.1.2 56 | jupyterlab-widgets==1.1.0 57 | kiwisolver==1.3.1 58 | Markdown==3.3.6 59 | MarkupSafe==2.0.1 60 | mathutils @ git+https://gitlab.com/m1lhaus/blender-mathutils.git@74ca2f141213227e41f8101fb8455ed759a1d024 61 | matplotlib==3.3.4 62 | mayavi==4.8.0 63 | mistune==0.8.4 64 | mkl-fft==1.3.0 65 | mkl-random==1.1.1 66 | mkl-service==2.3.0 67 | moviepy==1.0.3 68 | multidict==5.2.0 69 | munch==2.5.0 70 | nbclient==0.5.9 71 | nbconvert==6.0.7 72 | nbformat==5.1.3 73 | nest-asyncio==1.5.5 74 | networkx==2.5.1 75 | notebook==6.4.10 76 | numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1603487797006/work 77 | nuscenes-devkit==1.1.9 78 | oauthlib==3.2.0 79 | olefile==0.46 80 | open3d-python==0.7.0.0 81 | opencv-python==4.5.1.48 82 | packaging==21.3 83 | pandas==1.1.5 84 | pandocfilters==1.5.0 85 | parso==0.7.1 86 | pexpect==4.8.0 87 | pickleshare==0.7.5 88 | Pillow==8.4.0 89 | pluggy==1.0.0 90 | prefetch-generator==1.0.1 91 | proglog==0.1.10 92 | prometheus-client==0.14.1 93 | prompt-toolkit==3.0.29 94 | protobuf==3.19.4 95 | ptyprocess==0.7.0 96 | py==1.11.0 97 | py-cpuinfo==8.0.0 98 | pyasn1==0.4.8 99 | pyasn1-modules==0.2.8 100 | pycocotools==2.0.4 101 | pycparser==2.21 102 | pyface==7.4.2 103 | Pygments==2.12.0 104 | pykitti==0.3.1 105 | pyparsing==3.0.8 106 | pypng==0.20220715.0 107 | pyquaternion==0.9.9 108 | pyrsistent==0.18.0 109 | pytest==7.0.1 110 | python-dateutil==2.8.2 111 | pytz==2022.1 112 | PyWavelets==1.1.1 113 | pyzmq==22.3.0 114 | qtconsole==5.2.2 115 | QtPy==2.0.1 116 | requests==2.27.1 117 | requests-oauthlib==1.3.1 118 | rsa==4.8 119 | sacred==0.7.4 120 | scikit-image==0.17.2 121 | scikit-learn==0.24.2 122 | scipy==1.5.2 123 | Send2Trash==1.8.0 124 | Shapely==1.8.2 125 | six @ file:///tmp/build/80754af9/six_1644875935023/work 126 | 
tensorboard==2.4.1 127 | tensorboard-plugin-wit==1.8.1 128 | tensorboardX==2.5 129 | termcolor==1.1.0 130 | terminado==0.12.1 131 | testpath==0.6.0 132 | threadpoolctl==3.1.0 133 | tifffile==2020.9.3 134 | tomli==1.2.3 135 | torch==1.7.0 136 | torch-tb-profiler==0.4.0 137 | torchaudio==0.7.0a0+ac17b64 138 | torchstat==0.0.7 139 | torchvision==0.8.0 140 | tornado==6.1 141 | tqdm==4.48.2 142 | traitlets==4.3.3 143 | traits==6.4.1 144 | traitsui==7.4.0 145 | transforms3d==0.3.1 146 | trimesh==3.9.9 147 | typing_extensions @ file:///opt/conda/conda-bld/typing_extensions_1647553014482/work 148 | urllib3==1.26.9 149 | vtk==9.1.0 150 | wcwidth==0.2.5 151 | webencodings==0.5.1 152 | Werkzeug==2.0.3 153 | widgetsnbextension==3.6.0 154 | wrapt==1.14.0 155 | wslink==1.8.2 156 | yarl==1.7.2 157 | zipp==3.6.0 158 | -------------------------------------------------------------------------------- /save_fig/20_dg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Brickzhuantou/CalibDepth/45dc0252f01353d4a897e414e7bf2a63d273e1f3/save_fig/20_dg.png -------------------------------------------------------------------------------- /save_fig/20_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Brickzhuantou/CalibDepth/45dc0252f01353d4a897e414e7bf2a63d273e1f3/save_fig/20_gt.png -------------------------------------------------------------------------------- /save_fig/20_iter0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Brickzhuantou/CalibDepth/45dc0252f01353d4a897e414e7bf2a63d273e1f3/save_fig/20_iter0.png -------------------------------------------------------------------------------- /save_fig/20_iter1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Brickzhuantou/CalibDepth/45dc0252f01353d4a897e414e7bf2a63d273e1f3/save_fig/20_iter1.png -------------------------------------------------------------------------------- /save_fig/20_iter2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Brickzhuantou/CalibDepth/45dc0252f01353d4a897e414e7bf2a63d273e1f3/save_fig/20_iter2.png -------------------------------------------------------------------------------- /save_fig/20_iter3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Brickzhuantou/CalibDepth/45dc0252f01353d4a897e414e7bf2a63d273e1f3/save_fig/20_iter3.png -------------------------------------------------------------------------------- /save_fig/20_pred_d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Brickzhuantou/CalibDepth/45dc0252f01353d4a897e414e7bf2a63d273e1f3/save_fig/20_pred_d.png -------------------------------------------------------------------------------- /save_fig/20_rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Brickzhuantou/CalibDepth/45dc0252f01353d4a897e414e7bf2a63d273e1f3/save_fig/20_rgb.png -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | import numpy as np 7 | from torch.utils.tensorboard import SummaryWriter 8 | from tqdm import tqdm 9 | from prefetch_generator import BackgroundGenerator 10 | 11 | from utility.logger import Logger 12 | import utility.metrics as metrics 13 | from utility.quaternion_distances import quaternion_distance 14 | from models.model import Agent 15 | import 
models.model as util_model 16 | from dataset.DatasetLidarCam import DatasetKittiRawCalibNet 17 | from dataset.DatasetLidarCam import lidar_project_depth 18 | from dataset.data_utils import (merge_inputs, quaternion_from_matrix) 19 | 20 | from environment import environment as env 21 | from environment import transformations as tra 22 | from environment.buffer import Buffer 23 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | np.random.seed(42) 25 | 26 | 27 | def lidar_project_depth_batch(pc, calib, img_shape): 28 | depth_img_out = [] 29 | for idx in range(pc.shape[0]): 30 | depth_img, _ = lidar_project_depth(pc[idx].transpose(0, 1), calib[idx], img_shape) 31 | depth_img = depth_img.to(DEVICE) 32 | depth_img_out.append(depth_img) 33 | 34 | depth_img_out = torch.stack(depth_img_out) 35 | depth_img_out = F.interpolate(depth_img_out, size=[256, 512], mode = 'bilinear', align_corners=False) 36 | return depth_img_out 37 | 38 | def parse_args(): 39 | # fmt: off 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--data_folder", type=str, default='/home/zhujt/dataset_zjt/kitti_raw/', 42 | help="the data path") 43 | parser.add_argument("--dataset", type=str, default='raw', choices = ['raw', 'odometry', 'raw_calibNet'], 44 | help="the data type") 45 | parser.add_argument("--test_type", type=str, default='generalization', choices = ['precision', 'generalization'], 46 | help="the test type") 47 | parser.add_argument("--load_model", type=str, default='weights_lstm/raw_1.zip', 48 | help="the model to load") 49 | parser.add_argument("--val_sequence", type=list, default=['2011_09_26_drive_0005_sync', '2011_09_26_drive_0070_sync'], 50 | help="the data for valuation") 51 | # parser.add_argument("--val_sequence_generalization", type=list, 52 | # default=['2011_09_30_drive_0016_sync', '2011_09_30_drive_0018_sync', '2011_09_30_drive_0020_sync', 53 | # '2011_09_30_drive_0028_sync', '2011_09_30_drive_0033_sync', '2011_09_30_drive_0034_sync', 54 | 
# '2011_09_30_drive_0027_sync'], 55 | # help="the data for evaluate the generalization") 56 | 57 | # parser.add_argument("--val_sequence_generalization", type=list, 58 | # default=['2011_09_30_drive_0028_sync'], 59 | # help="the data for evaluate the generalization") 60 | 61 | parser.add_argument("--max_t", type=float, default=0.25, 62 | help="the translation decalibration range") 63 | parser.add_argument("--max_r", type=float, default=10., 64 | help="the rotation decalibration range") 65 | parser.add_argument("--ITER_EVAL", type=int, default=5, 66 | help="value iterations") 67 | parser.add_argument("--batch_size", type=int, default=1, 68 | help="the batch size for data collection") 69 | parser.add_argument("--num_worker", type=int, default=5, 70 | help="the worker nums for training") 71 | args = parser.parse_args() 72 | return args 73 | 74 | def evaluate(agent, data_path, max_t, max_r, val_sequence): 75 | args = parse_args() 76 | 77 | dataset_class = DatasetKittiRawCalibNet 78 | dataset_val = dataset_class(data_path, max_r=max_r, max_t=max_t, split='val', 79 | use_reflectance=False, val_sequence=val_sequence) 80 | ValImgLoader = torch.utils.data.DataLoader(dataset=dataset_val, 81 | shuffle=False, 82 | batch_size=args.batch_size, 83 | num_workers=args.num_worker, 84 | # worker_init_fn=init_fn, 85 | collate_fn=merge_inputs, 86 | drop_last=False, 87 | pin_memory=True) 88 | print(len(ValImgLoader)) 89 | 90 | agent.eval() 91 | progress = tqdm(BackgroundGenerator(ValImgLoader), total=len(ValImgLoader)) 92 | 93 | predictions = [] 94 | with torch.no_grad(): 95 | for data in progress: 96 | 97 | rgb_input, depth_input, depth_target, pose_target, pose_source, ds_pc_target, ds_pc_source, calib = env.init(data) 98 | 99 | current_source = ds_pc_source 100 | current_depth = depth_input 101 | 102 | for step in range(args.ITER_EVAL): 103 | # expert prediction 104 | if(step == 0): 105 | 106 | actions, _, hc = agent(rgb_input, current_depth, torch.zeros(2, depth_input.shape[0], 
256).to(DEVICE), torch.zeros(2, depth_input.shape[0], 256).to(DEVICE)) 107 | else: 108 | actions, _, hc = agent(rgb_input, current_depth, h_last, c_last) 109 | 110 | h_last, c_last = hc[0], hc[1] 111 | 112 | action_t, action_r = actions[0].unsqueeze(1), actions[1].unsqueeze(1) 113 | 114 | action_tr = torch.cat([action_t, action_r], dim = 1) 115 | new_source, pose_source = env.step_continous(ds_pc_source, action_tr, pose_source) 116 | 117 | current_source = new_source 118 | current_depth = lidar_project_depth_batch(current_source, calib, (384, 1280)) 119 | current_depth /= 80 120 | 121 | predictions.append(pose_source) 122 | 123 | predictions = torch.cat(predictions) 124 | eval_metrics, summary_metrics = metrics.compute_stats(predictions, data_loader=ValImgLoader) 125 | 126 | # log test metrics 127 | print(f"MAE R: {summary_metrics['r_mae']:0.4f}") 128 | print(f"MAE rr: {summary_metrics['rr_mae']:0.4f}") 129 | print(f"MAE pp: {summary_metrics['yy_mae']:0.4f}") 130 | print(f"MAE yy: {summary_metrics['pp_mae']:0.4f}") 131 | 132 | print(f"qdMAE R: {summary_metrics['qd_error']:0.4f}") 133 | print(f"MAE t: {summary_metrics['t_mae']:0.6f}") 134 | print(f"MAE x: {summary_metrics['x_mae']:0.6f}") 135 | print(f"MAE y: {summary_metrics['y_mae']:0.6f}") 136 | print(f"MAE z: {summary_metrics['z_mae']:0.6f}") 137 | 138 | print(f"ISO R: {summary_metrics['r_iso']:0.4f}") 139 | print(f"ISO t: {summary_metrics['t_iso']:0.6f}") 140 | 141 | 142 | 143 | if __name__ == '__main__': 144 | args = parse_args() 145 | dataset = args.data_folder 146 | code_path = os.path.dirname(os.path.abspath(__file__)) 147 | 148 | pretrain = os.path.join(code_path, args.load_model) 149 | print(" loading weights...") 150 | agent = Agent().to(DEVICE) 151 | if os.path.exists(pretrain): 152 | util_model.load(agent, pretrain) 153 | else: 154 | raise FileNotFoundError(f"No weights found at {pretrain}. 
Download pretrained weights or run training first.") 155 | 156 | evaluate(agent, dataset, args.max_t, args.max_r, args.val_sequence) 157 | 158 | 159 | 160 | -------------------------------------------------------------------------------- /test/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Brickzhuantou/CalibDepth/45dc0252f01353d4a897e414e7bf2a63d273e1f3/test/test.png -------------------------------------------------------------------------------- /test/test_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | from distutils.command.build_scripts import first_line_re 3 | import sys 4 | sys.path.append('/home/zhujt/code_calib/CalibDepth') 5 | import pytest 6 | import torch 7 | import tqdm 8 | from prefetch_generator import BackgroundGenerator 9 | from dataset.DatasetLidarCam import DatasetKittiRawCalibNet as DatasetKittiRawCalibNet 10 | from dataset.DatasetLidarCam import lidar_project_depth, Resampler 11 | from dataset.data_utils import merge_inputs 12 | from environment import environment as env 13 | import pykitti 14 | import os 15 | import numpy as np 16 | import torchvision 17 | from torchvision import transforms as transf 18 | from PIL import Image 19 | import cv2 20 | 21 | 22 | def test_data_init(): 23 | """验证数据处理后的维度是否符合模型的输入要求""" 24 | # 路径获取 25 | root_dir = '/home/zhujt/dataset_zjt/kitti_raw/' 26 | date = '2011_09_26' 27 | dataset_dir = root_dir 28 | seq_list = os.listdir(os.path.join(root_dir, date)) 29 | seq = seq_list[0] 30 | image_list = os.listdir(os.path.join(dataset_dir, date, seq, 'image_02/data')) 31 | image_name = image_list[0] 32 | item = os.path.join(date, seq, 'image_02/data', image_name.split('.')[0]) 33 | data = pykitti.raw(root_dir, date, '0001') 34 | calib = {'K2': data.calib.K_cam2, 'K3': data.calib.K_cam3, 35 | 'RT2': data.calib.T_cam2_velo, 'RT3': data.calib.T_cam3_velo} 36 | date = str(item.split('/')[0]) 37 | 
seq = str(item.split('/')[1]) 38 | rgb_name = str(item.split('/')[4]) 39 | 40 | # 读取图像数据 41 | img_path = os.path.join(root_dir, date, seq, 'image_02/data', rgb_name+'.png') # png 42 | img = Image.open(img_path) 43 | to_tensor = transf.ToTensor() 44 | img = to_tensor(img) 45 | real_shape = [img.shape[1], img.shape[2], img.shape[0]] 46 | 47 | # 读取点云数据 48 | lidar_path = os.path.join(root_dir, date, seq, 'velodyne_points/data', rgb_name+'.bin') 49 | lidar_scan = np.fromfile(lidar_path, dtype=np.float32) 50 | pc = lidar_scan.reshape((-1, 4)) 51 | valid_indices = pc[:, 0] < -3. 52 | valid_indices = valid_indices | (pc[:, 0] > 3.) 53 | valid_indices = valid_indices | (pc[:, 1] < -3.) 54 | valid_indices = valid_indices | (pc[:, 1] > 3.) 55 | pc = pc[valid_indices].copy() 56 | pc_org = torch.from_numpy(pc.astype(np.float32)) 57 | 58 | # 读取标定参数 59 | RT_cam02 = calib['RT2'].astype(np.float32) 60 | # camera intrinsic parameter 61 | calib_cam02 = calib['K2'] # 3x3 62 | E_RT = RT_cam02 63 | calib_cal = torch.tensor(calib_cam02, dtype = torch.float) 64 | 65 | if pc_org.shape[1] == 4 or pc_org.shape[1] == 3: 66 | pc_org = pc_org.t() 67 | if pc_org.shape[0] == 3: 68 | homogeneous = torch.ones(pc_org.shape[1]).unsqueeze(0) 69 | pc_org = torch.cat((pc_org, homogeneous), 0) 70 | elif pc_org.shape[0] == 4: 71 | if not torch.all(pc_org[3, :] == 1.): 72 | pc_org[3, :] = 1. 
73 | else: 74 | raise TypeError("Wrong PointCloud shape") 75 | 76 | pc_rot = np.matmul(E_RT, pc_org.numpy()) 77 | pc_rot = pc_rot.astype(np.float32).copy() 78 | pc_in = torch.from_numpy(pc_rot) 79 | 80 | transforms = torchvision.transforms.Compose([Resampler(100000)]) #采样多少个点 81 | pc_temp = {'points': pc_in} 82 | pc_temp['points'] = pc_temp['points'].transpose(0, 1) 83 | ds_pc = transforms(pc_temp)['points'].transpose(0, 1) 84 | 85 | depth_img, uv = lidar_project_depth(ds_pc, calib_cal, real_shape) 86 | 87 | assert depth_img.shape[1] == real_shape[0] and depth_img.shape[2] == real_shape[1] 88 | 89 | # 存储原图以及深度图到当前路径 90 | img = img.numpy() 91 | img = np.transpose(img, (1, 2, 0)) * 255 92 | 93 | depth_img = depth_img.numpy() 94 | depth_img = np.transpose(depth_img, (1, 2, 0)) * 255 95 | depth_img = np.concatenate((depth_img, depth_img, depth_img), axis=2) 96 | depth_img = cv2.cvtColor(depth_img, cv2.COLOR_RGB2BGR) 97 | # print(img.shape) 98 | # print(depth_img.shape) 99 | 100 | # 存储图片看投影效果; 101 | cv2.imwrite('/home/zhujt/code_calib/CalibDepth/test/test.png', img) 102 | cv2.imwrite('/home/zhujt/code_calib/CalibDepth/test/test_depth.png', depth_img) 103 | 104 | 105 | 106 | 107 | 108 | if __name__ == 'main': 109 | test_data_init() -------------------------------------------------------------------------------- /test/test_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Brickzhuantou/CalibDepth/45dc0252f01353d4a897e414e7bf2a63d273e1f3/test/test_depth.png -------------------------------------------------------------------------------- /test/test_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/home/zhujt/code_calib/CalibDepth') 3 | import pytest 4 | import torch 5 | from models.model import Agent 6 | 7 | @pytest.fixture 8 | def sample_input(): 9 | batch_size = 2 10 | channels = 3 11 | height = 256 12 | width = 512 13 | 
h_n = torch.zeros(2, batch_size, 256) 14 | c_n = torch.zeros(2, batch_size, 256) 15 | 16 | rgb_img = torch.rand(batch_size, channels, height, width) 17 | depth_img = torch.rand(batch_size, 1, height, width) 18 | 19 | return rgb_img, depth_img, h_n, c_n 20 | 21 | 22 | def test_agent_forward(sample_input): 23 | rgb_img, depth_img, h_n, c_n = sample_input 24 | agent = Agent() 25 | 26 | action_mean, predict_depth, hc = agent(rgb_img, depth_img, h_n, c_n) 27 | 28 | assert isinstance(action_mean, list) 29 | assert len(action_mean) == 2 30 | assert action_mean[0].shape == (rgb_img.shape[0], 3) 31 | assert action_mean[1].shape == (rgb_img.shape[0], 3) 32 | 33 | assert predict_depth.shape == depth_img.shape 34 | 35 | assert hc[0].shape == h_n.shape 36 | assert hc[1].shape == c_n.shape 37 | 38 | if __name__ == '__main__': 39 | pytest.main() 40 | 41 | 42 | # 主要用于测试模型的前向推理过程,测试模型的输入输出是否符合预期 -------------------------------------------------------------------------------- /test/test_pose_rotate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/home/zhujt/code_calib/CalibDepth') 3 | import pytest 4 | import torch 5 | import mathutils 6 | import numpy as np 7 | 8 | from environment import environment as env 9 | from utility.utils import ( invert_pose, quaternion_from_matrix, rotate_back, rotate_forward, 10 | quaternion_from_matrix, rotate_back ) 11 | 12 | def test_rotate(): 13 | """测试点云的扰动+恢复 14 | """ 15 | # 创建一个测试用的点云 PC,假设是 4xN 的形状 16 | pc_target = torch.tensor([[1.0, 2.0, 3.0, 1.0], 17 | [4.0, 5.0, 6.0, 1.0], 18 | [7.0, 8.0, 9.0, 1.0], 19 | [10.0, 11.0, 12.0, 1.0], 20 | [10.0, 11.0, 12.0, 1.0]]) 21 | 22 | # 添加扰动 23 | max_angle = 20 24 | max_t = 1.5 25 | rotz = np.random.uniform(-max_angle, max_angle) * (np.pi / 180.0) 26 | roty = np.random.uniform(-max_angle, max_angle) * (np.pi / 180.0) 27 | rotx = np.random.uniform(-max_angle, max_angle) * (np.pi / 180.0) 28 | transl_x = np.random.uniform(-max_t, max_t) 
29 | transl_y = np.random.uniform(-max_t, max_t) 30 | transl_z = np.random.uniform(-max_t, max_t) 31 | initial_RT = 0.0 32 | 33 | R = mathutils.Euler((rotx, roty, rotz), 'XYZ') 34 | T = mathutils.Vector((transl_x, transl_y, transl_z)) 35 | 36 | R_m = mathutils.Quaternion(R).to_matrix() 37 | R_m.resize_4x4() 38 | T_m = mathutils.Matrix.Translation(T) 39 | RT_m = T_m * R_m 40 | 41 | pc_rotated = rotate_back(pc_target, RT_m) # Pc’ = RT * Pc 42 | pc_source = pc_rotated 43 | 44 | # 位姿数据(扰动的逆作为点云位姿) 45 | i_pose_target = np.array(RT_m, dtype=np.float32) 46 | pose_target = i_pose_target.copy() 47 | pose_target[:3, :3] = pose_target[:3, :3].T 48 | pose_target[:3, 3] = -np.matmul(pose_target[:3, :3], pose_target[:3, 3]) 49 | pose_target = torch.from_numpy(pose_target) 50 | pose_source = torch.eye(4) 51 | 52 | # 计算专家动作 53 | expert_action = env.expert_step_real(pose_source.unsqueeze(0), pose_target.unsqueeze(0), False) 54 | new_pc_source, pos_src = env.step_continous(pc_source.unsqueeze(0), expert_action, pose_source.unsqueeze(0)) 55 | 56 | 57 | print("pc_target: ", pc_target) 58 | print("pc_source: ", pc_source) 59 | print("new_pc_source: ", new_pc_source) 60 | assert torch.allclose(pc_target.to("cuda"), new_pc_source) # 恢复后的点云与原始点云一致 61 | assert torch.allclose(pose_target.to("cuda"), pos_src) # 恢复后的位姿与原始位姿一致 62 | 63 | 64 | if __name__ == '__main__': 65 | test_rotate() -------------------------------------------------------------------------------- /test/test_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/home/zhujt/code_calib/CalibDepth') 3 | import torch 4 | import mathutils 5 | import pytest 6 | # import utility.utils as utils 7 | from utility import utils 8 | 9 | import unittest 10 | from utility.utils import rotate_forward 11 | from mathutils import Matrix, Euler 12 | 13 | def test_rotate_forward(): 14 | # 创建一个测试用的点云 PC,假设是 4xN 的形状 15 | PC = torch.tensor([[1.0, 2.0, 3.0, 1.0], 16 | [4.0, 
5.0, 6.0, 1.0], 17 | [7.0, 8.0, 9.0, 1.0], 18 | [10.0, 11.0, 12.0, 1.0]]) 19 | 20 | # 创建测试用的旋转矩阵 R,这里用单位矩阵作为测试 21 | R = torch.eye(4) 22 | 23 | # 创建测试用的平移向量 T,假设是 [1, 2, 3] 24 | # T = torch.tensor([1.0, 2.0, 3.0]) 25 | T = None 26 | 27 | # 调用 rotate_forward 函数进行旋转 28 | rotated_PC = rotate_forward(PC, R, T) 29 | 30 | # 进行断言,比较旋转后的点云 rotated_PC 是否与预期一致 31 | # expected_rotated_PC = torch.tensor([[ 6.0, 7.0, 8.0, 1.0], 32 | # [11.0, 12.0, 13.0, 1.0], 33 | # [16.0, 17.0, 18.0, 1.0], 34 | # [21.0, 22.0, 23.0, 1.0]]) 35 | expected_rotated_PC = PC 36 | assert torch.allclose(rotated_PC, expected_rotated_PC) 37 | 38 | # 添加更多测试用例,如测试不同的旋转角度、平移向量等情况 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | import numpy as np 7 | from torch.utils.tensorboard import SummaryWriter 8 | from tqdm import tqdm 9 | from prefetch_generator import BackgroundGenerator 10 | 11 | from utility.logger import Logger 12 | import utility.metrics as metrics 13 | from utility.quaternion_distances import quaternion_distance 14 | from models.model import Agent 15 | import models.model as util_model 16 | from dataset.DatasetLidarCam import DatasetKittiRawCalibNet 17 | from dataset.DatasetLidarCam import lidar_project_depth 18 | from dataset.data_utils import (merge_inputs, quaternion_from_matrix) 19 | 20 | from environment import environment as env 21 | from environment import transformations as tra 22 | from environment.buffer import Buffer 23 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | np.random.seed(42) 25 | 26 | 27 | def lidar_project_depth_batch(pc, calib, img_shape): 28 | depth_img_out = [] 29 | for idx in range(pc.shape[0]): 30 | depth_img, _ = lidar_project_depth(pc[idx].transpose(0, 1), calib[idx], img_shape) 31 | depth_img = 
depth_img.to(DEVICE) 32 | depth_img_out.append(depth_img) 33 | 34 | depth_img_out = torch.stack(depth_img_out) 35 | depth_img_out = F.interpolate(depth_img_out, size=[256, 512], mode = 'bilinear', align_corners=False) 36 | return depth_img_out 37 | 38 | 39 | def parse_args(): 40 | # fmt: off 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--data_folder", type=str, default='/home/zhujt/dataset_zjt/kitti_raw/', 43 | help="the data path") 44 | parser.add_argument("--dataset", type=str, default='raw', choices = ['raw', 'odometry', 'raw_calibNet'], 45 | help="the data type") # kitti数据的类型 46 | parser.add_argument("--use_reflectance", type=bool, default=False, 47 | help="use reflectance or not") # 是否使用反射率 48 | parser.add_argument("--val_sequence", type=list, default=['2011_09_26_drive_0005_sync', '2011_09_26_drive_0070_sync'], 49 | help="the data for valuation") # 验证集的序列 50 | parser.add_argument("--val_sequence_generalization", type=list, 51 | default=['2011_09_26_drive_0005_sync', '2011_09_26_drive_0070_sync'], 52 | help="the data for evaluate the generalization") 53 | parser.add_argument("--max_t", type=float, default=0.2, 54 | help="the translation decalibration range") # 平移扰动范围的最大值(cm) 55 | parser.add_argument("--max_r", type=float, default=10., 56 | help="the rotation decalibration range") # 旋转扰动范围的最大值(°) 57 | parser.add_argument("--save_id", type=int, default=1, 58 | help="the id of the model to be saved") # 模型保存的id 59 | parser.add_argument("--learning_rate", type=float, default=1e-4, 60 | help="the learning rate of the optimizer") # 学习率 61 | parser.add_argument("--learning_rate_step", type=int, default=8, 62 | help="the learning rate's scale step of the optimizer") 63 | parser.add_argument("--epoch", type=int, default=50, 64 | help="the epochs for training") 65 | parser.add_argument("--batch_size", type=int, default=8, 66 | help="the batch size for data collection") 67 | parser.add_argument("--update_batch_size", type=int, default=64, 68 | 
help="the batch size for training") 69 | parser.add_argument("--num_worker", type=int, default=5, 70 | help="the worker nums for training") 71 | parser.add_argument("--seed", type=int, default=42, 72 | help="seeds") 73 | parser.add_argument("--ITER_TRAIN", type=int, default=3, 74 | help="train iterations") 75 | parser.add_argument("--ITER_EVAL", type=int, default=3, 76 | help="value iterations") 77 | # parser.add_argument("--NUM_TRAJ", type=int, default=4, 78 | # help="trajectory numbers") 79 | 80 | # Algorithm specific arguments 81 | 82 | args = parser.parse_args() 83 | return args 84 | 85 | class GenerateSeq(nn.Module): 86 | """ 87 | 基于agent网络生成固定长度的标定动作序列 88 | """ 89 | def __init__(self, agent): 90 | super(GenerateSeq, self).__init__() 91 | self.agent = agent 92 | 93 | def forward(self, ds_pc_source, calib, depth, rgb, pos_src, pos_tgt, seq_len): 94 | batch_size = ds_pc_source.shape[0] 95 | trg_seqlen = seq_len 96 | 97 | # 初始化输出结果 98 | outputs_save_transl=torch.zeros(batch_size,trg_seqlen,3) 99 | outputs_save_rot=torch.zeros(batch_size,trg_seqlen,3) # agent动作 100 | exp_outputs_save_transl=torch.zeros(batch_size,trg_seqlen,3) 101 | exp_outputs_save_rot=torch.zeros(batch_size,trg_seqlen,3) # 专家监督动作 102 | h_last = torch.zeros(2, depth.shape[0], 256).to(DEVICE) 103 | c_last = torch.zeros(2, depth.shape[0], 256).to(DEVICE) # lstm的中间输出 104 | exp_pos_src = pos_src 105 | 106 | # 生成动作序列 107 | for i in range(0, trg_seqlen): 108 | # 专家动作 109 | expert_action = env.expert_step_real(exp_pos_src, pos_tgt) 110 | # agent动作 111 | actions, predict_depth, hc = agent(rgb, depth, h_last, c_last) 112 | h_last, c_last = hc[0], hc[1] 113 | action_t, action_r = actions[0].unsqueeze(1), actions[1].unsqueeze(1) 114 | action_tr = torch.cat([action_t, action_r], dim = 1) 115 | # 下一步状态 116 | new_source, pos_src = env.step_continous(ds_pc_source, action_tr, pos_src) # new_source只用来记录当前点云,不迭代更新输入点云(apply_trafo决定) 117 | exp_new_source, exp_pos_src = env.step_continous(ds_pc_source, 
expert_action, exp_pos_src) 118 | # 状态更新 119 | current_source = new_source 120 | depth = lidar_project_depth_batch(current_source, calib, (384, 1280)) # 更新后点云对应的一个batch的深度图 121 | depth /= 80 122 | # 保存 123 | exp_outputs_save_transl[:,i,:]=expert_action[:,0] 124 | exp_outputs_save_rot[:,i,:]=expert_action[:,1] 125 | outputs_save_transl[:,i,:]=actions[0].squeeze(1) 126 | outputs_save_rot[:,i,:]=actions[1].squeeze(1) 127 | return exp_outputs_save_transl, exp_outputs_save_rot, outputs_save_transl, outputs_save_rot, pos_src, current_source, predict_depth 128 | 129 | def train(calib_seq, agent, logger, datapath, max_t, max_r, epochs, batch_size, num_worker, 130 | lr, lr_step, model_path, val_sequence): 131 | args = parse_args() 132 | optimizer = torch.optim.Adam(agent.parameters(), lr=lr, amsgrad=True) 133 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_step, 0.5) 134 | dataset_train = DatasetKittiRawCalibNet(datapath, max_r=max_r, max_t=max_t, split='train', 135 | use_reflectance=False, val_sequence=val_sequence) 136 | dataset_val = DatasetKittiRawCalibNet(datapath, max_r=max_r, max_t=max_t, split='val', 137 | use_reflectance=False, val_sequence=args.val_sequence_generalization) 138 | TrainLoader = torch.utils.data.DataLoader(dataset_train, 139 | batch_size=batch_size, 140 | shuffle=True, 141 | num_workers=num_worker, 142 | pin_memory=True, 143 | drop_last=False, 144 | collate_fn=merge_inputs) 145 | ValLoader = torch.utils.data.DataLoader(dataset_val, 146 | batch_size=batch_size, 147 | shuffle=False, 148 | num_workers=num_worker, 149 | pin_memory=True, 150 | drop_last=False, 151 | collate_fn=merge_inputs) 152 | print(len(TrainLoader)) 153 | print(len(ValLoader)) 154 | 155 | # 初始化 156 | RANDOM_STATE = np.random.get_state() 157 | losses_bc, losses_q, losses_tme, losses_pd, loss_depth, losses_all = [], [], [], [], [], [] 158 | episode = 0 # for loss logging (not using epoch) 159 | best_chamfer = np.infty 160 | 161 | buffer = Buffer() 162 | 
buffer.start_trajectory() 163 | 164 | cal_loss = torch.nn.SmoothL1Loss(reduction='none') 165 | FEATURE_loss = util_model.berHuLoss() 166 | 167 | for epoch in range(epochs): 168 | print(f"Epoch {epoch}") 169 | 170 | # 训练阶段 171 | agent.train() 172 | np.random.set_state(RANDOM_STATE) 173 | progress = tqdm(BackgroundGenerator(TrainLoader), total=len(TrainLoader)) 174 | for data in progress: 175 | # 读取batch数据并且初始化 176 | rgb_input, depth_input, depth_target, pose_target, pose_source, ds_pc_target, ds_pc_source, calib = env.init(data) 177 | 178 | current_source = ds_pc_source 179 | current_depth = depth_input 180 | 181 | exp_transl_seq, exp_rot_seq, transl_seq, rot_seq, pos_final, current_source, predict_depth = calib_seq(ds_pc_source, 182 | calib, current_depth, rgb_input, pose_source, pose_target, args.ITER_TRAIN) 183 | # 每一步与专家动作之间的平均损失 184 | loss_translation = cal_loss(transl_seq, exp_transl_seq).sum(2).mean() 185 | loss_rotation = cal_loss(rot_seq, exp_rot_seq).sum(2).mean() 186 | clone_loss = loss_rotation + loss_translation 187 | 188 | # 计算最终位姿与真值之间的四元数损失 189 | R_composed_target = torch.stack([quaternion_from_matrix(pose_target[i, :]) for i in range(pose_target.shape[0])], dim = 0) 190 | R_composed = torch.stack([quaternion_from_matrix(pos_final[i, :]) for i in range(pos_final.shape[0])], dim = 0) 191 | qd_error = quaternion_distance(R_composed, 192 | R_composed_target, 193 | R_composed.device) 194 | qd_error = qd_error.abs() * (180.0/np.pi) 195 | 196 | # 计算最终位姿和目标之间的平移损失 197 | t_gt = pose_target[:, :3, 3] 198 | t_pred = pos_final[:, :3, 3] 199 | t_mae = torch.abs(t_gt - t_pred).mean(dim=1) 200 | 201 | # 点云距离损失 202 | rand_idxs = np.random.choice(current_source.shape[1], 1024, replace=False) 203 | src_transformed_samp = current_source[:, rand_idxs, :] 204 | ref_clean_samp = ds_pc_target[:, rand_idxs, :] 205 | dist = torch.min(tra.square_distance(src_transformed_samp, ref_clean_samp), dim=-1)[0] 206 | chamfer_dist = torch.mean(dist, dim=1).view(-1, 1, 1) 207 | 
geo_loss = chamfer_dist.mean() 208 | 209 | # 单目深度估计的损失 210 | mask = depth_target > 0 # 只用大于0的有效值部分进行监督 211 | depth_loss = FEATURE_loss(predict_depth[mask], depth_target[mask]) 212 | 213 | # 整体的损失函数 214 | loss = clone_loss*10 + qd_error.mean()*0.1 + t_mae.mean()*3 + geo_loss * 0.2 + depth_loss * 0.05 215 | 216 | # 优化 217 | optimizer.zero_grad() 218 | losses_bc.append(10*loss_translation.item()) # 暂时改一下 219 | losses_q.append(0.1*qd_error.mean().item()) 220 | losses_tme.append(3*t_mae.mean().item()) 221 | losses_pd.append(0.2*geo_loss.item()) 222 | loss_depth.append(depth_loss.item()) 223 | losses_all.append(loss.item()) 224 | 225 | loss.backward() 226 | optimizer.step() 227 | 228 | # 存到log 229 | logger.record("train/bc", np.mean(losses_bc)) 230 | logger.record("train/q", np.mean(losses_q)) 231 | logger.record("train/tme", np.mean(losses_tme)) 232 | logger.record("train/geo", np.mean(losses_pd)) 233 | logger.record("train/depth", np.mean(loss_depth)) 234 | logger.record("train/all", np.mean(losses_all)) 235 | logger.dump(step=episode) 236 | 237 | 238 | losses_bc, losses_q, losses_tme, losses_pd, loss_depth, losses_all = [], [], [], [], [], [] 239 | episode += 1 240 | 241 | scheduler.step() 242 | RANDOM_STATE = np.random.get_state() 243 | 244 | if ValLoader is not None: 245 | chamfer_val = evaluate(agent, logger, ValLoader, prefix='val') 246 | 247 | if chamfer_val <= best_chamfer: 248 | print(f"new best: {chamfer_val}") 249 | best_chamfer = chamfer_val 250 | infos = { 251 | 'epoch': epoch, 252 | 'optimizer_state_dict': optimizer.state_dict() 253 | } 254 | # 存储验证集上最优模型 255 | util_model.save(agent, f"{model_path}.zip", infos) 256 | logger.dump(step=epoch) 257 | 258 | 259 | def evaluate(agent, logger, loader, prefix='test'): 260 | agent.eval() 261 | args = parse_args() 262 | progress = tqdm(BackgroundGenerator(loader), total=len(loader)) 263 | predictions = [] 264 | 265 | with torch.no_grad(): 266 | for data in progress: 267 | rgb_input, depth_input, depth_target, 
pose_target, pose_source, ds_pc_target, ds_pc_source, calib = env.init(data) 268 | current_source = ds_pc_source 269 | current_depth = depth_input 270 | 271 | for step in range(args.ITER_EVAL): 272 | # 第一步迭代讲h_last和c_last初始化为0 273 | if(step == 0): 274 | actions, _, hc = agent(rgb_input, current_depth, 275 | torch.zeros(2, depth_input.shape[0], 256).to(DEVICE), 276 | torch.zeros(2, depth_input.shape[0], 256).to(DEVICE)) 277 | else: 278 | actions, _, hc = agent(rgb_input, current_depth, h_last, c_last) 279 | h_last, c_last = hc[0], hc[1] 280 | 281 | action_t, action_r = actions[0].unsqueeze(1), actions[1].unsqueeze(1) 282 | action_tr = torch.cat([action_t, action_r], dim = 1) 283 | 284 | new_source, pose_source = env.step_continous(ds_pc_source, action_tr, pose_source) 285 | current_source = new_source 286 | current_depth = lidar_project_depth_batch(current_source, calib, (384, 1280)) # 更新后点云对应的一个batch的深度图 287 | current_depth /= 80 288 | predictions.append(pose_source) 289 | 290 | predictions = torch.cat(predictions) 291 | _, summary_metrics = metrics.compute_stats(predictions, data_loader=loader) 292 | logger.record(f"{prefix}/mae-r", summary_metrics['r_mae']) 293 | logger.record(f"{prefix}/mae-t", summary_metrics['t_mae']) 294 | logger.record(f"{prefix}/iso-r", summary_metrics['r_iso']) 295 | logger.record(f"{prefix}/iso-t", summary_metrics['t_iso']) 296 | logger.record(f"{prefix}/chamfer", summary_metrics['chamfer_dist']) 297 | return summary_metrics['chamfer_dist'] 298 | 299 | if __name__ == '__main__': 300 | args = parse_args() 301 | dataset = args.dataset 302 | save_id = args.save_id 303 | code_path = os.path.dirname(os.path.abspath(__file__)) 304 | if not os.path.exists(os.path.join(code_path, "logs_lstm")): 305 | os.mkdir(os.path.join(code_path, "logs_lstm")) 306 | if not os.path.exists(os.path.join(code_path, "weights_lstm")): 307 | os.mkdir(os.path.join(code_path, "weights_lstm")) 308 | 309 | model_path = os.path.join(code_path, 
f"weights_lstm/{dataset}_{save_id}") 310 | logger = Logger(log_dir=os.path.join(code_path, f"logs_lstm/{dataset}/"), log_name=f"calibdepth_{save_id}", 311 | reset_num_timesteps=True) 312 | 313 | agent = Agent().to(DEVICE) 314 | calib_seq = GenerateSeq(agent).to(DEVICE) 315 | print(f"Training: dataset '{dataset}'") 316 | train(calib_seq, agent, logger, 317 | datapath=args.data_folder, max_t = args.max_t, max_r = args.max_r, 318 | epochs=args.epoch, batch_size=args.batch_size, num_worker=args.num_worker, 319 | lr=args.learning_rate, lr_step=args.learning_rate_step, 320 | model_path=model_path, val_sequence=args.val_sequence) 321 | 322 | -------------------------------------------------------------------------------- /utility/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Brickzhuantou/CalibDepth/45dc0252f01353d4a897e414e7bf2a63d273e1f3/utility/__init__.py -------------------------------------------------------------------------------- /utility/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | from collections import defaultdict 5 | import torch 6 | # from torch.utils.tensorboard import SummaryWriter 7 | # import tensorboard 8 | from tensorboardX import SummaryWriter 9 | 10 | 11 | 12 | class Logger: 13 | """Based off https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/logger.py""" 14 | def __init__(self, log_dir, log_name, reset_num_timesteps=True): 15 | self.name_to_value = defaultdict(float) # values this iteration 16 | self.name_to_count = defaultdict(int) 17 | self.name_to_excluded = defaultdict(str) 18 | 19 | latest_run_id = self.get_latest_run_id(log_dir, log_name) 20 | if not reset_num_timesteps: 21 | # Continue training in the same directory 22 | latest_run_id -= 1 23 | save_path = os.path.join(log_dir, f"{log_name}_{latest_run_id + 1}") 24 | 
os.makedirs(save_path, exist_ok=True) 25 | self.writer = SummaryWriter(log_dir=save_path) 26 | 27 | @staticmethod 28 | def get_latest_run_id(log_dir, log_name) -> int: 29 | """ 30 | Returns the latest run number for the given log name and log path, 31 | by finding the greatest number in the directories. 32 | :return: latest run number 33 | """ 34 | max_run_id = 0 35 | for path in glob.glob(f"{log_dir}/{log_name}_[0-9]*"): 36 | file_name = path.split(os.sep)[-1] 37 | ext = file_name.split("_")[-1] 38 | if log_name == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id: 39 | max_run_id = int(ext) 40 | return max_run_id 41 | 42 | def record(self, key, value, exclude=None): 43 | """ 44 | Log a value of some diagnostic 45 | Call this once for each diagnostic quantity, each iteration 46 | If called many times, last value will be used. 47 | :param key: save to log this key 48 | :param value: save to log this value 49 | :param exclude: outputs to be excluded 50 | """ 51 | self.name_to_value[key] = value 52 | self.name_to_excluded[key] = exclude 53 | 54 | def dump(self, step=0): 55 | """Write all of the diagnostics from the current iteration""" 56 | self.write(self.name_to_value, self.name_to_excluded, step) 57 | 58 | self.name_to_value.clear() 59 | self.name_to_count.clear() 60 | self.name_to_excluded.clear() 61 | 62 | def write(self, key_values, key_excluded, step=0): 63 | 64 | for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())): 65 | 66 | if isinstance(value, np.ScalarType): 67 | self.writer.add_scalar(key, value, step) 68 | 69 | if isinstance(value, torch.Tensor): 70 | self.writer.add_histogram(key, value, step) 71 | 72 | # Flush the output to the file 73 | self.writer.flush() 74 | 75 | def close(self) -> None: 76 | """Closes the file""" 77 | if self.writer: 78 | self.writer.close() 79 | self.writer = None 80 | -------------------------------------------------------------------------------- 
# /utility/metrics.py:
# ------------------------------------------------------------------------------
import sys
sys.path.append('/home/zhujt/code_calib/CalibDepth')
from collections import defaultdict
import numpy as np
from scipy.spatial.transform import Rotation
from tqdm import tqdm
import torch
import environment.transformations as tra
from environment import environment as env

from utility.utils import quaternion_from_matrix
from utility.quaternion_distances import quaternion_distance


def _wrap_degrees(angles):
    """Wrap angle differences given in degrees into the interval [-180, 180).

    Without wrapping, Euler angles of 179 and -179 degrees would be reported
    as 358 degrees apart instead of 2.
    """
    return (angles + 180.0) % 360.0 - 180.0


def compute_metrics(pose_target, ds_pc_source, ds_pc_target, pred_transforms, num_samples=1024):
    """Compute per-sample calibration error metrics for one batch.

    Args:
        pose_target (torch.Tensor): ground-truth transforms, indexed as [B, 4, 4]
            (rotation block ``[:, :3, :3]``, translation ``[:, :3, 3]``).
        ds_pc_source (torch.Tensor): source point clouds indexed as [B, N, C];
            only the first three channels (xyz) are used.
        ds_pc_target (torch.Tensor): reference point clouds, same layout as
            ``ds_pc_source``.
        pred_transforms (torch.Tensor): predicted transforms, same layout as
            ``pose_target``.
        num_samples (int): number of points randomly sampled for the Chamfer
            distance. Clipped to the number of available points so clouds with
            fewer points do not make ``np.random.choice`` fail.

    Returns:
        dict[str, np.ndarray]: one array of length B per metric:
            ``r_mae``: mean absolute Euler-angle ('xyz', degrees) error,
            ``t_mae``: mean absolute translation error over x/y/z,
            ``rr_mae``, ``pp_mae``, ``yy_mae``: absolute Euler angles
            (degrees) of the residual rotation inv(GT) @ pred, per axis,
            ``x_mae``, ``y_mae``, ``z_mae``: absolute translation error per axis,
            ``qd_error``: quaternion angular distance (degrees),
            ``r_iso`` (degrees) / ``t_iso``: isotropic rotation / translation
            error of the residual transform inv(GT) @ pred,
            ``chamfer_dist``: Chamfer distance between the prediction-corrected
            source cloud and the reference cloud.
    """
    gt_transforms = pose_target

    # Inverse of every ground-truth transform: [R^T | -R^T t].
    igt_transforms = torch.eye(4, device=pred_transforms.device).repeat(gt_transforms.shape[0], 1, 1)
    igt_transforms[:, :3, :3] = gt_transforms[:, :3, :3].transpose(2, 1)
    igt_transforms[:, :3, 3] = -(igt_transforms[:, :3, :3] @ gt_transforms[:, :3, 3].view(-1, 3, 1)).view(-1, 3)

    points_src = ds_pc_source[..., :3]
    points_ref = ds_pc_target[..., :3]

    # Quaternion metric: angular distance between predicted and GT rotations, in degrees.
    R_composed_target = torch.stack([quaternion_from_matrix(gt_transforms[i, :])
                                     for i in range(gt_transforms.shape[0])], dim=0)
    R_composed = torch.stack([quaternion_from_matrix(pred_transforms[i, :])
                              for i in range(pred_transforms.shape[0])], dim=0)
    qd_error = quaternion_distance(R_composed,
                                   R_composed_target,
                                   R_composed.device)
    qd_error = qd_error.abs() * (180.0 / np.pi)

    # Euler-angle metric ('xyz' convention, degrees).
    r_gt_euler_deg = np.stack([Rotation.from_matrix(r.cpu().numpy()).as_euler('xyz', degrees=True)
                               for r in gt_transforms[:, :3, :3]])
    r_pred_euler_deg = np.stack([Rotation.from_matrix(r.cpu().numpy()).as_euler('xyz', degrees=True)
                                 for r in pred_transforms[:, :3, :3]])
    t_gt = gt_transforms[:, :3, 3]
    t_pred = pred_transforms[:, :3, 3]
    # Mean absolute error over the three rotation / translation axes. The angle
    # differences are wrapped so errors across the +/-180 degree seam are not inflated.
    r_mae = np.abs(_wrap_degrees(r_gt_euler_deg - r_pred_euler_deg)).mean(axis=1)
    t_mae = torch.abs(t_gt - t_pred).mean(dim=1)

    # Isotropic errors of the residual transform inv(GT) @ pred.
    concatenated = igt_transforms @ pred_transforms
    rot_trace = concatenated[:, 0, 0] + concatenated[:, 1, 1] + concatenated[:, 2, 2]
    r_iso = torch.rad2deg(torch.acos(torch.clamp(0.5 * (rot_trace - 1), min=-1.0, max=1.0)))
    t_iso = concatenated[:, :3, 3].norm(dim=-1)

    # Per-axis errors for all six degrees of freedom.
    xyz_mae = torch.abs(t_gt - t_pred)
    x_mae, y_mae, z_mae = xyz_mae[:, 0], xyz_mae[:, 1], xyz_mae[:, 2]
    rpy_mae = np.abs(np.stack([Rotation.from_matrix(r.cpu().numpy()).as_euler('xyz', degrees=True)
                               for r in concatenated[:, :3, :3]]))
    rr_mae, pp_mae, yy_mae = rpy_mae[:, 0], rpy_mae[:, 1], rpy_mae[:, 2]

    # Point-cloud metric: source cloud corrected with the predicted transform.
    src_transformed = (pred_transforms[:, :3, :3] @ points_src.transpose(2, 1)).transpose(2, 1) \
        + pred_transforms[:, :3, 3][:, None, :]

    # The same random indices are used for both clouds (assumes the reference
    # cloud has at least as many points as the source). The sample size is
    # clipped because np.random.choice(..., replace=False) raises when asked
    # for more samples than there are points.
    num_points = src_transformed.shape[1]
    rand_idxs = np.random.choice(num_points, min(num_samples, num_points), replace=False)
    src_transformed_samp = src_transformed[:, rand_idxs, :]
    points_ref_samp = points_ref[:, rand_idxs, :]

    # Chamfer distance: nearest-neighbour distances in both directions, averaged per cloud.
    dist_src = torch.min(tra.square_distance(src_transformed_samp, points_ref_samp), dim=-1)[0]
    dist_ref = torch.min(tra.square_distance(points_ref_samp, src_transformed_samp), dim=-1)[0]
    chamfer_dist = torch.mean(dist_src, dim=1) + torch.mean(dist_ref, dim=1)

    metrics = {
        'r_mae': r_mae,
        'rr_mae': rr_mae,
        'pp_mae': pp_mae,
        'yy_mae': yy_mae,
        'qd_error': qd_error.cpu().numpy(),
        't_mae': t_mae.cpu().numpy(),
        'x_mae': x_mae.cpu().numpy(),
        'y_mae': y_mae.cpu().numpy(),
        'z_mae': z_mae.cpu().numpy(),
        'r_iso': r_iso.cpu().numpy(),
        't_iso': t_iso.cpu().numpy(),
        'chamfer_dist': chamfer_dist.cpu().numpy()
    }
    return metrics


def summarize_metrics(metrics):
    """Concatenate per-batch metric arrays and average each metric.

    Note: ``metrics`` is modified in place. Each value (a list of per-batch
    arrays) is replaced by its ``np.hstack`` concatenation.

    Returns:
        dict[str, float]: the mean of every metric over all samples.
    """
    summarized = {}
    for k in metrics:
        metrics[k] = np.hstack(metrics[k])
        summarized[k] = np.mean(metrics[k])
    return summarized


def compute_stats(pred_transforms, data_loader):
    """Evaluate pre-computed predictions against every batch of a data loader.

    Args:
        pred_transforms (torch.Tensor): predictions for the whole loader,
            consumed ``batch_size`` rows at a time in iteration order.
            NOTE(review): assumes ``data_loader`` yields samples in the same
            order the predictions were produced (e.g. no shuffling) - confirm.
        data_loader: iterable of batch dicts accepted by ``env.init``.

    Returns:
        tuple: ``(metrics_for_iter, summary_metrics)``. ``metrics_for_iter``
        maps each metric name to its concatenated per-sample values, and
        ``summary_metrics`` maps it to their mean.
    """
    metrics_for_iter = defaultdict(list)
    num_processed = 0
    with torch.no_grad():
        for data in tqdm(data_loader, leave=False):
            dict_all_to_device(data, pred_transforms.device)
            _, _, _, pose_target, pose_source, ds_pc_target, ds_pc_source, calib = env.init(data)
            batch_size = pose_source.shape[0]
            cur_pred_transforms = pred_transforms[num_processed:num_processed + batch_size]
            metrics = compute_metrics(pose_target, ds_pc_source, ds_pc_target, cur_pred_transforms)
            for k in metrics:
                metrics_for_iter[k].append(metrics[k])
            num_processed += batch_size

    # summarize_metrics also flattens metrics_for_iter in place.
    summary_metrics = summarize_metrics(metrics_for_iter)
    return metrics_for_iter, summary_metrics


def dict_all_to_device(tensor_dict, device):
    """Move every tensor of ``tensor_dict`` to ``device``, in place.

    Double-precision tensors are additionally cast to float32. Nested dicts
    are processed recursively; other values (e.g. lists) are left untouched.
    Adapted from RPMNet.
    """
    for k in tensor_dict:
        if isinstance(tensor_dict[k], torch.Tensor):
            tensor_dict[k] = tensor_dict[k].to(device)
            if tensor_dict[k].dtype == torch.double:
                tensor_dict[k] = tensor_dict[k].float()
        if isinstance(tensor_dict[k], dict):
            dict_all_to_device(tensor_dict[k], device)
# ------------------------------------------------------------------------------
# /utility/quaternion_distances.py:
# ------------------------------------------------------------------------------
# -------------------------------------------------------------------
# Copyright (C) 2020 Università degli studi di Milano-Bicocca, iralab
# Author: Daniele Cattaneo (d.cattaneo10@campus.unimib.it)
# Released under Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# http://creativecommons.org/licenses/by-nc-sa/4.0/
# -------------------------------------------------------------------

# Modified Author: Xudong Lv
# based on github.com/cattaneod/CMRNet/blob/master/quaternion_distances.py

import numpy as np
import torch


def quatmultiply(q, r, device='cpu'):
    """
    Batch quaternion multiplication: Hamilton product q * r, scalar-first [w, x, y, z].

    Args:
        q (torch.Tensor/np.ndarray): shape=[Nx4]
        r (torch.Tensor/np.ndarray): shape=[Nx4]
        device (str): 'cuda' or 'cpu'; only used for torch inputs

    Returns:
        torch.Tensor/np.ndarray: shape=[Nx4], same container type as ``q``
        (not normalized).

    Raises:
        TypeError: if ``q`` is neither a torch tensor nor a numpy array.
    """
    if isinstance(q, torch.Tensor):
        t = torch.zeros(q.shape[0], 4, device=device)
    elif isinstance(q, np.ndarray):
        # np.zeros takes the shape as one tuple; a second positional argument
        # would be interpreted as the dtype and raise a TypeError.
        t = np.zeros((q.shape[0], 4))
    else:
        raise TypeError("Type not supported")
    t[:, 0] = r[:, 0] * q[:, 0] - r[:, 1] * q[:, 1] - r[:, 2] * q[:, 2] - r[:, 3] * q[:, 3]
    t[:, 1] = r[:, 0] * q[:, 1] + r[:, 1] * q[:, 0] - r[:, 2] * q[:, 3] + r[:, 3] * q[:, 2]
    t[:, 2] = r[:, 0] * q[:, 2] + r[:, 1] * q[:, 3] + r[:, 2] * q[:, 0] - r[:, 3] * q[:, 1]
    t[:, 3] = r[:, 0] * q[:, 3] - r[:, 1] * q[:, 2] + r[:, 2] * q[:, 1] + r[:, 3] * q[:, 0]
    return t


def quatinv(q):
    """
    Batch quaternion inversion (conjugate: the vector part is negated).

    The conjugate equals the inverse only for unit quaternions.

    Args:
        q (torch.Tensor/np.ndarray): shape=[Nx4], scalar-first

    Returns:
        torch.Tensor/np.ndarray: shape=[Nx4], a new object (``q`` is not modified)

    Raises:
        TypeError: if ``q`` is neither a torch tensor nor a numpy array.
    """
    if isinstance(q, torch.Tensor):
        t = q.clone()
    elif isinstance(q, np.ndarray):
        t = q.copy()
    else:
        raise TypeError("Type not supported")
    t *= -1
    t[:, 0] *= -1
    return t


def quaternion_distance(q, r, device):
    """
    Batch quaternion distances, used as loss.

    Computes the rotation angle of q * inv(r). Using |w| selects the shortest
    rotation, so the result lies in [0, pi].

    Args:
        q (torch.Tensor): shape=[Nx4]
        r (torch.Tensor): shape=[Nx4]
        device (str): 'cuda' or 'cpu'

    Returns:
        torch.Tensor: shape=[N], angles in radians
    """
    t = quatmultiply(q, quatinv(r), device)
    return 2 * torch.atan2(torch.norm(t[:, 1:], dim=1), torch.abs(t[:, 0]))

# ------------------------------------------------------------------------------
# /utility/utils.py:
# ------------------------------------------------------------------------------
# -------------------------------------------------------------------
# Copyright (C) 2020 Università degli studi di Milano-Bicocca, iralab
# Author: Daniele Cattaneo (d.cattaneo10@campus.unimib.it)
# Released under Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# http://creativecommons.org/licenses/by-nc-sa/4.0/
# -------------------------------------------------------------------

# Modified Author: Xudong Lv
# based on github.com/cattaneod/CMRNet/blob/master/utils.py

import math

import mathutils
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import cm
from torch.utils.data.dataloader import default_collate


def _mathutils_matmul(a, b):
    """Return the matrix product of two mathutils matrices, across mathutils versions.

    mathutils >= 2.80 implements the matrix product as ``@`` and made ``*``
    element-wise. Older releases only provide ``*`` as the matrix product.
    Try the modern operator first and fall back for old versions.
    """
    try:
        return a @ b
    except TypeError:
        return a * b


def rotate_points(PC, R, T=None, inverse=True):
    """Apply a mathutils rigid transform to a homogeneous point cloud.

    Args:
        PC (torch.Tensor): point cloud in homogeneous coordinates, shape [4xN] or [Nx4].
        R: when ``T`` is given, a rotation providing ``to_matrix()``
            (e.g. mathutils.Euler / Quaternion). Otherwise, a full [4x4]
            mathutils.Matrix used as the transform.
        T (mathutils.Vector or None): translation of the transform.
        inverse (bool): apply the inverse of the transform when True.

    Returns:
        torch.Tensor: transformed point cloud with the same layout as ``PC``.

    Raises:
        TypeError: if ``PC`` is neither [4xN] nor [Nx4].
    """
    if T is not None:
        R = R.to_matrix()
        R.resize_4x4()
        T = mathutils.Matrix.Translation(T)
        RT = _mathutils_matmul(T, R)
    else:
        RT = R.copy()
    if inverse:
        RT.invert_safe()
    RT = torch.tensor(RT, device=PC.device, dtype=torch.float)

    if PC.shape[0] == 4:
        PC = torch.mm(RT, PC)
    elif PC.shape[1] == 4:
        PC = torch.mm(RT, PC.t())
        PC = PC.t()
    else:
        raise TypeError("Point cloud must have shape [Nx4] or [4xN] (homogeneous coordinates)")
    return PC


def rotate_points_torch(PC, R, T=None, inverse=True):
    """Apply a torch rigid transform to a homogeneous point cloud.

    Args:
        PC (torch.Tensor): point cloud in homogeneous coordinates, shape [4xN] or [Nx4].
        R (torch.Tensor): quaternion of shape [4] when ``T`` is given,
            otherwise the full [4x4] transform.
        T (torch.Tensor or None): translation of shape [3].
        inverse (bool): apply the inverse of the transform when True.

    Returns:
        torch.Tensor: transformed point cloud with the same layout as ``PC``.

    Raises:
        TypeError: if ``PC`` is neither [4xN] nor [Nx4].
    """
    if T is not None:
        R = quat2mat(R)
        T = tvector2mat(T)
        RT = torch.mm(T, R)
    else:
        RT = R.clone()
    if inverse:
        RT = RT.inverse()

    if PC.shape[0] == 4:
        PC = torch.mm(RT, PC)
    elif PC.shape[1] == 4:
        PC = torch.mm(RT, PC.t())
        PC = PC.t()
    else:
        raise TypeError("Point cloud must have shape [Nx4] or [4xN] (homogeneous coordinates)")
    return PC


def rotate_forward(PC, R, T=None):
    """
    Transform the point cloud PC, so to have the points 'as seen from' the new
    pose T*R
    Args:
        PC (torch.Tensor): Point Cloud to be transformed, shape [4xN] or [Nx4]
        R (torch.Tensor/mathutils.Euler): can be either:
            * (mathutils.Euler) euler angles of the rotation part, in this case T cannot be None
            * (torch.Tensor shape [4]) quaternion representation of the rotation part, in this case T cannot be None
            * (mathutils.Matrix shape [4x4]) Rotation matrix,
                in this case it should contains the translation part, and T should be None
            * (torch.Tensor shape [4x4]) Rotation matrix,
                in this case it should contains the translation part, and T should be None
        T (torch.Tensor/mathutils.Vector): Translation of the new pose, shape [3], or None (depending on R)

    Returns:
        torch.Tensor: Transformed Point Cloud 'as seen from' pose T*R
    """
    if isinstance(R, torch.Tensor):
        return rotate_points_torch(PC, R, T, inverse=True)
    else:
        return rotate_points(PC, R, T, inverse=True)


def rotate_back(PC_ROTATED, R, T=None):
    """
    Inverse of :func:`~utils.rotate_forward`: applies the pose T*R itself
    instead of its inverse.
    """
    if isinstance(R, torch.Tensor):
        return rotate_points_torch(PC_ROTATED, R, T, inverse=False)
    else:
        return rotate_points(PC_ROTATED, R, T, inverse=False)


def invert_pose(R, T):
    """
    Given the 'sampled pose' (aka H_init), we want CMRNet to predict inv(H_init).
    inv(T*R) will be used as ground truth for the network.
    Args:
        R (mathutils.Euler): Rotation of 'sampled pose'
        T (mathutils.Vector): Translation of 'sampled pose'

    Returns:
        (R_GT, T_GT) = (mathutils.Quaternion, mathutils.Vector)
    """
    R = R.to_matrix()
    R.resize_4x4()
    T = mathutils.Matrix.Translation(T)
    RT = _mathutils_matmul(T, R)
    RT.invert_safe()
    T_GT, R_GT, _ = RT.decompose()
    return R_GT.normalized(), T_GT


def merge_inputs(queries):
    """Collate function for a DataLoader.

    The variable-size entries 'point_cloud', 'rgb' and (when present)
    'reflectance' are returned as plain Python lists, one item per sample.
    Every other key is collated with ``default_collate``.
    """
    point_clouds = []
    imgs = []
    reflectances = []
    returns = {key: default_collate([d[key] for d in queries]) for key in queries[0]
               if key != 'point_cloud' and key != 'rgb' and key != 'reflectance'}
    for sample in queries:
        point_clouds.append(sample['point_cloud'])
        imgs.append(sample['rgb'])
        if 'reflectance' in sample:
            reflectances.append(sample['reflectance'])
    returns['point_cloud'] = point_clouds
    returns['rgb'] = imgs
    if len(reflectances) > 0:
        returns['reflectance'] = reflectances
    return returns


def quaternion_from_matrix(matrix):
    """
    Convert a rotation matrix to quaternion.
    Args:
        matrix (torch.Tensor): [4x4] transformation matrix or [3,3] rotation matrix.

    Returns:
        torch.Tensor: shape [4], normalized quaternion (scalar first)

    Raises:
        TypeError: if ``matrix`` is neither [4x4] nor [3x3].
    """
    if matrix.shape == (4, 4):
        R = matrix[:-1, :-1]
    elif matrix.shape == (3, 3):
        R = matrix
    else:
        raise TypeError("Not a valid rotation matrix")
    tr = R[0, 0] + R[1, 1] + R[2, 2]
    q = torch.zeros(4, device=matrix.device)
    # Branch on the largest diagonal term to keep S away from zero.
    if tr > 0.:
        S = (tr+1.0).sqrt() * 2
        q[0] = 0.25 * S
        q[1] = (R[2, 1] - R[1, 2]) / S
        q[2] = (R[0, 2] - R[2, 0]) / S
        q[3] = (R[1, 0] - R[0, 1]) / S
    elif R[0, 0] > R[1, 1] and R[0, 0] > R[2, 2]:
        S = (1.0 + R[0, 0] - R[1, 1] - R[2, 2]).sqrt() * 2
        q[0] = (R[2, 1] - R[1, 2]) / S
        q[1] = 0.25 * S
        q[2] = (R[0, 1] + R[1, 0]) / S
        q[3] = (R[0, 2] + R[2, 0]) / S
    elif R[1, 1] > R[2, 2]:
        S = (1.0 + R[1, 1] - R[0, 0] - R[2, 2]).sqrt() * 2
        q[0] = (R[0, 2] - R[2, 0]) / S
        q[1] = (R[0, 1] + R[1, 0]) / S
        q[2] = 0.25 * S
        q[3] = (R[1, 2] + R[2, 1]) / S
    else:
        S = (1.0 + R[2, 2] - R[0, 0] - R[1, 1]).sqrt() * 2
        q[0] = (R[1, 0] - R[0, 1]) / S
        q[1] = (R[0, 2] + R[2, 0]) / S
        q[2] = (R[1, 2] + R[2, 1]) / S
        q[3] = 0.25 * S
    return q / q.norm()


def quatmultiply(q, r):
    """
    Multiply two quaternions
    Args:
        q (torch.Tensor/nd.ndarray): shape=[4], first quaternion
        r (torch.Tensor/nd.ndarray): shape=[4], second quaternion

    Returns:
        torch.Tensor: shape=[4], normalized quaternion q*r
    """
    t = torch.zeros(4, device=q.device)
    t[0] = r[0] * q[0] - r[1] * q[1] - r[2] * q[2] - r[3] * q[3]
    t[1] = r[0] * q[1] + r[1] * q[0] - r[2] * q[3] + r[3] * q[2]
    t[2] = r[0] * q[2] + r[1] * q[3] + r[2] * q[0] - r[3] * q[1]
    t[3] = r[0] * q[3] - r[1] * q[2] + r[2] * q[1] + r[3] * q[0]
    return t / t.norm()


def quat2mat(q):
    """
    Convert a quaternion to a rotation matrix
    Args:
        q (torch.Tensor): shape [4], input quaternion (scalar first); normalized if needed

    Returns:
        torch.Tensor: [4x4] homogeneous rotation matrix
    """
    assert q.shape == torch.Size([4]), "Not a valid quaternion"
    if q.norm() != 1.:
        q = q / q.norm()
    mat = torch.zeros((4, 4), device=q.device)
    mat[0, 0] = 1 - 2*q[2]**2 - 2*q[3]**2
    mat[0, 1] = 2*q[1]*q[2] - 2*q[3]*q[0]
    mat[0, 2] = 2*q[1]*q[3] + 2*q[2]*q[0]
    mat[1, 0] = 2*q[1]*q[2] + 2*q[3]*q[0]
    mat[1, 1] = 1 - 2*q[1]**2 - 2*q[3]**2
    mat[1, 2] = 2*q[2]*q[3] - 2*q[1]*q[0]
    mat[2, 0] = 2*q[1]*q[3] - 2*q[2]*q[0]
    mat[2, 1] = 2*q[2]*q[3] + 2*q[1]*q[0]
    mat[2, 2] = 1 - 2*q[1]**2 - 2*q[2]**2
    mat[3, 3] = 1.
    return mat


def tvector2mat(t):
    """
    Translation vector to homogeneous transformation matrix with identity rotation
    Args:
        t (torch.Tensor): shape=[3], translation vector

    Returns:
        torch.Tensor: [4x4] homogeneous transformation matrix

    """
    assert t.shape == torch.Size([3]), "Not a valid translation"
    mat = torch.eye(4, device=t.device)
    mat[0, 3] = t[0]
    mat[1, 3] = t[1]
    mat[2, 3] = t[2]
    return mat


def mat2xyzrpy(rotmatrix):
    """
    Decompose transformation matrix into components
    Args:
        rotmatrix (torch.Tensor/np.ndarray): [4x4] transformation matrix

    Returns:
        torch.Tensor: shape=[6], contains xyzrpy (angles in radians)
    """
    roll = math.atan2(-rotmatrix[1, 2], rotmatrix[2, 2])
    pitch = math.asin(rotmatrix[0, 2])
    yaw = math.atan2(-rotmatrix[0, 1], rotmatrix[0, 0])
    x = rotmatrix[:3, 3][0]
    y = rotmatrix[:3, 3][1]
    z = rotmatrix[:3, 3][2]

    return torch.tensor([x, y, z, roll, pitch, yaw], device=rotmatrix.device, dtype=rotmatrix.dtype)


def to_rotation_matrix(R, T):
    """Build the [4x4] transform T @ R from a quaternion ``R`` and a translation ``T``."""
    R = quat2mat(R)
    T = tvector2mat(T)
    RT = torch.mm(T, R)
    return RT


def overlay_imgs(rgb, lidar, idx=0):
    """Blend a colour-mapped LiDAR depth map over an RGB image.

    Args:
        rgb (torch.Tensor): normalized image indexed as [3, H, W]; the
            normalization is undone with the hard-coded mean/std below.
        lidar (torch.Tensor): depth map, presumably [B, 1, H, W] with values
            in [0, 1] (only ``lidar[0][0]`` is used) - confirm with callers.
        idx (int): unused; kept for interface compatibility.

    Returns:
        np.ndarray: [H, W, 3] blended image clipped to [0, 1].
    """
    std = [0.229, 0.224, 0.225]
    mean = [0.485, 0.456, 0.406]

    rgb = rgb.clone().cpu().permute(1,2,0).numpy()
    rgb = rgb*std+mean
    lidar = lidar.clone()

    # Min-pool (negate, max-pool, negate) so each empty pixel takes the
    # nearest depth of its 3x3 neighbourhood; empty pixels are parked at 1000
    # so they never win, then restored to 0.
    lidar[lidar == 0] = 1000.
    lidar = -lidar
    lidar = F.max_pool2d(lidar, 3, 1, 1)
    lidar = -lidar
    lidar[lidar == 1000.] = 0.

    lidar = lidar[0][0]
    lidar = (lidar*255).int().cpu().numpy()
    lidar_color = cm.jet(lidar)
    lidar_color[:, :, 3] = 0.5
    lidar_color[lidar == 0] = [0, 0, 0, 0]
    blended_img = lidar_color[:, :, :3] * (np.expand_dims(lidar_color[:, :, 3], 2)) + \
        rgb * (1. - np.expand_dims(lidar_color[:, :, 3], 2))
    blended_img = blended_img.clip(min=0., max=1.)
    return blended_img

# ------------------------------------------------------------------------------
# /visual_test.py:
# ------------------------------------------------------------------------------
#%%
import argparse
import os
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from prefetch_generator import BackgroundGenerator

from utility.logger import Logger
import utility.metrics as metrics
from utility.quaternion_distances import quaternion_distance
from models.model import Agent
import models.model as util_model
from dataset.DatasetLidarCam import DatasetKittiRawCalibNet
from dataset.DatasetLidarCam import lidar_project_depth, get_2D_lidar_projection
from dataset.data_utils import (merge_inputs, quaternion_from_matrix)

from environment import environment as env
from environment import transformations as tra
from environment.buffer import Buffer

import ipcv_utils.utils as plt
import cv2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#%%
def lidar_project_depth_batch(pc, calib, img_shape):
    """Project every point cloud of a batch to a depth image and resize it.

    Args:
        pc (torch.Tensor): batch of point clouds; ``pc[idx]`` is transposed
            before projection (presumably [B, N, C] -> [C, N] - confirm with
            ``lidar_project_depth``).
        calib: per-sample camera calibration, indexed by batch position.
        img_shape (tuple): image size passed to ``lidar_project_depth``.

    Returns:
        torch.Tensor: stacked depth images on ``DEVICE``, bilinearly resized
        to 256x512.
    """
    depth_img_out = []
    for idx in range(pc.shape[0]):
        depth_img, _ = lidar_project_depth(pc[idx].transpose(0, 1), calib[idx], img_shape)
        depth_img = depth_img.to(DEVICE)
        depth_img_out.append(depth_img)

    depth_img_out = torch.stack(depth_img_out)
    depth_img_out = F.interpolate(depth_img_out, size=[256, 512], mode='bilinear', align_corners=False)
    return depth_img_out


def get_projected_pts(pc_rotated, cam_calib, img_shape):
    """Project a point cloud into the image and keep the visible points.

    Args:
        pc_rotated (torch.Tensor): points with xyz in the first three rows
            (only ``pc_rotated[:3, :]`` is used).
        cam_calib (torch.Tensor): camera intrinsics passed to
            ``get_2D_lidar_projection``.
        img_shape (tuple): ``(height, width)`` of the image.

    Returns:
        tuple: ``(pcl_uv, pcl_z)``, the pixel coordinates and depths of the
        points that fall strictly inside the image and have positive depth.
    """
    pc_rotated = pc_rotated[:3, :].detach().cpu().numpy()
    cam_intrinsic = cam_calib.detach().cpu().numpy()
    pcl_uv, pcl_z = get_2D_lidar_projection(pc_rotated, cam_intrinsic)
    mask = (pcl_uv[:, 0] > 0) & (pcl_uv[:, 0] < img_shape[1]) & (pcl_uv[:, 1] > 0) & (
            pcl_uv[:, 1] < img_shape[0]) & (pcl_z > 0)
    pcl_uv = pcl_uv[mask]
    pcl_z = pcl_z[mask]

    return pcl_uv, pcl_z


def max_normalize_pts(pts):
    """Min-max normalize ``pts`` to [0, 1]; the epsilon avoids division by zero."""
    return (pts - np.min(pts)) / (np.max(pts) - np.min(pts) + 1e-10)


def get_projected_img(pts, dist, img):
    """Draw projected points on ``img``, coloured by their depth.

    Each point is drawn as a filled circle of radius 1. Its HSV hue is its
    min-max normalized depth scaled to [0, 100]. The result is resized to 256x512.

    Args:
        pts (np.ndarray): pixel coordinates, one (u, v) row per point.
        dist (np.ndarray): depth of each point.
        img (np.ndarray): 3-channel image; treated as BGR by the colour conversions.

    Returns:
        np.ndarray: uint8 image of shape (256, 512, 3).
    """
    hsv_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)

    dist_norm = max_normalize_pts(dist)*100

    for i in range(pts.shape[0]):
        cv2.circle(hsv_img, (int(pts[i, 0]), int(pts[i, 1])), radius=1, color=(int(dist_norm[i]), 255, 255), thickness=-1)

    projection = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR)
    projection = F.interpolate(torch.from_numpy(projection.astype(np.float32)).permute(2,0,1).unsqueeze(0), size=[256, 512], mode='bilinear', align_corners=False)
    return projection[0].permute(1,2,0).detach().cpu().numpy().astype(np.uint8)

#%%
# Data loading
dataset_class = DatasetKittiRawCalibNet
dataset_val = dataset_class('/home/zhujt/dataset_zjt/kitti_raw/', max_r=10., max_t=0.25, split='val',
                            use_reflectance=False, val_sequence=['2011_09_26_drive_0020_sync', '2011_09_26_drive_0034_sync'])

# NOTE(review): shuffle=True means the sample visited at position target_num
# differs between runs.
ValImgLoader = torch.utils.data.DataLoader(dataset=dataset_val,
                                           shuffle=True,
                                           batch_size=1,
                                           num_workers=4,
                                           collate_fn=merge_inputs,
                                           drop_last=False,
                                           pin_memory=True)
#%%
# Model loading
agent = Agent().to(DEVICE)
code_path = '/home/zhujt/code_calib/CalibDepth/'
pretrain = os.path.join(code_path, 'weights_lstm/raw_1.zip')
if os.path.exists(pretrain):
    util_model.load(agent, pretrain)
else:
    # Previously this fell through silently and visualised an untrained model.
    print('Warning: pretrained weights not found at {}; using randomly initialized weights.'.format(pretrain))
# Inference only: put normalization/dropout layers (if any) in evaluation mode.
agent.eval()
progress = tqdm(BackgroundGenerator(ValImgLoader), total=len(ValImgLoader))

#%%
# Visualise the iterative calibration of the sample at loader position target_num.
# Figures are written to ./save_fig/<num>_{rgb,dg,iter0..3,gt,pred_d}.
target_num = 20
num = 0
with torch.no_grad():
    for data in progress:
        raw_img_path = data['img_path']
        # Read the original image from disk (OpenCV loads BGR) and convert it to RGB.
        rgb_for_show = cv2.imread(raw_img_path[0])
        rgb_for_show = cv2.cvtColor(rgb_for_show, cv2.COLOR_BGR2RGB)

        item = data['item'][0]
        rgb_input, depth_input, depth_target, pose_target, pose_source, ds_pc_target, ds_pc_source, calib = env.init(data)

        current_source = ds_pc_source
        current_depth = depth_input

        print(num)
        if num == target_num:
            # Save the initial images: RGB and the ground-truth depth (scaled by 1/80).
            plt.imwrite(rgb_for_show, './save_fig/'+str(num)+'_rgb')
            plt.imwrite((depth_target.expand(-1,3,-1,-1)[0]/80).permute(1,2,0).cpu(), './save_fig/'+str(num)+'_dg')

            # Projection of the initial (perturbed) point cloud.
            pcl_uv, pcl_z = get_projected_pts(current_source[0].transpose(0,1), calib[0], (384, 1280))
            init_project_img = get_projected_img(pcl_uv, pcl_z, rgb_for_show)
            plt.imwrite(init_project_img, './save_fig/'+str(num)+'_iter0')
            print(item)
            # Projection of the target point cloud.
            pcl_uv, pcl_z = get_projected_pts(ds_pc_target[0].transpose(0,1), calib[0], (384, 1280))
            gt_project_img = get_projected_img(pcl_uv, pcl_z, rgb_for_show)
            plt.imwrite(gt_project_img, './save_fig/'+str(num)+'_gt')

            for step in range(3):
                if step == 0:
                    # First step: the recurrent state starts from zeros.
                    actions, _, hc = agent(rgb_input, current_depth, torch.zeros(2, depth_input.shape[0], 256).to(DEVICE), torch.zeros(2, depth_input.shape[0], 256).to(DEVICE))
                else:
                    actions, depth_predict, hc = agent(rgb_input, current_depth, h_last, c_last)
                h_last, c_last = hc[0], hc[1]
                action_t, action_r = actions[0].unsqueeze(1), actions[1].unsqueeze(1)
                action_tr = torch.cat([action_t, action_r], dim=1)

                new_source, pose_source = env.step_continous(ds_pc_source, action_tr, pose_source)
                current_source = new_source
                # Depth maps of the updated point clouds for the whole batch.
                current_depth = lidar_project_depth_batch(current_source, calib, (384, 1280))
                current_depth /= 80

                pcl_uv, pcl_z = get_projected_pts(current_source[0].transpose(0,1), calib[0], (384, 1280))
                init_project_img = get_projected_img(pcl_uv, pcl_z, rgb_for_show)
                plt.imwrite(init_project_img, './save_fig/'+str(num)+'_iter'+str(step+1))

            # depth_predict comes from the last step (only steps >= 1 assign it).
            plt.imwrite((depth_predict.expand(-1,3,-1,-1)[0]/80).permute(1,2,0).detach().cpu().numpy(), './save_fig/'+str(num)+'_pred_d')
            break
        num += 1


# %%
# ------------------------------------------------------------------------------