├── .gitignore ├── README.md ├── config ├── maicity │ └── maicity_incre.yaml ├── ncd │ └── ncd_incre.yaml └── rgbd │ ├── rgbd_inre.yaml │ └── rgbd_inre_thin.yaml ├── dataset ├── kitti_dataset.py └── rgbd_to_kitti_format.py ├── eval ├── crop_intersection.py ├── eval_utils.py └── evaluator.py ├── model ├── __init__.py ├── decoder.py └── feature_octree.py ├── run.py ├── scripts ├── convert_rgbd_to_kitti_format.sh ├── download_maicity.sh ├── download_ncd_example.sh └── download_neural_rgbd_data.sh └── utils ├── __init__.py ├── config.py ├── data_sampler.py ├── incre_learning.py ├── loss.py ├── mapper.py ├── mesher.py ├── pose.py ├── scan.py ├── semantic_kitti_utils.py ├── tools.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .vscode/ 3 | *.key 4 | *.label 5 | *.txt 6 | *.ply 7 | *.bin 8 | *.pth 9 | *.png 10 | *.gif 11 | *.mp4 12 | *.xlsx 13 | *.csv 14 | TODO.md 15 | experiments/ 16 | data/ 17 | log/ 18 | visualize_density.py 19 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository is the official implementation of the paper [N3-Mapping](https://ieeexplore.ieee.org/abstract/document/10518078/): 2 | ``` 3 | @article{song2024n3, 4 | title={N$^{3}$-Mapping: Normal Guided Neural Non-Projective Signed Distance Fields for Large-scale 3D Mapping}, 5 | author={Song, Shuangfu and Zhao, Junqiao and Huang, Kai and Lin, Jiaye and Ye, Chen and Feng, Tiantian}, 6 | journal={IEEE Robotics and Automation Letters}, 7 | year={2024}, 8 | publisher={IEEE} 9 | } 10 | ``` 11 | 12 | ## Installation 13 | #### 1. Clone the repository 14 | ``` 15 | git clone git@github.com:tiev-tongji/N3-Mapping.git 16 | cd N3-Mapping 17 | ``` 18 | #### 2. Set up the conda environment 19 | ``` 20 | conda create --name n3 python=3.7 21 | conda activate n3 22 | ``` 23 | #### 3. Install the key requirement kaolin 24 | 25 | Kaolin depends on PyTorch (>= 1.8, <= 1.13.1); please install the PyTorch build corresponding to your CUDA version (which can be checked by ```nvcc --version```). You can find the installation commands [here](https://pytorch.org/get-started/previous-versions/). 26 | 27 | For example, for CUDA version >= 11.6, you can use: 28 | ``` 29 | pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116 30 | ``` 31 | 32 | Kaolin now supports installation with wheels. For example, to install kaolin 0.13.0 over torch 1.12.1 and CUDA 11.6: 33 | ``` 34 | pip install kaolin==0.13.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-1.12.1_cu116.html 35 | ``` 36 | 37 | #### 4. Install the other requirements 38 | ``` 39 | pip install open3d scikit-image wandb tqdm natsort 40 | ``` 41 | 42 | ## Run 43 | Download the dataset with the following script: 44 | ``` 45 | sh ./scripts/download_maicity.sh 46 | ``` 47 | Other datasets can also be downloaded in the same way: 48 | ``` 49 | sh ./scripts/download_ncd_example.sh 50 | sh ./scripts/download_neural_rgbd_data.sh 51 | ``` 52 | The data should follow the KITTI odometry format from [here](https://www.cvlibs.net/datasets/kitti/eval_odometry.php).
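For reference, each line of a KITTI-format `poses.txt` holds the first three rows of a 4x4 pose matrix, flattened into 12 space-separated values. A minimal parsing sketch (the loader actually used by this code is `read_poses_file` from `utils/pose.py`, which also takes the sensor calibration into account):
```
import numpy as np

def load_kitti_poses(pose_file):
    # each line: 12 values = flattened 3x4 [R | t], row-major
    poses = []
    with open(pose_file) as f:
        for line in f:
            pose = np.eye(4)
            pose[:3, :4] = np.array(line.split(), dtype=np.float64).reshape(3, 4)
            poses.append(pose)
    return poses
```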
53 | 54 | Therefore, if you need to use the [Neural RGBD dataset](https://github.com/dazinovic/neural-rgbd-surface-reconstruction), you can convert each of its sequences to the KITTI format by running: 55 | ``` 56 | sh ./scripts/convert_rgbd_to_kitti_format.sh 57 | ``` 58 | Now we take MaiCity as an example to show how to run the mapping system. 59 | First, check a config file such as `./config/maicity/maicity_incre.yaml` and set the correct paths (`pc_path`, `pose_path` and `calib_path`). Then use: 60 | ``` 61 | python run.py config/maicity/maicity_incre.yaml 62 | ``` 63 | 64 | ## Evaluation 65 | Please prepare your reconstructed mesh and the corresponding ground truth point cloud. Then set the right data paths and evaluation setup in `./eval/evaluator.py`. Now run: 66 | ``` 67 | python ./eval/evaluator.py 68 | ``` 69 | ## Contact 70 | Feel free to contact me if you have any questions :) 71 | - Song {[1911204@tongji.edu.cn]()} 72 | 73 | ## Acknowledgment 74 | Our work is mainly built on [SHINE-Mapping](https://github.com/PRBonn/SHINE_mapping). Many thanks to the authors of this excellent work! 75 | We also appreciate the following great open-source works: 76 | - [Voxfield](https://github.com/VIS4ROB-lab/voxfield) (comparison baseline, inspiration) 77 | - [Voxblox](https://github.com/ethz-asl/voxblox) (comparison baseline) 78 | - [NeRF-LOAM](https://github.com/JunyuanDeng/NeRF-LOAM) (comparison baseline) 79 | - [Loc-NDF](https://github.com/PRBonn/LocNDF) (inspiration) 80 | 81 | ## TODO 82 | Currently our implementation is more of a proof-of-concept and lacks optimization. We are working on improving this. A more efficient voxel-centric mapping design is on the way. 83 | -------------------------------------------------------------------------------- /config/maicity/maicity_incre.yaml: -------------------------------------------------------------------------------- 1 | setting: 2 | name: "maicity_incre" 3 | output_root: "./experiments/" 4 | pc_path: "/media/shy/Document/dataset/mai_city/ply/sequences/01/velodyne" 5 | pose_path: "/media/shy/Document/dataset/mai_city/ply/sequences/01/poses.txt" 6 | calib_path: "/media/shy/Document/dataset/mai_city/ply/sequences/01/calib.txt" 7 | label_path: "" 8 | load_model: False # load the pretrained decoder model (optional) 9 | model_path: "./pretrained/geo_decoder_8dim.pth" 10 | first_frame_ref: False 11 | begin_frame: 0 12 | end_frame: 100 13 | every_frame: 1 # 1 means does not skip 14 | device: "cuda" 15 | gpu_id: "0" 16 | process: 17 | min_range_m: 1.5 18 | pc_radius_m: 20.0 # distance filter for each frame 19 | min_z: -3.0 20 | max_z: 30.0 21 | rand_downsample: False # use random or voxel downsampling 22 | vox_down_m: 0.01 # 0.03 23 | rand_down_r: 1.0 24 | sampler: 25 | surface_sample_range_m: 0.5 # 0.3 26 | surface_sample_n: 3 27 | free_sample_begin_ratio: 0.5 28 | free_sample_end_dist: 0.5 29 | free_sample_n: 3 30 | normal_sampling_on: True 31 | gaussian_sampling_on: False 32 | #gaussian_sigma: 0.5 33 | sliding_window_on: True 34 | octree: 35 | leaf_vox_size: 0.2 36 | tree_level_world: 12 37 | tree_level_feat: 3 38 | feature_dim: 8 39 | poly_int_on: False # better false when using normal sampling 40 | octree_from_surface_samples: True 41 | decoder: 42 | mlp_level: 2 43 | mlp_hidden_dim: 32 44 | freeze_after_frame: 20 45 | predict_residual_sdf: False 46 | loss: 47 | ray_loss: False 48 | main_loss_type: sdf_bce # select from sdf_bce (our proposed), sdf_l1, sdf_l2, dr, dr_neus 49 | sigma_sigmoid_m: 0.05 # 0.05 50 | loss_weight_on:
False 51 | behind_dropoff_on: False 52 | ekional_loss_on: True 53 | weight_e: 0.08 #0.1 54 | normal_loss_on: False 55 | weight_n: 0.1 56 | continual: 57 | continual_learning_reg: False 58 | lambda_forget: 0 59 | optimizer: 60 | iters: 50 # iterations per frame 61 | batch_size: 8192 #4096 62 | learning_rate: 0.01 # 0.01 63 | weight_decay: 0 # l2 regularization 64 | extra_training: False 65 | eval: 66 | wandb_vis_on: False # log to wandb or not 67 | o3d_vis_on: False # visualize the mapping or not 68 | vis_freq_iters: 0 69 | save_freq_iters: 0 # save the model and octree every x iterations 70 | mesh_freq_frame: 100 # reconstruct the mesh every x frames 71 | mc_res_m: 0.2 # reconstruction marching cubes resolution 72 | mc_with_octree: False # querying sdf in the map bbx 73 | mc_vis_level: 2 # 1 more accurate more complete (may with more artifacts) 74 | clean_mesh_on: True 75 | save_map: False # save the sdf map -------------------------------------------------------------------------------- /config/ncd/ncd_incre.yaml: -------------------------------------------------------------------------------- 1 | setting: 2 | name: "ncd_incre" 3 | output_root: "./experiments/" 4 | pc_path: "/media/shy/Document/dataset/ncd_example/quad/pcd" 5 | pose_path: "/media/shy/Document/dataset/ncd_example/quad/poses.txt" 6 | calib_path: "/media/shy/Document/dataset/ncd_example/quad/calib.txt" 7 | label_path: "" 8 | load_model: False 9 | model_path: "" 10 | first_frame_ref: False 11 | begin_frame: 0 12 | end_frame: 1300 13 | every_frame: 5 # 1 means does not skip 14 | device: "cuda" 15 | gpu_id: "0" 16 | process: 17 | min_range_m: 1.5 18 | pc_radius_m: 50.0 # distance filter for each frame 19 | min_z: -3.0 20 | max_z: 30.0 21 | rand_downsample: False # use random or voxel downsampling 22 | vox_down_m: 0.03 # may cause map offset 23 | rand_down_r: 0.2 24 | sampler: 25 | surface_sample_range_m: 0.3 26 | surface_sample_n: 3 27 | free_sample_begin_ratio: 0.3 28 | free_sample_end_dist: 0.8 # 29 | free_sample_n: 6 # more free space sampling can suppress dynamic objects 30 | normal_sampling_on: True 31 | gaussian_sampling_on: False 32 | sliding_window_on: False # not necessary in ncd 33 | octree: 34 | leaf_vox_size: 0.4 # 0.2 35 | tree_level_world: 12 36 | tree_level_feat: 3 37 | feature_dim: 8 38 | poly_int_on: False 39 | octree_from_surface_samples: True 40 | decoder: 41 | mlp_level: 2 42 | mlp_hidden_dim: 32 43 | freeze_after_frame: 30 44 | predict_residual_sdf: False 45 | loss: 46 | ray_loss: False 47 | main_loss_type: sdf_bce # select from sdf_bce (our proposed), sdf_l1, sdf_l2, dr, dr_neus 48 | sigma_sigmoid_m: 0.1 49 | loss_weight_on: False 50 | behind_dropoff_on: False 51 | ekional_loss_on: False 52 | weight_e: 0.08 53 | normal_loss_on: False 54 | weight_n: 0.1 55 | continual: 56 | continual_learning_reg: False 57 | lambda_forget: 0 # the larger this value, the model would be less likely to forget 58 | optimizer: 59 | iters: 50 # iterations per frame 60 | batch_size: 8192 61 | learning_rate: 0.01 62 | weight_decay: 1e-7 # l2 regularization 63 | extra_training: False 64 | eval: 65 | wandb_vis_on: False # log to wandb or not 66 | o3d_vis_on: False # visualize the mapping or not 67 | vis_freq_iters: 0 68 | save_freq_iters: 0 # save the model and octree every x iterations 69 | mesh_freq_frame: 1300 # reconstruct the mesh every x frames 70 | mc_res_m: 0.2 # reconstruction marching cubes resolution 71 | mc_with_octree: True # querying sdf only in certain levels of the octree 72 | mc_vis_level: 2 73 | clean_mesh_on: True 
74 | save_map: False # save the sdf map or not 75 | -------------------------------------------------------------------------------- /config/rgbd/rgbd_inre.yaml: -------------------------------------------------------------------------------- 1 | # thin_geometry_kitti_gt 2 | # staircase_kitti_format 3 | # complete_kitchen_kitti_format 4 | # grey_white_room_kitti_format 5 | # morning_apartment_kitti_format 6 | 7 | setting: 8 | name: "rgbd_incre" 9 | output_root: "./experiments/" 10 | pc_path: "/media/shy/Document/dataset/neural_rgbd/staircase_kitti_format/rgbd_ply" 11 | pose_path: "/media/shy/Document/dataset/neural_rgbd/staircase_kitti_format/poses.txt" 12 | label_path: "" 13 | calib_path: "" 14 | load_model: False # load the pretrained decoder model (optional) 15 | model_path: "./pretrained/geo_decoder_8dim.pth" 16 | first_frame_ref: False 17 | begin_frame: 0 18 | end_frame: 2000 #1500 19 | every_frame: 5 # 1 means does not skip 20 | device: "cuda" 21 | gpu_id: "0" 22 | process: 23 | min_range_m: 0.1 24 | pc_radius_m: 8.0 # distance filter for each frame 25 | min_z: -8.0 26 | max_z: 8.0 27 | rand_downsample: False # use random or voxel downsampling 28 | vox_down_m: 0.01 29 | rand_down_r: 0.5 30 | sampler: 31 | surface_sample_range_m: 0.05 #0.05 32 | surface_sample_n: 3 33 | free_sample_begin_ratio: 0.5 34 | free_sample_end_dist: 0.2 35 | free_sample_n: 3 36 | normal_sampling_on: False 37 | gaussian_sampling_on: False 38 | sliding_window_on: False 39 | octree: 40 | leaf_vox_size: 0.04 41 | tree_level_world: 12 42 | tree_level_feat: 4 43 | feature_dim: 8 44 | poly_int_on: False 45 | octree_from_surface_samples: True 46 | decoder: 47 | mlp_level: 2 48 | mlp_hidden_dim: 32 49 | freeze_after_frame: 20 50 | predict_residual_sdf: False 51 | loss: 52 | ray_loss: False 53 | main_loss_type: sdf_bce # select from sdf_bce (our proposed), sdf_l1, sdf_l2, dr, dr_neus 54 | sigma_sigmoid_m: 0.02 55 | loss_weight_on: False 56 | behind_dropoff_on: False 57 | ekional_loss_on: False 58 | weight_e: 0.08 59 | normal_loss_on: False 60 | weight_n: 0.1 61 | continual: 62 | continual_learning_reg: False 63 | lambda_forget: 0 64 | optimizer: 65 | iters: 30 66 | batch_size: 8192 67 | learning_rate: 0.01 68 | weight_decay: 0 69 | extra_training: False 70 | eval: 71 | wandb_vis_on: False # log to wandb or not 72 | o3d_vis_on: False # visualize the mapping or not 73 | vis_freq_iters: 0 74 | save_freq_iters: 0 # save the model and octree every x iterations 75 | mesh_freq_frame: 5000 # reconstruct the mesh every x frames 76 | mc_res_m: 0.04 # reconstruction marching cubes resolution 77 | mc_with_octree: False # querying sdf in the map bbx 78 | mc_vis_level: 2 # 2 may lead to fake wall 79 | clean_mesh_on: True 80 | save_map: False # save the sdf map or not -------------------------------------------------------------------------------- /config/rgbd/rgbd_inre_thin.yaml: -------------------------------------------------------------------------------- 1 | # thin_geometry_kitti_format 2 | # thin_geometry_kitti_gt 3 | 4 | setting: 5 | name: "rgbd_incre" 6 | output_root: "./experiments/" 7 | pc_path: "/media/shy/Document/dataset/neural_rgbd/thin_geometry_kitti_gt/rgbd_ply" 8 | pose_path: "/media/shy/Document/dataset/neural_rgbd/thin_geometry_kitti_gt/poses.txt" 9 | label_path: "" 10 | calib_path: "" 11 | load_model: False # load the pretrained decoder model (optional) 12 | model_path: "./pretrained/geo_decoder_8dim.pth" 13 | first_frame_ref: False 14 | begin_frame: 0 15 | end_frame: 2000 #1500 16 | every_frame: 2 # 1 means does 
not skip 17 | device: "cuda" 18 | gpu_id: "0" 19 | process: 20 | min_range_m: 0.1 21 | pc_radius_m: 8.0 # distance filter for each frame 22 | min_z: -8.0 23 | max_z: 8.0 24 | rand_downsample: False # use random or voxel downsampling 25 | vox_down_m: 0.01 26 | rand_down_r: 0.5 27 | sampler: 28 | surface_sample_range_m: 0.02 29 | surface_sample_n: 3 30 | free_sample_begin_ratio: 0.5 31 | free_sample_end_dist: 0.03 32 | free_sample_n: 3 33 | normal_sampling_on: True 34 | gaussian_sampling_on: False 35 | sliding_window_on: False 36 | octree: 37 | leaf_vox_size: 0.02 38 | tree_level_world: 12 39 | tree_level_feat: 4 40 | feature_dim: 8 41 | poly_int_on: False 42 | octree_from_surface_samples: True 43 | decoder: 44 | mlp_level: 2 45 | mlp_hidden_dim: 32 46 | freeze_after_frame: 20 47 | predict_residual_sdf: False 48 | loss: 49 | ray_loss: False 50 | main_loss_type: sdf_bce # select from sdf_bce (our proposed), sdf_l1, sdf_l2, dr, dr_neus 51 | sigma_sigmoid_m: 0.01 52 | loss_weight_on: False 53 | behind_dropoff_on: False 54 | ekional_loss_on: False 55 | weight_e: 0.08 56 | normal_loss_on: False 57 | weight_n: 0.1 58 | continual: 59 | continual_learning_reg: False 60 | lambda_forget: 0 61 | optimizer: 62 | iters: 30 63 | batch_size: 8192 64 | learning_rate: 0.01 65 | weight_decay: 0 66 | extra_training: False 67 | eval: 68 | wandb_vis_on: False # log to wandb or not 69 | o3d_vis_on: False # visualize the mapping or not 70 | vis_freq_iters: 0 71 | save_freq_iters: 0 # save the model and octree every x iterations 72 | mesh_freq_frame: 5000 # reconstruct the mesh every x frames 73 | mc_res_m: 0.02 # 0.04 reconstruction marching cubes resolution 74 | mc_with_octree: False # querying sdf in the map bbx 75 | mc_vis_level: 2 76 | clean_mesh_on: True 77 | save_map: False # save the sdf map or not -------------------------------------------------------------------------------- /dataset/kitti_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import sys 4 | sys.path.append(osp.abspath('./')) 5 | 6 | import copy 7 | import numpy as np 8 | from numpy.linalg import inv, norm 9 | import torch 10 | from torch.utils.data import Dataset 11 | import open3d as o3d 12 | from natsort import natsorted 13 | 14 | from utils.config import SHINEConfig 15 | from utils.pose import * 16 | from utils.semantic_kitti_utils import * 17 | from utils.tools import voxel_down_sample_torch 18 | 19 | 20 | class KITTIDataset(Dataset): 21 | def __init__(self, config: SHINEConfig) -> None: 22 | 23 | super().__init__() 24 | 25 | self.config = config 26 | self.dtype = config.dtype 27 | torch.set_default_dtype(self.dtype) 28 | self.device = config.device 29 | self.pool_device = config.device 30 | 31 | self.calib = {} 32 | if config.calib_path != '': 33 | self.calib = read_calib_file(config.calib_path) 34 | else: 35 | self.calib['Tr'] = np.eye(4) 36 | if config.pose_path.endswith('txt'): 37 | self.poses_w = read_poses_file(config.pose_path, self.calib) 38 | elif config.pose_path.endswith('csv'): 39 | self.poses_w = csv_odom_to_transforms(config.pose_path) 40 | else: 41 | sys.exit( 42 | "Wrong pose file format. 
Please use either *.txt (KITTI format) or *.csv (xyz+quat format)" 43 | ) 44 | 45 | # pose in the reference frame (might be the first frame used) 46 | self.poses_ref = self.poses_w # initialize size 47 | 48 | # point cloud files 49 | self.pc_filenames = natsorted(os.listdir(config.pc_path)) # sort files as 1, 2,… 9, 10 not 1, 10, 100 with natsort 50 | self.total_pc_count = len(self.pc_filenames) 51 | 52 | # local map pc 53 | self.cur_frame_pc = o3d.geometry.PointCloud() 54 | # merged downsampled point cloud 55 | self.map_down_pc = o3d.geometry.PointCloud() 56 | # map bounding box in the world coordinate system 57 | self.map_bbx = o3d.geometry.AxisAlignedBoundingBox() 58 | 59 | # get the pose in the reference frame 60 | self.used_pc_count = 0 61 | begin_flag = False 62 | self.begin_pose_inv = np.eye(4) 63 | for frame_id in range(self.total_pc_count): 64 | if ( 65 | frame_id < config.begin_frame 66 | or frame_id > config.end_frame 67 | or frame_id % config.every_frame != 0 68 | ): 69 | continue 70 | if not begin_flag: # the first frame used 71 | begin_flag = True 72 | if config.first_frame_ref: 73 | self.begin_pose_inv = inv(self.poses_w[frame_id]) # T_rw 74 | else: 75 | # just a random number to avoid octree boudnary marching cubes problems on synthetic dataset such as MaiCity(TO FIX) 76 | self.begin_pose_inv[2,3] += config.global_shift_default 77 | # use the first frame as the reference (identity) 78 | self.poses_ref[frame_id] = np.matmul( 79 | self.begin_pose_inv, self.poses_w[frame_id] 80 | ) 81 | self.used_pc_count += 1 82 | # or we directly use the world frame as reference 83 | 84 | def process_frame(self, frame_id): 85 | 86 | pc_radius = self.config.pc_radius 87 | min_z = self.config.min_z 88 | max_z = self.config.max_z 89 | normal_radius_m = self.config.normal_radius_m 90 | normal_max_nn = self.config.normal_max_nn 91 | rand_down_r = self.config.rand_down_r 92 | vox_down_m = self.config.vox_down_m 93 | sor_nn = self.config.sor_nn 94 | sor_std = self.config.sor_std 95 | 96 | self.cur_pose_ref = self.poses_ref[frame_id] 97 | 98 | # step 0. load point cloud (support *pcd, *ply and kitti *bin format) 99 | frame_filename = os.path.join(self.config.pc_path, self.pc_filenames[frame_id]) 100 | 101 | if not self.config.semantic_on: 102 | frame_pc = self.read_point_cloud(frame_filename) 103 | # label_filename = os.path.join(self.config.label_path, self.pc_filenames[frame_id].replace('bin','label')) 104 | # frame_pc = self.read_semantic_point_label(frame_filename, label_filename) 105 | else: 106 | label_filename = os.path.join(self.config.label_path, self.pc_filenames[frame_id].replace('bin','label')) 107 | frame_pc = self.read_semantic_point_label(frame_filename, label_filename) 108 | 109 | #step 1. block filter: crop the point clouds into a cube 110 | bbx_min = o3d.core.Tensor([-pc_radius, -pc_radius, min_z], dtype = o3d.core.float32) 111 | bbx_max = o3d.core.Tensor([pc_radius, pc_radius, max_z], dtype = o3d.core.float32) 112 | bbx = o3d.t.geometry.AxisAlignedBoundingBox(bbx_min, bbx_max) 113 | frame_pc = frame_pc.crop(bbx) 114 | 115 | # surface normal estimation 116 | if self.config.estimate_normal: 117 | frame_pc.estimate_normals(max_nn = normal_max_nn) 118 | #frame_pc.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamKNN(normal_max_nn)) 119 | #frame_pc.estimate_normals(radius=normal_radius_m) 120 | frame_pc.orient_normals_towards_camera_location() # orient normals towards the default origin(0,0,0). 121 | 122 | 123 | #step2. 
point cloud downsampling 124 | if self.config.rand_downsample: 125 | # random downsampling 126 | frame_pc = frame_pc.random_down_sample(sampling_ratio=rand_down_r) 127 | else: 128 | # voxel downsampling 129 | frame_pc = frame_pc.voxel_down_sample(voxel_size=vox_down_m) 130 | 131 | # apply filter (optional) 132 | if self.config.filter_noise: 133 | frame_pc = frame_pc.remove_statistical_outlier( 134 | sor_nn, sor_std, print_progress=False 135 | )[0] 136 | 137 | frame_origin = self.cur_pose_ref[:3, 3] * self.config.scale # translation part 138 | frame_origin_torch = torch.tensor(frame_origin, dtype=self.dtype, device=self.pool_device) 139 | 140 | # step 3. transform to reference frame 141 | frame_pc = frame_pc.transform(self.cur_pose_ref) 142 | 143 | # step 3.5 make a backup of global point cloud map. 144 | frame_pc_clone = copy.deepcopy(frame_pc.to_legacy()) 145 | #frame_pc_clone = frame_pc_clone.voxel_down_sample(voxel_size=self.config.map_vox_down_m) # for smaller memory cost 146 | self.map_down_pc += frame_pc_clone # for marching cube filtering. 147 | self.cur_frame_pc = frame_pc_clone # for visualization 148 | self.map_bbx = self.map_down_pc.get_axis_aligned_bounding_box() 149 | if frame_id % 400 == 0: 150 | self.map_down_pc = self.map_down_pc.voxel_down_sample(0.5*self.config.mc_res_m) # to avoid out of memory for large map 151 | 152 | # step 4. and scale to [-1,1] coordinate system (important!) 153 | frame_pc_s = frame_pc.scale(self.config.scale, center=o3d.core.Tensor([0,0,0], dtype = o3d.core.float32)) 154 | 155 | # step 5 turn into torch format. 156 | frame_pc_s_torch = torch.tensor(frame_pc_s.point.positions.numpy(), dtype=self.dtype, device=self.pool_device) 157 | frame_normal_torch = None 158 | if self.config.estimate_normal: 159 | frame_normal_torch = torch.tensor(frame_pc_s.point.normals.numpy(), dtype=self.dtype, device=self.pool_device) 160 | frame_label_torch = None 161 | if self.config.semantic_on: 162 | frame_label_torch = torch.tensor(frame_pc_s.point.labels.numpy(), dtype=self.dtype, device=self.pool_device) 163 | 164 | return frame_id, frame_origin_torch, frame_pc_s_torch, frame_normal_torch, frame_label_torch 165 | 166 | def read_point_cloud(self, filename: str): 167 | # read point cloud from either (*.ply, *.pcd) or (kitti *.bin) format 168 | if ".bin" in filename: 169 | points = np.fromfile(filename, dtype=np.float32).reshape((-1, 4))[:, :3] 170 | elif ".ply" in filename or ".pcd" in filename: 171 | pc_load = o3d.io.read_point_cloud(filename) 172 | points = np.asarray(pc_load.points) 173 | else: 174 | sys.exit( 175 | "The format of the imported point cloud is wrong (support only *pcd, *ply and *bin)" 176 | ) 177 | preprocessed_points = self.preprocess_kitti(points, self.config.min_z, self.config.min_range) 178 | #preprocessed_points = points 179 | pcd_t = o3d.t.geometry.PointCloud() 180 | pcd_t.point.positions = o3d.core.Tensor(preprocessed_points, o3d.core.float32) 181 | return pcd_t 182 | 183 | def read_semantic_point_label(self, filename: str, label_filename: str): 184 | 185 | # read point cloud (kitti *.bin format) 186 | if ".bin" in filename: 187 | points = np.fromfile(filename, dtype=np.float32).reshape((-1, 4))[:, :3] 188 | elif ".ply" in filename or ".pcd" in filename: 189 | pc_load = o3d.io.read_point_cloud(filename) 190 | points = np.asarray(pc_load.points) 191 | else: 192 | sys.exit( 193 | "The format of the imported point cloud is wrong (support only *bin)" 194 | ) 195 | 196 | # read point cloud labels (*.label format) 197 | if ".label" in label_filename: 
198 | labels = np.fromfile(label_filename, dtype=np.uint32).reshape((-1)) 199 | else: 200 | sys.exit( 201 | "The format of the imported point labels is wrong (support only *label)" 202 | ) 203 | 204 | points, sem_labels = self.preprocess_sem_kitti( 205 | points, labels, self.config.min_z, self.config.min_range, filter_moving=self.config.filter_moving_object 206 | ) 207 | pcd_t = o3d.t.geometry.PointCloud() 208 | pcd_t.point.positions = o3d.core.Tensor(points, o3d.core.float32) 209 | pcd_t.point.labels = o3d.core.Tensor(sem_labels, o3d.core.int32) #.reshape(-1) 210 | return pcd_t 211 | 212 | def preprocess_kitti(self, points, z_th=-3.0, min_range=2.5): 213 | # filter the outliers 214 | z = points[:, 2] 215 | points = points[z > z_th] 216 | points = points[np.linalg.norm(points, axis=1) >= min_range] 217 | return points 218 | 219 | def preprocess_sem_kitti(self, points, labels, min_range=2.75, filter_outlier = True, filter_moving = True): 220 | # TODO: speed up 221 | sem_labels = np.array(labels & 0xFFFF) 222 | 223 | range_filtered_idx = np.linalg.norm(points, axis=1) >= min_range 224 | points = points[range_filtered_idx] 225 | sem_labels = sem_labels[range_filtered_idx] 226 | 227 | # filter the outliers according to semantic labels 228 | if filter_moving: 229 | filtered_idx = sem_labels < 100 230 | points = points[filtered_idx] 231 | sem_labels = sem_labels[filtered_idx] 232 | 233 | if filter_outlier: 234 | filtered_idx = (sem_labels != 1) # not outlier 235 | points = points[filtered_idx] 236 | sem_labels = sem_labels[filtered_idx] 237 | 238 | sem_labels_main_class = np.array([sem_kitti_learning_map[sem_label] for sem_label in sem_labels]) # get the reduced label [0-20] 239 | 240 | return points, sem_labels_main_class 241 | 242 | def write_merged_pc(self, out_path): 243 | #map_down_pc_out = copy.deepcopy(self.map_down_pc) 244 | map_down_pc_out = self.map_down_pc 245 | map_down_pc_out.transform(inv(self.begin_pose_inv)) # back to world coordinate (if taking the first frame as reference) 246 | o3d.io.write_point_cloud(out_path, map_down_pc_out) 247 | print("save the merged point cloud map to %s\n" % (out_path)) 248 | 249 | def __len__(self) -> int: 250 | return len(self.pc_filenames) 251 | 252 | def __getitem__(self, index: int): 253 | return self.process_frame(index) 254 | 255 | if __name__ == '__main__': 256 | config = SHINEConfig() 257 | if len(sys.argv) > 1: 258 | config.load(sys.argv[1]) 259 | else: 260 | sys.exit( 261 | "Please provide the path to the config file.\nTry: python shine_incre.py xxx/xxx_config.yaml" 262 | ) 263 | loader = KITTIDataset(config) 264 | seq_size = len(loader) 265 | print("the sequence has {0} frames in total.".format(seq_size)) 266 | for frame_id, origin, points, normals, labels in loader: 267 | print(frame_id) 268 | print(origin) 269 | print(points.shape) 270 | print(normals.shape) 271 | if(config.semantic_on): 272 | print(labels.shape) 273 | 274 | 275 | 276 | 277 | 278 | 279 | -------------------------------------------------------------------------------- /dataset/rgbd_to_kitti_format.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | from tqdm import tqdm 3 | import argparse 4 | import re 5 | import os 6 | import numpy as np 7 | import json 8 | import shutil 9 | 10 | def rgbd_to_kitti_format(args): 11 | 12 | ply_path = os.path.join(args.output_root, "rgbd_ply") 13 | os.makedirs(ply_path, exist_ok=True) 14 | 15 | # get pose 16 | pose_kitti_format_path = os.path.join(args.output_root, "poses.txt") 
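    # Note on the conversion below: a Neural RGBD pose file stores one 4x4
    # matrix per frame spread over several text lines (parsed by load_poses),
    # while the KITTI format flattens the first three rows of each matrix
    # into a single 12-value line (written by write_poses_kitti_format).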
17 | if args.already_kitti_format_pose: 18 | shutil.copyfile(args.pose_file, pose_kitti_format_path) # don't directly copy, may have some issues 19 | else: 20 | poses_mat = load_poses(args.pose_file, with_head = False) # with_head = True for open3d provided Redwood dataset 21 | write_poses_kitti_format(poses_mat, pose_kitti_format_path) 22 | 23 | # get an example image 24 | depth_img_files = sorted(os.listdir(args.depth_img_folder), key=alphanum_key) 25 | rgb_img_files = sorted(os.listdir(args.rgb_img_folder), key=alphanum_key) 26 | 27 | im_depth_example_path = os.path.join(args.depth_img_folder, depth_img_files[0]) 28 | # print(im_depth_example_path) 29 | im_depth_example = o3d.io.read_image(im_depth_example_path) 30 | H, W = np.array(im_depth_example).shape[:2] 31 | print("Image size:", H, "x", W) 32 | 33 | # load the camera intrinsic parameters 34 | intrinsic = o3d.camera.PinholeCameraIntrinsic() 35 | depth_scale = 1000. 36 | if args.intrinsic_file == "": 37 | # use the default parameter 38 | # W=640, H=480, fx=fy=525.0, cx=319.5, cy=239.5 39 | print("Default intrinsic for PrimeSense used") 40 | intrinsic = o3d.camera.PinholeCameraIntrinsic(o3d.camera.PinholeCameraIntrinsicParameters.PrimeSenseDefault) 41 | # use this extrinsic matrix to rotate the image since frames captured with RealSense camera are upside down 42 | extrinsic = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 43 | else: 44 | if args.is_focal_file: # load the focal length only txt file # This is used for NeuralRGBD dataset 45 | focal = load_focal_length(args.intrinsic_file) 46 | print("Focal length:", focal) 47 | intrinsic.set_intrinsics(height=H, 48 | width=W, 49 | fx=focal, 50 | fy=focal, 51 | cx=(W-1.)/2., 52 | cy=(H-1.)/2.) 53 | depth_scale = 1000. 54 | # use this extrinsic matrix to rotate the image since frames captured with RealSense camera are upside down 55 | extrinsic = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 56 | 57 | else: 58 | with open(args.intrinsic_file, 'r') as infile: # load intrinsic json file 59 | cam = json.load(infile)["camera"] 60 | intrinsic.set_intrinsics(height=cam["h"], 61 | width=cam["w"], 62 | fx=cam["fx"], 63 | fy=cam["fy"], 64 | cx=cam["cx"], 65 | cy=cam["cy"]) 66 | depth_scale = cam["scale"] 67 | extrinsic = np.eye(4) # this is used for Replica dataset 68 | 69 | 70 | # get point cloud 71 | frame_count = 0 72 | for color_path, depth_path in tqdm(zip(rgb_img_files, depth_img_files)): 73 | color_path = os.path.join(args.rgb_img_folder, color_path) 74 | depth_path = os.path.join(args.depth_img_folder, depth_path) 75 | print(color_path) 76 | print(depth_path) 77 | 78 | im_color = o3d.io.read_image(color_path) 79 | im_depth = o3d.io.read_image(depth_path) 80 | im_rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(im_color, im_depth, depth_scale = depth_scale, depth_trunc = args.max_depth_m, convert_rgb_to_intensity=False) # not just gray 81 | im_pcd = o3d.geometry.PointCloud.create_from_rgbd_image(im_rgbd, intrinsic, extrinsic) 82 | # remove rgb 83 | im_pcd.colors = o3d.utility.Vector3dVector([]) 84 | 85 | if args.vis_on: 86 | o3d.visualization.draw_geometries([im_pcd]) 87 | frame_id_str = f'{frame_count:06d}' 88 | cur_filename = frame_id_str+".ply" 89 | cur_path = os.path.join(ply_path, cur_filename) 90 | o3d.io.write_point_cloud(cur_path, im_pcd) 91 | 92 | frame_count+=1 93 | 94 | print("The rgbd dataset in KITTI format has been saved at %s", args.output_root) 95 | 96 | def str2bool(v): 97 | if isinstance(v, bool): 98 | return v 99 | if 
v.lower() in ('yes', 'true', 't', 'y', '1'): 100 | return True 101 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 102 | return False 103 | else: 104 | raise argparse.ArgumentTypeError('Boolean value expected.') 105 | 106 | def alphanum_key(s): 107 | """ Turn a string into a list of string and number chunks. 108 | "z23a" -> ["z", 23, "a"] 109 | """ 110 | return [int(x) if x.isdigit() else x for x in re.split('([0-9]+)', s)] 111 | 112 | def load_from_json(filename): 113 | """Load a dictionary from a JSON filename. 114 | Args: 115 | filename: The filename to load from. 116 | """ 117 | assert filename.suffix == ".json" 118 | with open(filename, encoding="UTF-8") as file: 119 | return json.load(file) 120 | 121 | def load_focal_length(filepath): 122 | file = open(filepath, "r") 123 | return float(file.readline()) 124 | 125 | def load_poses(posefile, with_head = False): 126 | file = open(posefile, "r") 127 | lines = file.readlines() 128 | file.close() 129 | poses = [] 130 | if not with_head: 131 | lines_per_matrix = 4 132 | skip_line = 0 133 | else: 134 | lines_per_matrix = 5 135 | skip_line = 1 136 | for i in range(0, len(lines), lines_per_matrix): 137 | pose_floats = np.array([[float(x) for x in line.split()] for line in lines[i+skip_line:i+lines_per_matrix]]) 138 | # print(pose_floats) 139 | poses.append(pose_floats) 140 | 141 | return poses 142 | 143 | def write_poses_kitti_format(poses_mat, posefile): 144 | poses_vec = [] 145 | for pose_mat in poses_mat: 146 | pose_vec= pose_mat.flatten()[0:12] 147 | poses_vec.append(pose_vec) 148 | np.savetxt(posefile, poses_vec, delimiter=' ') 149 | 150 | def parser_json_sdf_studio_format(json_file): 151 | meta_data = load_from_json(json_file) 152 | 153 | 154 | 155 | if __name__ == "__main__": 156 | 157 | parser = argparse.ArgumentParser() 158 | parser.add_argument('--depth_img_folder', help="folder containing the depth images") 159 | parser.add_argument('--rgb_img_folder', help="folder containing the rgb images") 160 | parser.add_argument('--intrinsic_file', default="", help="path to the json file containing the camera intrinsic parameters") 161 | parser.add_argument('--pose_file', help="path to the txt file containing the camera pose at each frame") 162 | parser.add_argument('--output_root', help="path for outputing the kitti format data") 163 | parser.add_argument('--max_depth_m', type=float, default=5.0, help="maximum depth to be used") 164 | parser.add_argument('--is_focal_file', type=str2bool, nargs='?', default=True, \ 165 | help="is the input intrinsic file a txt file containing only the focus length (as the Neural RGBD data format)\ 166 | or the json file containing all the intrinsic parameters (as the Replica format)") 167 | parser.add_argument('--already_kitti_format_pose', type=str2bool, nargs='?', default=False, \ 168 | help="is the input pose file already in KITTI pose format (also as the Replica format)\ 169 | or the input pose file is in a 4dim transformation form (as the Neural RGBD data format)") 170 | parser.add_argument('--vis_on', type=str2bool, nargs='?', default=False) 171 | args = parser.parse_args() 172 | 173 | rgbd_to_kitti_format(args) -------------------------------------------------------------------------------- /eval/crop_intersection.py: -------------------------------------------------------------------------------- 1 | from eval_utils import crop_intersection 2 | 3 | # This file presents an example to crop the ground truth point cloud to the intersection part of all 4 | # the compared method's mesh reconstruction 5 | 
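# Typical usage (the xxx paths below are placeholders): point gt_pcd_path at
# the full ground truth point cloud, list every method's reconstructed mesh in
# preds_path, then run `python ./eval/crop_intersection.py`. Ground truth
# points farther than dist_thre from any method's sampled mesh are dropped, so
# the cropped cloud only keeps regions that all methods actually reconstructed.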
6 | gt_pcd_path = "xxx/mai_city/01/gt_map_pc_mai.ply" 7 | 8 | pred_vdb_path = "xxx/mai_city/01/baseline/vdb_fusion/mesh_vdb_10cm.ply" 9 | 10 | pred_puma_path = "xxx/mai_city/01/baseline/puma/mesh_puma_l10.ply" 11 | 12 | pred_voxblox_path = "xxx/mai_city/01/baseline/voxblox/mesh_voxblox_10cm.ply" 13 | 14 | pred_shine_path = "xxx/mai_city/01/mesh_shine_10cm.ply" 15 | 16 | preds_path = [pred_vdb_path, pred_puma_path, pred_voxblox_path, pred_shine_path] 17 | 18 | crop_gt_pcd_path = "xxx/mai_city/01/gt_map_pc_mai_crop_intersection.ply" 19 | 20 | crop_intersection(gt_pcd_path, preds_path, crop_gt_pcd_path, dist_thre=0.2) -------------------------------------------------------------------------------- /eval/eval_utils.py: -------------------------------------------------------------------------------- 1 | # This file is derived from [Atlas](https://github.com/magicleap/Atlas). 2 | # Originating Author: Zak Murez (zak.murez.com) 3 | # Modified for [SHINEMapping] by Yue Pan. 4 | 5 | # Original header: 6 | # Copyright 2020 Magic Leap, Inc. 7 | 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | 21 | import open3d as o3d 22 | import numpy as np 23 | 24 | def eval_mesh(file_pred, file_trgt, down_sample_res=0.02, threshold=0.05, truncation_acc=0.50, truncation_com=0.50, gt_bbx_mask_on= True, 25 | mesh_sample_point=10000000, generate_error_map=False): 26 | """ Compute mesh metrics between prediction and target. 27 | Opens the meshes and runs the metrics. 28 | Args: 29 | file_pred: file path of prediction (should be a mesh) 30 | file_trgt: file path of target (should be a point cloud) 31 | down_sample_res: use voxel_downsample to uniformly sample mesh points 32 | threshold: distance threshold used to compute precision/recall 33 | truncation_acc: points whose nearest neighbor is farther than this distance are not taken into account (take pred as reference) 34 | truncation_com: points whose nearest neighbor is farther than this distance are not taken into account (take trgt as reference) 35 | gt_bbx_mask_on: use the bounding box of the trgt as a mask for the pred mesh 36 | mesh_sample_point: number of points sampled from the mesh 37 | possion_sample_init_factor: used for Poisson uniform sampling, check open3d for more details (deprecated) 38 | 39 | 40 | Returns: 41 | Dict of mesh metrics (chamfer distance, precision, recall, f1 score, etc.)
42 | """ 43 | 44 | mesh_pred = o3d.io.read_triangle_mesh(file_pred) 45 | 46 | pcd_trgt = o3d.io.read_point_cloud(file_trgt) 47 | 48 | # (optional) filter the prediction outside the gt bounding box (since gt sometimes is not complete enough) 49 | if gt_bbx_mask_on: 50 | trgt_bbx = pcd_trgt.get_axis_aligned_bounding_box() 51 | min_bound = trgt_bbx.get_min_bound() 52 | min_bound[2]-=down_sample_res 53 | max_bound = trgt_bbx.get_max_bound() 54 | max_bound[2]+=down_sample_res 55 | trgt_bbx = o3d.geometry.AxisAlignedBoundingBox(min_bound, max_bound) 56 | mesh_pred = mesh_pred.crop(trgt_bbx) 57 | 58 | # mesh uniform sampling 59 | pcd_sample_pred = mesh_pred.sample_points_uniformly(number_of_points=mesh_sample_point) 60 | 61 | if down_sample_res > 0: 62 | pred_pt_count_before = len(pcd_sample_pred.points) 63 | pcd_pred = pcd_sample_pred.voxel_down_sample(down_sample_res) 64 | pcd_trgt = pcd_trgt.voxel_down_sample(down_sample_res) 65 | pred_pt_count_after = len(pcd_pred.points) 66 | print("Predicted mesh unifrom sample: ", pred_pt_count_before, " --> ", pred_pt_count_after, " (", down_sample_res, "m)") 67 | 68 | verts_pred = np.asarray(pcd_pred.points) 69 | verts_trgt = np.asarray(pcd_trgt.points) 70 | 71 | _, dist_p = nn_correspondance(verts_trgt, verts_pred, truncation_acc, True) # find nn in ground truth samples for each predict sample -> precision related 72 | _, dist_r = nn_correspondance(verts_pred, verts_trgt, truncation_com, False) # find nn in predict samples for each ground truth sample -> recall related 73 | 74 | dist_p = np.array(dist_p) 75 | # dist_r = np.array(dist_r) 76 | all_dist_r = np.array(dist_r) # for error map 77 | dist_r = all_dist_r[all_dist_r < truncation_com] 78 | 79 | error_map = o3d.geometry.PointCloud() 80 | if generate_error_map: 81 | error_map = generate_save_error_map(verts_trgt, all_dist_r) 82 | 83 | dist_p_s = np.square(dist_p) 84 | dist_r_s = np.square(dist_r) 85 | 86 | dist_p_mean = np.mean(dist_p) 87 | dist_r_mean = np.mean(dist_r) 88 | 89 | dist_p_s_mean = np.mean(dist_p_s) 90 | dist_r_s_mean = np.mean(dist_r_s) 91 | 92 | chamfer_l1 = 0.5 * (dist_p_mean + dist_r_mean) 93 | chamfer_l2 = np.sqrt(0.5 * (dist_p_s_mean + dist_r_s_mean)) 94 | 95 | precision = np.mean((dist_p < threshold).astype('float')) * 100.0 # % 96 | recall = np.mean((dist_r < threshold).astype('float')) * 100.0 # % 97 | fscore = 2 * precision * recall / (precision + recall) # % 98 | 99 | metrics = {'MAE_accuracy (cm)': dist_p_mean*100, 100 | 'MAE_completeness (cm)': dist_r_mean*100, 101 | 'Chamfer_L1 (cm)': chamfer_l1*100, 102 | 'Precision [Accuracy] (%)': precision, 103 | 'Recall [Completeness] (%)': recall, 104 | 'F-score (%)': fscore, 105 | 'Inlier_threshold (m)': threshold, # evlaution setup 106 | 'Outlier_truncation_acc (m)': truncation_acc, # evlaution setup 107 | 'Outlier_truncation_com (m)': truncation_com # evlaution setup 108 | } 109 | return metrics, error_map 110 | 111 | def generate_save_error_map(points, errors): 112 | #errors = np.clip(errors, 0, 0.20)/0.2 113 | errors = np.clip(errors, 0, 0.05)/0.05 114 | colors = colormap(errors) 115 | 116 | pcd = o3d.geometry.PointCloud() 117 | 118 | pcd.points = o3d.utility.Vector3dVector(points) 119 | pcd.colors = o3d.utility.Vector3dVector(colors) 120 | return pcd 121 | 122 | def generate_mesh_error_map(file_pred, file_trgt, tr=0.50): 123 | mesh_pred = o3d.io.read_triangle_mesh(file_pred) 124 | pcd_trgt = o3d.io.read_point_cloud(file_trgt) 125 | 126 | mesh_verts_pred = np.asarray(mesh_pred.vertices) 127 | verts_trgt = 
np.asarray(pcd_trgt.points) 128 | _, acc_dist = nn_correspondance(verts_trgt, mesh_verts_pred, tr, False) # find nn in ground truth samples for each predict sample -> precision related 129 | normal_errors = np.clip(acc_dist, 0, 0.1)/0.1 # set error interval 130 | colors = colormap(normal_errors) 131 | mesh_pred.vertex_colors = o3d.utility.Vector3dVector(colors) 132 | return mesh_pred 133 | 134 | def colormap(errors): 135 | colors = np.zeros((len(errors), 3)) 136 | colors[:, 0] = 1.0 137 | colors[:, 1] = 1 - errors 138 | colors[:, 2] = 1 - errors 139 | 140 | return colors 141 | 142 | def nn_correspondance(verts1, verts2, truncation_dist, ignore_outlier=True): 143 | """ for each vertex in verts2 find the nearest vertex in verts1 144 | Args: 145 | nx3 np.array's 146 | scalar truncation_dist: points whose nearest neighbor is farther than the distance would not be taken into account 147 | Returns: 148 | ([indices], [distances]) 149 | """ 150 | 151 | indices = [] 152 | distances = [] 153 | if len(verts1) == 0 or len(verts2) == 0: 154 | return indices, distances 155 | 156 | pcd = o3d.geometry.PointCloud() 157 | pcd.points = o3d.utility.Vector3dVector(verts1) 158 | kdtree = o3d.geometry.KDTreeFlann(pcd) 159 | 160 | truncation_dist_square = truncation_dist**2 161 | 162 | for vert in verts2: 163 | _, inds, dist_square = kdtree.search_knn_vector_3d(vert, 1) 164 | 165 | if dist_square[0] < truncation_dist_square: 166 | indices.append(inds[0]) 167 | distances.append(np.sqrt(dist_square[0])) 168 | else: 169 | if not ignore_outlier: 170 | indices.append(inds[0]) 171 | distances.append(truncation_dist) 172 | 173 | return indices, distances 174 | 175 | 176 | def eval_depth(depth_pred, depth_trgt): 177 | """ Computes 2d metrics between two depth maps 178 | Args: 179 | depth_pred: mxn np.array containing prediction 180 | depth_trgt: mxn np.array containing ground truth 181 | Returns: 182 | Dict of metrics 183 | """ 184 | mask1 = depth_pred > 0 # ignore values where prediction is 0 (% complete) 185 | mask = (depth_trgt < 10) * (depth_trgt > 0) * mask1 186 | 187 | depth_pred = depth_pred[mask] 188 | depth_trgt = depth_trgt[mask] 189 | abs_diff = np.abs(depth_pred - depth_trgt) 190 | abs_rel = abs_diff / depth_trgt 191 | sq_diff = abs_diff ** 2 192 | sq_rel = sq_diff / depth_trgt 193 | sq_log_diff = (np.log(depth_pred) - np.log(depth_trgt)) ** 2 194 | thresh = np.maximum((depth_trgt / depth_pred), (depth_pred / depth_trgt)) 195 | r1 = (thresh < 1.25).astype('float') 196 | r2 = (thresh < 1.25 ** 2).astype('float') 197 | r3 = (thresh < 1.25 ** 3).astype('float') 198 | 199 | metrics = {} 200 | metrics['AbsRel'] = np.mean(abs_rel) 201 | metrics['AbsDiff'] = np.mean(abs_diff) 202 | metrics['SqRel'] = np.mean(sq_rel) 203 | metrics['RMSE'] = np.sqrt(np.mean(sq_diff)) 204 | metrics['LogRMSE'] = np.sqrt(np.mean(sq_log_diff)) 205 | metrics['r1'] = np.mean(r1) 206 | metrics['r2'] = np.mean(r2) 207 | metrics['r3'] = np.mean(r3) 208 | metrics['complete'] = np.mean(mask1.astype('float')) 209 | 210 | return metrics 211 | 212 | def crop_intersection(file_gt, files_pred, out_file_crop, dist_thre=0.1, mesh_sample_point=1000000): 213 | """ Get the cropped ground truth point cloud according to the intersection of the predicted 214 | mesh by different methods 215 | Args: 216 | file_gt: file path of the ground truth (shoud be point cloud) 217 | files_pred: a list of the paths of different methods's reconstruction (shoud be mesh) 218 | out_file_crop: output path of the cropped ground truth point cloud 219 | dist_thre: nearest 
neighbor distance threshold in meter 220 | mesh_sample_point: number of the sampling points from the mesh 221 | """ 222 | print("Load the original ground truth point cloud from:", file_gt) 223 | pcd_gt = o3d.io.read_point_cloud(file_gt) 224 | pcd_gt_pts = np.asarray(pcd_gt.points) 225 | dist_square_thre = dist_thre**2 226 | for i in range(len(files_pred)): 227 | cur_file_pred = files_pred[i] 228 | print("Process", cur_file_pred) 229 | cur_mesh_pred = o3d.io.read_triangle_mesh(cur_file_pred) 230 | 231 | cur_sample_pred = cur_mesh_pred.sample_points_uniformly(number_of_points=mesh_sample_point) 232 | 233 | cur_kdtree = o3d.geometry.KDTreeFlann(cur_sample_pred) 234 | 235 | crop_pcd_gt_pts = [] 236 | for pt in pcd_gt_pts: 237 | _, _, dist_square = cur_kdtree.search_knn_vector_3d(pt, 1) 238 | 239 | if dist_square[0] < dist_square_thre: 240 | crop_pcd_gt_pts.append(pt) 241 | 242 | pcd_gt_pts = np.asarray(crop_pcd_gt_pts) 243 | 244 | crop_pcd_gt = o3d.geometry.PointCloud() 245 | crop_pcd_gt.points = o3d.utility.Vector3dVector(pcd_gt_pts) 246 | 247 | print("Output the croped ground truth to:", out_file_crop) 248 | o3d.io.write_point_cloud(out_file_crop, crop_pcd_gt) 249 | 250 | 251 | -------------------------------------------------------------------------------- /eval/evaluator.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from eval_utils import * 3 | import open3d as o3d 4 | 5 | dataset_name = "maicity_01_" 6 | # ground truth point cloud (or mesh) file 7 | # (optional masked by the intersection part of all the compared method) 8 | gt_pcd_path = "xxx/dataset/mai_city/gt_map_pc_mai.ply" 9 | 10 | pred_mesh_path = "xx/mai_xx.ply" 11 | method_name = "maixx" 12 | 13 | # pred_mesh_path = "xxx/baseline/vdb_fusion_xxx.ply" 14 | # method_name = "vdb_fusion_xxx" 15 | 16 | # pred_mesh_path = "xxx/baseline/puma_xxx.ply" 17 | # method_name = "puma_xxx" 18 | 19 | # evaluation results output file 20 | base_output_folder = "./eval/eval_results/" 21 | output_csv_path = base_output_folder + dataset_name + method_name + "_eval.csv" 22 | 23 | # For MaiCity 24 | down_sample_vox = 0.02 25 | dist_thre = 0.1 26 | truncation_dist_acc = 0.2 27 | truncation_dist_com = 2.0 28 | 29 | # For NCD 30 | # down_sample_vox = 0.02 31 | # dist_thre = 0.2 32 | # truncation_dist_acc = 0.4 33 | # truncation_dist_com = 2.0 34 | 35 | # For NRGBD 36 | # down_sample_vox = 0.004 37 | # dist_thre = 0.04 #0.04? 38 | # truncation_dist_acc = 0.08 39 | # truncation_dist_com = 0.4 # 0.5? 
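# Note: all distances above are in meters. eval_mesh() samples
# mesh_sample_point points uniformly from the predicted mesh,
# voxel-downsamples both clouds at down_sample_vox, and then measures
# nearest-neighbor distances in both directions: prediction to GT gives the
# accuracy/precision numbers, GT to prediction gives the completeness/recall
# numbers, each truncated at its respective truncation distance.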
40 | 41 | # evaluation 42 | eval_metric, error_map = eval_mesh(pred_mesh_path, gt_pcd_path, down_sample_res=down_sample_vox, threshold=dist_thre, 43 | truncation_acc = truncation_dist_acc, truncation_com = truncation_dist_com, gt_bbx_mask_on = True, 44 | generate_error_map=False) 45 | 46 | print(eval_metric) 47 | 48 | if not error_map.is_empty(): 49 | o3d.io.write_point_cloud("./eval/eval_results/" + method_name + ".ply", error_map) 50 | # mesh_error_map = generate_mesh_error_map(pred_mesh_path, gt_pcd_path) 51 | # if not mesh_error_map.is_empty(): 52 | # o3d.io.write_triangle_mesh("./eval/eval_results/" + method_name + "_mesh.ply", mesh_error_map) 53 | 54 | evals = [eval_metric] 55 | 56 | csv_columns = ['MAE_accuracy (cm)', 'MAE_completeness (cm)', 'Chamfer_L1 (cm)', \ 57 | 'Precision [Accuracy] (%)', 'Recall [Completeness] (%)', 'F-score (%)', \ 58 | 'Inlier_threshold (m)', 'Outlier_truncation_acc (m)', 'Outlier_truncation_com (m)'] 59 | 60 | try: 61 | with open(output_csv_path, 'w') as csvfile: 62 | writer = csv.DictWriter(csvfile, fieldnames=csv_columns) 63 | writer.writeheader() 64 | for data in evals: 65 | writer.writerow(data) 66 | except IOError: 67 | print("I/O error") 68 | 69 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiev-tongji/N3-Mapping/9bc3ca76ba45cae7f5b224c81f4c59c9b048b7e2/model/__init__.py -------------------------------------------------------------------------------- /model/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.autograd import grad 5 | 6 | from utils.config import SHINEConfig 7 | 8 | 9 | class Decoder(nn.Module): 10 | def __init__(self, config: SHINEConfig, is_geo_encoder = True): 11 | 12 | super().__init__() 13 | 14 | if is_geo_encoder: 15 | mlp_hidden_dim = config.geo_mlp_hidden_dim 16 | mlp_bias_on = config.geo_mlp_bias_on 17 | mlp_level = config.geo_mlp_level 18 | else: 19 | mlp_hidden_dim = config.sem_mlp_hidden_dim 20 | mlp_bias_on = config.sem_mlp_bias_on 21 | mlp_level = config.sem_mlp_level 22 | 23 | # predict sdf (now it anyway only predict sdf without further sigmoid 24 | # Initializa the structure of shared MLP 25 | layers = [] 26 | for i in range(mlp_level): 27 | if i == 0: 28 | layers.append(nn.Linear(config.feature_dim, mlp_hidden_dim, mlp_bias_on)) 29 | else: 30 | layers.append(nn.Linear(mlp_hidden_dim, mlp_hidden_dim, mlp_bias_on)) 31 | self.layers = nn.ModuleList(layers) 32 | self.lout = nn.Linear(mlp_hidden_dim, 1, mlp_bias_on) 33 | self.nclass_out = nn.Linear(mlp_hidden_dim, config.sem_class_count + 1, mlp_bias_on) # sem class + free space class 34 | # self.bn = nn.BatchNorm1d(self.hidden_dim, affine=False) 35 | 36 | # predict residual sdf. 
37 | if config.predict_residual_sdf: 38 | layers_res = [] 39 | for i in range(mlp_level): 40 | if i == 0: 41 | layers_res.append(nn.Linear(config.feature_dim, mlp_hidden_dim, mlp_bias_on)) 42 | else: 43 | layers_res.append(nn.Linear(mlp_hidden_dim, mlp_hidden_dim, mlp_bias_on)) 44 | self.layers_res = nn.ModuleList(layers_res) 45 | self.lout_res = nn.Linear(mlp_hidden_dim, 1, mlp_bias_on) 46 | 47 | self.to(config.device) 48 | # torch.cuda.empty_cache() 49 | 50 | def forward(self, feature): 51 | # If we use BCEwithLogits loss, do not need to do sigmoid mannually 52 | output = self.sdf(feature) 53 | return output 54 | 55 | # predict the sdf (opposite sign to the actual sdf) 56 | def sdf(self, sum_features): 57 | for k, l in enumerate(self.layers): 58 | if k == 0: 59 | h = F.relu(l(sum_features)) 60 | else: 61 | h = F.relu(l(h)) 62 | 63 | out = self.lout(h).squeeze(1) 64 | # linear (feature_dim -> hidden_dim) 65 | # relu 66 | # linear (hidden_dim -> hidden_dim) 67 | # relu 68 | # linear (hidden_dim -> 1) 69 | 70 | return out 71 | 72 | def sum_sdf(self, fine_features, coarse_features): 73 | for k, l in enumerate(self.layers): 74 | if k == 0: 75 | h = F.relu(l(coarse_features)) 76 | else: 77 | h = F.relu(l(h)) 78 | coarse_sdf = self.lout(h).squeeze(1) 79 | 80 | for k, l in enumerate(self.layers_res): 81 | if k == 0: 82 | x = F.relu(l(fine_features)) 83 | else: 84 | x = F.relu(l(x)) 85 | res_sdf = self.lout_res(x).squeeze(1) 86 | 87 | sdf_out = coarse_sdf + res_sdf 88 | 89 | return sdf_out 90 | 91 | # predict the occupancy probability 92 | def occupancy(self, sum_features): 93 | out = torch.sigmoid(self.sdf(sum_features)) # to [0, 1] 94 | return out 95 | 96 | # predict the probabilty of each semantic label 97 | def sem_label_prob(self, sum_features): 98 | for k, l in enumerate(self.layers): 99 | if k == 0: 100 | h = F.relu(l(sum_features)) 101 | else: 102 | h = F.relu(l(h)) 103 | 104 | out = F.log_softmax(self.nclass_out(h), dim=1) 105 | return out 106 | 107 | def sem_label(self, sum_features): 108 | out = torch.argmax(self.sem_label_prob(sum_features), dim=1) 109 | return out 110 | -------------------------------------------------------------------------------- /model/feature_octree.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import time 5 | from tqdm import tqdm 6 | import kaolin as kal 7 | import numpy as np 8 | 9 | from functools import partial 10 | from utils.tools import * 11 | 12 | from utils.config import SHINEConfig 13 | 14 | # TODO: polish the codes 15 | 16 | class FeatureOctree(nn.Module): 17 | 18 | def __init__(self, config: SHINEConfig): 19 | 20 | super().__init__() 21 | 22 | # [0 1 2 3 ... max_level-1 max_level], 0 level is the root, which have 8 corners. 
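        # The top free_level_num levels carry no features; only the bottom
        # featured_level_num levels (tree_level_feat in the config) get
        # feature grids, which are stored top-down in hier_features with the
        # leaf level in the last entry.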
23 | self.max_level = config.tree_level_world 24 | # the number of levels with feature (begin from bottom) 25 | self.leaf_vox_size = config.leaf_vox_size 26 | self.featured_level_num = config.tree_level_feat 27 | self.free_level_num = self.max_level - self.featured_level_num + 1 28 | self.feature_dim = config.feature_dim 29 | self.feature_std = config.feature_std 30 | self.polynomial_interpolation = config.poly_int_on 31 | self.device = config.device 32 | 33 | # Initialize the look up tables 34 | self.corners_lookup_tables = [] # from corner morton to corner index (top-down) 35 | self.nodes_lookup_tables = [] # from nodes morton to corner index (top-down) 36 | # Initialize the look up table for each level, each is a dictionary 37 | for l in range(self.max_level+1): 38 | self.corners_lookup_tables.append({}) 39 | self.nodes_lookup_tables.append({}) 40 | 41 | # local leaf node morton list for local mapping 42 | self.nodes_lookup_tables.append({}) 43 | 44 | # Initialize the hierarchical grid feature list 45 | if self.featured_level_num < 1: 46 | raise ValueError('No level with grid features!') 47 | # hierarchical grid features list 48 | # top-down: leaf node level is stored in the last row (dim feature_level_num-1) 49 | # but only for the featured levels 50 | self.hier_features = nn.ParameterList([]) 51 | 52 | # the temporal stuffs that can be cleared 53 | # hierachical feature grid indices for the input batch point 54 | self.hierarchical_indices = [] 55 | # bottom-up: stored from the leaf node level (dim 0) 56 | 57 | # used for incremental learning (mapping) 58 | self.importance_weight = [] # weight for each feature dimension 59 | self.features_last_frame = [] # hierarchical features for the last frame 60 | 61 | self.to(config.device) 62 | 63 | # the last element of the each level of the hier_features is the trashbin element 64 | # after the optimization, we need to set it back to zero vector 65 | def set_zero(self): 66 | with torch.no_grad(): 67 | for n in range(len(self.hier_features)): 68 | self.hier_features[n][-1] = torch.zeros(1,self.feature_dim) 69 | 70 | def forward(self, x): 71 | feature = self.query_feature(x) 72 | return feature 73 | 74 | def get_morton(self, sample_points, level): 75 | points = kal.ops.spc.quantize_points(sample_points, level) # quantize to interger coords 76 | points_morton = kal.ops.spc.points_to_morton(points) # to 1d morton code 77 | sample_points_with_morton = torch.hstack((sample_points, points_morton.view(-1, 1))) 78 | morton_set = set(points_morton.cpu().numpy()) 79 | return sample_points_with_morton, morton_set 80 | 81 | def get_octree_nodes(self, level): # top-down 82 | nodes_morton = list(self.nodes_lookup_tables[level].keys()) 83 | nodes_morton = torch.tensor(nodes_morton).to(self.device, torch.int64) 84 | nodes_spc = kal.ops.spc.morton_to_points(nodes_morton) 85 | nodes_spc_np = nodes_spc.cpu().numpy() 86 | node_size = 2**(1-level) # 2/2^{level}, node size in the -1 to 1 kaolin space, 87 | # nodes_spc_np * node_size is in[0,2], afer minus 1, is in [-1,1] 88 | # +0.5*node_size means return the center coor of node. 89 | nodes_coord_scaled = (nodes_spc_np * node_size) - 1. 
+ 0.5 * node_size 90 | return nodes_coord_scaled 91 | 92 | def get_octree_nodes_spc(self, level): 93 | octree_data = {} 94 | nodes_morton = list(self.nodes_lookup_tables[level].keys()) 95 | nodes_morton = torch.tensor(nodes_morton).to(self.device, torch.int64) 96 | nodes_spc = kal.ops.spc.morton_to_points(nodes_morton) 97 | return nodes_spc 98 | 99 | def is_empty(self): 100 | return len(self.hier_features) == 0 101 | 102 | # clear the temp data (used for one batch) that is not needed 103 | def clear_temp(self): 104 | self.hierarchical_indices = [] 105 | self.importance_weight = [] 106 | self.features_last_frame = [] 107 | 108 | # update the octree according to new observations 109 | # if incremental_on = True, then we additional store the last frames' feature for regularization based incremental mapping 110 | def update(self, surface_points): 111 | # [0 1 2 3 ... max_level-1 max_level] 112 | spc = kal.ops.conversions.unbatched_pointcloud_to_spc(surface_points, self.max_level) 113 | pyramid = spc.pyramids[0].cpu() # represent the number of points in each level 114 | for i in range(self.max_level+1): # for each level (top-down) 115 | if i < self.free_level_num: # free levels (skip), only need to consider the featured levels 116 | continue 117 | # level storing features (i>=free_level_num) 118 | nodes = spc.point_hierarchies[pyramid[1, i]:pyramid[1, i+1]] 119 | nodes_morton = kal.ops.spc.points_to_morton(nodes).cpu().numpy().tolist() # nodes at certain level 120 | new_nodes_index = [] 121 | for idx in range(len(nodes_morton)): 122 | if nodes_morton[idx] not in self.nodes_lookup_tables[i]: 123 | new_nodes_index.append(idx) # nodes to corner dictionary: key is the morton code 124 | new_nodes = nodes[new_nodes_index] # get the newly added nodes 125 | if new_nodes.shape[0] == 0: 126 | continue 127 | corners = kal.ops.spc.points_to_corners(new_nodes).reshape(-1,3) 128 | corners_unique = torch.unique(corners, dim=0) 129 | # mortons of the coners from the new scan 130 | corners_morton = kal.ops.spc.points_to_morton(corners_unique).cpu().numpy().tolist() 131 | if len(self.corners_lookup_tables[i]) == 0: # for the first frame 132 | corners_dict = dict(zip(corners_morton, range(len(corners_morton)))) 133 | self.corners_lookup_tables[i] = corners_dict 134 | # initializa corner features 135 | fts = self.feature_std*torch.randn(len(corners_dict)+1, self.feature_dim, device=self.device) 136 | fts[-1] = torch.zeros(1,self.feature_dim) 137 | # Be careful, the size of the feature list equals to featured_level_num not max_level+1 138 | self.hier_features.append(nn.Parameter(fts)) 139 | else: # update for new frames 140 | pre_size = len(self.corners_lookup_tables[i]) 141 | for m in corners_morton: 142 | if m not in self.corners_lookup_tables[i]: # add new keys 143 | self.corners_lookup_tables[i][m] = len(self.corners_lookup_tables[i]) 144 | new_feature_num = len(self.corners_lookup_tables[i]) - pre_size 145 | new_fts = self.feature_std*torch.randn(new_feature_num+1, self.feature_dim, device=self.device) 146 | new_fts[-1] = torch.zeros(1,self.feature_dim) 147 | cur_featured_level = i-self.free_level_num 148 | self.hier_features[cur_featured_level] = nn.Parameter(torch.cat((self.hier_features[cur_featured_level][:-1],new_fts),0)) 149 | 150 | corners_m = kal.ops.spc.points_to_morton(corners).cpu().numpy().tolist() 151 | indexes = torch.tensor([self.corners_lookup_tables[i][x] for x in corners_m]).reshape(-1,8).numpy().tolist() 152 | new_nodes_morton = kal.ops.spc.points_to_morton(new_nodes).cpu().numpy().tolist() 
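            # Map each newly added node's morton code to the feature indices
            # of its 8 corners, so that get_indices() can later look up all
            # corner features of a queried voxel with a single dictionary access.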
153 | for k in range(len(new_nodes_morton)):
154 | self.nodes_lookup_tables[i][new_nodes_morton[k]] = indexes[k]
155 | 
156 | # nodes_coord = self.get_octree_nodes(self.max_level)
157 | # print(nodes_coord)
158 | 
159 | # tri-linear (or polynomial) interpolation of the features at a certain octree level for spatial points x
160 | def interpolat(self, x, level, polynomial_on = True):
161 | coords = ((2**level)*(x*0.5+0.5))
162 | d_coords = torch.frac(coords)
163 | if polynomial_on: # smoothstep weights 3t^2 - 2t^3
164 | tx = 3*(d_coords[:,0]**2) - 2*(d_coords[:,0]**3)
165 | ty = 3*(d_coords[:,1]**2) - 2*(d_coords[:,1]**3)
166 | tz = 3*(d_coords[:,2]**2) - 2*(d_coords[:,2]**3)
167 | else: # linear
168 | tx = d_coords[:,0]
169 | ty = d_coords[:,1]
170 | tz = d_coords[:,2]
171 | _1_tx = 1-tx
172 | _1_ty = 1-ty
173 | _1_tz = 1-tz
174 | p0 = _1_tx*_1_ty*_1_tz
175 | p1 = _1_tx*_1_ty*tz
176 | p2 = _1_tx*ty*_1_tz
177 | p3 = _1_tx*ty*tz
178 | p4 = tx*_1_ty*_1_tz
179 | p5 = tx*_1_ty*tz
180 | p6 = tx*ty*_1_tz
181 | p7 = tx*ty*tz
182 | 
183 | p = torch.stack((p0,p1,p2,p3,p4,p5,p6,p7),0).T.unsqueeze(2)
184 | return p
185 | 
186 | # get the unique indices of the feature nodes at spatial points x for each level
187 | # TODO: speed up !!!
188 | def get_indices(self, coord):
189 | self.hierarchical_indices = [] # initialize the hierarchical indices list for the batch points x
190 | for i in range(self.featured_level_num): # bottom-up, for each level
191 | current_level = self.max_level - i
192 | points = kal.ops.spc.quantize_points(coord,current_level) # quantize to integer coords
193 | points_morton = kal.ops.spc.points_to_morton(points).cpu().numpy().tolist() # convert to 1d morton code for the voxel center
194 | features_last_row = [-1 for t in range(8)] # if not in the lookup table, then assign all -1
195 | # look up the 8 corner nodes' unique indices for each 1d morton code in the lookup table [nx8], the most time-consuming part
196 | # [actually a kind of hashing realized by a python dictionary]
197 | indices_list = [self.nodes_lookup_tables[current_level].get(p,features_last_row) for p in points_morton]
198 | # if p is not found among the keys of the lookup table, use features_last_row,
199 | # which indexes the all-zero trashbin row of this level's features
200 | indices_torch = torch.tensor(indices_list, device=self.device)
201 | self.hierarchical_indices.append(indices_torch) # l level {nx8} # bottom-up
202 | 
203 | return self.hierarchical_indices
204 | 
205 | # get the hierarchically-summed interpolated feature at spatial points x
206 | def query_feature_with_indices(self, coord, hierarchical_indices):
207 | sum_features = torch.zeros(coord.shape[0], self.feature_dim, device=self.device)
208 | for i in range(self.featured_level_num): # for each level
209 | current_level = self.max_level - i
210 | feature_level = self.featured_level_num-i-1
211 | # Interpolating
212 | # get the interpolation coefficients for the 8 neighboring corners, in the same order as the hierarchical_indices
213 | coeffs = self.interpolat(coord,current_level,self.polynomial_interpolation)
214 | sum_features += (self.hier_features[feature_level][hierarchical_indices[i]]*coeffs).sum(1)
215 | # corner index -1 means the queried voxel is not allocated at this level.
If so, we fetch the trashbin row of the feature grid,
216 | # whose value is 0, so this level then contributes a zero feature
217 | return sum_features
218 | 
219 | # get the interpolated features at spatial points x, split into the leaf-level (fine) and coarser-level (coarse) parts
220 | def query_split_feature_with_indices(self, coord, hierarchical_indices):
221 | 
222 | coarse_features = torch.zeros(coord.shape[0], self.feature_dim, device=self.device)
223 | fine_features = torch.zeros(coord.shape[0], self.feature_dim, device=self.device)
224 | for i in range(self.featured_level_num): # for each level
225 | current_level = self.max_level - i
226 | feature_level = self.featured_level_num-i-1
227 | # Interpolating
228 | # get the interpolation coefficients for the 8 neighboring corners, in the same order as the hierarchical_indices
229 | coeffs = self.interpolat(coord,current_level,self.polynomial_interpolation)
230 | if i == 0: # leaf node features
231 | fine_features = (self.hier_features[feature_level][hierarchical_indices[i]]*coeffs).sum(1)
232 | else:
233 | coarse_features += (self.hier_features[feature_level][hierarchical_indices[i]]*coeffs).sum(1)
234 | # corner index -1 means the queried voxel is not allocated at this level. If so, we fetch the trashbin row of the feature grid,
235 | # whose value is 0, so this level then contributes a zero feature
236 | return fine_features, coarse_features
237 | 
238 | # all-in-one function to get the octree features for a batch of points
239 | def query_feature(self, coord, faster = False):
240 | self.set_zero() # set the trashbin feature vector back to 0 after the feature update
241 | if faster:
242 | indices = self.get_indices_fast(coord)
243 | else:
244 | indices = self.get_indices(coord)
245 | features = self.query_feature_with_indices(coord, indices)
246 | return features
247 | 
248 | def query_split_feature(self, coord, faster = False):
249 | self.set_zero() # set the trashbin feature vector back to 0 after the feature update
250 | if faster:
251 | indices = self.get_indices_fast(coord)
252 | else:
253 | indices = self.get_indices(coord)
254 | fine_features, coarse_features = self.query_split_feature_with_indices(coord, indices)
255 | return fine_features, coarse_features
256 | 
257 | def cal_regularization(self):
258 | regularization = 0.
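# Aside: the loop below is a regularization-based continual-learning penalty in
# the spirit of EWC: for the features touched by the current batch it
# accumulates importance_weight * (feature - feature_last_frame)^2, where
# importance_weight is the accumulated gradient magnitude computed in
# utils/incre_learning.py. A one-level toy version (hypothetical tensors, not
# code from this repository):
#
#     import torch
#     feats      = torch.randn(100, 8)                  # current level features
#     feats_last = feats + 0.01 * torch.randn_like(feats)
#     importance = torch.rand(100, 8)                   # accumulated |grad|
#     touched    = torch.unique(torch.randint(0, 100, (32,)))
#     diff       = feats[touched] - feats_last[touched]
#     reg        = (importance[touched] * diff ** 2).sum()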
259 | for i in range(self.featured_level_num): # for each level 260 | feature_level = self.featured_level_num-i-1 261 | unique_indices = self.hierarchical_indices[i].flatten().unique() 262 | # feature change between current and last frame 263 | difference = self.hier_features[feature_level][unique_indices] - self.features_last_frame[feature_level][unique_indices] 264 | # regularization for continous learning weighted by the feature importance and the change magnitude 265 | regularization += (self.importance_weight[feature_level][unique_indices]*(difference**2)).sum() 266 | return regularization 267 | 268 | # def list_duplicates(self, seq): 269 | # dd = defaultdict(list) 270 | # for i,item in enumerate(seq): 271 | # dd[item].append(i) 272 | # return [(key,locs) for key,locs in dd.items() if len(locs)>=1] 273 | 274 | # speed up for the batch sdf inferencing during meshing 275 | # points in the same voxel would be grouped and getting indices together 276 | # more efficient only when there are lots of samples from the same voxel in the batch (the case when conducting meshing) 277 | # This function contains some problem which would make the mesh worse, check it later (solved) 278 | def get_indices_fast(self, coord): 279 | self.hierarchical_indices = [] 280 | for i in range(self.featured_level_num): # bottom-up 281 | current_level = self.max_level - i 282 | points = kal.ops.spc.quantize_points(coord,current_level) # quantize to interger coords 283 | points_morton = kal.ops.spc.points_to_morton(points).cpu().numpy().tolist() # convert to 1d morton code for the voxel center 284 | features_last_row = [-1 for t in range(8)] # if not in the look up table, then assign -1 285 | 286 | dups_in_mortons = dict(list_duplicates(points_morton)) # list the x with the same morton code (samples inside the same voxel) 287 | dups_indices = np.zeros((len(points_morton), 8)) 288 | # print(len(dups_in_mortons.keys()), len(points_morton)) 289 | for p in dups_in_mortons.keys(): 290 | idx = dups_in_mortons[p] # indices, p is the point morton 291 | # get indices only once for these samples sharing the same voxel 292 | corner_indices = self.nodes_lookup_tables[current_level].get(p,features_last_row) 293 | dups_indices[idx,:] = corner_indices 294 | indices = torch.tensor(dups_indices, device=self.device).long() 295 | self.hierarchical_indices.append(indices) # l level {nx8} 296 | 297 | return self.hierarchical_indices 298 | 299 | def print_detail(self): 300 | print("Current Octomap:") 301 | total_vox_count = 0 302 | for level in range(self.featured_level_num): 303 | level_vox_size = self.leaf_vox_size*(2**(self.featured_level_num-1-level)) 304 | level_vox_count = self.hier_features[level].shape[0] 305 | print("%.2f m: %d voxel corners" %(level_vox_size, level_vox_count)) 306 | total_vox_count += level_vox_count 307 | total_map_memory = total_vox_count * self.feature_dim * 4 / 1024 / 1024 # unit: MB 308 | print("memory: %d x %d x 4 = %.3f MB" %(total_vox_count, self.feature_dim, total_map_memory)) 309 | print("--------------------------------") 310 | 311 | # with open('./log/memory-07.txt', 'a') as f: 312 | # f.write(f"{total_map_memory}\n") 313 | # return -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | #os.environ['CUDA_VISIBLE_DEVICES'] = "0" #set here before import torch 4 | 5 | from utils.config import SHINEConfig 6 | from utils.mapper import Mapper 7 | 8 | def 
run():
9 | 
10 | config = SHINEConfig()
11 | 
12 | if len(sys.argv) > 1:
13 | config.load(sys.argv[1])
14 | else:
15 | sys.exit(
16 | "Please provide the path to the config file.\nTry: python run.py xxx/xxx_config.yaml"
17 | )
18 | 
19 | mapper = Mapper(config)
20 | mapper.mapping()
21 | 
22 | if __name__ == "__main__":
23 | run()
-------------------------------------------------------------------------------- /scripts/convert_rgbd_to_kitti_format.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | root_path=/media/shy/Document/dataset/neural_rgbd
4 | sequence_name=thin_geometry
5 | base_path=${root_path}/${sequence_name}
6 | 
7 | # For the NeuralRGBD dataset, set is_focal_file to True and already_kitti_format_pose to False
8 | # For the Replica dataset, set is_focal_file to False and already_kitti_format_pose to True
9 | 
10 | command="python3 ./dataset/rgbd_to_kitti_format.py
11 | --output_root ${base_path}_kitti_gt
12 | --depth_img_folder ${base_path}/depth_gt/
13 | --rgb_img_folder ${base_path}/images/
14 | --intrinsic_file ${base_path}/focal.txt
15 | --pose_file ${base_path}/poses.txt
16 | --is_focal_file True
17 | --already_kitti_format_pose False
18 | --vis_on False
19 | --max_depth_m 5"
20 | 
21 | echo "Convert RGBD dataset to KITTI format"
22 | eval $command
23 | echo "done"
-------------------------------------------------------------------------------- /scripts/download_maicity.sh: --------------------------------------------------------------------------------
1 | echo Creating the dataset path...
2 | 
3 | mkdir -p data
4 | cd data
5 | 
6 | echo Downloading MaiCity dataset...
7 | wget https://www.ipb.uni-bonn.de/html/projects/mai_city/mai_city.tar.gz
8 | 
9 | echo Extracting dataset...
10 | tar -xvf mai_city.tar.gz
11 | 
12 | echo Downloading the MaiCity ground truth point cloud \(generated from sequence 02\) and the ground truth model...
13 | cd mai_city
14 | wget -O gt_map_pc_mai.ply -c https://uni-bonn.sciebo.de/s/DAMWVCC1Kxkfkyz/download
15 | cd ..
16 | 
17 | rm mai_city.tar.gz
18 | 
19 | cd ../..
-------------------------------------------------------------------------------- /scripts/download_ncd_example.sh: --------------------------------------------------------------------------------
1 | echo Creating the dataset path...
2 | 
3 | mkdir -p data
4 | cd data
5 | 
6 | echo Downloading Newer College dataset, Quad example subset...
7 | wget -O ncd_example.tar.gz -c https://uni-bonn.sciebo.de/s/ZKTMubNY9mqbfwN/download
8 | 
9 | echo Extracting dataset...
10 | tar -xvf ncd_example.tar.gz
11 | 
12 | rm ncd_example.tar.gz
13 | 
14 | cd ../..
-------------------------------------------------------------------------------- /scripts/download_neural_rgbd_data.sh: --------------------------------------------------------------------------------
1 | echo Creating the dataset path...
2 | 
3 | mkdir -p data
4 | cd data
5 | 
6 | echo Downloading Neural RGBD dataset...
7 | wget http://kaldir.vc.in.tum.de/neural_rgbd/neural_rgbd_data.zip
8 | 
9 | echo Extracting dataset...
10 | unzip neural_rgbd_data.zip
11 | 
12 | rm neural_rgbd_data.zip
13 | 
14 | cd ../..
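Note that run.py above is a thin wrapper, so the mapper can also be driven programmatically. A minimal sketch mirroring the wrapper, assuming the repository root is on PYTHONPATH; the config path is a placeholder to adjust for your setup:
```
# sketch of driving the mapper without the CLI, same steps as run.py
from utils.config import SHINEConfig
from utils.mapper import Mapper

config = SHINEConfig()                            # defaults from utils/config.py
config.load("config/maicity/maicity_incre.yaml")  # placeholder path
Mapper(config).mapping()
```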
-------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiev-tongji/N3-Mapping/9bc3ca76ba45cae7f5b224c81f4c59c9b048b7e2/utils/__init__.py -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | import torch 4 | from typing import List 5 | 6 | class SHINEConfig: 7 | def __init__(self): 8 | 9 | # Default values 10 | # settings 11 | self.name: str = "dummy" # experiment name 12 | 13 | self.output_root: str = "" # output root folder 14 | self.pc_path: str = "" # input point cloud folder 15 | self.pose_path: str = "" # input pose file 16 | self.calib_path: str = "" # input calib file (to sensor frame) 17 | self.label_path: str = "" # input point-wise label path, for semantic shine mapping 18 | self.load_model: bool = False # load the pre-trained model or not 19 | self.model_path: str = "/" # pre-trained model path 20 | 21 | self.first_frame_ref: bool = True # if false, we directly use the world 22 | # frame as the reference frame 23 | self.begin_frame: int = 0 # begin from this frame 24 | self.end_frame: int = 0 # end at this frame 25 | self.every_frame: int = 1 # process every x frame 26 | 27 | self.num_workers: int = 12 # number of worker for the dataloader 28 | self.device: str = "cuda" # use "cuda" or "cpu" 29 | self.gpu_id: str = "0" # used GPU id 30 | self.dtype = torch.float32 # default torch tensor data type 31 | self.pc_count_gpu_limit: int = 500 # maximum used frame number to be stored in the gpu 32 | 33 | # just a ramdom number for the global shift of the input on z axis (used to avoid octree boundary marching cubes issues) 34 | self.global_shift_default: float = 0.17241 35 | 36 | # process 37 | self.min_range: float = 2.75 # filter too-close points (and 0 artifacts) 38 | self.pc_radius: float = 20.0 # keep only the point cloud inside the 39 | # block with such radius (unit: m) 40 | self.min_z: float = -3.0 # filter for z coordinates (unit: m) 41 | self.max_z: float = 30.0 42 | 43 | self.rand_downsample: bool = ( 44 | False # apply random or voxel downsampling to input original point clcoud 45 | ) 46 | self.vox_down_m: float = ( 47 | 0.03 # the voxel size if using voxel downsampling (unit: m) 48 | ) 49 | self.rand_down_r: float = ( 50 | 1.0 # the decimation ratio if using random downsampling (0-1) 51 | ) 52 | 53 | self.filter_noise: bool = False # use SOR to remove the noise or not 54 | self.sor_nn: int = 5 # SOR neighborhood size 55 | self.sor_std: float = 2.5 # SOR std threshold 56 | 57 | self.estimate_normal: bool = True # estimate surface normal or not 58 | self.normal_radius_m: float = 0.5 # supporting radius for estimating the normal 59 | self.normal_max_nn: int = 20 # supporting neighbor count for estimating the normal 60 | 61 | # semantic related 62 | self.semantic_on: bool = False # semantic shine mapping on [semantic] 63 | self.sem_class_count: int = 20 # semantic class count: 20 for semantic kitti 64 | self.sem_label_decimation: int = 1 # use only 1/${sem_label_decimation} of the available semantic labels for training (fitting) 65 | self.filter_moving_object: bool = True 66 | 67 | # frame-wise downsampling voxel size for the merged map point cloud (unit: m) 68 | self.map_vox_down_m: float = 0.2 69 | 70 | # use distance based keyframe (only for KITTI) 71 | 
self.use_keyframe: bool = False 72 | 73 | # octree 74 | self.tree_level_world: int = ( 75 | 10 # the total octree level, allocated for the whole space 76 | ) 77 | self.tree_level_feat: int = 4 # the octree levels with optimizable feature grid 78 | # start from the leaf level 79 | self.leaf_vox_size: float = 0.5 # voxel size of the octree leaf nodes (unit: m) 80 | self.feature_dim: int = 8 # length of the feature for each grid feature 81 | self.feature_std: float = 0.05 # grid feature initialization standard deviation 82 | self.poly_int_on: bool = ( 83 | True # use polynomial interpolation or linear interpolation 84 | ) 85 | self.octree_from_surface_samples: bool = True # Use all the surface samples or just the exact measurements to build the octree. If True may lead to larger memory, but is more robust while the reconstruction. 86 | 87 | # sampler 88 | self.surface_sample_range_m: float = 0.5 # 89 | self.surface_sample_n: int = 5 90 | self.free_sample_begin_ratio: float = 0.3 91 | # self.free_sample_end_ratio: float = 1.0 # deprecated 92 | self.free_sample_end_dist: float = 0.5 # maximum distance after the surface (unit: m) 93 | self.free_sample_n: int = 2 94 | self.normal_sampling_on: bool = False 95 | self.gaussian_sampling_on: bool = False 96 | self.sliding_window_on: bool = False 97 | 98 | # space carving sampling related (deprecated) 99 | # self.carving_on = False 100 | # self.tree_level_carving = self.tree_level_world 101 | # self.carving_stop_depth_m = 0.5 102 | # self.carving_inte_thre_m = 0.1 103 | 104 | # continuous learning 105 | self.continual_learning_reg: bool = True 106 | # regularization based 107 | self.lambda_forget: float = 1e5 108 | self.cal_importance_weight_down_rate: int = 5 # set it larger to save the consuming time 109 | 110 | # decoder 111 | self.geo_mlp_level: int = 2 112 | self.geo_mlp_hidden_dim: int = 32 113 | self.geo_mlp_bias_on: bool = True 114 | 115 | self.sem_mlp_level: int = 2 116 | self.sem_mlp_hidden_dim: int = 32 117 | self.sem_mlp_bias_on: bool = True 118 | 119 | self.freeze_after_frame: int = 20 # For incremental mode only, if the decoder model is not loaded , it would be trained and freezed after such frame number 120 | self.predict_residual_sdf: bool = False 121 | 122 | # loss 123 | self.ray_loss: bool = False # one loss on a whole ray (including depth estimation loss or the differentiable rendering loss) 124 | # the main loss type, select from the sample sdf loss ('sdf_bce', 'sdf_l1', 'sdf_l2') and the ray rendering loss ('dr', 'dr_neus') 125 | self.main_loss_type: str = 'sdf_bce' 126 | 127 | self.loss_reduction: str = 'mean' # select from 'mean' and 'sum' (for incremental mapping) 128 | 129 | self.sigma_sigmoid_m: float = 0.1 130 | self.logistic_gaussian_ratio: float = 0.55 131 | 132 | self.neus_loss_on: bool = False # use the unbiased and occlusion-aware weights for differentiable rendering as introduced in NEUS 133 | self.loss_weight_on: bool = False # if True, the weight would be given to the loss, if False, the weight would be used to change the sigmoid's shape 134 | self.behind_dropoff_on: bool = False # behind surface drop off weight 135 | self.dropoff_min_sigma: float = 1.0 136 | self.dropoff_max_sigma: float = 5.0 137 | self.normal_loss_on: bool = False 138 | self.weight_n: float = 0.01 139 | self.ekional_loss_on: bool = False 140 | self.weight_e: float = 1e-4 141 | self.weight_s: float = 1.0 # weight for semantic classification loss 142 | 143 | # optimizer 144 | self.iters: int = 200 145 | self.batch_iters: int = 2000 # for global 
optimization at last frame. 146 | self.opt_adam: bool = True # use adam or sgd 147 | self.bs: int = 4096 148 | self.lr: float = 1e-3 149 | self.weight_decay: float = 0 150 | self.adam_eps: float = 1e-15 151 | self.lr_level_reduce_ratio: float = 1.0 152 | self.lr_iters_reduce_ratio: float = 0.1 153 | self.dropout: float = 0 154 | self.extra_training: bool = False 155 | 156 | # eval 157 | self.wandb_vis_on: bool = False 158 | self.o3d_vis_on: bool = True # visualize the mesh in-the-fly using o3d visualzier or not [press space to pasue/resume] 159 | self.eval_on: bool = False 160 | self.eval_outlier_thre = 0.5 # unit:m 161 | self.eval_freq_iters: int = 100 162 | self.vis_freq_iters: int = 100 163 | self.save_freq_iters: int = 100 164 | self.mesh_freq_frame: int = 1 # do the reconstruction per x frames 165 | 166 | # marching cubes related 167 | self.mc_res_m: float = 0.1 168 | self.pad_voxel: int = 0 169 | self.mc_with_octree: bool = True # conducting marching cubes reconstruction within a certain level of the octree or within the axis-aligned bounding box of the whole map 170 | self.mc_query_level: int = 8 171 | self.mc_vis_level: int = 1 # masked the marching cubes for level higher than this 172 | self.mc_mask_on: bool = True # use mask for marching cubes to avoid the artifacts 173 | self.clean_mesh_on: bool = False # clean mesh vertex not close to point cloud 174 | self.infer_bs: int = 4096 175 | 176 | self.save_map: bool = False # save the sdf map or not, the sdf would be saved in the intensity channel 177 | 178 | # initialization 179 | self.scale: float = 1.0 # then will scale world size to [-1,1] for kaolin octotree 180 | self.world_size: float = 1.0 181 | 182 | def load(self, config_file): 183 | config_args = yaml.safe_load(open(os.path.abspath(config_file))) 184 | 185 | # common 186 | self.name = config_args["setting"]["name"] 187 | 188 | self.output_root = config_args["setting"]["output_root"] 189 | self.pc_path = config_args["setting"]["pc_path"] 190 | self.pose_path = config_args["setting"]["pose_path"] 191 | self.calib_path = config_args["setting"]["calib_path"] 192 | # optional, when semantic shine mapping is on [semantic] 193 | self.label_path = config_args["setting"]["label_path"] 194 | 195 | self.load_model = config_args["setting"]["load_model"] 196 | self.model_path = config_args["setting"]["model_path"] 197 | 198 | self.first_frame_ref = config_args["setting"]["first_frame_ref"] 199 | self.begin_frame = config_args["setting"]["begin_frame"] 200 | self.end_frame = config_args["setting"]["end_frame"] 201 | self.every_frame = config_args["setting"]["every_frame"] 202 | 203 | self.device = config_args["setting"]["device"] 204 | self.gpu_id = config_args["setting"]["gpu_id"] 205 | 206 | # process 207 | self.min_range = config_args["process"]["min_range_m"] 208 | self.pc_radius = config_args["process"]["pc_radius_m"] 209 | self.min_z = config_args["process"]["min_z"] 210 | self.max_z = config_args["process"]["max_z"] 211 | self.rand_downsample = config_args["process"]["rand_downsample"] 212 | self.vox_down_m = config_args["process"]["vox_down_m"] 213 | self.rand_down_r = config_args["process"]["rand_down_r"] 214 | # self.estimate_normal = config_args["process"]["estimate_normal"] 215 | # self.filter_noise = config_args["process"]["filter_noise"] 216 | # self.semantic_on = config_args["process"]["semantic_on"] 217 | 218 | # sampler 219 | self.surface_sample_range_m = config_args["sampler"]["surface_sample_range_m"] 220 | self.surface_sample_n = 
config_args["sampler"]["surface_sample_n"] 221 | self.free_sample_begin_ratio = config_args["sampler"]["free_sample_begin_ratio"] 222 | self.free_sample_end_dist = config_args["sampler"]["free_sample_end_dist"] 223 | self.free_sample_n = config_args["sampler"]["free_sample_n"] 224 | self.normal_sampling_on = config_args["sampler"]["normal_sampling_on"] 225 | self.gaussian_sampling_on = config_args["sampler"]["gaussian_sampling_on"] 226 | self.sliding_window_on = config_args["sampler"]["sliding_window_on"] 227 | # label 228 | # self.occu_update_on = config_args["label"]["occu_update_on"] 229 | # use bayersian update of the occupancy prob. as the new label 230 | 231 | # octree 232 | self.tree_level_world = config_args["octree"]["tree_level_world"] 233 | # the number of the total octree level (defining the world scale) 234 | self.tree_level_feat = config_args["octree"][ 235 | "tree_level_feat" 236 | ] # the number of the octree level used for storing feature grid 237 | self.leaf_vox_size = config_args["octree"][ 238 | "leaf_vox_size" 239 | ] # the size of the grid on octree's leaf level (unit: m) 240 | self.feature_dim = config_args["octree"][ 241 | "feature_dim" 242 | ] # feature vector's dimension 243 | # self.feature_std = config_args["octree"][ 244 | # "feature_std" 245 | # ] # feature vector's initialization sigma (a zero mean, sigma standard deviation gaussian distribution) 246 | self.poly_int_on = config_args["octree"][ 247 | "poly_int_on" 248 | ] # use polynomial or linear interpolation of feature grids 249 | self.octree_from_surface_samples = config_args["octree"][ 250 | "octree_from_surface_samples" 251 | ] # build the octree from the surface samples or only the measurement points 252 | 253 | # decoder 254 | self.geo_mlp_level = config_args["decoder"][ 255 | "mlp_level" 256 | ] # number of the level of the mlp decoder 257 | self.geo_mlp_hidden_dim = config_args["decoder"][ 258 | "mlp_hidden_dim" 259 | ] # dimension of the mlp's hidden layer 260 | # freeze the decoder after runing for x frames (used for incremental mapping to avoid forgeting) 261 | self.freeze_after_frame = config_args["decoder"]["freeze_after_frame"] 262 | self.predict_residual_sdf = config_args["decoder"]["predict_residual_sdf"] 263 | 264 | # loss 265 | self.ray_loss = config_args["loss"]["ray_loss"] 266 | self.main_loss_type = config_args["loss"]["main_loss_type"] 267 | self.sigma_sigmoid_m = config_args["loss"]["sigma_sigmoid_m"] 268 | 269 | self.loss_weight_on = config_args["loss"]["loss_weight_on"] 270 | 271 | self.behind_dropoff_on = config_args["loss"][ 272 | "behind_dropoff_on" 273 | ] # apply "behind the surface" loss weight drop-off or not 274 | 275 | self.normal_loss_on = config_args["loss"][ 276 | "normal_loss_on" 277 | ] # use normal consistency loss [deprecated] 278 | self.weight_n = float(config_args["loss"]["weight_n"]) 279 | 280 | self.ekional_loss_on = config_args["loss"][ 281 | "ekional_loss_on" 282 | ] # use ekional loss (gradient = 1 loss) 283 | self.weight_e = float(config_args["loss"]["weight_e"]) 284 | 285 | 286 | 287 | # continual learning 288 | # using the regularization based continuous learning or the rehersal based continuous learning 289 | self.continual_learning_reg = config_args["continual"]["continual_learning_reg"] 290 | # the forgeting lambda for regularization based continual learning 291 | self.lambda_forget = float( 292 | config_args["continual"]["lambda_forget"] 293 | ) 294 | 295 | # # regularization based method 296 | # # rehersal based method 297 | # self.history_sample_ratio 
= float( 298 | # config_args["continuous"]["history_sample_ratio"] 299 | # ) # sample the history samples by a scale of the number of current samples 300 | # self.history_sample_res = config_args["continuous"][ 301 | # "history_sample_res" 302 | # ] # the resolution of the kept history samples (unit: m) 303 | 304 | # optimizer 305 | self.iters = config_args["optimizer"][ 306 | "iters" 307 | ] # maximum iters (in our implementation, iters means iteration actually) 308 | #self.batch_iters = config_args["optimizer"]["batch_iters"] 309 | self.bs = config_args["optimizer"]["batch_size"] 310 | # self.adam_eps = float(config_args["optimizer"]["adam_eps"]) 311 | self.lr = float(config_args["optimizer"]["learning_rate"]) 312 | # self.lr_level_reduce_ratio = config_args["optimizer"][ 313 | # "lr_level_reduce_ratio" 314 | # ] # decay the learning rate for higher level of feature grids by such ratio 315 | # self.lr_iters_reduce_ratio = config_args["optimizer"][ 316 | # "lr_iters_reduce_ratio" 317 | # ] # decay the learning rate after certain iterss by such ratio 318 | self.weight_decay = float( 319 | config_args["optimizer"]["weight_decay"] 320 | ) # coefficient for L2 regularization 321 | self.extra_training = config_args["optimizer"]["extra_training"] 322 | 323 | # vis and eval 324 | self.wandb_vis_on = config_args["eval"][ 325 | "wandb_vis_on" 326 | ] # use weight and bias to monitor the experiment or not 327 | self.o3d_vis_on = config_args["eval"][ 328 | "o3d_vis_on" 329 | ] # turn on the open3d visualizer to visualize the mapping progress or not 330 | self.vis_freq_iters = config_args["eval"][ 331 | "vis_freq_iters" 332 | ] # frequency for mesh reconstruction for batch mode (per x iters) 333 | self.save_freq_iters = config_args["eval"][ 334 | "save_freq_iters" 335 | ] # frequency for model saving for batch mode (per x iters) 336 | self.mesh_freq_frame = config_args["eval"][ 337 | "mesh_freq_frame" 338 | ] # frequency for mesh reconstruction for incremental mode (per x frame) 339 | self.mc_with_octree = config_args["eval"][ 340 | "mc_with_octree" 341 | ] # using octree to narrow down the region that needs the sdf query so as to boost the efficieny 342 | # if false, we query all the positions within the map bounding box 343 | self.mc_res_m = config_args["eval"][ 344 | "mc_res_m" 345 | ] # marching cubes grid sampling interval (unit: m) 346 | self.mc_vis_level = config_args["eval"][ 347 | "mc_vis_level" 348 | ] 349 | # self.mc_mask_on = config_args["eval"]["mc_mask_on"] # using masked marching cubes according to the octree or not, default true 350 | self.clean_mesh_on = config_args["eval"]["clean_mesh_on"] 351 | self.save_map = config_args["eval"][ 352 | "save_map" 353 | ] 354 | # tree level starting for reconstruction and visualization, the larger of this value, 355 | # the larger holes would be filled (better completion), but at the same time more artifacts 356 | # would appear at the boundary of the map 357 | # it's a trading-off of the compeltion and the artifacts 358 | 359 | self.calculate_world_scale() 360 | self.infer_bs = self.bs * 16 361 | self.mc_query_level = self.tree_level_world - self.tree_level_feat + 1 362 | 363 | # calculate the scale for compressing the world into a [-1,1] kaolin cube 364 | def calculate_world_scale(self): 365 | self.world_size = self.leaf_vox_size*(2**(self.tree_level_world-1)) 366 | self.scale = 1.0 / self.world_size 367 | -------------------------------------------------------------------------------- /utils/data_sampler.py: 
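Before moving on to the sampler, a quick worked example of calculate_world_scale above, using this file's default values (leaf_vox_size = 0.5, tree_level_world = 10):
```
# worked example of the world-to-kaolin-cube scaling with the defaults above
leaf_vox_size, tree_level_world = 0.5, 10
world_size = leaf_vox_size * (2 ** (tree_level_world - 1))  # 256.0 m edge length
scale = 1.0 / world_size                                    # 0.00390625 per meter
# a point 100 m from the origin lands at coordinate ~0.39 inside the [-1, 1] cube
```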
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.linalg import inv, norm 3 | import kaolin as kal 4 | import kaolin.render.spc as spc_render 5 | import torch 6 | 7 | from utils.config import SHINEConfig 8 | from utils.tools import * 9 | 10 | class dataSampler(): 11 | 12 | def __init__(self, config: SHINEConfig): 13 | 14 | self.config = config 15 | self.dev = config.device 16 | 17 | def sampling_rectified_sdf(self, points_torch, 18 | sensor_origin_torch, 19 | normal_torch, 20 | sem_label_torch): 21 | 22 | dev = self.config.device 23 | 24 | surface_sample_range_scaled = self.config.surface_sample_range_m * self.config.scale 25 | surface_sample_n = self.config.surface_sample_n 26 | freespace_sample_n = self.config.free_sample_n 27 | all_sample_n = surface_sample_n+freespace_sample_n 28 | free_min_ratio = self.config.free_sample_begin_ratio 29 | free_sample_end_dist_scaled = self.config.free_sample_end_dist * self.config.scale #meter 30 | 31 | # get sample points 32 | shift_points = points_torch - sensor_origin_torch 33 | point_num = shift_points.shape[0] 34 | distances = torch.linalg.norm(shift_points, dim=1, keepdim=True) # ray distances (scaled) 35 | ray_direction = shift_points/distances # normalized ray direction 36 | 37 | # Part 1. close-to-surface sampling along with normals. 38 | # uniform sample in the close-to-surface range (+- range) (-1,1) 39 | surface_sample_ratio_uniform = (torch.rand(point_num*surface_sample_n, 1, device=dev)-0.5)*2 40 | 41 | # gaussian sampling (TODO: gaussian should provide both near surface samples and free space samples) 42 | if self.config.gaussian_sampling_on: 43 | surface_sample_ratio_gaussian = torch.randn(point_num*surface_sample_n,1,device=dev)*0.3 44 | condition = torch.logical_and(surface_sample_ratio_gaussian > -1, surface_sample_ratio_gaussian < 1) 45 | surface_sample_ratio = torch.where(condition, surface_sample_ratio_gaussian, surface_sample_ratio_uniform) 46 | #print(surface_sample_ratio) 47 | else: 48 | surface_sample_ratio = surface_sample_ratio_uniform 49 | 50 | surface_sample_displacement = surface_sample_ratio * surface_sample_range_scaled 51 | repeated_dist = distances.repeat(surface_sample_n,1) 52 | surface_sample_dist_ratio = surface_sample_displacement/repeated_dist + 1.0 # 1.0 means on the surface 53 | 54 | surface_repeated_points = shift_points.repeat(surface_sample_n,1) 55 | surface_sample_points = sensor_origin_torch + surface_repeated_points*surface_sample_dist_ratio 56 | 57 | # only near surface samples are assigned to semantic labels. 58 | if sem_label_torch is not None: 59 | surface_sem_label_tensor = sem_label_torch.repeat(1, surface_sample_n).transpose(0,1) 60 | 61 | # Part 2. 
free space uniform sampling 62 | repeated_dist = distances.repeat(freespace_sample_n,1) 63 | if sem_label_torch is not None: 64 | free_sem_label_tensor = torch.zeros_like(repeated_dist) 65 | 66 | free_max_ratio = free_sample_end_dist_scaled / repeated_dist + 1.0 67 | free_diff_ratio = free_max_ratio - free_min_ratio 68 | free_sample_dist_ratio = torch.rand(point_num*freespace_sample_n, 1, device=dev)*free_diff_ratio + free_min_ratio 69 | free_sample_displacement = (free_sample_dist_ratio - 1.0) * repeated_dist 70 | free_repeated_points = shift_points.repeat(freespace_sample_n,1) 71 | free_sample_points = free_repeated_points*free_sample_dist_ratio + sensor_origin_torch 72 | 73 | all_sample_points = torch.cat((surface_sample_points,free_sample_points),0) 74 | all_sample_displacement = torch.cat((surface_sample_displacement, free_sample_displacement),0) 75 | 76 | weight_tensor = torch.ones_like(all_sample_displacement) 77 | if self.config.behind_dropoff_on: 78 | dropoff_min = 0.2 * self.config.scale 79 | dropoff_max = 0.8 * self.config.scale 80 | dropoff_diff = dropoff_max - dropoff_min 81 | dropoff_weight = (dropoff_max - all_sample_displacement) / dropoff_diff 82 | dropoff_weight = torch.clamp(dropoff_weight, min = 0.1, max = 1.0) 83 | #print(dropoff_weight) 84 | 85 | # give a flag indicating the type of the sample [negative: freespace, positive: surface] 86 | weight_tensor[point_num*surface_sample_n:] *= -1.0 87 | 88 | # assign sdf labels to the samples 89 | # projective distance as the label: behind -, in-front + 90 | sdf_label_tensor = - all_sample_displacement.squeeze(1) 91 | 92 | # assign the normal label to the samples 93 | normal_label_tensor = None 94 | if normal_torch is not None: 95 | normal_label_tensor = normal_torch.repeat(all_sample_n,1).reshape(-1, 3) 96 | 97 | # rectify sdf label by normals 98 | ray_direction_tensor = ray_direction.repeat(all_sample_n,1) 99 | correct_ratio = (normal_label_tensor * ray_direction_tensor).sum(dim=1).abs() 100 | sdf_label_tensor = sdf_label_tensor * correct_ratio 101 | 102 | # assign the semantic label to the samples (including free space as the 0 label) 103 | sem_label_tensor = None 104 | if sem_label_torch is not None: 105 | sem_label_tensor = torch.cat((surface_sem_label_tensor, free_sem_label_tensor),0).int().reshape(-1, 3) 106 | 107 | # samples to voxel int coords 108 | all_sample_voxels = kal.ops.spc.quantize_points(all_sample_points, self.config.tree_level_world) 109 | all_sample_morton = kal.ops.spc.points_to_morton(all_sample_voxels) 110 | 111 | samples = {} 112 | samples["count"] = sdf_label_tensor.shape[0] 113 | #samples["point_morton_count"] = point_morton_count 114 | samples["pcd_count"] = point_num 115 | samples["coord"] = all_sample_points.reshape(-1,3) 116 | samples["voxel_coord"] = all_sample_voxels.reshape(-1, 3) 117 | samples["morton"] = all_sample_morton.reshape(-1) 118 | samples["sdf"] = sdf_label_tensor.reshape(-1) 119 | samples["normal"] = normal_label_tensor 120 | samples["sem"] = sem_label_tensor 121 | samples["weight"] = weight_tensor.reshape(-1) 122 | 123 | return samples 124 | 125 | # free space sampling jump near surface 126 | def sampling(self, points_torch, 127 | sensor_origin_torch, 128 | normal_torch, 129 | sem_label_torch, 130 | normal_guided_sampling = False): 131 | 132 | dev = self.config.device 133 | 134 | surface_sample_range_scaled = self.config.surface_sample_range_m * self.config.scale 135 | surface_sample_n = self.config.surface_sample_n 136 | freespace_sample_n = self.config.free_sample_n 137 | 
all_sample_n = surface_sample_n+freespace_sample_n 138 | free_min_ratio = self.config.free_sample_begin_ratio 139 | free_sample_end_dist_scaled = self.config.free_sample_end_dist * self.config.scale #meter 140 | 141 | # get sample points 142 | shift_points = points_torch - sensor_origin_torch 143 | point_num = shift_points.shape[0] 144 | distances = torch.linalg.norm(shift_points, dim=1, keepdim=True) # ray distances (scaled) 145 | 146 | # Part 1. close-to-surface sampling 147 | # uniform sample in the close-to-surface range (+- range) (-1,1) 148 | surface_sample_ratio_uniform = (torch.rand(point_num*surface_sample_n, 1, device=dev)-0.5)*2 149 | 150 | # gaussian sampling (gaussian should provide both near surface samples and free space samples) 151 | if self.config.gaussian_sampling_on: 152 | surface_sample_ratio_gaussian = torch.randn(point_num*surface_sample_n,1,device=dev)*0.3 153 | condition = torch.logical_and(surface_sample_ratio_gaussian > -1, surface_sample_ratio_gaussian < 1) 154 | surface_sample_ratio = torch.where(condition, surface_sample_ratio_gaussian, surface_sample_ratio_uniform) 155 | #print(surface_sample_ratio) 156 | else: 157 | surface_sample_ratio = surface_sample_ratio_uniform 158 | 159 | surface_sample_displacement = surface_sample_ratio * surface_sample_range_scaled 160 | repeated_dist = distances.repeat(surface_sample_n,1) 161 | surface_sample_dist_ratio = surface_sample_displacement/repeated_dist + 1.0 # 1.0 means on the surface 162 | 163 | surface_repeated_points = shift_points.repeat(surface_sample_n,1) 164 | if normal_guided_sampling: 165 | normal_direction = normal_torch.repeat(surface_sample_n,1) # normals are oriented towards sensors. 166 | #note that normals are oriented towards origin (inwards) 167 | surface_sample_points = sensor_origin_torch + surface_repeated_points + surface_sample_displacement * (-normal_direction) 168 | else: 169 | surface_sample_points = sensor_origin_torch+ surface_repeated_points*surface_sample_dist_ratio 170 | 171 | # only near surface samples are assigned to semantic labels. 172 | if sem_label_torch is not None: 173 | surface_sem_label_tensor = sem_label_torch.repeat(1, surface_sample_n).transpose(0,1) 174 | 175 | # Part 2. 
free space uniform sampling 176 | repeated_dist = distances.repeat(freespace_sample_n,1) 177 | if sem_label_torch is not None: 178 | free_sem_label_tensor = torch.zeros_like(repeated_dist) 179 | 180 | free_max_ratio = free_sample_end_dist_scaled / repeated_dist + 1.0 181 | free_diff_ratio = free_max_ratio - free_min_ratio 182 | free_sample_dist_ratio = torch.rand(point_num*freespace_sample_n, 1, device=dev)*free_diff_ratio + free_min_ratio 183 | free_sample_displacement = (free_sample_dist_ratio - 1.0) * repeated_dist 184 | free_repeated_points = shift_points.repeat(freespace_sample_n,1) 185 | free_sample_points = free_repeated_points*free_sample_dist_ratio + sensor_origin_torch 186 | 187 | # remove near-surface samples from free-space samples 188 | tr = surface_sample_range_scaled*1.33 189 | valid_mask = torch.logical_or(free_sample_displacement < -tr, free_sample_displacement > tr).reshape(-1) 190 | free_sample_displacement = free_sample_displacement[valid_mask] 191 | free_sample_points = free_sample_points[valid_mask] 192 | 193 | all_sample_points = torch.cat((surface_sample_points,free_sample_points),0) 194 | all_sample_displacement = torch.cat((surface_sample_displacement, free_sample_displacement),0) 195 | 196 | weight_tensor = torch.ones_like(all_sample_displacement) 197 | 198 | # give a flag indicating the type of the sample [negative: freespace, positive: surface] 199 | weight_tensor[point_num*surface_sample_n:] *= -1.0 200 | 201 | # assign sdf labels to the samples 202 | # projective distance as the label: behind -, in-front + 203 | sdf_label_tensor = - all_sample_displacement.squeeze(1) 204 | 205 | # assign the normal label to the samples 206 | normal_label_tensor = None 207 | if normal_torch is not None: 208 | surface_normal = normal_torch.repeat(surface_sample_n,1) 209 | free_normal = normal_torch.repeat(freespace_sample_n,1) 210 | free_normal = free_normal[valid_mask] 211 | normal_label_tensor = torch.cat((surface_normal,free_normal),0) 212 | 213 | # assign the semantic label to the samples (including free space as the 0 label) 214 | sem_label_tensor = None 215 | if sem_label_torch is not None: 216 | free_sem_label_tensor = free_sem_label_tensor[valid_mask] 217 | sem_label_tensor = torch.cat((surface_sem_label_tensor, free_sem_label_tensor),0).int().reshape(-1, 3) 218 | 219 | # samples to voxel int coords 220 | all_sample_voxels = kal.ops.spc.quantize_points(all_sample_points, self.config.tree_level_world) 221 | all_sample_morton = kal.ops.spc.points_to_morton(all_sample_voxels) 222 | 223 | samples = {} 224 | samples["count"] = sdf_label_tensor.shape[0] 225 | #samples["point_morton_count"] = point_morton_count 226 | samples["pcd_count"] = point_num 227 | samples["coord"] = all_sample_points.reshape(-1,3) 228 | samples["voxel_coord"] = all_sample_voxels.reshape(-1, 3) 229 | samples["morton"] = all_sample_morton.reshape(-1) 230 | samples["sdf"] = sdf_label_tensor.reshape(-1) 231 | samples["normal"] = normal_label_tensor.reshape(-1, 3) 232 | samples["sem"] = sem_label_tensor 233 | samples["weight"] = weight_tensor.reshape(-1) 234 | 235 | return samples -------------------------------------------------------------------------------- /utils/incre_learning.py: -------------------------------------------------------------------------------- 1 | import math 2 | from tqdm import tqdm 3 | from model.feature_octree import FeatureOctree 4 | from model.decoder import Decoder 5 | from dataset.lidar_dataset import LiDARDataset 6 | from utils.loss import * 7 | 8 | def 
cal_feature_importance(data: LiDARDataset, octree: FeatureOctree, mlp: Decoder, 9 | sigma, bs, down_rate=1, loss_reduction='mean', loss_weight_on = False): 10 | 11 | # shuffle_indice = torch.randperm(data.coord_pool.shape[0]) 12 | # shuffle_coord = data.coord_pool[shuffle_indice] 13 | # shuffle_label = data.sdf_label_pool[shuffle_indice] 14 | 15 | sample_count = data.coord_pool.shape[0] 16 | batch_interval = bs*down_rate 17 | iter_n = math.ceil(sample_count/batch_interval) 18 | for n in tqdm(range(iter_n)): 19 | head = n*batch_interval 20 | tail = min((n+1)*batch_interval, sample_count) 21 | # batch_coord = data.coord_pool[head:tail:down_rate] 22 | # batch_label = data.sdf_label_pool[head:tail:down_rate] 23 | 24 | batch_coord = data.coord_pool[head:tail:down_rate] 25 | batch_label = data.sdf_label_pool[head:tail:down_rate] 26 | # batch_weight = data.weight_pool[head:tail:down_rate] 27 | count = batch_label.shape[0] 28 | 29 | octree.get_indices(batch_coord) 30 | features = octree.query_feature(batch_coord) 31 | pred = mlp(features) # before sigmoid 32 | # add options for other losses here 33 | sdf_loss = sdf_bce_loss(pred, batch_label, sigma, None, loss_weight_on, loss_reduction) 34 | sdf_loss.backward() 35 | 36 | for i in range(len(octree.importance_weight)): # for each level 37 | octree.importance_weight[i] += octree.hier_features[i].grad.abs() 38 | octree.hier_features[i].grad.zero_() 39 | 40 | octree.importance_weight[i][-1] *= 0 # reseting the trashbin feature weight to 0 -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | 6 | def sdf_diff_loss(pred, label, weight, scale, l2_loss=True): 7 | count = pred.shape[0] 8 | diff = pred - label 9 | diff_m = diff / scale 10 | if l2_loss: 11 | loss = (weight * (diff_m**2)).sum() / count # l2 loss 12 | else: 13 | loss = (weight * abs(diff_m)).sum() / count # l1 loss 14 | return loss 15 | 16 | # TODO: add drop off weighting. 
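# Aside: sdf_bce_loss below implements the project's BCE supervision: the
# signed-distance label is squashed into a soft occupancy probability and the
# raw SDF prediction is treated as a logit. A tiny self-contained illustration
# (hypothetical numbers, not code from this repository):
#
#     import torch
#     import torch.nn.functional as F
#     sigma = 0.05
#     sdf_label = torch.tensor([0.10, 0.0, -0.10])  # in front / on / behind surface
#     occ_label = torch.sigmoid(sdf_label / sigma)  # ~0.88, 0.50, ~0.12
#     pred_logit = torch.tensor([2.0, 0.0, -2.0])   # network output before sigmoid
#     loss = F.binary_cross_entropy_with_logits(pred_logit, occ_label)
#     # small loss, since the predictions agree with the soft labels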
17 | def sdf_bce_loss(pred, label, sigma, weight, weighted=False, bce_reduction = "mean"): 18 | if weighted: 19 | loss_bce = nn.BCEWithLogitsLoss(reduction=bce_reduction, weight=weight) 20 | else: 21 | loss_bce = nn.BCEWithLogitsLoss(reduction=bce_reduction) 22 | label_op = torch.sigmoid(label / sigma) # occupancy prob 23 | loss = loss_bce(pred, label_op) 24 | return loss 25 | 26 | def normal_diff_loss(pred, label, mask): 27 | # diff1 = (pred - label)[mask].abs().norm(2, dim=1) 28 | # diff2 = (pred + label)[mask].abs().norm(2, dim=1) 29 | # diff = torch.min(diff1,diff2) 30 | diff = (pred - label)[mask].abs().norm(2, dim=1) 31 | return diff.mean() 32 | 33 | 34 | def ray_estimation_loss(x, y, d_meas): # for each ray 35 | # x as depth 36 | # y as sdf prediction 37 | # d_meas as measured depth 38 | 39 | # print(x.shape, y.shape, d_meas.shape) 40 | 41 | # regard each sample as a ray 42 | mat_A = torch.vstack((x, torch.ones_like(x))).transpose(0, 1) 43 | vec_b = y.view(-1, 1) 44 | 45 | # print(mat_A.shape, vec_b.shape) 46 | 47 | least_square_estimate = torch.linalg.lstsq(mat_A, vec_b).solution 48 | 49 | a = least_square_estimate[0] # -> -1 (added in ekional loss term) 50 | b = least_square_estimate[1] 51 | 52 | # d_estimate = -b/a 53 | d_estimate = torch.clamp(-b / a, min=1.0, max=40.0) # -> d 54 | 55 | # d_error = (d_estimate-d_meas)**2 56 | 57 | d_error = torch.abs(d_estimate - d_meas) 58 | 59 | # print(mat_A.shape, vec_b.shape, least_square_estimate.shape) 60 | # print(d_estimate.item(), d_meas.item(), d_error.item()) 61 | 62 | return d_error 63 | 64 | 65 | def ray_rendering_loss(x, y, d_meas): # for each ray [should run in batch] 66 | # x as depth 67 | # y as occ.prob. prediction 68 | x = x.squeeze(1) 69 | sort_x, indices = torch.sort(x) 70 | sort_y = y[indices] 71 | 72 | w = torch.ones_like(y) 73 | for i in range(sort_x.shape[0]): 74 | w[i] = sort_y[i] 75 | for j in range(i): 76 | w[i] *= 1.0 - sort_y[j] 77 | 78 | d_render = (w * x).sum() 79 | 80 | d_error = torch.abs(d_render - d_meas) 81 | 82 | # print(x.shape, y.shape, d_meas.shape) 83 | # print(mat_A.shape, vec_b.shape, least_square_estimate.shape) 84 | # print(d_render.item(), d_meas.item(), d_error.item()) 85 | 86 | return d_error 87 | 88 | 89 | def batch_ray_rendering_loss(x, y, d_meas, neus_on=True): # for all rays in a batch 90 | # x as depth [ray number * sample number] 91 | # y as prediction (the alpha in volume rendering) [ray number * sample number] 92 | # d_meas as measured depth [ray number] 93 | # w as the raywise weight [ray number] 94 | # neus_on determine if using the loss defined in NEUS 95 | 96 | # print(x.shape, y.shape, d_meas.shape, w.shape) 97 | 98 | sort_x, indices = torch.sort(x, 1) # for each row 99 | sort_y = torch.gather(y, 1, indices) # for each row 100 | 101 | if neus_on: 102 | neus_alpha = (sort_y[:, 1:] - sort_y[:, 0:-1]) / ( 1. 
- sort_y[:, 0:-1] + 1e-10) 103 | # avoid dividing by 0 (nan) 104 | # print(neus_alpha) 105 | alpha = torch.clamp(neus_alpha, min=0.0, max=1.0) 106 | else: 107 | alpha = sort_y 108 | 109 | one_minus_alpha = torch.ones_like(alpha) - alpha + 1e-10 110 | 111 | cum_mat = torch.cumprod(one_minus_alpha, 1) 112 | 113 | weights = cum_mat / one_minus_alpha * alpha 114 | 115 | weights_x = weights * sort_x[:, 0 : alpha.shape[1]] 116 | 117 | d_render = torch.sum(weights_x, 1) 118 | 119 | d_error = torch.abs(d_render - d_meas) 120 | 121 | # d_error = torch.abs(d_render - d_meas) * w # times ray-wise weight 122 | 123 | d_error_mean = torch.mean(d_error) 124 | 125 | return d_error_mean 126 | -------------------------------------------------------------------------------- /utils/mapper.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import numpy as np 3 | from collections import deque 4 | from numpy.linalg import inv, norm 5 | from tqdm import tqdm 6 | import open3d as o3d 7 | import kaolin as kal 8 | import wandb 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | import torch.functional as F 13 | import matplotlib.pyplot as plt 14 | 15 | from utils.config import SHINEConfig 16 | from utils.tools import * 17 | from utils.loss import * 18 | from utils.data_sampler import dataSampler 19 | from utils.mesher import Mesher 20 | from utils.visualizer import MapVisualizer, random_color_table 21 | from model.feature_octree import FeatureOctree 22 | from model.decoder import Decoder 23 | from dataset.kitti_dataset import KITTIDataset 24 | 25 | class Mapper(): 26 | def __init__(self, config: SHINEConfig): 27 | 28 | self.config = config 29 | self.device = config.device 30 | self.dtype = config.dtype 31 | 32 | self.run_path = setup_experiment(config) 33 | 34 | # initialize the feature octree 35 | self.octree = FeatureOctree(config) 36 | # initialize the mlp decoder 37 | self.geo_mlp = Decoder(config, is_geo_encoder=True) 38 | self.sem_mlp = Decoder(config, is_geo_encoder=False) 39 | 40 | # Load the decoder model 41 | if config.load_model: 42 | loaded_model = torch.load(config.model_path) 43 | self.geo_mlp.load_state_dict(loaded_model["geo_decoder"]) 44 | print("Pretrained decoder loaded") 45 | freeze_model(self.geo_mlp) # fixed the decoder 46 | if config.semantic_on: 47 | self.sem_mlp.load_state_dict(loaded_model["sem_decoder"]) 48 | freeze_model(self.sem_mlp) # fixed the decoder 49 | if 'feature_octree' in loaded_model.keys(): # also load the feature octree 50 | self.octree = loaded_model["feature_octree"] 51 | self.octree.print_detail() 52 | 53 | # dataset 54 | self.dataset = KITTIDataset(config) 55 | 56 | # sampler 57 | self.sampler = dataSampler(config) 58 | 59 | # mesh reconstructor 60 | self.mesher = Mesher(config, self.octree, self.geo_mlp, self.sem_mlp) 61 | self.mesher.global_transform = inv(self.dataset.begin_pose_inv) 62 | 63 | # Non-blocking visualizer 64 | if config.o3d_vis_on: 65 | self.vis = MapVisualizer() 66 | 67 | # learnable parameters 68 | self.geo_mlp_param = list(self.geo_mlp.parameters()) 69 | # learnable sigma for differentiable rendering 70 | self.sigma_size = torch.nn.Parameter(torch.ones(1, device=self.device)*1.0) 71 | # fixed sigma for sdf prediction supervised with BCE loss 72 | self.sigma_sigmoid = config.logistic_gaussian_ratio*config.sigma_sigmoid_m*config.scale 73 | 74 | # # key scan list 75 | # self.keyframelist = [] 76 | # # key frames local window 77 | # self.local_frames_list = [] 78 | 79 | 
self.last_frame_origin = np.zeros(3) # record the last frame origin. 80 | 81 | self.window_size = 10 # frames size 82 | self.window_traj_gap = 50 # meters 83 | # samples of each frame 84 | self.frame_samples_list = deque(maxlen = self.window_size) 85 | 86 | # samples count of each frame 87 | self.frame_samples_count = deque(maxlen = self.window_size) 88 | 89 | self.coord_pool = torch.empty((0, 3), device=self.device, dtype=self.dtype) 90 | self.voxel_pool = torch.empty((0, 3), device=self.device, dtype=int) 91 | self.morton_pool = torch.empty((0), device=self.device, dtype=torch.int64) 92 | self.sdf_label_pool = torch.empty((0), device=self.device, dtype=self.dtype) 93 | self.normal_label_pool = torch.empty((0, 3), device=self.device, dtype=self.dtype) 94 | self.sem_label_pool = torch.empty((0), device=self.device, dtype=int) 95 | self.weight_pool = torch.empty((0), device=self.device, dtype=self.dtype) 96 | 97 | self.extra_sample_pool_index = [] # for extra local sampling 98 | self.extra_index_pool = torch.empty((0), dtype=torch.int64) 99 | # update 100 | def update_samples_pool(self, samples, min_xyz, max_xyz, use_sliding_window = False): 101 | 102 | frame_idx = samples["index"] 103 | # concatenate new samples with samples pool 104 | self.coord_pool = torch.cat((self.coord_pool, samples["coord"]), 0) 105 | self.sdf_label_pool = torch.cat((self.sdf_label_pool, samples["sdf"]), 0) 106 | self.weight_pool = torch.cat((self.weight_pool, samples["weight"]), 0) 107 | self.voxel_pool = torch.cat((self.voxel_pool, samples["voxel_coord"]), 0) 108 | self.morton_pool = torch.cat((self.morton_pool, samples["morton"]), 0) 109 | if samples["normal"] is not None: 110 | self.normal_label_pool = torch.cat((self.normal_label_pool, samples["normal"]), 0) 111 | else: 112 | self.normal_label_pool = None 113 | if samples["sem"] is not None: 114 | self.sem_label_pool = torch.cat((self.sem_label_pool, samples["sem"]), 0) 115 | else: 116 | self.sem_label_pool = None 117 | 118 | if use_sliding_window: 119 | maskx = torch.logical_and(self.voxel_pool[:,0] > min_xyz[0], 120 | self.voxel_pool[:,0] < max_xyz[0]) 121 | 122 | masky = torch.logical_and(self.voxel_pool[:,1] > min_xyz[1], 123 | self.voxel_pool[:,1] < max_xyz[1]) 124 | 125 | maskz = torch.logical_and(self.voxel_pool[:,2] > min_xyz[2], 126 | self.voxel_pool[:,2] < max_xyz[2]) 127 | 128 | mask = torch.logical_and(torch.logical_and(maskx, masky),maskz) 129 | 130 | self.coord_pool = self.coord_pool[mask] 131 | self.sdf_label_pool = self.sdf_label_pool[mask] 132 | self.weight_pool = self.weight_pool[mask] 133 | self.voxel_pool = self.voxel_pool[mask] 134 | self.morton_pool = self.morton_pool[mask] 135 | if self.normal_label_pool is not None: 136 | self.normal_label_pool = self.normal_label_pool[mask] 137 | if self.sem_label_pool is not None: 138 | self.sem_label_pool = self.sem_label_pool[mask] 139 | 140 | # downsampling samples or extra local sampling 141 | # pay more attention to observation-less region. 142 | if self.config.extra_training: 143 | if frame_idx > 10: 144 | dups_in_mortons = dict(list_duplicates(self.morton_pool.cpu().numpy().tolist())) 145 | extra_sample_pool_index = [] 146 | for m, indexes in dups_in_mortons.items(): 147 | if len(indexes) < 10: # 10 148 | extra_sample_pool_index.extend(indexes) 149 | self.extra_index_pool = torch.tensor(extra_sample_pool_index, device=self.config.device, dtype=torch.int64) 150 | 151 | # TODO: better do it in data sampling 152 | # filter out all free samples out of voxels. 
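# Aside: the sliding-window pruning in update_samples_pool above keeps only
# pooled samples whose voxel coordinates fall inside an axis-aligned box around
# the current sensor origin. The three per-axis logical_and masks are
# equivalent to this vectorized form (hypothetical tensors, not code from this
# repository):
#
#     import torch
#     voxel_pool = torch.randint(0, 100, (1000, 3))  # pooled voxel coords
#     min_xyz = torch.tensor([20, 20, 20])
#     max_xyz = torch.tensor([80, 80, 80])
#     mask = ((voxel_pool > min_xyz) & (voxel_pool < max_xyz)).all(dim=1)
#     voxel_pool = voxel_pool[mask]   # every other pool is masked the same way
#
# The filter_samples function just below handles the complementary case:
# dropping free-space samples that fall outside any allocated octree voxel.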
153 | def filter_samples(self, samples):
154 | #t1 = get_time()
155 | 
156 | pcd_count = samples["pcd_count"]
157 | surface_sample_num = self.config.surface_sample_n * pcd_count
158 | #free_sample_num = self.config.free_sample_n * pcd_count
159 | 
160 | # unpack samples
161 | count = samples["count"]
162 | coord = samples["coord"]
163 | voxel_coord = samples["voxel_coord"]
164 | morton = samples["morton"]
165 | sdf = samples["sdf"]
166 | normal = samples["normal"]
167 | sem = samples["sem"]
168 | weight = samples["weight"]
169 | 
170 | valid_index = list(range(surface_sample_num))
171 | voxel_morton = morton.cpu().numpy().tolist() # morton codes of the sample voxels
172 | 
173 | for idx in range(surface_sample_num, len(voxel_morton)):
174 | if voxel_morton[idx] in self.octree.nodes_lookup_tables[self.config.tree_level_world]:
175 | valid_index.append(idx) # keep only free-space samples that fall inside allocated voxels
176 | 
177 | samples["coord"] = coord[valid_index]
178 | samples["voxel_coord"] = voxel_coord[valid_index]
179 | samples["morton"] = morton[valid_index]
180 | samples["sdf"] = sdf[valid_index]
181 | samples["weight"] = weight[valid_index]
182 | if sem is not None:
183 | samples["sem"] = sem[valid_index]
184 | if normal is not None:
185 | samples["normal"] = normal[valid_index]
186 | 
187 | samples["count"] = len(valid_index)
188 | 
189 | #t2 = get_time()
190 | #print("filter {:d} samples cost {:.1f}ms".format(count-len(valid_index),1000*(t2-t1)))
191 | 
192 | return samples
193 | 
194 | # sample training pairs from the label pool
195 | def get_batch(self):
196 | train_sample_count = self.sdf_label_pool.shape[0]
197 | if not self.config.extra_training:
198 | index = torch.randint(0, train_sample_count, (self.config.bs,), device=self.config.device)
199 | else:
200 | extra_batch_size = round(self.config.bs/3)
201 | batch_size = self.config.bs - extra_batch_size
202 | #batch_size = self.config.bs
203 | extra_sample_count = self.extra_index_pool.shape[0]
204 | index = torch.randint(0, train_sample_count, (batch_size,), device=self.config.device)
205 | if extra_sample_count > batch_size:
206 | extra_index_index = torch.randint(0, extra_sample_count, (extra_batch_size,), device=self.config.device)
207 | extra_index = self.extra_index_pool[extra_index_index]
208 | index = torch.cat((index, extra_index), dim=0)
209 | 
210 | coord = self.coord_pool[index, :]
211 | sdf_label = self.sdf_label_pool[index]
212 | weight = self.weight_pool[index]
213 | 
214 | if self.normal_label_pool is not None:
215 | normal_label = self.normal_label_pool[index, :]
216 | else:
217 | normal_label = None
218 | 
219 | if self.sem_label_pool is not None:
220 | sem_label = self.sem_label_pool[index]
221 | else:
222 | sem_label = None
223 | 
224 | return coord, sdf_label, normal_label, sem_label, weight
225 | 
226 | def mapping(self):
227 | processed_frame = 0
228 | total_iter = 0
229 | 
230 | for frame_id in tqdm(range(self.dataset.total_pc_count)):
231 | if (frame_id < self.config.begin_frame or frame_id > self.config.end_frame or \
232 | frame_id % self.config.every_frame != 0):
233 | continue
234 | 
235 | # # for kitti 07, skip the pose jump
236 | # if (frame_id > 659 and frame_id < 730):
237 | # processed_frame += 1
238 | # continue
239 | 
240 | vis_mesh = False
241 | if processed_frame == self.config.freeze_after_frame: # freeze the decoder after a certain number of frames
242 | print("Freeze the decoder")
243 | freeze_model(self.geo_mlp) # freeze the decoder
244 | 
245 | T0 = get_time()
246 | # sampling
247 | _, frame_origin_torch, frame_pc_s_torch, 
226 |     def mapping(self):
227 |         processed_frame = 0
228 |         total_iter = 0
229 |
230 |         for frame_id in tqdm(range(self.dataset.total_pc_count)):
231 |             if (frame_id < self.config.begin_frame or frame_id > self.config.end_frame or \
232 |                 frame_id % self.config.every_frame != 0):
233 |                 continue
234 |
235 |             # # for kitti 07, skip the pose jump
236 |             # if (frame_id > 659 and frame_id < 730):
237 |             #     processed_frame += 1
238 |             #     continue
239 |
240 |             vis_mesh = False
241 |             if processed_frame == self.config.freeze_after_frame: # freeze the decoder after a certain number of frames
242 |                 print("Freeze the decoder")
243 |                 freeze_model(self.geo_mlp) # fix the decoder
244 |
245 |             T0 = get_time()
246 |             # sampling
247 |             _, frame_origin_torch, frame_pc_s_torch, frame_normal_torch, frame_label_torch = self.dataset[frame_id]
248 |             samples = self.sampler.sampling(frame_pc_s_torch,
249 |                                             frame_origin_torch,
250 |                                             frame_normal_torch,
251 |                                             frame_label_torch,
252 |                                             normal_guided_sampling=self.config.normal_sampling_on)
253 |             # samples = self.sampler.sampling_rectified_sdf(frame_pc_s_torch,
254 |             #                                               frame_origin_torch,
255 |             #                                               frame_normal_torch,
256 |             #                                               frame_label_torch)
257 |
258 |             # avoid duplicated samples due to slow motion
259 |             samples["index"] = frame_id
260 |             current_frame_origin = frame_origin_torch.cpu().numpy().reshape(-1)
261 |             relative_dist = np.mean(norm(current_frame_origin - self.last_frame_origin))
262 |             if processed_frame > 5 and relative_dist < 0.5*self.config.scale and self.config.use_keyframe:
263 |                 print("slow motion! skip the frame")
264 |                 with open('./log/jump.txt', 'a') as f:
265 |                     f.write(f"{frame_id}, {relative_dist/self.config.scale}\n")
266 |                 processed_frame += 1
267 |                 if (frame_id+1) % self.config.mesh_freq_frame != 0:
268 |                     continue
269 |             self.last_frame_origin = current_frame_origin
270 |
271 |             # update the feature octree
272 |             if self.config.octree_from_surface_samples:
273 |                 # update with the sampled surface points
274 |                 self.octree.update(samples["coord"][samples["weight"] > 0, :])
275 |             else:
276 |                 # update with the original points
277 |                 self.octree.update(frame_pc_s_torch.to("cuda"))
278 |
279 |             # calculate the local boundary
280 |             frame_origin_voxel = kal.ops.spc.quantize_points(frame_origin_torch, self.config.tree_level_world)
281 |             radius_vox_count = round(self.config.pc_radius/self.config.leaf_vox_size)
282 |             min_xyz = frame_origin_voxel[:3] - radius_vox_count # lower bound
283 |             max_xyz = frame_origin_voxel[:3] + radius_vox_count # upper bound
284 |
285 |             # update the sample pool
286 |             samples = self.filter_samples(samples)
287 |             self.update_samples_pool(samples, min_xyz, max_xyz, use_sliding_window = self.config.sliding_window_on)
288 |
289 |             octree_feat = list(self.octree.parameters())
290 |             opt = setup_optimizer(self.config, octree_feat, self.geo_mlp_param, None, self.sigma_size)
291 |             self.octree.print_detail()
292 |
293 |             T1 = get_time()
294 |             for iter in tqdm(range(self.config.iters)):
295 |                 # load a batch (a dataloader is avoided since the data is already on the GPU; trades memory for speed)
296 |                 coord, sdf_label, normal_label, sem_label, weight = self.get_batch()
297 |                 if self.config.normal_loss_on or self.config.ekional_loss_on:
298 |                     coord.requires_grad_(True)
299 |
300 |                 # interpolate and concatenate the hierarchical grid features,
301 |                 # then predict the scaled sdf from the feature
302 |                 if self.config.predict_residual_sdf:
303 |                     feature, coarse_features = self.octree.query_split_feature(coord)
304 |                     sdf_pred = self.geo_mlp.sum_sdf(feature, coarse_features)
305 |                 else:
306 |                     feature = self.octree.query_feature(coord)
307 |                     sdf_pred = self.geo_mlp.sdf(feature)
308 |
309 |                 if self.config.semantic_on:
310 |                     sem_pred = self.sem_mlp.sem_label_prob(feature)
311 |
312 |                 # calculate the gradients
313 |                 if self.config.normal_loss_on or self.config.ekional_loss_on:
314 |                     g = get_gradient(coord, sdf_pred)*self.sigma_sigmoid
315 |                     g_normalized = F.normalize(g, p=2, dim=1)
316 |
317 |                 # calculate the loss
318 |                 surface_mask = weight > 0
319 |
320 |                 cur_loss = 0.
321 |                 weight = torch.abs(weight) # the weight's sign indicates whether a sample lies near the surface or in free space
322 |                 sdf_loss = sdf_bce_loss(sdf_pred, sdf_label, self.sigma_sigmoid, weight, self.config.loss_weight_on, self.config.loss_reduction)
323 |                 cur_loss += sdf_loss
324 |
325 |                 # incremental learning regularization loss (not used in this work)
326 |                 reg_loss = 0.
327 |                 if self.config.continual_learning_reg:
328 |                     reg_loss = self.octree.cal_regularization()
329 |                     cur_loss += self.config.lambda_forget * reg_loss
330 |
331 |                 # optional eikonal loss
332 |                 eikonal_loss = 0.
333 |                 if self.config.ekional_loss_on:
334 |                     #eikonal_loss = ((g[~surface_mask].norm(2, dim=-1) - 1.0) ** 2).mean() # MSE with respect to 1
335 |                     eikonal_loss = ((g.norm(2, dim=-1) - 1.0) ** 2).mean() # MSE with respect to 1
336 |                     cur_loss += self.config.weight_e * eikonal_loss
337 |
338 |                 normal_loss = 0.
339 |                 if self.config.normal_loss_on:
340 |                     normal_loss = normal_diff_loss(g_normalized, normal_label, surface_mask)
341 |                     cur_loss += self.config.weight_n * normal_loss
342 |
343 |                 # semantic classification loss
344 |                 # sem_loss = 0.
345 |                 if self.config.semantic_on:
346 |                     loss_nll = nn.NLLLoss(reduction='mean')
347 |                     sem_loss = loss_nll(sem_pred[::self.config.sem_label_decimation,:], sem_label[::self.config.sem_label_decimation])
348 |                     cur_loss += self.config.weight_s * sem_loss
349 |
350 |                 opt.zero_grad(set_to_none=True)
351 |                 cur_loss.backward() # this is the slowest part (about 10x the forward time)
352 |                 opt.step()
353 |
354 |                 total_iter += 1
355 |
356 |             T2 = get_time()
357 |
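The main SDF loss above can be sketched in isolation. Assuming (the actual `sdf_bce_loss` lives in `utils/loss.py`, so the details here are an approximation that ignores per-sample weighting and the reduction options) that labels are mapped through a sigmoid with scale `sigma` and the scaled predictions are treated as logits:
```
import torch
import torch.nn.functional as F

sigma = 0.05                                   # cf. sigma_sigmoid_m in the config
sdf_pred = torch.tensor([0.02, -0.10, 0.30])   # hypothetical predicted sdf values
sdf_label = torch.tensor([0.00, -0.12, 0.25])  # hypothetical sampled sdf labels

target = torch.sigmoid(sdf_label / sigma)      # soft occupancy target in (0, 1)
loss = F.binary_cross_entropy_with_logits(sdf_pred / sigma, target)
print(float(loss))
```
Compared with an L1/L2 loss on truncated SDF values, the sigmoid mapping softly saturates the supervision far from the surface, which is why no hard truncation distance is needed here.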
358 |             # reconstruction by marching cubes
359 |             if frame_id == 0 or (processed_frame) % self.config.mesh_freq_frame == 0:
360 |                 print("Begin mesh reconstruction from the implicit map")
361 |                 vis_mesh = True
362 |                 mesh_path = self.run_path + '/mesh/mesh_frame_' + str(frame_id+1) + ".ply"
363 |                 map_path = self.run_path + '/map/sdf_map_frame_' + str(frame_id+1) + ".ply"
364 |                 if self.config.mc_with_octree: # default
365 |                     cur_mesh = self.mesher.recon_octree_mesh(self.config.mc_query_level, self.dataset.map_down_pc, self.config.mc_res_m,
366 |                                                              mesh_path, map_path, self.config.save_map, self.config.semantic_on,
367 |                                                              filter_free_space_vertices=self.config.clean_mesh_on)
368 |                 else:
369 |                     cur_mesh = self.mesher.recon_bbx_mesh(self.dataset.map_bbx, self.dataset.map_down_pc, self.config.mc_res_m,
370 |                                                           mesh_path, map_path, self.config.save_map, self.config.semantic_on,
371 |                                                           filter_free_space_vertices=self.config.clean_mesh_on)
372 |                 # save the raw point cloud
373 |                 # pc_map_path = self.run_path + '/map/pc_frame_' + str(frame_id+1) + ".ply"
374 |                 # self.dataset.write_merged_pc(pc_map_path)
375 |
376 |             T3 = get_time()
377 |
378 |             if self.config.o3d_vis_on:
379 |                 if vis_mesh:
380 |                     cur_mesh.transform(self.dataset.begin_pose_inv) # back to the globally shifted frame for vis
381 |                     self.vis.update(self.dataset.cur_frame_pc, self.dataset.cur_pose_ref, cur_mesh)
382 |                 else: # only show the frame and the current point cloud
383 |                     self.vis.update(self.dataset.cur_frame_pc, self.dataset.cur_pose_ref)
384 |
385 |             processed_frame += 1
386 |
387 |         print("Begin mesh reconstruction from the implicit map")
388 |         mesh_path = self.run_path + '/mesh/final_mesh.ply'
389 |         map_path = self.run_path + '/map/final_sdf.ply'
390 |         if self.config.mc_with_octree: # default
391 |             cur_mesh = self.mesher.recon_octree_mesh(self.config.mc_query_level, self.dataset.map_down_pc, self.config.mc_res_m,
392 |                                                      mesh_path, map_path, self.config.save_map, self.config.semantic_on,
393 |                                                      filter_free_space_vertices=self.config.clean_mesh_on)
394 |         else:
395 |             cur_mesh = self.mesher.recon_bbx_mesh(self.dataset.map_bbx, self.dataset.map_down_pc, self.config.mc_res_m,
396 |                                                   mesh_path, map_path, self.config.save_map, self.config.semantic_on,
397 |                                                   filter_free_space_vertices=self.config.clean_mesh_on)
398 |         if self.config.o3d_vis_on:
399 |             cur_mesh.transform(self.dataset.begin_pose_inv) # back to the globally shifted frame for vis
400 |             self.vis.update(self.dataset.cur_frame_pc, self.dataset.cur_pose_ref, cur_mesh)
401 |             self.vis.stop()
402 |
403 |
--------------------------------------------------------------------------------
/utils/mesher.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | import skimage.measure
4 | from scipy.spatial import cKDTree
5 | import torch
6 | import math
7 | import open3d as o3d
8 | from utils.config import SHINEConfig
9 | from utils.semantic_kitti_utils import *
10 | from model.feature_octree import FeatureOctree
11 | from model.decoder import Decoder
12 |
13 | class Mesher():
14 |
15 |     def __init__(self, config: SHINEConfig, octree: FeatureOctree, \
16 |         geo_decoder: Decoder, sem_decoder: Decoder):
17 |
18 |         self.config = config
19 |
20 |         self.octree = octree
21 |         self.geo_decoder = geo_decoder
22 |         self.sem_decoder = sem_decoder
23 |         self.device = config.device
24 |         self.cur_device = self.device
25 |         self.dtype = config.dtype
26 |         self.world_scale = config.scale
27 |
28 |         self.global_transform = np.eye(4)
29 |
30 |     def query_points(self, coord, bs, query_sdf = True, query_sem = False, query_mask = True):
31 |         """ query the sdf value, semantic label and marching cubes mask for points
32 |         Args:
33 |             coord: Nx3 torch tensor, the coordinates of all N (axbxc) query points in the scaled
34 |                 kaolin coordinate system [-1,1]
35 |             bs: batch size for the inference
36 |         Returns:
37 |             sdf_pred: Ndim numpy array, signed distance value (scaled) at each query point
38 |             sem_pred: Ndim numpy array, semantic label prediction at each query point
39 |             mc_mask: Ndim bool numpy array, marching cubes mask at each query point
40 |         """
41 |         # the coord torch tensor is already scaled in the [-1,1] coordinate system
42 |         sample_count = coord.shape[0]
43 |         iter_n = math.ceil(sample_count/bs)
44 |         check_level = min(self.octree.featured_level_num, self.config.mc_vis_level)-1
45 |         if query_sdf:
46 |             sdf_pred = np.zeros(sample_count)
47 |         else:
48 |             sdf_pred = None
49 |         if query_sem:
50 |             sem_pred = np.zeros(sample_count)
51 |         else:
52 |             sem_pred = None
53 |         if query_mask:
54 |             mc_mask = np.zeros(sample_count)
55 |         else:
56 |             mc_mask = None
57 |
58 |         with torch.no_grad(): # eval step
59 |             if iter_n > 1:
60 |                 for n in tqdm(range(iter_n)):
61 |                     head = n*bs
62 |                     tail = min((n+1)*bs, sample_count)
63 |                     batch_coord = coord[head:tail, :]
64 |                     if self.cur_device == "cpu" and self.device == "cuda":
65 |                         batch_coord = batch_coord.cuda()
66 |
67 |                     if self.config.predict_residual_sdf:
68 |                         batch_feature, batch_coarse_feature = self.octree.query_split_feature(batch_coord, True) # query features
69 |                         batch_sdf = self.geo_decoder.sum_sdf(batch_feature, batch_coarse_feature)
70 |                     else:
71 |                         batch_feature = self.octree.query_feature(batch_coord, True) # query features
72 |                         batch_sdf = self.geo_decoder.sdf(batch_feature)
73 |                     sdf_pred[head:tail] = batch_sdf.detach().cpu().numpy()
74 |
75 |                     if query_sem:
76 |                         batch_sem = self.sem_decoder.sem_label(batch_feature)
77 |                         sem_pred[head:tail] = batch_sem.detach().cpu().numpy()
78 |                     if query_mask:
79 |                         # get the marching cubes mask
80 |                         # hierarchical_indices: bottom-up
81 |                         check_level_indices = self.octree.hierarchical_indices[check_level]
82 |                         # print(check_level_indices)
83 |                         # if the index is -1 for the level, it means the point is not valid at this level
84 |                         mask_mc = check_level_indices >= 0
85 |                         # print(mask_mc.shape)
86 |                         # all should be true (all the corners should be valid)
87 |                         mask_mc = torch.all(mask_mc, dim=1)
88 |                         mc_mask[head:tail] = mask_mc.detach().cpu().numpy()
89 |                         # but for skimage's marching cubes, the top right corner's mask should also be true to conduct marching cubes
90 |             else:
91 |                 feature = self.octree.query_feature(coord, True)
92 |                 if query_sdf:
93 |                     sdf_pred = self.geo_decoder.sdf(feature).detach().cpu().numpy()
94 |                 if query_sem:
95 |                     sem_pred = self.sem_decoder.sem_label(feature).detach().cpu().numpy()
96 |                 if query_mask:
97 |                     # get the marching cubes mask
98 |                     check_level_indices = self.octree.hierarchical_indices[check_level]
99 |                     # if the index is -1 for the level, it means the point is not valid at this level
100 |                     mask_mc = check_level_indices >= 0
101 |                     # all should be true (all the corners should be valid)
102 |                     mc_mask = torch.all(mask_mc, dim=1).detach().cpu().numpy()
103 |
104 |         return sdf_pred, sem_pred, mc_mask
105 |
106 |     def get_query_from_bbx(self, bbx, voxel_size):
107 |         """ get grid query points inside a given bounding box (bbx)
108 |         Args:
109 |             bbx: open3d bounding box, in world coordinate system, with unit m
110 |             voxel_size: scalar, marching cubes voxel size with unit m
111 |         Returns:
112 |             coord: Nx3 torch tensor, the coordinates of all N (axbxc) query points in the scaled
113 |                 kaolin coordinate system [-1,1]
114 |             voxel_num_xyz: 3dim numpy array, the number of voxels on each axis for the bbx
115 |             voxel_origin: 3dim numpy array, the coordinate of the bottom-left corner of the 3d grids
116 |                 for marching cubes, in world coordinate system with unit m
117 |         """
118 |         # bbx and voxel_size are all in the world coordinate system
119 |         min_bound = bbx.get_min_bound()
120 |         max_bound = bbx.get_max_bound()
121 |         len_xyz = max_bound - min_bound
122 |         voxel_num_xyz = (np.ceil(len_xyz/voxel_size)+self.config.pad_voxel*2).astype(np.int_)
123 |         voxel_origin = min_bound-self.config.pad_voxel*voxel_size
124 |         # pad an additional voxel underground to guarantee the reconstruction of the ground
125 |         voxel_origin[2] -= voxel_size
126 |         voxel_num_xyz[2] += 1
127 |
128 |         voxel_count_total = voxel_num_xyz[0] * voxel_num_xyz[1] * voxel_num_xyz[2]
129 |         if voxel_count_total > 1e8: # TODO: avoid gpu memory issue, dirty fix
130 |             self.cur_device = "cpu" # first keep it in cpu memory (which is usually larger than the gpu's)
131 |             print("too many query points, use cpu memory")
132 |         x = torch.arange(voxel_num_xyz[0], dtype=torch.int16, device=self.cur_device)
133 |         y = torch.arange(voxel_num_xyz[1], dtype=torch.int16, device=self.cur_device)
134 |         z = torch.arange(voxel_num_xyz[2], dtype=torch.int16, device=self.cur_device)
135 |
136 |         # order: [0,0,0], [0,0,1], [0,0,2], [0,1,0], [0,1,1], [0,1,2] ...
137 |         x, y, z = torch.meshgrid(x, y, z, indexing='ij')
138 |         # get the vector of all the grid points' 3D coordinates
139 |         coord = torch.stack((x.flatten(), y.flatten(), z.flatten())).transpose(0, 1).float()
140 |         # transform to the world coordinate system
141 |         coord *= voxel_size
142 |         coord += torch.tensor(voxel_origin, dtype=self.dtype, device=self.cur_device)
143 |         # scaling to the [-1, 1] coordinate system
144 |         coord *= self.world_scale
145 |
146 |         return coord, voxel_num_xyz, voxel_origin
147 |
148 |     def generate_sdf_map(self, coord, sdf_pred, mc_mask, map_path):
149 |         device = o3d.core.Device("CPU:0")
150 |         dtype = o3d.core.float32
151 |         sdf_map_pc = o3d.t.geometry.PointCloud(device)
152 |
153 |         # scaling back to the world coordinate system
154 |         coord /= self.world_scale
155 |         coord_np = coord.detach().cpu().numpy()
156 |
157 |         sdf_pred_world = sdf_pred * self.config.logistic_gaussian_ratio*self.config.sigma_sigmoid_m # convert to unit: m
158 |
159 |         # the sdf (unit: m) would be saved in the intensity channel
160 |         sdf_map_pc.point['positions'] = o3d.core.Tensor(coord_np, dtype, device)
161 |         sdf_map_pc.point['intensities'] = o3d.core.Tensor(np.expand_dims(sdf_pred_world, axis=1), dtype, device) # scaled sdf prediction
162 |         if mc_mask is not None:
163 |             # the marching cubes mask would be saved in the labels channel (indicating the hierarchical position in the octree)
164 |             sdf_map_pc.point['labels'] = o3d.core.Tensor(np.expand_dims(mc_mask, axis=1), o3d.core.int32, device) # mask
165 |
166 |         # global transform (to world coordinate system) before output
167 |         sdf_map_pc.transform(self.global_transform)
168 |         o3d.t.io.write_point_cloud(map_path, sdf_map_pc, print_progress=False)
169 |         print("save the sdf map to %s" % (map_path))
170 |
171 |     def assign_to_bbx(self, sdf_pred, sem_pred, mc_mask, voxel_num_xyz):
172 |         """ assign the queried sdf, semantic label and marching cubes mask back to the 3D grids in the specified bounding box
173 |         Args:
174 |             sdf_pred: Ndim np.array
175 |             sem_pred: Ndim np.array
176 |             mc_mask: Ndim bool np.array
177 |             voxel_num_xyz: 3dim numpy array, the number of voxels on each axis for the bbx
178 |         Returns:
179 |             sdf_pred: a*b*c np.array, 3d grids of signed distance values
180 |             sem_pred: a*b*c np.array, 3d grids of semantic labels
181 |             mc_mask: a*b*c np.array, 3d grids of marching cube masks, marching cubes only on where
182 |                 the mask is true
183 |         """
184 |         if sdf_pred is not None:
185 |             sdf_pred = sdf_pred.reshape(voxel_num_xyz[0], voxel_num_xyz[1], voxel_num_xyz[2])
186 |
187 |         if sem_pred is not None:
188 |             sem_pred = sem_pred.reshape(voxel_num_xyz[0], voxel_num_xyz[1], voxel_num_xyz[2])
189 |
190 |         if mc_mask is not None:
191 |             mc_mask = mc_mask.reshape(voxel_num_xyz[0], voxel_num_xyz[1], voxel_num_xyz[2]).astype(dtype=bool)
192 |             # mc_mask[:,:,0:1] = True # TODO: dirty fix for the ground issue
193 |
194 |         return sdf_pred, sem_pred, mc_mask
195 |
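`mc_mesh` below is a thin wrapper around `skimage.measure.marching_cubes`. A minimal standalone example on an analytic sphere SDF shows the same vertex/face extraction and the de-scaling by grid origin and voxel size (all values here are made up):
```
import numpy as np
import skimage.measure

voxel_size = 0.1
origin = np.array([-1.6, -1.6, -1.6])
ax = origin[0] + voxel_size * np.arange(32)
x, y, z = np.meshgrid(ax, ax, ax, indexing='ij')
sdf = np.sqrt(x**2 + y**2 + z**2) - 1.0   # unit sphere, zero level set inside the grid

verts, faces, normals, values = skimage.measure.marching_cubes(sdf, level=0.0)
verts = origin + verts * voxel_size       # back to metric coordinates, as in mc_mesh
print(verts.shape, faces.shape)           # a few thousand vertices and triangles
```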
196 |     def mc_mesh(self, mc_sdf, mc_mask, voxel_size, mc_origin):
197 |         """ use the marching cubes algorithm to get mesh vertices and faces
198 |         Args:
199 |             mc_sdf: a*b*c np.array, 3d grids of signed distance values
200 |             mc_mask: a*b*c np.array, 3d grids of marching cube masks, marching cubes only on where
201 |                 the mask is true
202 |             voxel_size: scalar, marching cubes voxel size with unit m
203 |             mc_origin: 3*1 np.array, the coordinate of the bottom-left corner of the 3d grids for
204 |                 marching cubes, in world coordinate system with unit m
205 |         Returns:
206 |             ([verts], [faces]), mesh vertices and triangle faces
207 |         """
208 |         print("Marching cubes ...")
209 |         # the inputs are all already numpy arrays
210 |         verts, faces, normals, values = np.zeros((0, 3)), np.zeros((0, 3)), np.zeros((0, 3)), np.zeros(0)
211 |         try:
212 |             verts, faces, normals, values = skimage.measure.marching_cubes(
213 |                 mc_sdf, level=0.0, allow_degenerate=False, mask=mc_mask)
214 |         except Exception: # e.g. no zero crossing found inside the masked grid
215 |             pass
216 |
217 |         verts = mc_origin + verts * voxel_size
218 |         return verts, faces
219 |
220 |     def estimate_vertices_sem(self, mesh, verts, filter_free_space_vertices = True):
221 |         print("predict semantic labels of the vertices")
222 |         verts_scaled = torch.tensor(verts * self.world_scale, dtype=self.dtype, device=self.device)
223 |         _, verts_sem, _ = self.query_points(verts_scaled, self.config.infer_bs, False, True, False)
224 |         verts_sem_list = list(verts_sem)
225 |         verts_sem_rgb = [sem_kitti_color_map[sem_label] for sem_label in verts_sem_list]
226 |         verts_sem_rgb = np.asarray(verts_sem_rgb)/255.0
227 |         mesh.vertex_colors = o3d.utility.Vector3dVector(verts_sem_rgb)
228 |
229 |         # filter the freespace vertices
230 |         if filter_free_space_vertices:
231 |             non_freespace_idx = verts_sem <= 0
232 |             mesh.remove_vertices_by_mask(non_freespace_idx)
233 |
234 |         return mesh
235 |
236 |     def filter_free_space_vertices(self, mesh, verts, map_down_pc):
237 |
238 |         print("******** begin to clean mesh ********")
239 |         map_down_pc = map_down_pc.voxel_down_sample(0.2 * self.config.mc_res_m)
240 |         print("******** build kd-tree ********")
241 |         points_kd_tree = cKDTree(np.asarray(map_down_pc.points))
242 |         print("******** query neighbors ********")
243 |         radius = min(0.2, self.config.mc_res_m) # 0.15
244 |         verts_find_neighbors_num = points_kd_tree.query_ball_point(verts, 0.80*radius, workers=12, return_length=True)
245 |         mesh.remove_vertices_by_mask(verts_find_neighbors_num < 1)
246 |         print("******** finished ********")
247 |         return mesh
248 |
249 |     def filter_isolated_vertices(self, mesh, filter_cluster_min_tri = 100):
250 |         # print("Cluster connected triangles")
251 |         triangle_clusters, cluster_n_triangles, _ = (mesh.cluster_connected_triangles())
252 |         triangle_clusters = np.asarray(triangle_clusters)
253 |         cluster_n_triangles = np.asarray(cluster_n_triangles)
254 |         # cluster_area = np.asarray(cluster_area)
255 |         # print("Remove the small clusters")
256 |         # mesh_0 = copy.deepcopy(mesh)
257 |         triangles_to_remove = cluster_n_triangles[triangle_clusters] < filter_cluster_min_tri
258 |         mesh.remove_triangles_by_mask(triangles_to_remove)
259 |         # mesh = mesh_0
260 |         return mesh
261 |
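The core of `filter_free_space_vertices` above is counting, for each mesh vertex, how many map points fall within a radius. A self-contained sketch of that kd-tree query with random stand-in data:
```
import numpy as np
from scipy.spatial import cKDTree

map_points = np.random.rand(10000, 3)   # hypothetical down-sampled map cloud
verts = np.random.rand(500, 3)          # hypothetical mesh vertices

tree = cKDTree(map_points)
# number of map points within the radius of each vertex (return_length skips building the index lists)
n_support = tree.query_ball_point(verts, 0.05, workers=4, return_length=True)
remove_mask = n_support < 1             # vertices with no nearby map point are dropped
print(int(remove_mask.sum()), "vertices would be removed")
```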
262 |     def recon_bbx_mesh(self, bbx, map_down_pc, voxel_size, mesh_path, map_path, \
263 |         save_map = False, estimate_sem = False, estimate_normal = True, \
264 |         filter_isolated_mesh = True, filter_free_space_vertices = False):
265 |
266 |         # reconstruct and save the (semantic) mesh from the feature octree and the decoders within a
267 |         # given bounding box.
268 |         # bbx and voxel_size are all with unit m, in the world coordinate system
269 |
270 |         coord, voxel_num_xyz, voxel_origin = self.get_query_from_bbx(bbx, voxel_size)
271 |         sdf_pred, _, mc_mask = self.query_points(coord, self.config.infer_bs, True, False, self.config.mc_mask_on)
272 |         if save_map:
273 |             self.generate_sdf_map(coord, sdf_pred, mc_mask, map_path)
274 |         mc_sdf, _, mc_mask = self.assign_to_bbx(sdf_pred, None, mc_mask, voxel_num_xyz)
275 |         verts, faces = self.mc_mesh(mc_sdf, mc_mask, voxel_size, voxel_origin)
276 |
277 |         # directly use open3d to get the mesh
278 |         mesh = o3d.geometry.TriangleMesh(
279 |             o3d.utility.Vector3dVector(verts),
280 |             o3d.utility.Vector3iVector(faces)
281 |         )
282 |
283 |         if estimate_sem:
284 |             mesh = self.estimate_vertices_sem(mesh, verts, filter_free_space_vertices)
285 |
286 |         if filter_free_space_vertices:
287 |             mesh = self.filter_free_space_vertices(mesh, verts, map_down_pc)
288 |
289 |         if estimate_normal:
290 |             mesh.compute_vertex_normals()
291 |             #self.visualize_normals_with_rgb(mesh)
292 |
293 |         if filter_isolated_mesh:
294 |             mesh = self.filter_isolated_vertices(mesh)
295 |
296 |         # global transform (to world coordinate system) before output
297 |         mesh.transform(self.global_transform)
298 |
299 |         # write the mesh to a ply file
300 |         o3d.io.write_triangle_mesh(mesh_path, mesh)
301 |         print("save the mesh to %s\n" % (mesh_path))
302 |
303 |         return mesh
304 |
305 |     # reconstruct the map sparsely using the octree, only querying the sdf at a certain level ($query_level) of the octree
306 |     # much faster and also memory-wise more efficient
307 |     def recon_octree_mesh(self, query_level, map_down_pc, mc_res_m, mesh_path, map_path, \
308 |         save_map = False, estimate_sem = False, estimate_normal = True, \
309 |         filter_isolated_mesh = True, filter_free_space_vertices = False):
310 |
311 |         nodes_coord_scaled = self.octree.get_octree_nodes(query_level) # query level top-down
312 |         nodes_count = nodes_coord_scaled.shape[0]
313 |         min_nodes = np.min(nodes_coord_scaled, 0)
314 |         max_nodes = np.max(nodes_coord_scaled, 0)
315 |
316 |         node_res_scaled = 2**(1-query_level) # voxel size for the queried octree node in the [-1,1] coordinate system
317 |         # the queried octree node's size should be evenly divisible by the marching cubes voxel size
318 |         # if query_level=tree_level, voxel_count_per_side_node = feature_voxel_size/mc_res_m
319 |         voxel_count_per_side_node = np.ceil(node_res_scaled / self.world_scale / mc_res_m).astype(dtype=int)
320 |         # assign coordinates for the queried octree node
321 |         x = torch.arange(voxel_count_per_side_node, dtype=torch.int16, device=self.device)
322 |         y = torch.arange(voxel_count_per_side_node, dtype=torch.int16, device=self.device)
323 |         z = torch.arange(voxel_count_per_side_node, dtype=torch.int16, device=self.device)
324 |         node_box_size = (np.ones(3)*voxel_count_per_side_node).astype(dtype=int)
325 |
326 |         # order: [0,0,0], [0,0,1], [0,0,2], [0,1,0], [0,1,1], [0,1,2] ...
327 |         x, y, z = torch.meshgrid(x, y, z, indexing='ij')
328 |         # get the vector of all the grid points' 3D coordinates
329 |         coord = torch.stack((x.flatten(), y.flatten(), z.flatten())).transpose(0, 1).float()
330 |         mc_res_scaled = node_res_scaled / voxel_count_per_side_node # voxel size for marching cubes in the [-1,1] coordinate system
331 |         # transform to the [-1,1] coordinate system
332 |         coord *= mc_res_scaled
333 |
334 |         # the voxel count for the whole map
335 |         voxel_count_per_side = ((max_nodes - min_nodes)/mc_res_scaled+voxel_count_per_side_node).astype(int)
336 |         # initialize the whole map
337 |         query_grid_sdf = np.zeros((voxel_count_per_side[0], voxel_count_per_side[1], voxel_count_per_side[2]), dtype=np.float16) # use float16 to save memory
338 |         query_grid_mask = np.zeros((voxel_count_per_side[0], voxel_count_per_side[1], voxel_count_per_side[2]), dtype=bool) # mask off
339 |
340 |         for node_idx in tqdm(range(nodes_count)):
341 |             node_coord_scaled = nodes_coord_scaled[node_idx, :]
342 |             cur_origin = torch.tensor(node_coord_scaled - 0.5 * (node_res_scaled - mc_res_scaled), device=self.device)
343 |             cur_coord = coord.clone()
344 |             cur_coord += cur_origin
345 |             cur_sdf_pred, _, cur_mc_mask = self.query_points(cur_coord, self.config.infer_bs, True, False, self.config.mc_mask_on)
346 |             cur_sdf_pred, _, cur_mc_mask = self.assign_to_bbx(cur_sdf_pred, None, cur_mc_mask, node_box_size)
347 |             shift_coord = (node_coord_scaled - min_nodes)/node_res_scaled
348 |             shift_coord = (shift_coord*voxel_count_per_side_node).astype(int)
349 |             query_grid_sdf[shift_coord[0]:shift_coord[0]+voxel_count_per_side_node, shift_coord[1]:shift_coord[1]+voxel_count_per_side_node, shift_coord[2]:shift_coord[2]+voxel_count_per_side_node] = cur_sdf_pred
350 |             query_grid_mask[shift_coord[0]:shift_coord[0]+voxel_count_per_side_node, shift_coord[1]:shift_coord[1]+voxel_count_per_side_node, shift_coord[2]:shift_coord[2]+voxel_count_per_side_node] = cur_mc_mask
351 |
352 |         mc_voxel_size = mc_res_scaled / self.world_scale
353 |         mc_voxel_origin = (min_nodes - 0.5 * (node_res_scaled - mc_res_scaled)) / self.world_scale
354 |
355 |         # if save_map: # ignore it for now, takes too much memory
356 |         #     # query_grid_coord
57 |         #     self.generate_sdf_map(query_grid_coord, query_grid_sdf, query_grid_mask, map_path)
358 |
359 |         verts, faces = self.mc_mesh(query_grid_sdf, query_grid_mask, mc_voxel_size, mc_voxel_origin)
360 |         # directly use open3d to get the mesh
361 |         mesh = o3d.geometry.TriangleMesh(
362 |             o3d.utility.Vector3dVector(verts),
363 |             o3d.utility.Vector3iVector(faces)
364 |         )
365 |
366 |         if estimate_sem:
367 |             mesh = self.estimate_vertices_sem(mesh, verts, filter_free_space_vertices)
368 |
369 |         if filter_free_space_vertices:
370 |             mesh = self.filter_free_space_vertices(mesh, verts, map_down_pc)
371 |
372 |         if estimate_normal:
373 |             mesh.compute_vertex_normals()
374 |             #self.visualize_normals_with_rgb(mesh)
375 |
376 |         if filter_isolated_mesh:
377 |             mesh = self.filter_isolated_vertices(mesh)
378 |
379 |         # global transform (to world coordinate system) before output
380 |         mesh.transform(self.global_transform)
381 |
382 |         # write the mesh to a ply file
383 |         o3d.io.write_triangle_mesh(mesh_path, mesh)
384 |         print("save the mesh to %s\n" % (mesh_path))
385 |
386 |         return mesh
--------------------------------------------------------------------------------
/utils/pose.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy.linalg import inv
3 | import csv
4 | from pyquaternion import Quaternion
5 |
6 |
7 | def read_calib_file(filename):
8 |     """
9 |     read calibration file (with the kitti format)
10 |     returns -> dict of calibration matrices as 4x4 numpy arrays
11 |     """
12 |     calib = {}
13 |     calib_file = open(filename)
14 |     key_num = 0
15 |
16 |     for line in calib_file:
17 |         # print(line)
18 |         key, content = line.strip().split(":")
19 |         values = [float(v) for v in content.strip().split()]
20 |         pose = np.zeros((4, 4))
21 |
22 |         pose[0, 0:4] = values[0:4]
23 |         pose[1, 0:4] = values[4:8]
24 |         pose[2, 0:4] = values[8:12]
25 |         pose[3, 3] = 1.0
26 |
27 |         calib[key] = pose
28 |
29 |     calib_file.close()
30 |     return calib
31 |
32 |
33 | def read_poses_file(filename, calibration):
34 |     """
35 |     read pose file (with the kitti format)
36 |     """
37 |     pose_file = open(filename)
38 |
39 |     poses = []
40 |
41 |     Tr = calibration["Tr"]
42 |     Tr_inv = inv(Tr)
43 |
44 |     for line in pose_file:
45 |         values = [float(v) for v in line.strip().split()]
46 |
47 |         pose = np.zeros((4, 4))
48 |         pose[0, 0:4] = values[0:4]
49 |         pose[1, 0:4] = values[4:8]
50 |         pose[2, 0:4] = values[8:12]
51 |         pose[3, 3] = 1.0
52 |
53 |         poses.append(
54 |             np.matmul(Tr_inv, np.matmul(pose, Tr))
55 |         ) # lidar pose in world frame
56 |
57 |     pose_file.close()
58 |     return poses
59 |
60 |
61 | def csv_odom_to_transforms(path):
62 |
63 |     # odom_tfs = {}
64 |     poses = []
65 |     with open(path, mode="r") as f:
66 |         reader = csv.reader(f)
67 |         # get the header and change the timestamp label name
68 |         header = next(reader)
69 |         header[0] = "ts"
70 |         # convert string odometry to numpy transform matrices
71 |         for row in reader:
72 |             odom = {l: row[i] for i, l in enumerate(header)}
73 |             # translation and rotation quaternion as numpy arrays
74 |             trans = np.array([float(odom[l]) for l in ["tx", "ty", "tz"]])
75 |             quat = Quaternion(
76 |                 np.array([float(odom[l]) for l in ["qw", "qx", "qy", "qz"]]) # pyquaternion expects (w, x, y, z) ordering
77 |             )
78 |             rot = quat.rotation_matrix
79 |             # build the numpy transform matrix
80 |             odom_tf = np.eye(4)
81 |             odom_tf[0:3, 3] = trans
82 |             odom_tf[0:3, 0:3] = rot
83 |             # add the transform to the timestamp indexed dictionary
84 |             # odom_tfs[odom["ts"]] = odom_tf
85 |             poses.append(odom_tf)
86 |
87 |     return poses
88 |
--------------------------------------------------------------------------------
/utils/scan.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import open3d.core as o3c
3 | # deprecated now
4 | class Frame():
5 |     def __init__(self, frame_id, origin, points, normals, labels) -> None:
6 |
7 |         self.frame_id = frame_id
8 |         self.origin = origin
9 |         self.points = points
10 |         self.normals = normals
11 |         self.labels = labels
12 |         #self.ranges = torch.linalg.norm(self.points, dim=1, keepdim=True) # nx1
13 |
14 |     def get_rays(self):
15 |         return self.points - self.origin
16 |
17 |     def get_points(self):
18 |         return self.points
19 |
20 |     def sample_data(self, N_rays):
21 |
22 |         return
23 |
24 |     def get_sample_data(self, N_rays):
25 |
26 |         return
27 |
28 |
--------------------------------------------------------------------------------
/utils/semantic_kitti_utils.py:
--------------------------------------------------------------------------------
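A quick illustration of the `.label` encoding that `LabelDataConverter` below unpacks: each 32-bit SemanticKITTI label stores the semantic class in the lower 16 bits and the instance id in the upper 16 bits. A sketch with a made-up label value (raw class id 30 remaps to "person" via the tables in this file):
```
import numpy as np
from utils.semantic_kitti_utils import sem_kitti_learning_map, sem_kitti_labels

label = np.uint32((7 << 16) | 30)   # hypothetical raw value: instance 7, class 30
sem_id = int(label) & 0xFFFF        # lower 16 bits  -> 30
instance_id = int(label) >> 16      # upper 16 bits  -> 7
print(sem_id, instance_id, sem_kitti_labels[sem_kitti_learning_map[sem_id]])  # 30 7 person
```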
1 | import numpy as np
2 |
3 |
4 | class LabelDataConverter:
5 |     """Convert .label binary data to instance id and rgb"""
6 |
7 |     def __init__(self, labelscan):
8 |
9 |         self.convertdata(labelscan)
10 |
11 |     def convertdata(self, labelscan):
12 |
13 |         self.semantic_id = []
14 |         self.rgb_id = []
15 |         self.instance_id = []
16 |         self.rgb_arr_id = []
17 |
18 |         for counting in range(len(labelscan)):
19 |
20 |             sem_id = int(labelscan[counting]) & 0xFFFF # lower 16 bits
21 |             rgb, rgb_arr = self.get_sem_rgb(sem_id)
22 |             instance_id = int(labelscan[counting]) >> 16 # upper 16 bits
23 |             # rgb = self.get_random_rgb(instance_id)
24 |
25 |             # print("Sem label:", sem_id, "Ins label:", instance_id, "Color:", hex(rgb))
26 |             # print(hex(rgb))
27 |             # an instance label is given within each semantic label
28 |
29 |             self.semantic_id.append(sem_id)
30 |             self.rgb_id.append(rgb)
31 |             self.rgb_arr_id.append(rgb_arr)
32 |             self.instance_id.append(instance_id)
33 |
34 |
35 | def get_random_rgb(n):
36 |     n = ((n ^ n >> 15) * 2246822519) & 0xFFFFFFFF
37 |     n = ((n ^ n >> 13) * 3266489917) & 0xFFFFFFFF
38 |     n = (n ^ n >> 16) >> 8
39 |     # print(n)
40 |     return tuple(n.to_bytes(3, "big"))
41 |
42 |
43 | sem_kitti_learning_map = {
44 |     0 : 0,   # "unlabeled"
45 |     1 : 0,   # "outlier" mapped to "unlabeled" --------------------------mapped
46 |     10: 1,   # "car"
47 |     11: 2,   # "bicycle"
48 |     13: 5,   # "bus" mapped to "other-vehicle" --------------------------mapped
49 |     15: 3,   # "motorcycle"
50 |     16: 5,   # "on-rails" mapped to "other-vehicle" ---------------------mapped
51 |     18: 4,   # "truck"
52 |     20: 5,   # "other-vehicle"
53 |     30: 6,   # "person"
54 |     31: 7,   # "bicyclist"
55 |     32: 8,   # "motorcyclist"
56 |     40: 9,   # "road"
57 |     44: 10,  # "parking"
58 |     48: 11,  # "sidewalk"
59 |     49: 12,  # "other-ground"
60 |     50: 13,  # "building"
61 |     51: 14,  # "fence"
62 |     52: 20,  # "other-structure" mapped to "unlabeled" ------------------mapped
63 |     60: 9,   # "lane-marking" to "road" ---------------------------------mapped
64 |     70: 15,  # "vegetation"
65 |     71: 16,  # "trunk"
66 |     72: 17,  # "terrain"
67 |     80: 18,  # "pole"
68 |     81: 19,  # "traffic-sign"
69 |     99: 20,  # "other-object" to "unlabeled" ----------------------------mapped
70 |     252: 1,  # "moving-car"
71 |     253: 7,  # "moving-bicyclist"
72 |     254: 6,  # "moving-person"
73 |     255: 8,  # "moving-motorcyclist"
74 |     256: 5,  # "moving-on-rails" mapped to "moving-other-vehicle" ------mapped
75 |     257: 5,  # "moving-bus" mapped to "moving-other-vehicle" -----------mapped
76 |     258: 4,  # "moving-truck"
77 |     259: 5,  # "moving-other-vehicle"
78 | }
79 |
80 | sem_kitti_labels = {
81 |     0: "unlabeled",
82 |     1: "car",
83 |     2: "bicycle",
84 |     3: "motorcycle",
85 |     4: "truck",
86 |     5: "other-vehicle",
87 |     6: "person",
88 |     7: "bicyclist",
89 |     8: "motorcyclist",
90 |     9: "road",
91 |     10: "parking",
92 |     11: "sidewalk",
93 |     12: "other-ground",
94 |     13: "building",
95 |     14: "fence",
96 |     15: "vegetation",
97 |     16: "trunk",
98 |     17: "terrain",
99 |     18: "pole",
100 |     19: "traffic-sign",
101 |     20: "others",
102 | }
103 |
104 | sem_kitti_color_map = { # rgb
105 |     0: [255, 255, 255],
106 |     1: [100, 150, 245],
107 |     2: [100, 230, 245],
108 |     3: [30, 60, 150],
109 |     4: [80, 30, 180],
110 |     5: [0, 0, 255],
111 |     6: [255, 30, 30],
112 |     7: [255, 40, 200],
113 |     8: [150, 30, 90],
114 |     9: [255, 0, 255],
115 |     10: [255, 150, 255],
116 |     11: [75, 0, 75],
117 |     12: [175, 0, 75],
118 |     13: [255, 200, 0],
119 |     14: [255, 120, 50],
120 |     15: [0, 175, 0],
121 |     16: [135, 60, 0],
122 |     17: [150, 240, 80],
123 |     18: [255, 240, 150],
124 |     19: [255, 0, 0],
125 |     20: [30, 30, 30]
126 | }
--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import sys
3 | import os
4 | import multiprocessing
5 | import getpass
6 | import time
7 | from collections import defaultdict
8 | from pathlib import Path
9 | from datetime import datetime
10 | from torch import optim
11 | from torch.optim.optimizer import Optimizer
12 | from torch.autograd import grad
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | import wandb
17 | import json
18 | import open3d as o3d
19 |
20 | from utils.config import SHINEConfig
21 |
22 | # set up this run
23 | def setup_experiment(config: SHINEConfig):
24 |
25 |     os.environ["NUMEXPR_MAX_THREADS"] = str(multiprocessing.cpu_count())
26 |     os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_id)
27 |     ts = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") # beginning timestamp
28 |     run_name = config.name + "_" + ts # modified to a name that is easier to index
29 |
30 |     run_path = os.path.join(config.output_root, run_name)
31 |     access = 0o755
32 |     os.makedirs(run_path, access, exist_ok=True)
33 |     assert os.access(run_path, os.W_OK)
34 |     print(f"Start {run_path}")
35 |
36 |     mesh_path = os.path.join(run_path, "mesh")
37 |     map_path = os.path.join(run_path, "map")
38 |     model_path = os.path.join(run_path, "model")
39 |     os.makedirs(mesh_path, access, exist_ok=True)
40 |     os.makedirs(map_path, access, exist_ok=True)
41 |     os.makedirs(model_path, access, exist_ok=True)
42 |
43 |     if config.wandb_vis_on:
44 |         # set up wandb
45 |         setup_wandb()
46 |         wandb.init(project="SHINEMapping", config=vars(config), dir=run_path) # your own workspace
47 |         wandb.run.name = run_name
48 |
49 |     # o3d.utility.random.seed(42)
50 |
51 |     return run_path
52 |
53 |
54 | def setup_optimizer(config: SHINEConfig, octree_feat, mlp_geo_param, mlp_sem_param, sigma_size) -> Optimizer:
55 |     lr_cur = config.lr
56 |     opt_setting = []
57 |     # weight_decay is for L2 regularization, only applied to the MLPs
58 |     if mlp_geo_param is not None:
59 |         mlp_geo_param_opt_dict = {'params': mlp_geo_param, 'lr': lr_cur, 'weight_decay': config.weight_decay}
60 |         opt_setting.append(mlp_geo_param_opt_dict)
61 |     if config.semantic_on and mlp_sem_param is not None:
62 |         mlp_sem_param_opt_dict = {'params': mlp_sem_param, 'lr': lr_cur, 'weight_decay': config.weight_decay}
63 |         opt_setting.append(mlp_sem_param_opt_dict)
64 |     # feature octree
65 |     for i in range(config.tree_level_feat):
66 |         # one could also add L2 regularization on the feature octree (results were not quite good)
67 |         feat_opt_dict = {'params': octree_feat[config.tree_level_feat-i-1], 'lr': lr_cur}
68 |         lr_cur *= config.lr_level_reduce_ratio
69 |         opt_setting.append(feat_opt_dict)
70 |     # make sigma also learnable for differentiable rendering (but not for our method)
71 |     if config.ray_loss:
72 |         sigma_opt_dict = {'params': sigma_size, 'lr': config.lr}
73 |         opt_setting.append(sigma_opt_dict)
74 |
75 |     if config.opt_adam:
76 |         opt = optim.Adam(opt_setting, betas=(0.9,0.99), eps = config.adam_eps)
77 |     else:
78 |         opt = optim.SGD(opt_setting, momentum=0.9)
79 |
80 |     return opt
81 |
82 |
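The loop in `setup_optimizer` above gives each octree feature level its own learning rate: the last entry of the feature list receives the full base rate, and each earlier entry is scaled down by `lr_level_reduce_ratio`. A quick sketch of the resulting rates with hypothetical config values:
```
import torch

lr, ratio, tree_level_feat = 0.01, 0.5, 3  # hypothetical config.lr, lr_level_reduce_ratio, tree_level_feat
feats = [torch.nn.Parameter(torch.zeros(100, 8)) for _ in range(tree_level_feat)]

lr_cur, opt_setting = lr, []
for i in range(tree_level_feat):           # same loop as above
    opt_setting.append({'params': feats[tree_level_feat - i - 1], 'lr': lr_cur})
    lr_cur *= ratio
opt = torch.optim.Adam(opt_setting, betas=(0.9, 0.99))
print([g['lr'] for g in opt.param_groups])  # [0.01, 0.005, 0.0025]
```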
83 | # set up weights & biases (wandb)
84 | def setup_wandb():
85 |     print("Weights & Biases logging is on. Disable it by setting wandb_vis_on: False in the config file.")
86 |     username = getpass.getuser()
87 |     # print(username)
88 |     wandb_key_path = username + "_wandb.key"
89 |     if not os.path.exists(wandb_key_path):
90 |         wandb_key = input(
91 |             "[You need to first set up and log in to wandb] Please enter your wandb key (https://wandb.ai/authorize):"
92 |         )
93 |         with open(wandb_key_path, "w") as fh:
94 |             fh.write(wandb_key)
95 |     else:
96 |         print("wandb key already set")
97 |     os.environ["WANDB_API_KEY"] = open(wandb_key_path).read().strip() # note: os.system('export ...') would only set it in a subshell
98 |
99 | def step_lr_decay(
100 |     optimizer: Optimizer,
101 |     learning_rate: float,
102 |     iteration_number: int,
103 |     steps: List,
104 |     reduce: float = 1.0):
105 |
106 |     if reduce > 1.0 or reduce <= 0.0:
107 |         sys.exit(
108 |             "The decay rate should be between 0 and 1."
109 |         )
110 |
111 |     if iteration_number in steps:
112 |         steps.remove(iteration_number)
113 |         learning_rate *= reduce
114 |         print("Reduce base learning rate to {}".format(learning_rate))
115 |
116 |         for param in optimizer.param_groups:
117 |             param["lr"] *= reduce
118 |
119 |     return learning_rate
120 |
121 |
122 | def num_model_weights(model: nn.Module) -> int:
123 |     num_weights = int(
124 |         sum(
125 |             [
126 |                 np.prod(p.size())
127 |                 for p in filter(lambda p: p.requires_grad, model.parameters())
128 |             ]
129 |         )
130 |     )
131 |     return num_weights
132 |
133 |
134 | def print_model_summary(model: nn.Module):
135 |     for child in model.children():
136 |         print(child)
137 |
138 |
139 | def get_gradient(inputs, outputs):
140 |     d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
141 |     points_grad = grad(
142 |         outputs=outputs,
143 |         inputs=inputs,
144 |         grad_outputs=d_points,
145 |         create_graph=True,
146 |         retain_graph=True,
147 |         only_inputs=True,
148 |     )[0]
149 |     return points_grad
150 |
151 | def correct_sdf(sdf_label, normal_label, g_normalized, ray_direction, trunc_dist, surface_mask):
152 |     # all in the global coordinate frame.
153 |     cos_theta = (ray_direction * g_normalized).sum(dim=1).abs() # theta: angle between the ray and the predicted sdf gradient
154 |     cos_alpha = (normal_label * g_normalized).sum(dim=1).abs() # alpha: angle between the measured normal and the predicted sdf gradient
155 |     sin_theta = (1.0 - cos_theta*cos_theta).sqrt()
156 |     sin_alpha = (1.0 - cos_alpha*cos_alpha).sqrt()
157 |
158 |     a = (ray_direction * g_normalized).sum(dim=1) # TODO: ray-wise condition.
159 |     b = (ray_direction * normal_label).sum(dim=1)
160 |     # convex_mask = a-b > 0 # convex
161 |     correct_ratio = torch.ones_like(sdf_label, device=sdf_label.device)
162 |     convex_ratio = torch.abs(cos_theta - sin_theta*(1-cos_alpha)/sin_alpha) # convex
163 |     concave_ratio = torch.abs(cos_theta + sin_theta*(1-cos_alpha)/sin_alpha) # concave
164 |     correct_ratio = convex_ratio
165 |     correct_ratio[a - b < 0] = concave_ratio[a - b < 0] # concave case
166 |     correct_ratio[cos_alpha > 0.9] = cos_theta[cos_alpha > 0.9] # nearly parallel: treat as a plane
167 |     correct_ratio_masked = torch.ones_like(correct_ratio)
168 |     correct_ratio_masked[surface_mask] = correct_ratio[surface_mask]
169 |
170 |     np_sdf = sdf_label * correct_ratio_masked
171 |
172 |     #np_sdf[np_sdf > trunc_dist] = trunc_dist # Tr should be scaled
173 |     #np_sdf[np_sdf < -trunc_dist] = -trunc_dist
174 |
175 |     # plane_correc_ratio = (normal_label * ray_direction).sum(dim=1).abs()
176 |     # np_sdf = sdf_label * plane_correc_ratio
177 |
178 |     return np_sdf
179 |
180 | # not feasible for pytorch versions < 2.0
181 | def voxel_down_sample_torch(points: torch.tensor, voxel_size: float):
182 |     """
183 |     voxel based downsampling. Returns the indices of the points which are closest to the voxel centers.
184 |     Args:
185 |         points (torch.Tensor): [N,3] point coordinates
186 |         voxel_size (float): grid resolution
187 |
188 |     Returns:
189 |         indices (torch.Tensor): [M] indices of the original point cloud, downsampled point cloud would be `points[indices]`
190 |
191 |     Reference: Louis Wiesmann
192 |     """
193 |     _quantization = 1000 # if changed to 1, it would be a random sample
194 |
195 |     offset = torch.floor(points.min(dim=0)[0]/voxel_size).long()
196 |     grid = torch.floor(points / voxel_size)
197 |     center = (grid + 0.5) * voxel_size
198 |     dist = ((points - center) ** 2).sum(dim=1)**0.5
199 |     dist = dist / dist.max() * (_quantization - 1) # for speed up
200 |
201 |     grid = grid.long() - offset
202 |     #v_size = grid.max().ceil()
203 |     v_size = grid.max()
204 |     grid_idx = grid[:, 0] + grid[:, 1] * v_size + grid[:, 2] * v_size * v_size
205 |
206 |     unique, inverse = torch.unique(grid_idx, return_inverse=True)
207 |     idx_d = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
208 |
209 |     offset = 10**len(str(idx_d.max().item()))
210 |
211 |     idx_d = idx_d + dist.long() * offset
212 |     idx = torch.empty(unique.shape, dtype=inverse.dtype,
213 |                       device=inverse.device).scatter_reduce_(dim=0, index=inverse, src=idx_d, reduce="amin", include_self=False)
214 |     idx = idx % offset
215 |     return idx.long()
216 |
217 | def list_duplicates(seq):
218 |     dd = defaultdict(list)
219 |     for i, item in enumerate(seq):
220 |         dd[item].append(i)
221 |     return [(key, locs) for key, locs in dd.items() if len(locs) >= 1]
222 |
223 | def list_count(seq):
224 |     dd = {}
225 |     for m in seq:
226 |         if m in dd:
227 |             dd[m] += 1
228 |         else:
229 |             dd[m] = 1
230 |     return dd
231 |
232 | def freeze_model(model: nn.Module):
233 |     for child in model.children():
234 |         for param in child.parameters():
235 |             param.requires_grad = False
236 |
237 |
238 | def unfreeze_model(model: nn.Module):
239 |     for child in model.children():
240 |         for param in child.parameters():
241 |             param.requires_grad = True
242 |
243 |
244 | def save_checkpoint(
245 |     feature_octree, geo_decoder, sem_decoder, optimizer, run_path, checkpoint_name, iters
246 | ):
247 |     torch.save(
248 |         {
249 |             "iters": iters,
250 |             "feature_octree": feature_octree, # save the whole NN module (the hierarchical features and the indexing structure)
251 |             "geo_decoder": geo_decoder.state_dict(),
252 |             "sem_decoder": sem_decoder.state_dict(),
253 |             "optimizer": optimizer.state_dict(),
254 |         },
255 |         os.path.join(run_path, f"{checkpoint_name}.pth"),
256 |     )
257 |     print(f"save the model to {run_path}/{checkpoint_name}.pth")
258 |
259 |
260 | def save_decoder(geo_decoder, sem_decoder, run_path, checkpoint_name):
261 |     torch.save({"geo_decoder": geo_decoder.state_dict(),
262 |                 "sem_decoder": sem_decoder.state_dict()},
263 |                os.path.join(run_path, f"{checkpoint_name}_decoders.pth"),
264 |     )
265 |
266 | def save_geo_decoder(geo_decoder, run_path, checkpoint_name):
267 |     torch.save({"geo_decoder": geo_decoder.state_dict()},
268 |                os.path.join(run_path, f"{checkpoint_name}_geo_decoder.pth"),
269 |     )
270 |
271 | def save_sem_decoder(sem_decoder, run_path, checkpoint_name):
272 |     torch.save({"sem_decoder": sem_decoder.state_dict()},
273 |                os.path.join(run_path, f"{checkpoint_name}_sem_decoder.pth"),
274 |     )
275 |
276 | def get_time():
277 |     """
278 |     :return: get timing statistics
279 |     """
280 |     torch.cuda.synchronize()
281 |     return time.time()
282 |
283 | def load_from_json(filename: Path):
284 |     """Load a dictionary from a JSON filename.
285 |     Args:
286 |         filename: The filename to load from.
287 |     """
288 |     assert filename.suffix == ".json"
289 |     with open(filename, encoding="UTF-8") as file:
290 |         return json.load(file)
291 |
292 |
293 | def write_to_json(filename: Path, content: dict):
294 |     """Write data to a JSON file.
295 |     Args:
296 |         filename: The filename to write to.
297 |         content: The dictionary data to write.
298 |     """
299 |     assert filename.suffix == ".json"
300 |     with open(filename, "w", encoding="UTF-8") as file:
301 |         json.dump(content, file)
302 |
--------------------------------------------------------------------------------
/utils/visualizer.py:
--------------------------------------------------------------------------------
1 | # Partially borrowed from Nacho's lidar odometry (KISS-ICP)
2 |
3 | import copy
4 | from functools import partial
5 | import os
6 | from typing import Callable, List
7 | import cv2
8 |
9 | import numpy as np
10 | import open3d as o3d
11 |
12 | YELLOW = np.array([1, 0.706, 0])
13 | RED = np.array([128, 0, 0]) / 255.0
14 | BLACK = np.array([0, 0, 0]) / 255.0
15 | GOLDEN = np.array([1.0, 0.843, 0.0])
16 |
17 | random_color_table = [[230. / 255., 0., 0.], # red
18 |                       [60. / 255., 180. / 255., 75. / 255.], # green
19 |                       [0., 0., 255. / 255.], # blue
20 |                       [255. / 255., 0, 255. / 255.],
21 |                       [255. / 255., 165. / 255., 0.],
22 |                       [128. / 255., 0, 128. / 255.],
23 |                       [0., 255. / 255., 255. / 255.],
24 |                       [210. / 255., 245. / 255., 60. / 255.],
25 |                       [250. / 255., 190. / 255., 190. / 255.],
26 |                       [0., 128. / 255., 128. / 255.]
27 |                      ]
28 |
29 | class MapVisualizer():
30 |     # Public Interface ----------------------------------------------------------------------------
31 |     def __init__(self):
32 |         # Initialize GUI controls
33 |         self.block_vis = True
34 |         self.play_crun = True
35 |         self.reset_bounding_box = True
36 |
37 |         # Create data
38 |         self.scan = o3d.geometry.PointCloud()
39 |         self.frame_axis_len = 0.5
40 |         self.frame = o3d.geometry.TriangleMesh()
41 |         self.mesh = o3d.geometry.TriangleMesh()
42 |         self.sample_points = o3d.geometry.PointCloud()
43 |
44 |         # Initialize visualizer
45 |         self.vis = o3d.visualization.VisualizerWithKeyCallback()
46 |         self._register_key_callbacks()
47 |         self._initialize_visualizer()
48 |
49 |         # Visualization options
50 |         self.render_map = True
51 |         self.render_normal = True
52 |         self.render_frame = True
53 |
54 |         self.global_view = False
55 |         self.view_control = self.vis.get_view_control()
56 |         self.camera_params = self.view_control.convert_to_pinhole_camera_parameters()
57 |
58 |         # self.global_transform = np.eye(4)
59 |         self.frame_num = 0
60 |
61 |     def update_view(self):
62 |         self.vis.poll_events()
63 |         self.vis.update_renderer()
64 |
65 |     def pause_view(self):
66 |         while self.block_vis:
67 |             self.update_view()
68 |             if self.play_crun:
69 |                 break
70 |
71 |     def update(self, scan, pose, mesh = None):
72 |         if mesh is not None:
73 |             if self.render_normal:
74 |                 mesh = self.visualize_normals_with_rgb(mesh)
75 |         self._update_geometries(scan, pose, mesh)
76 |         self.update_view()
77 |         self.pause_view()
78 |         # self.vis.capture_screen_image(f"./record/kitti/frame{self.frame_num}.png", do_render=True)
79 |         # self.frame_num = self.frame_num+1
80 |
81 |     def update_mesh(self, mesh):
82 |         if mesh is not None:
83 |             if self.render_normal:
84 |                 mesh = self.visualize_normals_with_rgb(mesh)
85 |         self._update_mesh(mesh)
86 |         self.update_view()
87 |         self.pause_view()
88 |
89 |     def update_point_cloud(self, sample_points):
90 |         self._update_point_cloud(sample_points)
91 |         self.update_view()
92 |         self.pause_view()
93 |
94 |     def destroy_window(self):
95 |         self.vis.destroy_window()
96 |
97 |     def stop(self):
98 |         self.play_crun = not self.play_crun
99 |         while self.block_vis:
100 |             self.vis.poll_events()
101 |             self.vis.update_renderer()
102 |             if self.play_crun:
103 |                 break
104 |
105 |     def visualize_normals_with_rgb(self, mesh):
106 |         normals = np.asarray(mesh.vertex_normals)
107 |         colors = (normals+1.0)*0.5
108 |         mesh.vertex_colors = o3d.utility.Vector3dVector(colors)
109 |         return mesh
110 |
111 |     def save_video(self, output_path):
112 |         # save the recorded images to a video file.
113 |         fourcc = cv2.VideoWriter_fourcc(*"mp4v")
114 |         video_writer = cv2.VideoWriter(output_path, fourcc, 10.0, (1920, 1080))
115 |         print("begin to generate video...")
116 |         print(len(self.record_images))
117 |         for image in self.record_images:
118 |             #cv2.imshow("image", image)
119 |             video_writer.write(image)
120 |         print("finished!")
121 |         video_writer.release()
122 |
123 |     # Private Interface ---------------------------------------------------------------------------
124 |     def _initialize_visualizer(self):
125 |         w_name = self.__class__.__name__
126 |         self.vis.create_window(window_name=w_name, width=1920, height=1080)
127 |         self.vis.add_geometry(self.scan)
128 |         self.vis.add_geometry(self.frame)
129 |         self.vis.add_geometry(self.mesh)
130 |         self._set_white_background(self.vis)
131 |         self.vis.get_render_option().point_size = 4
132 |         self.vis.get_render_option().light_on = False
133 |         self.vis.get_render_option().mesh_show_back_face = True
134 |         print(100 * "*")
135 |         print(f"{w_name} initialized. Press [SPACE] to pause/start, [N] to step, [ESC] to exit.")
136 |
137 |     def _register_key_callback(self, keys: List, callback: Callable):
138 |         for key in keys:
139 |             self.vis.register_key_callback(ord(str(key)), partial(callback))
140 |
141 |     def _register_key_callbacks(self):
142 |         self._register_key_callback(["Q", "\x1b"], self._quit)
143 |         self._register_key_callback([" "], self._start_stop)
144 |         self._register_key_callback(["W"], self._toggle_view)
145 |         self._register_key_callback(["F"], self._toggle_frame)
146 |         self._register_key_callback(["M"], self._toggle_map)
147 |         # self._register_key_callback(["B"], self._set_black_background)
148 |         # self._register_key_callback(["W"], self._set_white_background)
149 |
150 |     def _set_black_background(self, vis):
151 |         vis.get_render_option().background_color = [0.0, 0.0, 0.0]
152 |
153 |     def _set_white_background(self, vis):
154 |         vis.get_render_option().background_color = [1.0, 1.0, 1.0]
155 |
156 |     def _quit(self, vis):
157 |         print("Destroying Visualizer")
158 |         vis.destroy_window()
159 |         os._exit(0)
160 |
161 |     def _next_frame(self, vis):
162 |         self.block_vis = not self.block_vis
163 |
164 |     def _start_stop(self, vis):
165 |         self.play_crun = not self.play_crun
166 |
167 |     def _toggle_frame(self, vis):
168 |         self.render_frame = not self.render_frame
169 |         return False
170 |
171 |     def _toggle_map(self, vis):
172 |         self.render_map = not self.render_map
173 |         return False
174 |
175 |     def _update_mesh(self, mesh):
176 |         if mesh is not None:
177 |             self.vis.remove_geometry(self.mesh, self.reset_bounding_box)
178 |             self.mesh = mesh
179 |             self.vis.add_geometry(self.mesh, self.reset_bounding_box)
180 |
181 |     def _update_point_cloud(self, pcd):
182 |         if pcd is not None:
183 |             self.vis.remove_geometry(self.sample_points, self.reset_bounding_box)
184 |             self.sample_points = pcd
185 |             self.vis.add_geometry(self.sample_points, self.reset_bounding_box)
186 |             print("update point cloud")
187 |
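A minimal usage sketch of `MapVisualizer` with hypothetical data (this opens an Open3D window, so it only makes sense on a machine with a display):
```
import numpy as np
import open3d as o3d
from utils.visualizer import MapVisualizer

vis = MapVisualizer()  # creates the window and registers the key callbacks
scan = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(np.random.rand(1000, 3)))
vis.update(scan, np.eye(4))  # show one frame with an identity pose, no mesh yet
vis.stop()                   # pause; press [SPACE] to resume, [ESC]/[Q] to quit
```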
188 |     def _update_geometries(self, scan, pose, mesh = None):
189 |         # Scan (toggled by "F")
190 |         if self.render_frame:
191 |             self.scan.points = o3d.utility.Vector3dVector(scan.points)
192 |             self.scan.paint_uniform_color(GOLDEN)
193 |         else:
194 |             self.scan.points = o3d.utility.Vector3dVector()
195 |
196 |         # Always visualize the coordinate frame
197 |         self.frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=self.frame_axis_len, origin=np.zeros(3))
198 |         self.frame = self.frame.transform(pose)
199 |
200 |         # Mesh Map (toggled by "M")
201 |         # the mesh has already been globally shifted
202 |         if self.render_map:
203 |             if mesh is not None:
204 |                 self.vis.remove_geometry(self.mesh, self.reset_bounding_box) # if commented out, the previously reconstructed mesh is kept (for the case of local map reconstruction)
205 |                 self.mesh = mesh
206 |                 self.vis.add_geometry(self.mesh, self.reset_bounding_box)
207 |         else:
208 |             self.vis.remove_geometry(self.mesh, self.reset_bounding_box)
209 |
210 |         self.vis.update_geometry(self.scan)
211 |         self.vis.add_geometry(self.frame, self.reset_bounding_box)
212 |
213 |         if self.reset_bounding_box:
214 |             self.vis.reset_view_point(True)
215 |             self.reset_bounding_box = False
216 |
217 |     def _toggle_view(self, vis):
218 |         self.global_view = not self.global_view
219 |         vis.update_renderer()
220 |         vis.reset_view_point(True)
221 |         current_camera = self.view_control.convert_to_pinhole_camera_parameters()
222 |         if self.camera_params and not self.global_view:
223 |             self.view_control.convert_from_pinhole_camera_parameters(self.camera_params)
224 |         self.camera_params = current_camera
225 |
--------------------------------------------------------------------------------