├── .gitignore ├── .idea ├── .gitignore ├── CLID-SLAM.iml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── LICENSE ├── README.md ├── assets ├── GUI_Mesh.png └── GUI_Neural_Points.png ├── cad ├── camera.ply ├── drone.ply ├── ipb_car.ply └── kitti_car.ply ├── config ├── run_SubT_MRS.yaml └── run_ncd128.yaml ├── dataset └── converter │ ├── config │ └── rosbag2dataset.yaml │ └── rosbag2dataset_parallel.py ├── experiment └── .gitkeep ├── gui ├── __pycache__ │ ├── gui_utils.cpython-312.pyc │ └── slam_gui.cpython-312.pyc ├── gui_utils.py └── slam_gui.py ├── model ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-312.pyc │ ├── decoder.cpython-312.pyc │ ├── local_point_cloud_map.cpython-312.pyc │ └── neural_points.cpython-312.pyc ├── decoder.py ├── local_point_cloud_map.py └── neural_points.py ├── requirements.txt ├── slam.py ├── tools.ipynb ├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-312.pyc │ ├── config.cpython-312.pyc │ ├── data_sampler.cpython-312.pyc │ ├── dataset_indexing.cpython-312.pyc │ ├── error_state_iekf.cpython-312.pyc │ ├── eval_traj_utils.cpython-312.pyc │ ├── loss.cpython-312.pyc │ ├── mapper.cpython-312.pyc │ ├── mesher.cpython-312.pyc │ ├── semantic_kitti_utils.cpython-312.pyc │ ├── slam_dataset.cpython-312.pyc │ ├── so3_math.cpython-312.pyc │ └── tools.cpython-312.pyc ├── config.py ├── data_sampler.py ├── dataset_indexing.py ├── error_state_iekf.py ├── eval_traj_utils.py ├── loss.py ├── mapper.py ├── mesher.py ├── point_cloud2.py ├── semantic_kitti_utils.py ├── slam_dataset.py ├── so3_math.py ├── tools.py └── visualizer.py └── vis_pin_map.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # 默认忽略的文件 2 | /shelf/ 3 | /workspace.xml 4 | 
-------------------------------------------------------------------------------- /.idea/CLID-SLAM.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Photogrammetry & Robotics Bonn 4 | Copyright (c) 2025 Dalian University of Technology Intelligent Robotics Laboratory 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the 
Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 |

⚔️CLID-SLAM: A Coupled LiDAR-Inertial Neural Implicit Dense SLAM with Region-Specific SDF Estimation

3 |

4 | 5 | 6 | 7 | 8 | 9 | 10 | FORK 11 | Issues 12 |

13 |

14 | 15 | | Mesh | Neural Points | 16 | |-------------------------------|-----------------------------------| 17 | | ![Mesh](./assets/GUI_Mesh.png) | ![Neural Points](./assets/GUI_Neural_Points.png) | 18 | 19 | ## TODO 📝 20 | 21 | - [x] Release the source code 22 | - [ ] Enhance the README.md 23 | - [ ] Include the theory derivations 24 | 25 | ## Installation 26 | 27 | ### Platform Requirements 28 | - Ubuntu 20.04 29 | - GPU (tested on RTX 4090) 30 | 31 | ### Steps 32 | 1. **Clone the repository** 33 | ```bash 34 | git clone git@github.com:DUTRobot/CLID-SLAM.git 35 | cd CLID-SLAM 36 | ``` 37 | 38 | 2. **Create Conda Environment** 39 | ```bash 40 | conda create -n slam python=3.12 41 | conda activate slam 42 | ``` 43 | 44 | 3. **Install PyTorch** 45 | ```bash 46 | pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126 47 | ``` 48 | 49 | 4. **Install ROS Dependencies** 50 | ```bash 51 | sudo apt install ros-noetic-rosbag ros-noetic-sensor-msgs 52 | ``` 53 | 54 | 5. **Install Other Dependencies** 55 | ```bash 56 | pip3 install -r requirements.txt 57 | ``` 58 | 59 | ## Data Preparation 60 | 61 | ### Download ROSbag Files 62 | Download these essential ROSbag datasets: 63 | - [**Newer College Dataset**](https://ori-drs.github.io/newer-college-dataset/) 64 | - [**SubT-MRS Dataset**](https://superodometry.com/iccv23_challenge_LiI) 65 | 66 | ### Convert to Sequences 67 | 1. Edit `./dataset/converter/config/rosbag2dataset.yaml`. 68 | 2. Run: 69 | ```bash 70 | python3 ./dataset/converter/rosbag2dataset_parallel.py 71 | ``` 72 | ## Run CLID-SLAM 73 | ```bash 74 | python3 slam.py ./config/run_ncd128.yaml 75 | ``` 76 | ## Acknowledgements 🙏 77 | 78 | This project is built upon the open-source project [**PIN-SLAM**](https://github.com/PRBonn/PIN_SLAM), developed by [**PRBonn/YuePanEdward**](https://github.com/YuePanEdward). A huge thanks to the contributors of **PIN-SLAM** for their outstanding work and dedication! 
79 | -------------------------------------------------------------------------------- /assets/GUI_Mesh.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/assets/GUI_Mesh.png -------------------------------------------------------------------------------- /assets/GUI_Neural_Points.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/assets/GUI_Neural_Points.png -------------------------------------------------------------------------------- /cad/camera.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/cad/camera.ply -------------------------------------------------------------------------------- /cad/drone.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/cad/drone.ply -------------------------------------------------------------------------------- /cad/ipb_car.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/cad/ipb_car.ply -------------------------------------------------------------------------------- /cad/kitti_car.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/cad/kitti_car.ply -------------------------------------------------------------------------------- /config/run_SubT_MRS.yaml: -------------------------------------------------------------------------------- 1 | setting: 2 | name: "SubT_MRS" 
3 | output_root: "./experiment" 4 | imu_path: "./dataset/SubT_MRS/Final_Challenge_UGV1/sequences/imu" 5 | pc_path: "./dataset/SubT_MRS/Final_Challenge_UGV1/sequences/lidar" 6 | pose_ts_path: "./dataset/SubT_MRS/Final_Challenge_UGV1/sequences/pose_ts.txt" 7 | deskew: True 8 | 9 | process: 10 | min_range_m: 1.0 11 | max_range_m: 60.0 12 | min_z_m: -10.0 13 | vox_down_m: 0.1 14 | sampler: 15 | local_voxel_size_m: 0.2 16 | surface_sample_range_m: 0.25 17 | surface_sample_n: 4 18 | free_sample_begin_ratio: 0.8 19 | free_front_sample_n: 2 20 | neuralpoints: 21 | voxel_size_m: 0.4 22 | num_nei_cells: 2 23 | search_alpha: 0.5 24 | weighted_first: True 25 | layer_norm_on: True 26 | loss: 27 | sigma_sigmoid_m: 0.1 28 | loss_weight_on: True 29 | dist_weight_scale: 0.8 30 | continual: 31 | batch_size_new_sample: 1000 32 | pool_capacity: 1e7 33 | tracker: 34 | measurement_noise_covariance: 0.01 35 | bias_noise_covariance: 0.0001 36 | source_vox_down_m: 0.6 37 | iter_n: 50 38 | T_imu_lidar: 39 | - [ 1.0, 0, 0, 0] 40 | - [ 0, 1.0, 0, 0] 41 | - [ 0, 0, 1.0, 0] 42 | - [ 0, 0, 0, 1.0] 43 | optimizer: 44 | iters: 10 45 | batch_size: 16384 46 | learning_rate: 0.01 47 | adaptive_iters: True 48 | eval: 49 | wandb_vis_on: False 50 | o3d_vis_on: False 51 | silence_log: True 52 | mesh_freq_frame: 50 53 | mesh_min_nn: 15 54 | save_map: True -------------------------------------------------------------------------------- /config/run_ncd128.yaml: -------------------------------------------------------------------------------- 1 | setting: 2 | name: "ncd128" 3 | output_root: "./experiment" 4 | imu_path: "./dataset/ncd128/collection1/quad_easy/sequences/imu" 5 | pc_path: "./dataset/ncd128/collection1/quad_easy/sequences/lidar" 6 | pose_ts_path: "./dataset/ncd128/collection1/quad_easy/sequences/pose_ts.txt" 7 | deskew: True 8 | 9 | process: 10 | min_range_m: 1.0 11 | max_range_m: 60.0 12 | min_z_m: -10.0 13 | vox_down_m: 0.1 14 | sampler: 15 | local_voxel_size_m: 0.2 16 | surface_sample_range_m: 
0.25 17 | surface_sample_n: 4 18 | free_sample_begin_ratio: 0.5 19 | free_sample_end_dist_m: 1.2 20 | free_front_sample_n: 2 21 | neuralpoints: 22 | voxel_size_m: 0.4 23 | num_nei_cells: 2 24 | search_alpha: 0.5 25 | weighted_first: True 26 | loss: 27 | sigma_sigmoid_m: 0.1 28 | loss_weight_on: True 29 | dist_weight_scale: 0.8 30 | continual: 31 | batch_size_new_sample: 1000 32 | pool_capacity: 1e7 33 | tracker: 34 | measurement_noise_covariance: 0.01 35 | bias_noise_covariance: 0.0001 36 | source_vox_down_m: 0.6 37 | iter_n: 50 38 | T_imu_lidar: 39 | - [ 1.0, 0, 0, 0.014 ] 40 | - [ 0, 1.0, 0, -0.012 ] 41 | - [ 0, 0, 1.0, -0.015 ] 42 | - [ 0, 0, 0, 1.0 ] 43 | optimizer: 44 | iters: 10 45 | batch_size: 16384 46 | learning_rate: 0.01 47 | adaptive_iters: True 48 | eval: 49 | wandb_vis_on: False 50 | o3d_vis_on: True 51 | silence_log: True 52 | mesh_freq_frame: 50 53 | mesh_min_nn: 15 54 | save_map: True -------------------------------------------------------------------------------- /dataset/converter/config/rosbag2dataset.yaml: -------------------------------------------------------------------------------- 1 | input_bag: './dataset/ncd128/collection1/quad-easy.bag' 2 | output_folder: './dataset/ncd128/collection1/quad_easy' 3 | imu_topic: '/os_cloud_node/imu' 4 | lidar_topic: '/os_cloud_node/points' 5 | #imu_topic: '/imu/data' 6 | #lidar_topic: '/velodyne_points' 7 | batch_size: 100 # Number of messages per batch 8 | end_frame: -1 # -1 means process the entire bag file -------------------------------------------------------------------------------- /dataset/converter/rosbag2dataset_parallel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # @file rosbag2dataset_parallel.py 3 | # @author Junlong Jiang [jiangjunlong@mail.dlut.edu.cn] 4 | # Copyright (c) 2025 Junlong Jiang, all rights reserved 5 | import csv 6 | import os 7 | import yaml 8 | from multiprocessing import Process, Queue 9 | from typing 
G_M_S2 = 9.81  # Standard gravitational acceleration in m/s^2


def load_config(path: str) -> dict:
    """Load the converter configuration from a YAML file.

    Args:
        path: Path to the YAML configuration file.

    Returns:
        The parsed configuration mapping.
    """
    with open(path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)


def write_ply(filename: str, data: tuple) -> bool:
    """Write a point cloud with per-point timestamps to a binary PLY file.

    Args:
        filename: Path to the output .ply file.
        data: Pair ``(points, timestamps)``. ``points`` is a 2D NumPy array of
            shape (n, 4) holding x, y, z, intensity; ``timestamps`` is a 1D
            NumPy array of length n.

    Returns:
        Always True; failures surface as exceptions from NumPy / plyfile.
    """
    points, timestamps = data
    # Timestamps become a fifth column so every vertex carries its own time.
    combined_data = np.hstack([points, timestamps.reshape(-1, 1)])
    # np.rec is the public home of fromarrays; the np.core path is deprecated
    # in NumPy 2.
    structured_array = np.rec.fromarrays(
        combined_data.transpose(),
        names=['x', 'y', 'z', 'intensity', 'timestamp'],
    )
    PlyData([PlyElement.describe(structured_array, 'vertex')], text=False).write(filename)
    return True


def write_csv(filename: str, imu_data_pool: List[Tuple[float, float, float, float, float, float, float]]) -> None:
    """Write the IMU samples collected for one LiDAR frame to a CSV file.

    Args:
        filename: Path to the output .csv file.
        imu_data_pool: Rows of ``(timestamp, acc_x, acc_y, acc_z, gyro_x,
            gyro_y, gyro_z)``. As filled by ``sync_and_save`` the first field
            is the time elapsed since the previous message, not an absolute
            time, although the column header reads ``timestamp``.
    """
    with open(filename, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['timestamp', 'acc_x', 'acc_y', 'acc_z', 'gyro_x', 'gyro_y', 'gyro_z'])
        writer.writerows(imu_data_pool)


def extract_lidar_data(msg) -> Tuple[np.ndarray, np.ndarray]:
    """Extract point data and per-point timestamps from a LiDAR message.

    Args:
        msg: Point cloud message readable by ``pc2.read_points``.

    Returns:
        ``(points, timestamps)``: the first four fields of every point, and the
        fifth field scaled by 1e-9. An empty scan yields arrays of shape
        (0, 4) and (0,).
    """
    # NOTE(review): assumes the message fields are ordered x, y, z, intensity,
    # <time in ns>, ...; confirm for sensors whose fifth field is not a time.
    pc_array = np.array(list(pc2.read_points(msg, skip_nans=True)))
    if pc_array.size == 0:
        # A scan without valid points gives a 1D empty array that cannot be
        # column-indexed below.
        return np.empty((0, 4)), np.empty((0,))
    timestamps = pc_array[:, 4] * 1e-9  # Convert to seconds
    return pc_array[:, :4], timestamps


def process_lidar_data(batch_data: List[Tuple[str, Tuple[np.ndarray, np.ndarray]]]) -> None:
    """Save a batch of LiDAR frames as PLY files (runs in a worker process).

    Args:
        batch_data: ``(ply_file_path, (points, timestamps))`` entries.
    """
    for ply_file_path, data in batch_data:
        if write_ply(ply_file_path, data):
            print(f"Exported LiDAR point cloud PLY file: {ply_file_path}")


def sync_and_save(config: dict) -> None:
    """Split a ROS bag into per-frame LiDAR PLY files and IMU CSV files.

    Every LiDAR message becomes ``lidar/<index>.ply``. The IMU messages
    received between two consecutive LiDAR messages are written to
    ``imu/<index>.csv``. The header stamps of all LiDAR messages are written
    to ``pose_timestamps.txt``.

    Args:
        config: Mapping with the keys ``input_bag``, ``output_folder``,
            ``imu_topic``, ``lidar_topic``, ``batch_size`` and ``end_frame``
            (a value <= 0 processes the whole bag).
    """
    # NOTE(review): the run_*.yaml configs in this repository point at
    # "<output_folder>/sequences/{imu,lidar}" and ".../sequences/pose_ts.txt",
    # which differ from the paths written here; confirm which is intended.
    output_folder = config["output_folder"]
    os.makedirs(output_folder, exist_ok=True)
    os.makedirs(os.path.join(output_folder, "lidar"), exist_ok=True)
    os.makedirs(os.path.join(output_folder, "imu"), exist_ok=True)

    frame_index = 0
    start_flag = False
    imu_last_timestamp = None
    imu_data_pool = []
    # Plain list: the timestamps never leave this process, and
    # multiprocessing.Queue.empty() is not a reliable way to drain a queue.
    lidar_timestamps = []

    processes = []
    batch_size = config["batch_size"]  # Number of messages per batch
    batch_lidar_data = []

    # The context manager closes the bag even if conversion fails midway.
    with rosbag.Bag(config["input_bag"]) as in_bag:
        for topic, msg, t in in_bag.read_messages(topics=[config["imu_topic"], config["lidar_topic"]]):
            current_timestamp = t.to_sec()

            if topic == config["lidar_topic"]:
                if not start_flag:
                    # IMU data before the first LiDAR frame is discarded.
                    start_flag = True
                else:
                    csv_file_path = os.path.join(output_folder, "imu", f"{frame_index}.csv")
                    write_csv(csv_file_path, imu_data_pool)
                    imu_data_pool = []
                    print(f"Exported IMU measurement CSV file: {csv_file_path}")

                if len(batch_lidar_data) >= batch_size:
                    # Hand the full batch to a worker so PLY writing overlaps
                    # with reading the bag.
                    p = Process(target=process_lidar_data, args=(batch_lidar_data,))
                    p.start()
                    processes.append(p)
                    batch_lidar_data = []

                lidar_timestamps.append(msg.header.stamp.to_sec())

                ply_file_path = os.path.join(output_folder, "lidar", f"{frame_index}.ply")
                point_cloud_data = extract_lidar_data(msg)
                batch_lidar_data.append((ply_file_path, point_cloud_data))

                imu_last_timestamp = current_timestamp
                frame_index += 1

                if 0 < config["end_frame"] <= frame_index:
                    break

            elif topic == config["imu_topic"]:
                if start_flag:
                    time_delta = current_timestamp - imu_last_timestamp
                    imu_last_timestamp = current_timestamp
                    imu_data = (
                        time_delta,
                        msg.linear_acceleration.x,
                        msg.linear_acceleration.y,
                        msg.linear_acceleration.z,
                        msg.angular_velocity.x,
                        msg.angular_velocity.y,
                        msg.angular_velocity.z
                    )
                    imu_data_pool.append(imu_data)

    # Flush the last, possibly partial, batch.
    if batch_lidar_data:
        p = Process(target=process_lidar_data, args=(batch_lidar_data,))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    with open(os.path.join(output_folder, "pose_timestamps.txt"), 'w', newline='') as file:
        print("Writing pose timestamps...")
        writer = csv.writer(file)
        writer.writerow(['timestamp'])
        for lidar_timestamp in lidar_timestamps:
            writer.writerow([lidar_timestamp])
        print("Pose timestamps written successfully.")


if __name__ == "__main__":
    config = load_config('./dataset/converter/config/rosbag2dataset.yaml')
    sync_and_save(config)
class VisPacket:
    """Container for the data handed to the GUI for one visualization update.

    Holds the current scan, mesh, trajectories, SDF slice / training pool
    samples and, optionally, a dictionary describing the neural point map.
    """

    def __init__(
        self,
        frame_id = None,
        finish=False,
        current_pointcloud_xyz=None,
        current_pointcloud_rgb=None,
        mesh_verts=None,
        mesh_faces=None,
        mesh_verts_rgb=None,
        odom_poses=None,
        gt_poses=None,
        slam_poses=None,
        travel_dist=None,
        slam_finished=False,
    ):
        """Store the given scan, mesh and trajectory data.

        The neural point data and the SDF slice / pool fields start out empty
        and are filled by the ``add_*`` methods.
        """
        self.has_neural_points = False

        self.neural_points_data = None

        self.frame_id = frame_id

        self.add_scan(current_pointcloud_xyz, current_pointcloud_rgb)

        self.add_mesh(mesh_verts, mesh_faces, mesh_verts_rgb)

        self.add_traj(odom_poses, gt_poses, slam_poses)

        self.sdf_slice_xyz = None
        self.sdf_slice_rgb = None

        self.sdf_pool_xyz = None
        self.sdf_pool_rgb = None

        self.travel_dist = travel_dist
        self.slam_finished = slam_finished

        self.finish = finish

    # the surrounding map is also added here
    def add_neural_points_data(self, neural_points: NeuralPoints, only_local_map: bool = True,
                               pca_color_on: bool = True):
        """Copy the fields of the neural point map needed for display.

        Does nothing when ``neural_points`` is None.

        Args:
            neural_points: Neural point map to read from.
            only_local_map: If True, read the ``local_*`` attributes of the
                map; otherwise read the global ones and, when available, the
                local mask.
            pca_color_on: If True, additionally store a PCA projection of the
                features (via ``feature_pca_torch``) for coloring.
        """

        if neural_points is not None:
            self.has_neural_points = True
            self.neural_points_data = {}
            self.neural_points_data["count"] = neural_points.count()
            self.neural_points_data["local_count"] = neural_points.local_count()
            self.neural_points_data["map_memory_mb"] = neural_points.cur_memory_mb
            self.neural_points_data["resolution"] = neural_points.resolution

            if only_local_map:
                self.neural_points_data["position"] = neural_points.local_neural_points
                self.neural_points_data["orientation"] = neural_points.local_point_orientations
                self.neural_points_data["geo_feature"] = neural_points.local_geo_features.detach()
                if neural_points.color_on:
                    self.neural_points_data["color_feature"] = neural_points.local_color_features.detach()
                self.neural_points_data["ts"] = neural_points.local_point_ts_update
                self.neural_points_data["stability"] = neural_points.local_point_certainties

                if pca_color_on:
                    # The last feature row is excluded from the PCA input.
                    # NOTE(review): presumably a padding entry; confirm in NeuralPoints.
                    local_geo_feature_3d, _ = feature_pca_torch((self.neural_points_data["geo_feature"])[:-1], principal_components=neural_points.geo_feature_pca, down_rate=17)
                    self.neural_points_data["color_pca_geo"] = local_geo_feature_3d

                    if neural_points.color_on:
                        local_color_feature_3d, _ = feature_pca_torch((self.neural_points_data["color_feature"])[:-1], principal_components=neural_points.color_feature_pca, down_rate=17)
                        self.neural_points_data["color_pca_color"] = local_color_feature_3d

            else:
                self.neural_points_data["position"] = neural_points.neural_points
                self.neural_points_data["orientation"] = neural_points.point_orientations
                self.neural_points_data["geo_feature"] = neural_points.geo_features
                if neural_points.color_on:
                    self.neural_points_data["color_feature"] = neural_points.color_features
                self.neural_points_data["ts"] = neural_points.point_ts_update
                self.neural_points_data["stability"] = neural_points.point_certainties
                if neural_points.local_mask is not None:
                    self.neural_points_data["local_mask"] = neural_points.local_mask[:-1]

                if pca_color_on:
                    # The global map uses a larger down_rate (97) than the local map (17).
                    geo_feature_3d, _ = feature_pca_torch(neural_points.geo_features[:-1], principal_components=neural_points.geo_feature_pca, down_rate=97)
                    self.neural_points_data["color_pca_geo"] = geo_feature_3d

                    if neural_points.color_on:
                        color_feature_3d, _ = feature_pca_torch(neural_points.color_features[:-1], principal_components=neural_points.color_feature_pca, down_rate=97)
                        self.neural_points_data["color_pca_color"] = color_feature_3d


    def add_scan(self, current_pointcloud_xyz=None, current_pointcloud_rgb=None):
        """Store the positions and colors of the current scan."""
        self.current_pointcloud_xyz = current_pointcloud_xyz
        self.current_pointcloud_rgb = current_pointcloud_rgb

        # TODO: add normal later

    def add_sdf_slice(self, sdf_slice_xyz=None, sdf_slice_rgb=None):
        """Store the positions and colors of the SDF slice."""
        self.sdf_slice_xyz = sdf_slice_xyz
        self.sdf_slice_rgb = sdf_slice_rgb

    def add_sdf_training_pool(self, sdf_pool_xyz=None, sdf_pool_rgb=None):
        """Store the positions and colors of the SDF training pool samples."""
        self.sdf_pool_xyz = sdf_pool_xyz
        self.sdf_pool_rgb = sdf_pool_rgb

    def add_mesh(self, mesh_verts=None, mesh_faces=None, mesh_verts_rgb=None):
        """Store the vertices, faces and vertex colors of the mesh."""
        self.mesh_verts = mesh_verts
        self.mesh_faces = mesh_faces
        self.mesh_verts_rgb = mesh_verts_rgb

    def add_traj(self, odom_poses=None, gt_poses=None, slam_poses=None, loop_edges=None):
        """Store the odometry, ground-truth and SLAM poses and the loop edges.

        When no SLAM poses are given, the odometry poses are used in their place.
        """

        self.odom_poses = odom_poses
        self.gt_poses = gt_poses
        self.slam_poses = slam_poses

        if slam_poses is None:
            self.slam_poses = odom_poses

        self.loop_edges = loop_edges


def get_latest_queue(q):
    """Drain ``q`` and return the most recent message, or None if it was empty.

    Older messages are discarded so the consumer always works on the newest one.
    """
    message = None
    while True:
        try:
            message_latest = q.get_nowait()
            if message is not None:
                del message
            message = message_latest
        except queue.Empty:
            # Stop only once the queue also reports itself empty; otherwise retry.
            if q.empty():
                break
    return message


class ControlPacket:
    """Control flags and parameters with their default values.

    NOTE(review): the name and ParamsGUI.q_vis2main suggest this is sent from
    the GUI to the main process; confirm against the callers.
    """
    flag_pause = False
    flag_vis = True
    flag_mesh = False
    flag_sdf = False
    flag_global = False
    flag_source = False
    mc_res_m = 0.2
    mesh_min_nn = 10
    mesh_freq_frame = 50
    sdf_freq_frame = 50
    sdf_slice_height = 0.2
    sdf_res_m = 0.2
    cur_frame_id = 0

class ParamsGUI:
    """Start-up parameters of the GUI: the two queues, the config and the
    default on/off state of the display options."""

    def __init__(
        self,
        q_main2vis=None,
        q_vis2main=None,
        config=None,
        local_map_default_on: bool = True,
        robot_default_on: bool = True,
        mesh_default_on: bool = False,
        sdf_default_on: bool = False,
        neural_point_map_default_on: bool = False,
        neural_point_color_default_mode: int = 1, # 1: geo feature pca, 2: photo feature pca, 3: time, 4: height
        neural_point_vis_down_rate: int = 1,
    ):
        """Store all arguments unchanged as attributes of the same name."""
        self.q_main2vis = q_main2vis
        self.q_vis2main = q_vis2main
        self.config = config

        self.robot_default_on = robot_default_on
        self.neural_point_map_default_on = neural_point_map_default_on
        self.mesh_default_on = mesh_default_on
        self.sdf_default_on = sdf_default_on
        self.local_map_default_on = local_map_default_on
        self.neural_point_color_default_mode = neural_point_color_default_mode
        self.neural_point_vis_down_rate = neural_point_vis_down_rate
class Decoder(nn.Module):
    """MLP decoder mapping per-point input features to an output vector.

    The same network body serves several heads: SDF (``sdf``), occupancy
    (``occupancy``), semantic label (``sem_label_prob`` / ``sem_label``) and
    color (``regress_color``).
    """

    def __init__(
        self,
        config: "Config",
        hidden_dim,
        hidden_level,
        out_dim,
        is_time_conditioned=False,
    ):
        """Build the MLP.

        Args:
            config: Configuration object; the attributes read here are
                ``mlp_leaky_relu``, ``mlp_bias_on``, ``use_gaussian_pe``,
                ``pos_input_dim``, ``pos_encoding_band``, ``feature_dim``,
                ``main_loss_type``, ``logistic_gaussian_ratio``,
                ``sigma_sigmoid_m`` and ``device``.
            hidden_dim: Width of every hidden layer.
            hidden_level: Number of hidden layers.
            out_dim: Width of the output layer.
            is_time_conditioned: If True, the input is widened by one channel
                that carries a timestamp (see ``time_conditionded_sdf``).
        """

        super().__init__()

        self.out_dim = out_dim
        self.use_leaky_relu = config.mlp_leaky_relu
        bias_on = config.mlp_bias_on

        # default not used
        if config.use_gaussian_pe:
            position_dim = config.pos_input_dim + 2 * config.pos_encoding_band
        else:
            position_dim = config.pos_input_dim * (2 * config.pos_encoding_band + 1)

        feature_dim = config.feature_dim
        input_dim = feature_dim + position_dim

        # default not used
        if is_time_conditioned:
            # One extra input channel for the timestamp. The original code
            # incremented an undefined name here and raised as soon as time
            # conditioning was enabled.
            input_dim += 1

        # Shared MLP body: the first layer maps input_dim -> hidden_dim,
        # the remaining ones hidden_dim -> hidden_dim.
        layers = []
        for i in range(hidden_level):
            in_dim = input_dim if i == 0 else hidden_dim
            layers.append(nn.Linear(in_dim, hidden_dim, bias_on))
        self.layers = nn.ModuleList(layers)
        self.lout = nn.Linear(hidden_dim, out_dim, bias_on)

        # With the BCE loss the raw network output is rescaled into an SDF value.
        self.sdf_scale = 1.0
        if config.main_loss_type == "bce":
            self.sdf_scale = config.logistic_gaussian_ratio * config.sigma_sigmoid_m

        self.to(config.device)

    def mlp(self, features):
        """Run the hidden layers with (leaky) ReLU, then the linear output layer."""
        activation = F.leaky_relu if self.use_leaky_relu else F.relu
        h = features
        for layer in self.layers:
            h = activation(layer(h))
        return self.lout(h)

    # predict the sdf (opposite sign to the actual sdf)
    # unit is already m
    def sdf(self, features):
        """Return the network output, squeezed on dim 1 and scaled by ``sdf_scale``."""
        out = self.mlp(features).squeeze(1) * self.sdf_scale
        return out

    def time_conditionded_sdf(self, features, ts):
        """Predict the SDF from features concatenated with a per-point timestamp.

        Args:
            features: Tensor of shape (N, K, D); K is read from dim 1.
            ts: 1D tensor with one timestamp per point (length N).

        Returns:
            Output of ``sdf`` on the (N, K, D + 1) time-conditioned features.
        """
        nn_k = features.shape[1]
        # Broadcast each point's timestamp over its K entries. Tiling with
        # ``ts.repeat(nn_k).view(-1, nn_k, 1)`` would interleave the timestamps
        # of different points.
        ts_nn_k = ts.reshape(-1, 1, 1).expand(-1, nn_k, 1)
        time_conditioned_feature = torch.cat((features, ts_nn_k), dim=-1)
        out = self.sdf(time_conditioned_feature)
        return out

    # predict the occupancy probability
    def occupancy(self, features):
        """Return sigmoid(-sdf / sdf_scale), an occupancy value in [0, 1]."""
        out = torch.sigmoid(self.sdf(features) / -self.sdf_scale)  # to [0, 1]
        return out

    # predict the probability of each semantic label
    def sem_label_prob(self, features):
        """Return the log-softmax over the last dimension of the network output."""
        out = F.log_softmax(self.mlp(features), dim=-1)
        return out

    def sem_label(self, features):
        """Return the index of the most likely label (argmax over dim 1)."""
        out = torch.argmax(self.sem_label_prob(features), dim=1)
        return out

    def regress_color(self, features):
        """Return the network output squashed to [0, 1] by a sigmoid."""
        out = torch.sigmoid(self.mlp(features))  # sigmoid map to [0,1]
        return out
class LocalPointCloudMap:
    """Local point cloud map whose points are indexed by a voxel hash table."""

    def __init__(self, config: Config) -> None:
        """Allocate the hash table and an empty map.

        Reads ``dtype``, ``device``, ``local_voxel_size_m``,
        ``local_buffer_size`` and ``local_map_size`` from ``config``.
        """
        self.config = config
        self.idx_dtype = torch.int64
        self.dtype = config.dtype
        self.device = config.device
        self.resolution = config.local_voxel_size_m
        self.buffer_size = config.local_buffer_size

        self.buffer_pt_index = -torch.ones(self.buffer_size, dtype=self.idx_dtype, device=self.device)  # hash table: slot -> point index, -1 means empty
        self.local_point_cloud_map = torch.empty((0, 3), dtype=torch.float32, device=self.device)

        self.primes = torch.tensor([73856093, 19349663, 83492791], dtype=self.idx_dtype, device=self.device)
        self.neighbor_idx = None
        self.max_valid_range = None
        # Fills neighbor_idx and max_valid_range using the method's default arguments.
        self.set_search_neighborhood()
        self.map_size = config.local_map_size

    def voxel_hash(self, points):
        """Return the hash table slot of each point.

        Points are quantized to integer grid coordinates with the map
        resolution, multiplied by the primes, summed over the last dimension
        and reduced with ``torch.fmod`` by the buffer size.
        """
        grid_coords = (points / self.resolution).floor().to(self.primes)
        hash_values = torch.fmod((grid_coords * self.primes).sum(-1), self.buffer_size)
        return hash_values

    def insert_points(self, points):
        """Downsample ``points`` and append those whose hash slot is still empty."""
        sample_idx = voxel_down_sample_torch(points, self.resolution)
        sample_points = points[sample_idx]
        hash_values = self.voxel_hash(sample_points)
        hash_idx = self.buffer_pt_index[hash_values]

        # Only points hashing to an unoccupied slot (-1) are added.
        update_mask = (hash_idx == -1)
        new_points = sample_points[update_mask]

        # New points are appended at the end of the map, so their indices start at the current count.
        cur_pt_count = self.local_point_cloud_map.shape[0]
        self.buffer_pt_index[hash_values[update_mask]] = torch.arange(new_points.shape[0], device=self.device) + cur_pt_count

        self.local_point_cloud_map = torch.cat(( self.local_point_cloud_map, new_points), 0)
def update_map(self, sensor_position, points): 49 | self.insert_points(points) 50 | distances = torch.norm(self.local_point_cloud_map - sensor_position, dim=-1) 51 | keep_mask = distances < self.map_size 52 | self.local_point_cloud_map = self.local_point_cloud_map[keep_mask] 53 | 54 | new_buffer_pt_index = -torch.ones(self.buffer_size, dtype=self.idx_dtype, device=self.device) 55 | new_hash_values = self.voxel_hash(self.local_point_cloud_map) 56 | new_buffer_pt_index[new_hash_values] = torch.arange(self.local_point_cloud_map.shape[0], device=self.device) 57 | 58 | self.buffer_pt_index = new_buffer_pt_index 59 | 60 | def set_search_neighborhood( 61 | self, num_nei_cells: int = 1, search_alpha: float = 0.2 62 | ): 63 | dx = torch.arange(-num_nei_cells, num_nei_cells + 1, device=self.primes.device, dtype=self.primes.dtype,) 64 | 65 | coords = torch.meshgrid(dx, dx, dx, indexing="ij") 66 | dx = torch.stack(coords, dim=-1).reshape(-1, 3) # [K,3] 67 | 68 | dx2 = torch.sum(dx**2, dim=-1) 69 | self.neighbor_idx = dx[dx2 < (num_nei_cells + search_alpha) ** 2] 70 | self.max_valid_range = 1.732*(num_nei_cells + 1) * self.resolution 71 | # in the sphere --> smaller K --> faster training 72 | # when num_cells = 2 when num_cells = 3 73 | # alpha 0.2, K = 33 alpha 0.2, K = 147 74 | # alpha 0.3, K = 57 alpha 0.5, K = 179 75 | # alpha 0.5, K = 81 alpha 1.0, K = 251 76 | # alpha 1.0, K = 93 77 | # alpha 2.0, K = 125 78 | 79 | def region_specific_sdf_estimation(self, points: torch.Tensor): 80 | point_num = points.shape[0] 81 | sdf_abs = torch.ones(point_num, device=points.device)*self.max_valid_range 82 | surface_mask = torch.ones(point_num, dtype=torch.bool, device=self.config.device) 83 | 84 | bs = 262144 # 256 × 1024 85 | iter_n = math.ceil(point_num / bs) 86 | # 为了避免爆显存,采用分批处理的办法 87 | for n in range(iter_n): 88 | head, tail = n * bs, min((n + 1) * bs, point_num) 89 | batch_points = points[head:tail, :] 90 | batch_coords = (batch_points / self.resolution).floor().to(self.primes) 
91 | batch_neighbord_cells = (batch_coords[..., None, :] + self.neighbor_idx) 92 | batch_hash = torch.fmod((batch_neighbord_cells * self.primes).sum(-1), self.buffer_size) 93 | batch_neighb_idx = self.buffer_pt_index[batch_hash] 94 | batch_neighb_pts = self.local_point_cloud_map[batch_neighb_idx] 95 | batch_dist = torch.norm(batch_neighb_pts - batch_points.view(-1, 1, 3), dim=-1) 96 | batch_dist = torch.where(batch_neighb_idx == -1, self.max_valid_range, batch_dist) 97 | 98 | # k nearst neighbors neural points 99 | batch_sdf_abs, batch_min_idx = torch.topk(batch_dist, 4, largest=False, dim=1) 100 | batch_min_idx_expanded = batch_min_idx.unsqueeze(-1).expand(-1, -1, 3) 101 | batch_knn_points = torch.gather(batch_neighb_pts, 1, batch_min_idx_expanded) 102 | valid_fit_mask = batch_sdf_abs[:, 3] < self.max_valid_range 103 | valid_batch_knn_points = batch_knn_points[valid_fit_mask] 104 | unit_normal_vector = torch.zeros_like(batch_points) 105 | plane_constant = torch.zeros(batch_points.size(0), device=batch_points.device) 106 | fit_success = torch.zeros(batch_points.size(0), dtype=torch.bool, device=batch_points.device) 107 | 108 | valid_unit_normal_vector, valid_plane_constant, valid_fit_success = estimate_plane(valid_batch_knn_points) 109 | unit_normal_vector[valid_fit_mask] = valid_unit_normal_vector 110 | plane_constant[valid_fit_mask] = valid_plane_constant 111 | fit_success[valid_fit_mask] = valid_fit_success 112 | 113 | fit_success &= batch_sdf_abs[:, 3] < self.max_valid_range # 平面拟合失败 114 | surface_mask[head:tail] &= (batch_sdf_abs[:, 0] < self.max_valid_range) 115 | distance = torch.abs(torch.sum(unit_normal_vector * batch_points, dim=1) + plane_constant) 116 | sdf_abs[head:tail][fit_success] = distance[fit_success] 117 | sdf_abs[head:tail][~fit_success] = batch_sdf_abs[:, 0][~fit_success] 118 | 119 | if not self.config.silence: 120 | print(surface_mask.sum().item() / surface_mask.numel()) 121 | return sdf_abs, surface_mask 122 | 123 | def 
estimate_plane(points: torch.Tensor, eta_threshold: float = 0.2, threshold: float = 0.1): 124 | """Estimates planes from a given set of 3D points using Singular Value Decomposition (SVD)""" 125 | def fit_planes(points: torch.Tensor): 126 | """Fits multiple planes using SVD""" 127 | centroid = points.mean(dim=1, keepdim=True) 128 | centered_points = points - centroid 129 | U, S, Vh = torch.linalg.svd(centered_points, full_matrices=False) # Perform SVD. 130 | 131 | # The normal vector of the plane is the last row of Vh (since Vh is the transpose of V). 132 | normals = Vh[:, -1, :] 133 | return normals, centroid.squeeze(1), S 134 | 135 | def is_valid_planes(singular_values: torch.Tensor, eta_threshold: float): 136 | """Determines whether the fitted planes are valid based on the η value.""" 137 | lambda_min = singular_values[:, -1] # The smallest singular value. 138 | lambda_mid = singular_values[:, 1] # The middle singular value. 139 | 140 | eta = lambda_min / (lambda_mid + 1e-6) 141 | return eta <= eta_threshold 142 | 143 | m, num_points, _ = points.shape 144 | normal_vector, centroids, singular_values = fit_planes(points) # Fit planes to the input points. 145 | 146 | # Initialize normal vectors with zeros. 147 | unit_normal_vector = torch.zeros((m, 3), dtype=points.dtype, device=points.device) 148 | 149 | valid_mask = is_valid_planes(singular_values, eta_threshold) 150 | normal_vector = normal_vector[valid_mask] 151 | unit_normal_vector[valid_mask] = normal_vector 152 | plane_constant = -1.0 * torch.sum(unit_normal_vector * centroids, dim=1) 153 | 154 | # Compute the distance of each point to its respective plane. 
155 | # distances = torch.abs((points @ unit_normal_vector.unsqueeze(-1)).squeeze() + plane_constant.unsqueeze(-1)) 156 | distances = torch.abs(torch.bmm(points, unit_normal_vector.unsqueeze(-1)).squeeze() + plane_constant.unsqueeze(-1)) 157 | fit_success = torch.max(distances, dim=1).values <= threshold 158 | mask = fit_success & valid_mask 159 | 160 | return unit_normal_vector, plane_constant, mask -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.5.1 2 | rospkg==1.5.1 3 | catkin_pkg==1.0.0 4 | opencv-python==4.11.0.86 5 | scikit-image==0.25.2 6 | plyfile==1.1 7 | evo==1.28.0 8 | gnupg==2.3.1 9 | laspy==2.5.3 10 | natsort==8.1.0 11 | open3d==0.19.0 12 | pycryptodomex==3.20.0 13 | pypose==0.6.8 14 | pyquaternion==0.9.9 15 | rerun-sdk==0.17.0 16 | rich==12.5.1 17 | roma==1.5.0 18 | rospkg==1.5.1 19 | wandb==0.19.8 -------------------------------------------------------------------------------- /slam.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # @file pin_slam.py 3 | # @author Yue Pan [yue.pan@igg.uni-bonn.de] 4 | # Copyright (c) 2024 Yue Pan, all rights reserved 5 | # Modifications by: 6 | # Junlong Jiang [jiangjunlong@mail.dlut.edu.cn] 7 | # Copyright (c) 2025 Junlong Jiang, all rights reserved. 
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # by default 0, change it here if you want to use other GPU
import sys
import time

import numpy as np
import open3d as o3d
import torch
import torch.multiprocessing as mp
import wandb
from rich import print
from tqdm import tqdm

from gui import slam_gui
from gui.gui_utils import ParamsGUI, VisPacket, ControlPacket, get_latest_queue
from model.decoder import Decoder
from model.local_point_cloud_map import LocalPointCloudMap
from model.neural_points import NeuralPoints
from utils.config import Config
from utils.dataset_indexing import set_dataset_path
from utils.error_state_iekf import IEKFOM
from utils.mapper import Mapper
from utils.mesher import Mesher
from utils.slam_dataset import SLAMDataset
from utils.tools import freeze_model, get_time, save_implicit_map, setup_experiment, split_chunks, remove_gpu_cache, \
    create_bbx_o3d


def run_slam(config_path=None, dataset_name=None, sequence_name=None, seed=None):
    """Run the SLAM pipeline over every frame of the configured dataset.

    When ``config_path`` is given, the configuration comes from the arguments;
    otherwise it is read from ``sys.argv``
    (``slam.py path_to_config.yaml [dataset_name] [sequence_name] [random_seed]``).

    Args:
        config_path: path to the config file, or None to use ``sys.argv``.
        dataset_name: optional dataset name passed to ``set_dataset_path``.
        sequence_name: optional sequence name passed to ``set_dataset_path``.
        seed: optional random seed overriding ``config.seed``.

    Returns:
        The value returned by ``dataset.write_results()``. When
        ``config.o3d_vis_on`` is set the function keeps feeding the GUI in an
        endless loop after the run and does not return.
    """
    # Cap the CPU thread count at 16; using too many threads can freeze the machine.
    torch.set_num_threads(16)
    config = Config()
    if config_path is not None:
        config.load(config_path)
        set_dataset_path(config, dataset_name, sequence_name)
        if seed is not None:
            config.seed = seed
        argv = ['slam.py', config_path, dataset_name, sequence_name, str(seed)]
        run_path = setup_experiment(config, argv)
    else:
        if len(sys.argv) > 1:
            config.load(sys.argv[1])
        else:
            # Adjacent literals instead of a backslash continuation, which used to
            # embed the source indentation in the message.
            sys.exit("Please provide the path to the config file.\nTry: "
                     "python3 slam.py path_to_config.yaml [dataset_name] [sequence_name] [random_seed]")
        # specific dataset [optional]
        if len(sys.argv) == 3:
            set_dataset_path(config, sys.argv[2])
        if len(sys.argv) > 3:
            set_dataset_path(config, sys.argv[2], sys.argv[3])
        if len(sys.argv) > 4:  # random seed [optional]
            config.seed = int(sys.argv[4])
        run_path = setup_experiment(config, sys.argv)
        print("⚔️", "[bold green]CLID-SLAM starts[/bold green]")

    if config.o3d_vis_on:
        mp.set_start_method("spawn")  # don't forget this

    # Initialise the MLP decoder
    geo_mlp = Decoder(config, config.geo_mlp_hidden_dim, config.geo_mlp_level, 1)
    mlp_dict = {"sdf": geo_mlp, "semantic": None, "color": None}

    # Initialise the neural point map and the local point cloud map
    neural_points = NeuralPoints(config)
    local_point_cloud_map = LocalPointCloudMap(config)

    # Initialise the dataset
    dataset = SLAMDataset(config)

    # Odometry tracking module
    iekfom = IEKFOM(config, neural_points, geo_mlp)
    dataset.tracker = iekfom

    # Mapping module
    mapper = Mapper(config, dataset, neural_points, local_point_cloud_map, geo_mlp)

    # Mesh reconstruction
    mesher = Mesher(config, neural_points, mlp_dict)

    last_frame = dataset.total_pc_count - 1

    # Visualisation
    q_main2vis = q_vis2main = None
    if config.o3d_vis_on:
        # communicator between the processes
        q_main2vis = mp.Queue()
        q_vis2main = mp.Queue()

        params_gui = ParamsGUI(
            q_main2vis=q_main2vis,
            q_vis2main=q_vis2main,
            config=config,
            local_map_default_on=config.local_map_default_on,
            mesh_default_on=config.mesh_default_on,
            sdf_default_on=config.sdf_default_on,
            neural_point_map_default_on=config.neural_point_map_default_on,
        )
        gui_process = mp.Process(target=slam_gui.run, args=(params_gui,))
        gui_process.start()
        time.sleep(3)  # second

        # visualizer configs (may be overwritten by control packets from the GUI)
        vis_visualize_on = True
        vis_source_pc_weight = False
        vis_global_on = not config.local_map_default_on
        vis_mesh_on = config.mesh_default_on
        vis_mesh_freq_frame = config.mesh_freq_frame
        vis_mesh_mc_res_m = config.mc_res_m
        vis_mesh_min_nn = config.mesh_min_nn
        vis_sdf_on = config.sdf_default_on
        vis_sdf_freq_frame = config.sdfslice_freq_frame
        vis_sdf_slice_height = config.sdf_slice_height
        vis_sdf_res_m = config.vis_sdf_res_m

    cur_mesh = None
    cur_sdf_slice = None

    for frame_id in tqdm(range(dataset.total_pc_count)):

        # I. Load data and preprocess
        T0 = get_time()
        dataset.read_frame(frame_id)

        T1 = get_time()
        valid_frame = dataset.preprocess_frame()
        if not valid_frame:
            dataset.processed_frame += 1
            continue

        T2 = get_time()

        # II. Odometry
        if frame_id > 0:
            if config.track_on:
                cur_pose_torch, valid_flag = iekfom.update_iterated(dataset.cur_source_points)
                dataset.lose_track = not valid_flag
                dataset.update_odom_pose(cur_pose_torch)  # update dataset.cur_pose_torch

        travel_dist = dataset.travel_dist[:frame_id+1]
        neural_points.travel_dist = torch.tensor(travel_dist, device=config.device, dtype=config.dtype)  # always update this
        valid_mapping_flag = (not dataset.lose_track) and (not dataset.stop_status)

        T3 = get_time()
        # III: Mapping and bundle adjustment
        # if lose track, we will not update the map and data pool (don't let the wrong pose to corrupt the map)
        # if the robot stop, also don't process this frame, since there's no new oberservations
        if not dataset.lose_track and valid_mapping_flag:
            mapper.process_frame(dataset.cur_point_cloud_torch, dataset.cur_sem_labels_torch,
                                 dataset.cur_pose_torch, frame_id, (config.dynamic_filter_on and frame_id > 0))
        else:
            mapper.determine_used_pose()
            neural_points.reset_local_map(dataset.cur_pose_torch[:3, 3], None, frame_id)  # not efficient for large map

        T4 = get_time()

        # for the first frame, we need more iterations to do the initialization (warm-up)
        cur_iter_num = config.iters * config.init_iter_ratio if frame_id == 0 else config.iters
        if dataset.stop_status:
            cur_iter_num = max(1, cur_iter_num - 10)
        if frame_id == config.freeze_after_frame:  # freeze the decoder after certain frame
            freeze_model(geo_mlp)

        # mapping with fixed poses (every frame)
        if frame_id % config.mapping_freq_frame == 0:
            mapper.mapping(cur_iter_num)

        T5 = get_time()

        # regular saving logs
        if config.log_freq_frame > 0 and (frame_id + 1) % config.log_freq_frame == 0:
            dataset.write_results_log()

        remove_gpu_cache()

        # IV: Mesh reconstruction and visualisation
        if config.o3d_vis_on:

            if not q_vis2main.empty():
                control_packet: ControlPacket = get_latest_queue(q_vis2main)

                vis_visualize_on = control_packet.flag_vis
                vis_global_on = control_packet.flag_global
                vis_mesh_on = control_packet.flag_mesh
                vis_sdf_on = control_packet.flag_sdf
                vis_source_pc_weight = control_packet.flag_source
                vis_mesh_mc_res_m = control_packet.mc_res_m
                vis_mesh_min_nn = control_packet.mesh_min_nn
                vis_mesh_freq_frame = control_packet.mesh_freq_frame
                vis_sdf_slice_height = control_packet.sdf_slice_height
                vis_sdf_freq_frame = control_packet.sdf_freq_frame
                vis_sdf_res_m = control_packet.sdf_res_m

                # Block here while the GUI asks for a pause.
                while control_packet.flag_pause:
                    time.sleep(0.1)
                    if not q_vis2main.empty():
                        control_packet = get_latest_queue(q_vis2main)
                        if not control_packet.flag_pause:
                            break

            if vis_visualize_on:

                dataset.update_o3d_map()

                T6 = get_time()

                # reconstruction by marching cubes
                if vis_mesh_on and (frame_id == 0 or frame_id == last_frame or (
                        frame_id + 1) % vis_mesh_freq_frame == 0):
                    # update map bbx
                    global_neural_pcd_down = neural_points.get_neural_points_o3d(query_global=True,
                                                                                 random_down_ratio=37)  # prime number
                    dataset.map_bbx = global_neural_pcd_down.get_axis_aligned_bounding_box()

                    if not vis_global_on:  # only build the local mesh
                        chunks_aabb = split_chunks(global_neural_pcd_down, dataset.cur_bbx,
                                                   vis_mesh_mc_res_m * 100)  # reconstruct in chunks
                        cur_mesh = mesher.recon_aabb_collections_mesh(chunks_aabb, vis_mesh_mc_res_m, None, True,
                                                                      config.semantic_on, config.color_on,
                                                                      filter_isolated_mesh=True,
                                                                      mesh_min_nn=vis_mesh_min_nn)
                    else:
                        aabb = global_neural_pcd_down.get_axis_aligned_bounding_box()
                        chunks_aabb = split_chunks(global_neural_pcd_down, aabb,
                                                   vis_mesh_mc_res_m * 300)  # reconstruct in chunks
                        cur_mesh = mesher.recon_aabb_collections_mesh(chunks_aabb, vis_mesh_mc_res_m, None, False,
                                                                      config.semantic_on, config.color_on,
                                                                      filter_isolated_mesh=True,
                                                                      mesh_min_nn=vis_mesh_min_nn)

                if vis_sdf_on and (frame_id == 0 or frame_id == last_frame or (frame_id + 1) % vis_sdf_freq_frame == 0):
                    sdf_bound = config.surface_sample_range_m * 4.0
                    vis_sdf_bbx = create_bbx_o3d(dataset.cur_pose_ref[:3, 3], config.max_range / 2)
                    cur_sdf_slice_h = mesher.generate_bbx_sdf_hor_slice(vis_sdf_bbx, dataset.cur_pose_ref[
                        2, 3] + vis_sdf_slice_height, vis_sdf_res_m, True, -sdf_bound,
                                                                        sdf_bound)  # horizontal slice (local)
                    if config.vis_sdf_slice_v:
                        cur_sdf_slice_v = mesher.generate_bbx_sdf_ver_slice(dataset.cur_bbx, dataset.cur_pose_ref[0, 3],
                                                                            vis_sdf_res_m, True, -sdf_bound,
                                                                            sdf_bound)  # vertical slice (local)
                        cur_sdf_slice = cur_sdf_slice_h + cur_sdf_slice_v
                    else:
                        cur_sdf_slice = cur_sdf_slice_h

                pool_pcd = mapper.get_data_pool_o3d(down_rate=37)
                odom_poses, gt_poses, pgo_poses = dataset.get_poses_np_for_vis()
                loop_edges = None  # no loop closure / pose graph in this pipeline

                packet_to_vis: VisPacket = VisPacket(frame_id=frame_id, travel_dist=travel_dist[-1])

                if not neural_points.is_empty():
                    packet_to_vis.add_neural_points_data(neural_points, only_local_map=(not vis_global_on),
                                                         pca_color_on=config.decoder_freezed)

                if dataset.cur_frame_o3d is not None:
                    packet_to_vis.add_scan(np.array(dataset.cur_frame_o3d.points, dtype=np.float64),
                                           np.array(dataset.cur_frame_o3d.colors, dtype=np.float64))

                if cur_mesh is not None:
                    packet_to_vis.add_mesh(np.array(cur_mesh.vertices, dtype=np.float64), np.array(cur_mesh.triangles),
                                           np.array(cur_mesh.vertex_colors, dtype=np.float64))

                if cur_sdf_slice is not None:
                    packet_to_vis.add_sdf_slice(np.array(cur_sdf_slice.points, dtype=np.float64),
                                                np.array(cur_sdf_slice.colors, dtype=np.float64))

                if pool_pcd is not None:
                    packet_to_vis.add_sdf_training_pool(np.array(pool_pcd.points, dtype=np.float64),
                                                        np.array(pool_pcd.colors, dtype=np.float64))

                packet_to_vis.add_traj(odom_poses, gt_poses, pgo_poses, loop_edges)

                q_main2vis.put(packet_to_vis)

                T8 = get_time()

        # visualization and I/O time excluded; the last slot is unused (no loop closure / PGO)
        cur_frame_process_time = np.array([T2 - T1, T3 - T2, T4 - T3, T5 - T4, 0])
        dataset.time_table.append(cur_frame_process_time)  # in s

        if config.wandb_vis_on:
            # Mapping ends at T5. T6 is only set inside the visualisation branch, so using it
            # here raised NameError with the GUI off and included visualisation time otherwise.
            wandb_log_content = {'frame': frame_id, 'timing(s)/preprocess': T2 - T1, 'timing(s)/tracking': T3 - T2,
                                 'timing(s)/pgo': T4 - T3, 'timing(s)/mapping': T5 - T4}
            wandb.log(wandb_log_content)

        dataset.processed_frame += 1

    # V. Save results
    mapper.free_pool()
    pose_eval_results = dataset.write_results()

    neural_points.prune_map(config.max_prune_certainty, 0, True)  # prune uncertain points for the final output
    neural_points.recreate_hash(None, None, False, False)  # merge the final neural point map
    neural_pcd = neural_points.get_neural_points_o3d(query_global=True, color_mode=0)
    if config.save_map:
        o3d.io.write_point_cloud(os.path.join(run_path, "map", "neural_points.ply"), neural_pcd)  # write the neural point cloud
    neural_points.clear_temp()  # clear temp data for output

    output_mc_res_m = config.mc_res_m*0.6
    mc_cm_str = str(round(output_mc_res_m*1e2))
    if config.save_mesh:
        chunks_aabb = split_chunks(neural_pcd, neural_pcd.get_axis_aligned_bounding_box(), output_mc_res_m * 300)  # reconstruct in chunks
        mesh_path = os.path.join(run_path, "mesh", "mesh_" + mc_cm_str + "cm.ply")
        print("Reconstructing the global mesh with resolution {} cm".format(mc_cm_str))
        cur_mesh = mesher.recon_aabb_collections_mesh(chunks_aabb, output_mc_res_m, mesh_path, False, config.semantic_on, config.color_on, filter_isolated_mesh=True, mesh_min_nn=config.mesh_min_nn)
        print("Reconstructing the global mesh done")
    neural_points.clear_temp()  # clear temp data for output
    if config.save_map:
        save_implicit_map(run_path, neural_points, mlp_dict)
        print("Use 'python vis_pin_map.py {} -m {} -o mesh_out_{}cm.ply' to inspect the map offline.".format(run_path, output_mc_res_m, mc_cm_str))

    if config.save_merged_pc:
        dataset.write_merged_point_cloud()  # replay: save merged point cloud map

    remove_gpu_cache()

    if config.o3d_vis_on:

        # Keep the GUI fed with the final state; this loop has no exit condition.
        while True:
            if not q_vis2main.empty():
                q_vis2main.get()

            packet_to_vis: VisPacket = VisPacket(frame_id=frame_id, travel_dist=travel_dist[-1], slam_finished=True)

            if not neural_points.is_empty():
                packet_to_vis.add_neural_points_data(neural_points, only_local_map=False,
                                                     pca_color_on=config.decoder_freezed)

            if cur_mesh is not None:
                packet_to_vis.add_mesh(np.array(cur_mesh.vertices, dtype=np.float64), np.array(cur_mesh.triangles),
                                       np.array(cur_mesh.vertex_colors, dtype=np.float64))
                cur_mesh = None  # send the mesh only once

            packet_to_vis.add_traj(odom_poses, gt_poses, pgo_poses, loop_edges)

            q_main2vis.put(packet_to_vis)
            time.sleep(1.0)

    return pose_eval_results


if __name__ == "__main__":
    run_slam()
ifile)\n", 41 | " with Bag(os.path.join(input_folder, ifile), 'r') as ib:\n", 42 | " for topic, msg, t in ib:\n", 43 | " ob.write(topic, msg, t)\n", 44 | "merge()" 45 | ], 46 | "id": "7de5937f52f49bc" 47 | }, 48 | { 49 | "metadata": {}, 50 | "cell_type": "markdown", 51 | "source": [ 52 | "\n", 53 | "# Convert Pose Format" 54 | ], 55 | "id": "dbb04396689dba20" 56 | }, 57 | { 58 | "metadata": {}, 59 | "cell_type": "code", 60 | "outputs": [], 61 | "execution_count": null, 62 | "source": [ 63 | "# 该文件的作用是将csv格式转换为标准的tum位姿格式\n", 64 | "\n", 65 | "input_file = './dataset/SubT_MRS/SubT_MRS_Urban_Challenge_UGV2/poses/ground_truth_path.csv'\n", 66 | "output_file = './dataset/SubT_MRS/SubT_MRS_Urban_Challenge_UGV2/poses/gt_poses_tum.txt'\n", 67 | "\n", 68 | "with open(input_file, 'r') as file:\n", 69 | " reader = csv.reader(file)\n", 70 | " header = next(reader) # Skip the header\n", 71 | " with open(output_file, 'w') as outfile:\n", 72 | " for row in reader:\n", 73 | " nsec, x, y, z, qx, qy, qz, qw = map(float, row)\n", 74 | " sec = nsec * 1e-9 # Convert nanoseconds to seconds\n", 75 | " output_line = f\"{sec} {x} {y} {z} {qx} {qy} {qz} {qw}\\n\"\n", 76 | " outfile.write(output_line)\n", 77 | "\n", 78 | "print(\"Conversion completed, file saved as\", output_file)\n" 79 | ], 80 | "id": "38bfd4ef6b5b3eea" 81 | }, 82 | { 83 | "metadata": {}, 84 | "cell_type": "markdown", 85 | "source": [ 86 | "\n", 87 | "# Mapping Performance Evaluation\n" 88 | ], 89 | "id": "d7bda514d936edca" 90 | }, 91 | { 92 | "metadata": {}, 93 | "cell_type": "code", 94 | "outputs": [], 95 | "execution_count": null, 96 | "source": [ 97 | "def quaternion_to_rotation_matrix(qx: float, qy: float, qz: float, qw: float) -> np.ndarray:\n", 98 | " \"\"\"\n", 99 | " Converts a quaternion into a 3x3 rotation matrix.\n", 100 | "\n", 101 | " Parameters:\n", 102 | " - qx (float): X component of the quaternion.\n", 103 | " - qy (float): Y component of the quaternion.\n", 104 | " - qz (float): Z component of the 
quaternion.\n", 105 | " - qw (float): W (scalar) component of the quaternion.\n", 106 | "\n", 107 | " Returns:\n", 108 | " - np.ndarray: A 3x3 NumPy array representing the rotation matrix.\n", 109 | " \"\"\"\n", 110 | " # Normalize the quaternion to ensure a valid rotation\n", 111 | " norm = np.sqrt(qx**2 + qy**2 + qz**2 + qw**2)\n", 112 | " qx /= norm\n", 113 | " qy /= norm\n", 114 | " qz /= norm\n", 115 | " qw /= norm\n", 116 | "\n", 117 | " # Compute the rotation matrix using the normalized quaternion\n", 118 | " rotation_matrix = np.array([\n", 119 | " [1 - 2 * (qy**2 + qz**2), 2 * (qx * qy - qz * qw), 2 * (qx * qz + qy * qw)],\n", 120 | " [2 * (qx * qy + qz * qw), 1 - 2 * (qx**2 + qz**2), 2 * (qy * qz - qx * qw)],\n", 121 | " [2 * (qx * qz - qy * qw), 2 * (qy * qz + qx * qw), 1 - 2 * (qx**2 + qy**2)]\n", 122 | " ])\n", 123 | "\n", 124 | " return rotation_matrix" 125 | ], 126 | "id": "5431eea935e8ad5e" 127 | }, 128 | { 129 | "metadata": {}, 130 | "cell_type": "markdown", 131 | "source": "## Newer College Dataset Ncd Sequence", 132 | "id": "6d77fde250a16151" 133 | }, 134 | { 135 | "metadata": {}, 136 | "cell_type": "code", 137 | "outputs": [], 138 | "execution_count": null, 139 | "source": [ 140 | "# ImMesh\n", 141 | "matrix_values = [\n", 142 | " 5.304626993818075675e-01, -8.474622417882305969e-01, 2.042261633276983360e-02, -8.377865843928848644e-02,\n", 143 | " 8.463450710843981595e-01, 5.308216667107832354e-01, 4.391332211528171242e-02, 3.370663104058911230e+00,\n", 144 | " -4.805564502133059107e-02, -6.009799083416286596e-03, 9.988265833638158009e-01, 7.037440120229881968e-01\n", 145 | "]\n", 146 | "T_lidar = np.vstack([np.array(matrix_values).reshape(3, 4), [0, 0, 0, 1]])\n", 147 | "\n", 148 | "T_lidar_imu = np.array([[-1.0, 0, 0, -0.006253],\n", 149 | " [0, -1.0, 0, 0.011775],\n", 150 | " [0, 0, 1.0, -0.028535],\n", 151 | " [0, 0, 0, 1]])" 152 | ], 153 | "id": "fb8b78b5c4feb986" 154 | }, 155 | { 156 | "metadata": {}, 157 | "cell_type": "markdown", 158 | 
"source": "## Newer College Dataset Extension Math Easy Sequence", 159 | "id": "c864e629b6013c7c" 160 | }, 161 | { 162 | "metadata": {}, 163 | "cell_type": "code", 164 | "outputs": [], 165 | "execution_count": null, 166 | "source": [ 167 | "T_lidar = np.eye(4)\n", 168 | "\n", 169 | "# SLAMesh\n", 170 | "T_lidar[:3, :3] = quaternion_to_rotation_matrix(-0.00987445, 0.00774057, 0.842868, 0.537974)\n", 171 | "T_lidar[:3, 3] = np.array([-23.7176, -31.2646, 1.03258])\n", 172 | "\n", 173 | "# ImMesh\n", 174 | "# T_lidar[:3, :3] = quaternion_to_rotation_matrix(-0.00248205, 0.00444627, 0.842838, 0.538143)\n", 175 | "# T_lidar[:3, 3] = np.array([-23.7202, -31.2861, 1.04326])\n", 176 | "\n", 177 | "T_lidar_imu = np.array([\n", 178 | " [1.0, 0, 0, 0.014],\n", 179 | " [0, 1.0, 0, -0.012],\n", 180 | " [0, 0, 1.0, -0.015],\n", 181 | " [0, 0, 0, 1.0]])\n", 182 | "\n", 183 | "T = T_lidar @ T_lidar_imu" 184 | ], 185 | "id": "e9a22013e913de27" 186 | }, 187 | { 188 | "metadata": {}, 189 | "cell_type": "markdown", 190 | "source": "## Transform Mesh To Align The Ground Truth", 191 | "id": "61456f82338bb7e0" 192 | }, 193 | { 194 | "metadata": {}, 195 | "cell_type": "code", 196 | "outputs": [], 197 | "execution_count": null, 198 | "source": [ 199 | "def mesh_transform():\n", 200 | " mesh_file = \"./math_easy.ply\"\n", 201 | " output_file = \"./math_easy_transformed.ply\"\n", 202 | "\n", 203 | " if not os.path.exists(mesh_file):\n", 204 | " sys.exit(f\"Mesh file {mesh_file} does not exist.\")\n", 205 | " print(\"Loading Mesh file: \", mesh_file)\n", 206 | "\n", 207 | " # Load Mesh\n", 208 | " mesh = o3d.io.read_triangle_mesh(mesh_file)\n", 209 | " mesh.compute_vertex_normals()\n", 210 | " transformation_matrix = np.array([[-4.20791735e-01, -9.07157072e-01, 6.01526210e-04, -2.37152142e+01],\n", 211 | " [ 9.07112929e-01, -4.20764518e-01, 1.01663692e-02, -3.12685037e+01],\n", 212 | " [-8.96939285e-03, 4.82357635e-03, 9.99948140e-01, 1.02807732e+00],\n", 213 | " [ 0.00000000e+00, 
0.00000000e+00, 0.00000000e+00, 1.00000000e+00]])\n", 214 | "\n", 215 | " mesh.transform(transformation_matrix)\n", 216 | " o3d.io.write_triangle_mesh(output_file, mesh)\n", 217 | "\n", 218 | "mesh_transform()" 219 | ], 220 | "id": "532df31dd565ed8b" 221 | }, 222 | { 223 | "metadata": {}, 224 | "cell_type": "markdown", 225 | "source": "# Visualize Mesh\n", 226 | "id": "52c9c86e3ad5e34f" 227 | }, 228 | { 229 | "metadata": {}, 230 | "cell_type": "code", 231 | "outputs": [], 232 | "execution_count": null, 233 | "source": [ 234 | "def vis_mesh():\n", 235 | " mesh_file = \"./ours_mesh_20cm.ply\" # the path of the mesh which you want to visualize\n", 236 | " if not os.path.exists(mesh_file):\n", 237 | " sys.exit(f\"Mesh file {mesh_file} does not exist.\")\n", 238 | " print(\"Loading Mesh file: \", mesh_file)\n", 239 | "\n", 240 | " # Load the mesh\n", 241 | " mesh = o3d.io.read_triangle_mesh(mesh_file)\n", 242 | " mesh.compute_vertex_normals()\n", 243 | " # Check if the mesh was loaded successfully\n", 244 | " if not mesh.has_vertices():\n", 245 | " sys.exit(\"Failed to load the mesh. 
No vertices found.\")\n", 246 | " vis = o3d.visualization.Visualizer()\n", 247 | " vis.create_window(window_name=\"Mesh Visualization\")\n", 248 | " vis.add_geometry(mesh)\n", 249 | " opt = vis.get_render_option()\n", 250 | " opt.light_on = True # Enable lighting to show the mesh color\n", 251 | " opt.mesh_show_back_face = True\n", 252 | " # Enable shortcuts in the console (e.g., Ctrl+9)\n", 253 | " vis.run()\n", 254 | " vis.destroy_window()\n", 255 | "\n", 256 | "vis_mesh()" 257 | ], 258 | "id": "abe2bf68d78b341e" 259 | } 260 | ], 261 | "metadata": { 262 | "kernelspec": { 263 | "display_name": "Python 3", 264 | "language": "python", 265 | "name": "python3" 266 | }, 267 | "language_info": { 268 | "codemirror_mode": { 269 | "name": "ipython", 270 | "version": 2 271 | }, 272 | "file_extension": ".py", 273 | "mimetype": "text/x-python", 274 | "name": "python", 275 | "nbconvert_exporter": "python", 276 | "pygments_lexer": "ipython2", 277 | "version": "2.7.6" 278 | } 279 | }, 280 | "nbformat": 4, 281 | "nbformat_minor": 5 282 | } 283 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /utils/__pycache__/config.cpython-312.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/config.cpython-312.pyc -------------------------------------------------------------------------------- /utils/__pycache__/data_sampler.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/data_sampler.cpython-312.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataset_indexing.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/dataset_indexing.cpython-312.pyc -------------------------------------------------------------------------------- /utils/__pycache__/error_state_iekf.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/error_state_iekf.cpython-312.pyc -------------------------------------------------------------------------------- /utils/__pycache__/eval_traj_utils.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/eval_traj_utils.cpython-312.pyc -------------------------------------------------------------------------------- /utils/__pycache__/loss.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/loss.cpython-312.pyc -------------------------------------------------------------------------------- 
/utils/__pycache__/mapper.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/mapper.cpython-312.pyc -------------------------------------------------------------------------------- /utils/__pycache__/mesher.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/mesher.cpython-312.pyc -------------------------------------------------------------------------------- /utils/__pycache__/semantic_kitti_utils.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/semantic_kitti_utils.cpython-312.pyc -------------------------------------------------------------------------------- /utils/__pycache__/slam_dataset.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/slam_dataset.cpython-312.pyc -------------------------------------------------------------------------------- /utils/__pycache__/so3_math.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/so3_math.cpython-312.pyc -------------------------------------------------------------------------------- /utils/__pycache__/tools.cpython-312.pyc: -------------------------------------------------------------------------------- 
class DataSampler:
    """Generate SDF training samples along the rays of one point cloud.

    For every input point the sampler produces the measured point itself,
    samples close to the surface, free-space samples in front of the surface
    and free-space samples behind it, together with an SDF label and a
    weight for each sample.
    """

    def __init__(self, config: Config):
        # Keep the whole config (sampling hyper-parameters are read lazily in
        # `sample`) and cache the target device for tensor creation.
        self.config = config
        self.dev = config.device

    def sample(self, points_torch, local_point_cloud_map: LocalPointCloudMap, cur_pose_torch):
        """Sample points along each ray and build their SDF labels and weights.

        Every sample is `point * ratio`, i.e. it lies on the line through the
        origin of the frame `points_torch` is expressed in and the measured
        point; `ratio == 1.0` is the measured point itself.

        Args:
            points_torch: point tensor; used as (N, 3) (norm over dim 1 and a
                final reshape to (..., 3)).
            local_point_cloud_map: object providing
                `region_specific_sdf_estimation`, queried with the
                close-to-surface samples after they were transformed by
                `cur_pose_torch`.
            cur_pose_torch: transformation handed to `transform_torch`
                (presumably the current sensor-to-global pose -- confirm
                against the caller).

        Returns:
            Tuple `(sample_points, sdf_labels, weights)` in ray-wise order
            (all samples of a ray are contiguous), restricted to the samples
            whose mask entry is True.
        """
        dev = self.dev
        surface_sample_range = self.config.surface_sample_range_m
        surface_sample_n = self.config.surface_sample_n
        freespace_behind_sample_n = self.config.free_behind_n
        freespace_front_sample_n = self.config.free_front_n
        all_sample_n = (
            surface_sample_n + freespace_behind_sample_n + freespace_front_sample_n + 1
        )  # 1 as the exact measurement
        free_front_min_ratio = self.config.free_sample_begin_ratio
        free_sample_end_dist = self.config.free_sample_end_dist_m

        # get sample points
        point_num = points_torch.shape[0]
        distances = torch.linalg.norm(
            points_torch, dim=1, keepdim=True
        )  # ray distances (scaled)

        # Part 0. the exact measured point
        measured_sample_displacement = torch.zeros_like(distances)
        measured_sample_dist_ratio = torch.ones_like(distances)

        # Part 1. close-to-surface sampling
        # NOTE: despite being called "uniform" historically, torch.randn draws
        # from a standard normal, so the displacement along the ray is
        # Gaussian with standard deviation `surface_sample_range`.
        surface_sample_displacement = (
            torch.randn(point_num * surface_sample_n, 1, device=dev)
            * surface_sample_range
        )

        repeated_dist = distances.repeat(surface_sample_n, 1)
        surface_sample_dist_ratio = (
            surface_sample_displacement / repeated_dist + 1.0
        )  # 1.0 means on the surface

        # Part 2. free space (in front of surface) uniform sampling
        # if you want to reconstruct the thin objects (like poles, tree branches) well, you need more freespace samples to have
        # a space carving effect

        # Free-space samples stay at least `sigma_ratio * surface_sample_range`
        # away from the measured surface (same margin is reused in Part 3).
        sigma_ratio = 2.0
        repeated_dist = distances.repeat(freespace_front_sample_n, 1)
        free_max_ratio = 1.0 - sigma_ratio * surface_sample_range / repeated_dist
        free_diff_ratio = free_max_ratio - free_front_min_ratio
        # torch.rand is uniform in [0, 1): ratio is uniform between
        # free_front_min_ratio and free_max_ratio.
        free_sample_front_dist_ratio = (
            torch.rand(point_num * freespace_front_sample_n, 1, device=dev)
            * free_diff_ratio
            + free_front_min_ratio
        )
        free_sample_front_displacement = (
            free_sample_front_dist_ratio - 1.0
        ) * repeated_dist

        # Part 3. free space (behind surface) uniform sampling
        # Ratio is uniform between the surface margin and the point lying
        # `free_sample_end_dist` behind the measured point.
        repeated_dist = distances.repeat(freespace_behind_sample_n, 1)
        free_max_ratio = free_sample_end_dist / repeated_dist + 1.0
        free_behind_min_ratio = 1.0 + sigma_ratio * surface_sample_range / repeated_dist
        free_diff_ratio = free_max_ratio - free_behind_min_ratio

        free_sample_behind_dist_ratio = (
            torch.rand(point_num * freespace_behind_sample_n, 1, device=dev)
            * free_diff_ratio
            + free_behind_min_ratio
        )

        free_sample_behind_displacement = (
            free_sample_behind_dist_ratio - 1.0
        ) * repeated_dist

        # all together
        # Concatenation order (block-wise, each block holds all rays):
        # measured | close-to-surface | free front | free behind.
        all_sample_displacement = torch.cat(
            (
                measured_sample_displacement,
                surface_sample_displacement,
                free_sample_front_displacement,
                free_sample_behind_displacement,
            ),
            0,
        )
        all_sample_dist_ratio = torch.cat(
            (
                measured_sample_dist_ratio,
                surface_sample_dist_ratio,
                free_sample_front_dist_ratio,
                free_sample_behind_dist_ratio,
            ),
            0,
        )

        repeated_points = points_torch.repeat(all_sample_n, 1)
        repeated_dist = distances.repeat(all_sample_n, 1)
        all_sample_points = repeated_points * all_sample_dist_ratio
        ####################################### Added By Jiang Junlong #################################################
        # Derive the SDF sign from the surface-sample displacement: samples
        # displaced towards the origin (negative displacement) get +1, the
        # others -1.
        sdf_sign = torch.where(surface_sample_displacement.squeeze(1) < 0, 1, -1)
        mask = torch.ones(point_num * all_sample_n, dtype=torch.bool, device=self.config.device)

        # Default label: negated displacement along the ray.
        sdf_label_tensor = -1 * all_sample_displacement.squeeze(1)
        # Number of measured + close-to-surface samples (the first two blocks).
        surface_sample_count = point_num * (surface_sample_n + 1)
        # Only the close-to-surface block (the measured block is skipped) is
        # re-labelled with the map-based estimate below.
        surface_sample_points = all_sample_points[point_num: surface_sample_count]
        surface_sample_points_G = transform_torch(surface_sample_points, cur_pose_torch)
        dist, valid_mask = local_point_cloud_map.region_specific_sdf_estimation(surface_sample_points_G)
        # Samples the map could not label (valid_mask False) are dropped at the end.
        mask[point_num:surface_sample_count] = valid_mask
        sdf_label_tensor[point_num:surface_sample_count] = sdf_sign * dist
        # sdf_label = torch.clamp(sdf_label, -0.4, 0.4)

        # depth tensor of all the samples
        depths_tensor = repeated_dist * all_sample_dist_ratio
        # get the weight vector as the inverse of sigma
        weight_tensor = torch.ones_like(depths_tensor)
        if self.config.dist_weight_on:  # far away surface samples would have lower weight
            weight_tensor[:surface_sample_count] = (
                1
                + self.config.dist_weight_scale * 0.5
                - (repeated_dist[:surface_sample_count] / self.config.max_range)
                * self.config.dist_weight_scale
            )  # e.g. [0.6, 1.4] when dist_weight_scale is 0.8 and dist <= max_range

        # Free-space samples carry a negative weight (presumably a flag that
        # lets the consumer tell them apart from surface samples -- confirm
        # where the weights are used).
        weight_tensor[surface_sample_count:] *= -1.0

        # Convert from the all ray surface + all ray free order to the ray-wise (surface + free) order

        all_sample_points = (
            all_sample_points.reshape(all_sample_n, -1, 3)
            .transpose(0, 1)
            .reshape(-1, 3)
        )
        sdf_label_tensor = (
            sdf_label_tensor.reshape(all_sample_n, -1).transpose(0, 1).reshape(-1)
        )

        weight_tensor = (
            weight_tensor.reshape(all_sample_n, -1).transpose(0, 1).reshape(-1)
        )

        mask = (
            mask.reshape(all_sample_n, -1).transpose(0, 1).reshape(-1)
        )
        return all_sample_points[mask], sdf_label_tensor[mask], weight_tensor[mask]
13 | 14 | if config.use_kiss_dataloader: 15 | config.data_loader_name = dataset_name 16 | config.data_loader_seq = seq 17 | print('KISS-ICP data loaders used') 18 | from kiss_icp.datasets import available_dataloaders 19 | print('Available dataloaders:', available_dataloaders()) 20 | 21 | else: 22 | if dataset_name == "kitti": 23 | base_path = config.pc_path.rsplit('/', 3)[0] 24 | config.pc_path = os.path.join(base_path, 'sequences', seq, "velodyne") # input point cloud folder 25 | pose_file_name = seq + '.txt' 26 | config.pose_path = os.path.join(base_path, 'poses', pose_file_name) # input pose file 27 | config.calib_path = os.path.join(base_path, 'sequences', seq, "calib.txt") # input calib file (to sensor frame) 28 | config.label_path = os.path.join(base_path, 'sequences', seq, "labels") # input point-wise label path, for semantic mapping (optional) 29 | config.kitti_correction_on = True 30 | config.correction_deg = 0.195 31 | elif dataset_name == "mulran": 32 | config.name = config.name + "_mulran_" + seq 33 | base_path = config.pc_path.rsplit('/', 2)[0] 34 | config.pc_path = os.path.join(base_path, seq, "Ouster") # input point cloud folder 35 | config.pose_path = os.path.join(base_path, seq, "poses.txt") # input pose file 36 | elif dataset_name == "kitti_carla": 37 | config.name = config.name + "_kitti_carla_" + seq 38 | base_path = config.pc_path.rsplit('/', 3)[0] 39 | config.pc_path = os.path.join(base_path, seq, "generated", "frames") # input point cloud folder 40 | config.pose_path = os.path.join(base_path, seq, "generated", "poses.txt") # input pose file 41 | config.calib_path = os.path.join(base_path, seq, "generated", "calib.txt") # input calib file (to sensor frame) 42 | elif dataset_name == "ncd": 43 | config.name = config.name + "_ncd_" + seq 44 | base_path = config.pc_path.rsplit('/', 2)[0] 45 | config.pc_path = os.path.join(base_path, seq, "bin") # input point cloud folder 46 | config.pose_path = os.path.join(base_path, seq, "poses.txt") # input 
pose file 47 | config.calib_path = os.path.join(base_path, seq, "calib.txt") # input calib file (to sensor frame) 48 | elif dataset_name == "ncd128": 49 | config.name = config.name + "_ncd128_" + seq 50 | base_path = config.pc_path.rsplit('/', 2)[0] 51 | config.pc_path = os.path.join(base_path, seq, "ply") # input point cloud folder 52 | config.pose_path = os.path.join(base_path, seq, "poses.txt") # input pose file 53 | elif dataset_name == "ipbcar": 54 | config.name = config.name + "_ipbcar_" + seq 55 | base_path = config.pc_path.rsplit('/', 2)[0] 56 | config.pc_path = os.path.join(base_path, seq, "ouster") # input point cloud folder 57 | config.pose_path = os.path.join(base_path, seq, "poses.txt") # input pose file 58 | config.calib_path = os.path.join(base_path, seq, "calib.txt") # input calib file (to sensor frame) 59 | elif dataset_name == "hilti": 60 | config.name = config.name + "_hilti_" + seq 61 | base_path = config.pc_path.rsplit('/', 2)[0] 62 | config.pc_path = os.path.join(base_path, seq, "ply") # input point cloud folder 63 | # config.pose_path = os.path.join(base_path, seq, "poses.txt") # input pose file 64 | elif dataset_name == "m2dgr": 65 | config.name = config.name + "_m2dgr_" + seq 66 | base_path = config.pc_path.rsplit('/', 2)[0] 67 | config.pc_path = os.path.join(base_path, seq, "points") # input point cloud folder 68 | config.pose_path = os.path.join(base_path, seq, "poses.txt") # input pose file 69 | elif dataset_name == "replica": 70 | config.name = config.name + "_replica_" + seq 71 | base_path = config.pc_path.rsplit('/', 2)[0] 72 | config.pc_path = os.path.join(base_path, seq, "rgbd_down_ply") # input point cloud folder 73 | #config.pc_path = os.path.join(base_path, seq, "rgbd_ply") # input point cloud folder 74 | config.pose_path = os.path.join(base_path, seq, "poses.txt") # input pose file 75 | -------------------------------------------------------------------------------- /utils/error_state_iekf.py: 
class InputIkfom:
    """IMU input vector holding accelerometer and gyroscope measurements."""

    def __init__(
            self,
            dtype,
            acc=None,
            gyro=None
    ):
        """Store the measurements as tensors of the requested dtype.

        Args:
            dtype: torch dtype used for both stored tensors.
            acc: accelerometer measurement (array-like or tensor). Defaults
                to a zero 3-vector.
            gyro: gyroscope measurement (array-like or tensor). Defaults to a
                zero 3-vector.

        The previous defaults were the ``np.array`` *function* itself, which
        made default construction fail inside ``torch.tensor``; ``None`` now
        maps to zeros, so every call that used to work behaves the same.
        """
        self.dtype = dtype
        self.acc = self._as_vector(acc, dtype)
        self.gyro = self._as_vector(gyro, dtype)

    @staticmethod
    def _as_vector(value, dtype):
        """Return an independent tensor copy of *value* with the given dtype."""
        if value is None:
            return torch.zeros(3, dtype=dtype)
        if isinstance(value, torch.Tensor):
            # torch.tensor(tensor) emits a copy-construct warning; an explicit
            # detach + clone keeps the same "always copy" semantics.
            return value.detach().clone().to(dtype=dtype)
        return torch.tensor(value, dtype=dtype)
def df_dw(self, s: StateIkfom, in_: InputIkfom, dt: float):
    """Return the Jacobian of the state propagation w.r.t. the process noise.

    The result is an 18x12 matrix (18-dim error state, 12-dim noise), already
    scaled by the time step ``dt``. ``in_`` is accepted for signature symmetry
    with the other model functions but is not read here.
    """
    eye3 = torch.eye(3, dtype=self.tran_dtype)
    noise_jacobian = torch.zeros((18, 12), dtype=self.tran_dtype)

    # Unscaled 3x3 blocks as (state rows, noise columns, value).
    # The rotation block would be -A(omega * dt); it is approximated by -I.
    blocks = (
        (slice(0, 3), slice(0, 3), -eye3),
        (slice(6, 9), slice(3, 6), -s.rot),  # -R
        (slice(9, 12), slice(6, 9), eye3),
        (slice(12, 15), slice(9, 12), eye3),
    )
    for state_rows, noise_cols, value in blocks:
        noise_jacobian[state_rows, noise_cols] = value

    return noise_jacobian * dt
res[3:6] = s.vel 178 | res[6:9] = a_inertial 179 | return res 180 | 181 | def h_model(self, point_cloud_imu: torch.tensor): 182 | bs = self.config.infer_bs 183 | mask_min_nn_count = self.config.track_mask_query_nn_k 184 | min_grad_norm = self.config.reg_min_grad_norm 185 | max_grad_norm = self.config.reg_max_grad_norm 186 | 187 | pose = torch.eye(4) 188 | pose[:3, :3] = self.x.rot 189 | pose[:3, 3] = self.x.pos 190 | 191 | point_cloud_global = transform_torch(point_cloud_imu, pose) 192 | sample_count = point_cloud_global.shape[0] 193 | iter_n = math.ceil(sample_count / bs) 194 | 195 | sdf_pred = torch.zeros(sample_count, device=point_cloud_global.device) 196 | sdf_std = torch.zeros(sample_count, device=point_cloud_global.device) 197 | mc_mask = torch.zeros(sample_count, device=point_cloud_global.device, dtype=torch.bool) 198 | sdf_grad = torch.zeros((sample_count, 3), device=point_cloud_global.device) 199 | certainty = torch.zeros(sample_count, device=point_cloud_global.device) 200 | 201 | # 为了避免爆显存,采用分批处理的办法 202 | # 计算点云的SDF预测值 203 | for n in range(iter_n): 204 | head = n * bs 205 | tail = min((n + 1) * bs, sample_count) 206 | batch_coord = point_cloud_global[head:tail, :] 207 | batch_coord.requires_grad_(True) 208 | 209 | ( 210 | batch_geo_feature, 211 | _, 212 | weight_knn, 213 | nn_count, 214 | batch_certainty, 215 | ) = self.neural_points.query_feature( 216 | batch_coord, 217 | training_mode=False, 218 | query_locally=True, 219 | query_color_feature=False, 220 | ) # inference mode 221 | 222 | batch_sdf = self.geo_decoder.sdf(batch_geo_feature) 223 | if not self.config.weighted_first: 224 | batch_sdf_mean = torch.sum(batch_sdf * weight_knn, dim=1) 225 | batch_sdf_var = torch.sum((weight_knn * (batch_sdf - batch_sdf_mean.unsqueeze(-1)) ** 2), dim=1) 226 | batch_sdf_std = torch.sqrt(batch_sdf_var).squeeze(1) 227 | batch_sdf = batch_sdf_mean.squeeze(1) 228 | sdf_std[head:tail] = batch_sdf_std.detach() 229 | 230 | batch_sdf_grad = get_gradient(batch_coord, 
def update_iterated(self, point_cloud_imu: torch.tensor):
    """Run the iterated measurement update of the filter.

    Each iteration evaluates ``h_model`` at the current state, builds the
    gain from the measurement Jacobian and ``self.R_inv`` and applies the
    resulting correction with ``boxplus``. Iteration stops once the
    correction is below the configured rotation / translation thresholds
    (and ``self.eps`` for the remaining state entries), or when the update is
    declared invalid.

    Args:
        point_cloud_imu: source points forwarded unchanged to ``h_model`` on
            every iteration; ``shape[0]`` is taken as the source point count.

    Returns:
        Tuple ``(updated_pose, valid_flag)``: a 4x4 pose (``self.dtype`` on
        ``self.device``) assembled from the final rotation and position, and
        a flag that is False when, on the last allowed iteration, fewer than
        20% of the source points yielded a valid measurement.
    """
    # Move the state and its covariance to the GPU for the update.
    self.x.cuda()
    self.P = self.P.cuda()
    valid_flag = True
    converged = False

    x_propagated = self.x
    P_inv = torch.linalg.inv(self.P)
    I = torch.eye(18, device=self.device, dtype=self.tran_dtype)
    term_thre_deg = self.config.reg_term_thre_deg
    term_thre_m = self.config.reg_term_thre_m
    source_point_count = point_cloud_imu.shape[0]

    # Remain None when no iteration runs (max_iteration == 0) so the
    # covariance update below can be skipped instead of raising NameError.
    K = None
    H = None

    for i in range(self.max_iteration):
        dx_new = boxminus(self.x, x_propagated)
        z, H, valid_points = self.h_model(point_cloud_imu)
        valid_point_count = valid_points.shape[0]
        # An empty input cloud has no valid measurements at all; treat the
        # ratio as 0 rather than dividing by zero.
        if source_point_count > 0:
            valid_ratio = valid_point_count / source_point_count
        else:
            valid_ratio = 0.0

        if valid_ratio < 0.2 and i == self.max_iteration - 1:
            if not self.config.silence:
                print("[bold yellow](Warning) registration failed: not enough valid points[/bold yellow]")
            valid_flag = False

        H_T_R_inv = H.T * self.R_inv
        S = H_T_R_inv @ H

        K_front = torch.linalg.inv(S + P_inv)
        K = K_front @ H_T_R_inv

        dx_ = -K @ z + (K @ H - I) @ dx_new
        self.x = boxplus(self.x, dx_)
        tran_m = dx_[3:6].norm()
        rot_angle_deg = dx_[0:3].norm() * 180.0 / np.pi

        # Termination criterion with a physical meaning: small rotation and
        # translation corrections, and all remaining entries below eps.
        if rot_angle_deg < term_thre_deg and tran_m < term_thre_m and torch.all(torch.abs(dx_[6:]) < self.eps):
            if not self.config.silence:
                print("Converged after", i, "iterations")
            converged = True

        if not valid_flag or converged:
            break

    if K is not None:
        # Posterior covariance from the gain and Jacobian of the last iteration.
        self.P = (I - K @ H) @ self.P
    updated_pose = torch.eye(4, dtype=self.dtype, device=self.device)
    updated_pose[:3, :3] = self.x.rot.to(self.dtype)
    updated_pose[:3, 3] = self.x.pos.to(self.dtype)

    # Move the state and its covariance back to the CPU.
    self.x.cpu()
    self.P = self.P.cpu()
    return updated_pose, valid_flag
def absolute_error(
    poses_gt: np.ndarray, poses_result: np.ndarray, align_on: bool = True
):
    """Compute the absolute rotation / translation RMSE of a trajectory.

    Args:
        poses_gt: ground-truth poses, indexed per frame as 4x4 matrices.
        poses_result: estimated poses with the same number of frames.
        align_on: when True, the estimate is first aligned to the ground
            truth with ``align_traj`` and the alignment is applied to every
            estimated pose.

    Returns:
        Tuple ``(rot_rmse, tran_rmse, align_mat)``: rotation RMSE converted
        from radians to degrees, translation RMSE in the unit of the poses,
        and the 4x4 alignment matrix (identity when ``align_on`` is False).
    """
    assert poses_gt.shape[0] == poses_result.shape[0], "poses length should be identical"

    align_mat = np.eye(4)
    if align_on:
        align_rot, align_tran, _ = align_traj(poses_result, poses_gt)
        align_mat[:3, :3] = align_rot
        align_mat[:3, 3] = np.squeeze(align_tran)

    frame_count = poses_gt.shape[0]
    rot_errors = np.empty(frame_count)
    tran_errors = np.empty(frame_count)

    for idx, (gt_pose, est_pose) in enumerate(zip(poses_gt, poses_result)):
        est_aligned = align_mat @ est_pose
        # Relative rotation between ground truth and the aligned estimate.
        rot_delta = np.linalg.inv(gt_pose[:3, :3]) @ est_aligned[:3, :3]
        rot_errors[idx] = rotation_error(rot_delta)
        # Plain position difference (KISS-ICP instead rotates the estimate
        # by rot_delta before subtracting).
        tran_errors[idx] = np.linalg.norm(gt_pose[:3, 3] - est_aligned[:3, 3])

    # NOTE: the original author flagged the rotation RMSE as possibly problematic.
    rot_rmse = np.sqrt(np.dot(rot_errors, rot_errors) / frame_count) * 180.0 / np.pi
    tran_rmse = np.sqrt(np.dot(tran_errors, tran_errors) / frame_count)

    return rot_rmse, tran_rmse, align_mat
def align(model, data):
    """Align two trajectories using the method of Horn (closed-form).

    Input:
    model -- first trajectory (3xn)
    data -- second trajectory (3xn)

    Output:
    rot -- rotation matrix (3x3, np.matrix)
    trans -- translation vector (3x1, np.matrix)
    trans_error -- translational error per point (length-n array)

    The returned transform maps ``model`` onto ``data``:
    ``data ~= rot * model + trans``.

    Borrowed from NICE-SLAM
    """
    model = np.asarray(model, dtype=float)
    data = np.asarray(data, dtype=float)

    model_mean = model.mean(1, keepdims=True)
    data_mean = data.mean(1, keepdims=True)
    model_zerocentered = model - model_mean
    data_zerocentered = data - data_mean

    # Cross-covariance: the sum over all columns of
    # outer(model_column, data_column), written as one matrix product.
    W = model_zerocentered @ data_zerocentered.T

    # Public API; `np.linalg.linalg` is a private module path.
    U, _, Vh = np.linalg.svd(W.T)
    S = np.identity(3)
    if np.linalg.det(U) * np.linalg.det(Vh) < 0:
        # Flip the last axis so the result is a proper rotation (det = +1)
        # instead of a reflection.
        S[2, 2] = -1
    rot = U @ S @ Vh
    trans = data_mean - rot @ model_mean

    alignment_error = rot @ model + trans - data
    trans_error = np.sqrt(np.sum(alignment_error * alignment_error, axis=0))

    # Keep the historical np.matrix return types for existing callers.
    return np.asmatrix(rot), np.asmatrix(trans), trans_error
def trajectory_distances(poses_np):
    """Compute the travelled distance of each pose w.r.t. frame-0.

    The distance is the accumulated Euclidean length of the translation
    steps between consecutive poses. Only the translation column is used:
    taking the norm of the full 4x4 pose difference (as done previously)
    also counts rotation changes as travelled distance, which disagrees with
    the KITTI segment-length definition.

    Args:
        poses_np: kx4x4 np.array of poses
    Returns:
        dist (float list): distance of each pose w.r.t. frame-0; ``[0.0]``
        for an empty or single-pose trajectory.
    """
    positions = np.asarray(poses_np)[:, :3, 3]
    step_lengths = np.linalg.norm(np.diff(positions, axis=0), axis=1)
    return [0.0, *np.cumsum(step_lengths).tolist()]
def translation_error(pose_error):
    """Return the Euclidean length of the translation part of a pose.

    Args:
        pose_error (4x4 array): relative pose error
    Returns:
        trans_error (float): norm of the entries [0, 3], [1, 3] and [2, 3]
    """
    dx, dy, dz = (pose_error[axis, 3] for axis in range(3))
    return np.sqrt(dx ** 2 + dy ** 2 + dz ** 2)
if not found return -1 231 | """ 232 | for i in range(first_frame, len(dist), 1): 233 | if dist[i] > (dist[first_frame] + length): 234 | return i 235 | return -1 236 | 237 | 238 | def plot_trajectories( 239 | traj_plot_path: str, 240 | poses_est, 241 | poses_ref, 242 | poses_est_2=None, 243 | plot_3d: bool = True, 244 | grid_on: bool = True, 245 | plot_start_end_markers: bool = True, 246 | vis_now: bool = False, 247 | close_all: bool = True, 248 | ) -> None: 249 | # positions_est, positions_ref, positions_est_2 as list of numpy array 250 | 251 | from evo.core.trajectory import PosePath3D 252 | from evo.tools import plot as evoplot 253 | from evo.tools.settings import SETTINGS 254 | 255 | # without alignment 256 | 257 | if close_all: 258 | plt.close("all") 259 | 260 | poses = PosePath3D(poses_se3=poses_est) 261 | gt_poses = PosePath3D(poses_se3=poses_ref) 262 | if poses_est_2 is not None: 263 | poses_2 = PosePath3D(poses_se3=poses_est_2) 264 | 265 | if plot_3d: 266 | plot_mode = evoplot.PlotMode.xyz 267 | else: 268 | plot_mode = evoplot.PlotMode.xy 269 | 270 | fig = plt.figure(f"Trajectory results") 271 | ax = evoplot.prepare_axis(fig, plot_mode) 272 | evoplot.traj( 273 | ax=ax, 274 | plot_mode=plot_mode, 275 | traj=gt_poses, 276 | label="ground truth", 277 | style=SETTINGS.plot_reference_linestyle, 278 | color=SETTINGS.plot_reference_color, 279 | alpha=SETTINGS.plot_reference_alpha, 280 | plot_start_end_markers=False, 281 | ) 282 | evoplot.traj( 283 | ax=ax, 284 | plot_mode=plot_mode, 285 | traj=poses, 286 | label="PIN-SLAM", 287 | style=SETTINGS.plot_trajectory_linestyle, 288 | color="#4c72b0bf", 289 | alpha=SETTINGS.plot_trajectory_alpha, 290 | plot_start_end_markers=plot_start_end_markers, 291 | ) 292 | if poses_est_2 is not None: # better to change color (or the alpha) 293 | evoplot.traj( 294 | ax=ax, 295 | plot_mode=plot_mode, 296 | traj=poses_2, 297 | label="PIN-Odom", 298 | style=SETTINGS.plot_trajectory_linestyle, 299 | color="#FF940E", 300 | 
alpha=SETTINGS.plot_trajectory_alpha / 3.0, 301 | plot_start_end_markers=False, 302 | ) 303 | 304 | plt.tight_layout() 305 | ax.legend(frameon=grid_on) 306 | 307 | if traj_plot_path is not None: 308 | plt.savefig(traj_plot_path, dpi=600) 309 | 310 | if vis_now: 311 | plt.show() 312 | 313 | 314 | def read_kitti_format_calib(filename: str): 315 | """ 316 | read calibration file (with the kitti format) 317 | returns -> dict calibration matrices as 4*4 numpy arrays 318 | """ 319 | calib = {} 320 | calib_file = open(filename) 321 | 322 | for line in calib_file: 323 | key, content = line.strip().split(":") 324 | values = [float(v) for v in content.strip().split()] 325 | pose = np.zeros((4, 4)) 326 | 327 | pose[0, 0:4] = values[0:4] 328 | pose[1, 0:4] = values[4:8] 329 | pose[2, 0:4] = values[8:12] 330 | pose[3, 3] = 1.0 331 | 332 | calib[key] = pose 333 | 334 | calib_file.close() 335 | return calib 336 | 337 | 338 | def read_kitti_format_poses(filename: str) -> List[np.ndarray]: 339 | """ 340 | read pose file (with the kitti format) 341 | returns -> list, transformation before calibration transformation 342 | """ 343 | pose_file = open(filename) 344 | 345 | poses = [] 346 | 347 | for line in pose_file: 348 | values = [float(v) for v in line.strip().split()] 349 | 350 | pose = np.zeros((4, 4)) 351 | pose[0, 0:4] = values[0:4] 352 | pose[1, 0:4] = values[4:8] 353 | pose[2, 0:4] = values[8:12] 354 | pose[3, 3] = 1.0 355 | poses.append(pose) 356 | 357 | pose_file.close() 358 | return poses 359 | 360 | 361 | # copyright: Nacho et al. KISS-ICP 362 | def apply_kitti_format_calib(poses: List[np.ndarray], calib_T_cl) -> List[np.ndarray]: 363 | """Converts from Velodyne to Camera Frame (# T_camera<-lidar)""" 364 | poses_calib = [] 365 | for pose in poses: 366 | poses_calib.append(calib_T_cl @ pose @ np.linalg.inv(calib_T_cl)) 367 | return poses_calib 368 | 369 | 370 | # copyright: Nacho et al. 
# copyright: Nacho et al. KISS-ICP
def write_kitti_format_poses(filename: str, poses: List[np.ndarray]):
    """Save the poses to ``{filename}_kitti.txt``, one pose per row.

    Each row holds the first three rows of the pose matrix, concatenated.
    """

    def _to_kitti_format(poses: np.ndarray) -> np.ndarray:
        return np.array([np.concatenate((pose[0], pose[1], pose[2])) for pose in poses])

    np.savetxt(fname=f"{filename}_kitti.txt", X=_to_kitti_format(poses))


# for LiDAR dataset
def get_metrics(seq_result: List[Dict]):
    """Collect the odometry / SLAM metrics of one sequence in a flat dict.

    Args:
        seq_result: two dicts; the first one provides the odometry drift
            entries, the second one the absolute trajectory error
    Returns:
        dict with the keys "Odometry ATE [%]", "Odometry ARE [deg/100m]"
        and "SLAM RMSE [m]"
    """
    odom_result, slam_result = seq_result[0], seq_result[1]
    return {
        "Odometry ATE [%]": odom_result["Average Translation Error [%]"],
        # deg/m -> deg/100m
        "Odometry ARE [deg/100m]": odom_result["Average Rotational Error [deg/m]"] * 100.0,
        # NOTE: the misspelled key is looked up as-is; it has to match the
        # producer of the result dict
        "SLAM RMSE [m]": slam_result["Absoulte Trajectory Error [m]"],
    }


def mean_metrics(seq_metrics: List[Dict]):
    """Average every metric over the sequences that report it.

    A key missing in some sequences is averaged over the sequences that
    do contain it.
    """
    sums = defaultdict(float)
    counts = defaultdict(int)

    for seq_metric in seq_metrics:
        for key, value in seq_metric.items():
            sums[key] += value
            counts[key] += 1

    # (local renamed: it used to shadow the function name)
    return {key: total / counts[key] for key, total in sums.items()}


# --------------------------------------------------------------------------------
# /utils/loss.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python3
# @file      loss.py
# @author    Yue Pan     [yue.pan@igg.uni-bonn.de]
# Copyright (c) 2024 Yue Pan, all rights reserved

import torch
import torch.nn as nn


def sdf_diff_loss(pred, label, weight, scale=1.0, l2_loss=True):
    """Weighted L2 (default) or L1 loss on the SDF difference.

    The difference is divided by ``scale`` first and the weighted sum is
    divided by the number of samples ``pred.shape[0]``.
    """
    count = pred.shape[0]
    diff_m = (pred - label) / scale  # so it's still in m unit
    if l2_loss:
        residual = diff_m**2  # l2 loss
    else:
        residual = torch.abs(diff_m)  # l1 loss
    return (weight * residual).sum() / count


def sdf_l1_loss(pred, label):
    """Mean absolute difference between prediction and label."""
    return torch.abs(pred - label).mean()


def sdf_l2_loss(pred, label):
    """Mean squared difference between prediction and label."""
    return ((pred - label) ** 2).mean()


def color_diff_loss(pred, label, weight, weighted=False, l2_loss=False):
    """Mean L1 (default) or L2 color loss, optionally weighted per sample.

    When ``weighted`` is set, ``weight`` gets a trailing axis
    (``unsqueeze(1)``) so that it broadcasts over the second axis of the
    difference.
    """
    diff = pred - label
    weight = weight.unsqueeze(1) if weighted else 1.0
    residual = diff**2 if l2_loss else torch.abs(diff)
    return (weight * residual).mean()


# used by our approach
def sdf_bce_loss(pred, label, sigma, weight, weighted=False, bce_reduction="mean"):
    """Calculate the binary cross entropy (BCE) loss for SDF supervision
    Args:
        pred (torch.tensor): batch of predicted SDF values
        label (torch.tensor): batch of the target SDF values
        sigma (float): scale factor for the sigmoid function
        weight (torch.tensor): batch of the per-sample weight
        weighted (bool, optional): apply the weight or not
        bce_reduction (string, optional): specifies the reduction to apply to the output
    Returns:
        loss (torch.tensor): BCE loss for the batch
    """
    if weighted:
        loss_bce = nn.BCEWithLogitsLoss(reduction=bce_reduction, weight=weight)
    else:
        loss_bce = nn.BCEWithLogitsLoss(reduction=bce_reduction)
    label_op = torch.sigmoid(label / sigma)  # occupancy prob
    # the prediction is passed as a logit, the sigmoid is applied inside the loss
    return loss_bce(pred / sigma, label_op)


# the loss devised by Starry Zhong
def sdf_zhong_loss(pred, label, trunc_dist=None, weight=None, weighted=False):
    """L1-style loss that is zero while ``pred`` lies between 0 and ``label``.

    Outside of that interval the loss is the distance to the interval.
    For samples with ``|label| < trunc_dist`` (if given) the plain L1 loss
    ``|pred - label|`` is used instead.
    """
    if not weighted:
        weight = 1.0
    loss = torch.zeros_like(label, dtype=label.dtype, device=label.device)
    middle_point = label / 2.0
    middle_point_abs = torch.abs(middle_point)
    shift_difference_abs = torch.abs(pred - middle_point)
    mask = shift_difference_abs > middle_point_abs
    # not masked region simply has a loss of zero, masked region L1 loss
    loss[mask] = (shift_difference_abs - middle_point_abs)[mask]
    if trunc_dist is not None:
        surface_mask = torch.abs(label) < trunc_dist
        loss[surface_mask] = torch.abs(pred - label)[surface_mask]
    loss *= weight
    return loss.mean()


# not used
def smooth_sdf_loss(pred, label, delta=20.0, weight=None, weighted=False):
    """Smooth loss built from a log-sum-exp over three terms.

    The three terms are ``-s*delta*pred/2``, ``0`` and
    ``s*delta*(pred/2 - label)`` with ``s`` the sign of the label
    (+1 for ``label >= 0``); the result is rescaled by ``2/delta``.
    """
    if not weighted:
        weight = 1.0
    sign_factors = torch.ones_like(label, dtype=label.dtype, device=label.device)
    sign_factors[label < 0.0] = -1.0
    sign_loss = -sign_factors * delta * pred / 2.0
    no_loss = torch.zeros_like(pred, dtype=pred.dtype, device=pred.device)
    truncated_loss = sign_factors * delta * (pred / 2.0 - label)
    losses = torch.stack((sign_loss, no_loss, truncated_loss), dim=0)
    final_loss = torch.logsumexp(losses, dim=0)
    return ((2.0 / delta) * final_loss * weight).mean()


def ray_estimation_loss(x, y, d_meas):  # for each ray
    """Depth error of the zero crossing of a line fitted to the SDF samples.

    Args:
        x: depth of the samples (1D tensor)
        y: sdf prediction of the samples
        d_meas: measured depth
    Returns:
        absolute difference between the estimated and the measured depth
    """
    # regard each sample as a ray: fit y = a * x + b in the least-squares sense
    mat_A = torch.vstack((x, torch.ones_like(x))).transpose(0, 1)
    vec_b = y.view(-1, 1)
    least_square_estimate = torch.linalg.lstsq(mat_A, vec_b).solution

    a = least_square_estimate[0]  # -> -1 (added in eikonal loss term)
    b = least_square_estimate[1]

    # zero crossing of the fitted line, clamped to [1, 40]
    d_estimate = torch.clamp(-b / a, min=1.0, max=40.0)  # -> d

    return torch.abs(d_estimate - d_meas)


def ray_rendering_loss(x, y, d_meas):  # for each ray [should run in batch]
    """Depth rendering error for the samples of a single ray.

    Args:
        x: depth of the samples, (N, 1) (a flat (N,) tensor is accepted too)
        y: occ. prob. prediction of the samples, N values
        d_meas: measured depth
    Returns:
        absolute difference between the rendered and the measured depth
    """
    depth = x.reshape(-1)
    occ = y.reshape(-1)
    sort_x, indices = torch.sort(depth)
    sort_y = occ[indices]

    # w_i = occ_i * prod_{j<i} (1 - occ_j), samples ordered by depth
    # (vectorized replacement of the former O(N^2) double loop)
    transmittance = torch.cumprod(
        torch.cat((torch.ones_like(sort_y[:1]), 1.0 - sort_y[:-1])), dim=0
    )
    w = sort_y * transmittance

    # the weights follow the sorted order, so they have to be paired with the
    # sorted depths (pairing them with the unsorted `x` gave a wrong depth
    # whenever the samples were not already ordered)
    d_render = (w * sort_x).sum()

    return torch.abs(d_render - d_meas)
prediction 127 | x = x.squeeze(1) 128 | sort_x, indices = torch.sort(x) 129 | sort_y = y[indices] 130 | 131 | w = torch.ones_like(y) 132 | for i in range(sort_x.shape[0]): 133 | w[i] = sort_y[i] 134 | for j in range(i): 135 | w[i] *= 1.0 - sort_y[j] 136 | 137 | d_render = (w * x).sum() 138 | 139 | d_error = torch.abs(d_render - d_meas) 140 | 141 | return d_error 142 | 143 | 144 | def batch_ray_rendering_loss(x, y, d_meas, neus_on=True): # for all rays in a batch 145 | # x as depth [ray number * sample number] 146 | # y as prediction (the alpha in volume rendering) [ray number * sample number] 147 | # d_meas as measured depth [ray number] 148 | # w as the raywise weight [ray number] 149 | # neus_on determine if using the loss defined in NEUS 150 | 151 | sort_x, indices = torch.sort(x, 1) # for each row 152 | sort_y = torch.gather(y, 1, indices) # for each row 153 | 154 | if neus_on: 155 | neus_alpha = (sort_y[:, 1:] - sort_y[:, 0:-1]) / (1.0 - sort_y[:, 0:-1] + 1e-10) 156 | # avoid dividing by 0 (nan) 157 | # print(neus_alpha) 158 | alpha = torch.clamp(neus_alpha, min=0.0, max=1.0) 159 | else: 160 | alpha = sort_y 161 | 162 | one_minus_alpha = torch.ones_like(alpha) - alpha + 1e-10 163 | 164 | cum_mat = torch.cumprod(one_minus_alpha, 1) 165 | 166 | weights = cum_mat / one_minus_alpha * alpha 167 | 168 | weights_x = weights * sort_x[:, 0 : alpha.shape[1]] 169 | 170 | d_render = torch.sum(weights_x, 1) 171 | 172 | d_error = torch.abs(d_render - d_meas) 173 | 174 | d_error_mean = torch.mean(d_error) 175 | 176 | return d_error_mean 177 | -------------------------------------------------------------------------------- /utils/mesher.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # @file mesher.py 3 | # @author Yue Pan [yue.pan@igg.uni-bonn.de] 4 | # Copyright (c) 2024 Yue Pan, all rights reserved 5 | 6 | import math 7 | 8 | import matplotlib.cm as cm 9 | import numpy as np 10 | import open3d as o3d 11 | 
import skimage.measure
import torch
from tqdm import tqdm

from model.decoder import Decoder
from model.neural_points import NeuralPoints
from utils.config import Config
from utils.semantic_kitti_utils import sem_kitti_color_map
from utils.tools import remove_gpu_cache


class Mesher:
    """Query the neural point map on regular grids and turn the result into
    SDF point clouds (slices) or triangle meshes (marching cubes)."""

    def __init__(
        self,
        config: Config,
        neural_points: NeuralPoints,
        decoders: dict,
    ):
        """
        Args:
            config: run configuration (device, dtype, batch size, ...)
            neural_points: the neural point map that is queried for features
            decoders: dict providing the "sdf", "semantic" and "color" decoders
        """
        self.config = config
        self.silence = config.silence
        self.neural_points = neural_points
        self.sdf_mlp = decoders["sdf"]
        self.sem_mlp = decoders["semantic"]
        self.color_mlp = decoders["color"]
        self.device = config.device
        # device on which the query grids are created; the slice queries
        # switch it to "cpu" for very large grids
        self.cur_device = self.device
        self.dtype = config.dtype
        # applied to every output geometry unless it is the identity
        self.global_transform = np.eye(4)

    def query_points(
        self,
        coord,
        bs,
        query_sdf=True,
        query_sem=False,
        query_color=False,
        query_mask=True,
        query_locally=False,
        mask_min_nn_count: int = 4,
        out_torch: bool = False,
    ):
        """query the sdf value, semantic label and marching cubes mask for points
        Args:
            coord: Nx3 torch tensor, the coordinates of all N (axbxc) query points
            bs: batch size for the inference
            query_sdf / query_sem / query_color / query_mask: which outputs to compute
            query_locally: forwarded to ``neural_points.query_feature``
            mask_min_nn_count: minimum neighbor count for a point to be kept in the mask
            out_torch: return torch tensors instead of numpy arrays
        Returns:
            sdf_pred: Ndim numpy array or torch tensor, signed distance value at each query point
            sem_pred: Ndim numpy array or torch tensor, semantic label prediction at each query point
            color_pred: NxC numpy array or torch tensor, color prediction at each query point
            mc_mask: Ndim numpy array or torch tensor, marching cubes mask at each query point
            (each entry is None when the corresponding output was not requested)
        """
        sample_count = coord.shape[0]
        iter_n = math.ceil(sample_count / bs)

        def _alloc(shape):
            # output buffer, torch or numpy depending on `out_torch`
            return torch.zeros(shape) if out_torch else np.zeros(shape)

        sdf_pred = _alloc(sample_count) if query_sdf else None
        sem_pred = _alloc(sample_count) if query_sem else None
        color_pred = (
            _alloc((sample_count, self.config.color_channel)) if query_color else None
        )
        mc_mask = _alloc(sample_count) if query_mask else None

        with torch.no_grad():  # eval step
            for n in tqdm(range(iter_n), disable=self.silence):
                head = n * bs
                tail = min((n + 1) * bs, sample_count)
                batch_coord = coord[head:tail, :]
                batch_size = batch_coord.shape[0]
                # the grid may have been created on the cpu (see the slice
                # queries), move the current batch to the gpu in that case
                if self.cur_device == "cpu" and self.device == "cuda":
                    batch_coord = batch_coord.cuda()
                (
                    batch_geo_feature,
                    batch_color_feature,
                    weight_knn,
                    nn_count,
                    _,
                ) = self.neural_points.query_feature(
                    batch_coord,
                    training_mode=False,  # inference mode
                    query_locally=query_locally,
                    query_geo_feature=query_sdf or query_sem,
                    query_color_feature=query_color,
                )

                pred_mask = nn_count >= 1  # only query sdf here
                if query_sdf:
                    if self.config.weighted_first:
                        batch_sdf = torch.zeros(batch_size, device=self.device)
                    else:
                        batch_sdf = torch.zeros(
                            batch_size,
                            batch_geo_feature.shape[1],
                            1,
                            device=self.device,
                        )
                    # predict the sdf with the feature, only do for the unmasked part
                    # (not in the unknown freespace)
                    batch_sdf[pred_mask] = self.sdf_mlp.sdf(
                        batch_geo_feature[pred_mask]
                    )

                    if not self.config.weighted_first:
                        batch_sdf = torch.sum(batch_sdf * weight_knn, dim=1).squeeze(1)
                    if out_torch:
                        sdf_pred[head:tail] = batch_sdf.detach()
                    else:
                        sdf_pred[head:tail] = batch_sdf.detach().cpu().numpy()
                if query_sem:
                    batch_sem_prob = self.sem_mlp.sem_label_prob(batch_geo_feature)
                    if not self.config.weighted_first:
                        batch_sem_prob = torch.sum(batch_sem_prob * weight_knn, dim=1)
                    batch_sem = torch.argmax(batch_sem_prob, dim=1)
                    if out_torch:
                        sem_pred[head:tail] = batch_sem.detach()
                    else:
                        sem_pred[head:tail] = batch_sem.detach().cpu().numpy()
                if query_color:
                    batch_color = self.color_mlp.regress_color(batch_color_feature)
                    if not self.config.weighted_first:
                        batch_color = torch.sum(batch_color * weight_knn, dim=1)  # N, C
                    if out_torch:
                        color_pred[head:tail] = batch_color.detach()
                    else:
                        color_pred[head:tail] = (
                            batch_color.detach().cpu().numpy().astype(dtype=np.float64)
                        )
                if query_mask:
                    # do marching cubes only when there are at least K near neural points
                    mask_mc = nn_count >= mask_min_nn_count
                    if out_torch:
                        mc_mask[head:tail] = mask_mc.detach()
                    else:
                        mc_mask[head:tail] = mask_mc.detach().cpu().numpy()

        return sdf_pred, sem_pred, color_pred, mc_mask

    def _grid_coords(self, voxel_num_xyz, voxel_size, voxel_origin):
        """Return the Nx3 coordinates of a regular grid with ``voxel_num_xyz``
        nodes per axis, spacing ``voxel_size`` and first node ``voxel_origin``."""
        x = torch.arange(voxel_num_xyz[0], dtype=torch.int16, device=self.cur_device)
        y = torch.arange(voxel_num_xyz[1], dtype=torch.int16, device=self.cur_device)
        z = torch.arange(voxel_num_xyz[2], dtype=torch.int16, device=self.cur_device)

        # order: [0,0,0], [0,0,1], [0,0,2], [0,1,0], [0,1,1], [0,1,2] ...
        x, y, z = torch.meshgrid(x, y, z, indexing="ij")
        # get the vector of all the grid point's 3D coordinates
        coord = (
            torch.stack((x.flatten(), y.flatten(), z.flatten())).transpose(0, 1).float()
        )
        # transform to world coordinate system
        coord *= voxel_size
        coord += torch.tensor(voxel_origin, dtype=self.dtype, device=self.cur_device)
        return coord

    def get_query_from_bbx(self, bbx, voxel_size, pad_voxel=0, skip_top_voxel=0):
        """
        get grid query points inside a given bounding box (bbx)
        Args:
            bbx: open3d bounding box, in world coordinate system, with unit m
            voxel_size: scalar, marching cubes voxel size with unit m
            pad_voxel: number of voxels padded on both sides of every axis
            skip_top_voxel: number of voxel layers dropped at the top (z axis)
        Returns:
            coord: Nx3 torch tensor, the coordinates of all N (axbxc) query points
            voxel_num_xyz: 3dim numpy array, the number of voxels on each axis for the bbx
            voxel_origin: 3dim numpy array the coordinate of the bottom-left corner of the 3d grids
                for marching cubes, in world coordinate system with unit m
            (None, None, None) when the grid would have more than 5e8 nodes
        """
        # bbx and voxel_size are all in the world coordinate system
        min_bound = bbx.get_min_bound()
        max_bound = bbx.get_max_bound()
        len_xyz = max_bound - min_bound
        voxel_num_xyz = (np.ceil(len_xyz / voxel_size) + pad_voxel * 2).astype(np.int_)
        voxel_origin = min_bound - pad_voxel * voxel_size
        # pad an additional voxel underground to guarantee the reconstruction of ground
        voxel_origin[2] -= voxel_size
        voxel_num_xyz[2] += 1
        voxel_num_xyz[2] -= skip_top_voxel

        voxel_count_total = voxel_num_xyz[0] * voxel_num_xyz[1] * voxel_num_xyz[2]
        if voxel_count_total > 5e8:  # this value is determined by your gpu memory
            print("too many query points, use smaller chunks")
            return None, None, None

        coord = self._grid_coords(voxel_num_xyz, voxel_size, voxel_origin)

        return coord, voxel_num_xyz, voxel_origin

    def get_query_from_hor_slice(self, bbx, slice_z, voxel_size):
        """
        get grid query points inside a given bounding box (bbx) at slice height (slice_z)
        Returns the same triple as ``get_query_from_bbx``.
        """
        # bbx and voxel_size are all in the world coordinate system
        min_bound = bbx.get_min_bound()
        max_bound = bbx.get_max_bound()
        len_xyz = max_bound - min_bound
        voxel_num_xyz = (np.ceil(len_xyz / voxel_size)).astype(np.int_)
        voxel_num_xyz[2] = 1  # a single layer along z
        voxel_origin = min_bound
        voxel_origin[2] = slice_z

        query_count_total = voxel_num_xyz[0] * voxel_num_xyz[1]
        if query_count_total > 1e8:  # avoid gpu memory issue, dirty fix
            # firstly save in cpu memory (which would be larger than gpu's)
            # NOTE: this setting stays in effect for all later queries
            self.cur_device = "cpu"
            print("too much query points, use cpu memory")

        coord = self._grid_coords(voxel_num_xyz, voxel_size, voxel_origin)

        return coord, voxel_num_xyz, voxel_origin

    def get_query_from_ver_slice(self, bbx, slice_x, voxel_size):
        """
        get grid query points inside a given bounding box (bbx) at slice position (slice_x)
        Returns the same triple as ``get_query_from_bbx``.
        """
        # bbx and voxel_size are all in the world coordinate system
        min_bound = bbx.get_min_bound()
        max_bound = bbx.get_max_bound()
        len_xyz = max_bound - min_bound
        voxel_num_xyz = (np.ceil(len_xyz / voxel_size)).astype(np.int_)
        voxel_num_xyz[0] = 1  # a single layer along x
        voxel_origin = min_bound
        voxel_origin[0] = slice_x

        query_count_total = voxel_num_xyz[1] * voxel_num_xyz[2]
        if query_count_total > 1e8:  # avoid gpu memory issue, dirty fix
            # firstly save in cpu memory (which would be larger than gpu's)
            # NOTE: this setting stays in effect for all later queries
            self.cur_device = "cpu"
            print("too much query points, use cpu memory")

        coord = self._grid_coords(voxel_num_xyz, voxel_size, voxel_origin)

        return coord, voxel_num_xyz, voxel_origin

    def generate_sdf_map(self, coord, sdf_pred, mc_mask):
        """
        Generate the SDF map for saving
        Returns an open3d tensor point cloud with the sdf in the "intensities"
        channel and (if given) the marching cubes mask in the "labels" channel.
        """
        device = o3d.core.Device("CPU:0")
        dtype = o3d.core.float32
        sdf_map_pc = o3d.t.geometry.PointCloud(device)

        coord_np = coord.detach().cpu().numpy()

        # the sdf (unit: m) would be saved in the intensity channel
        sdf_map_pc.point["positions"] = o3d.core.Tensor(coord_np, dtype, device)
        sdf_map_pc.point["intensities"] = o3d.core.Tensor(
            np.expand_dims(sdf_pred, axis=1), dtype, device
        )  # scaled sdf prediction
        if mc_mask is not None:
            # the marching cubes mask would be saved in the labels channel
            sdf_map_pc.point["labels"] = o3d.core.Tensor(
                np.expand_dims(mc_mask, axis=1), o3d.core.int32, device
            )  # mask

        # global transform (to world coordinate system) before output
        if not np.array_equal(self.global_transform, np.eye(4)):
            sdf_map_pc.transform(self.global_transform)

        return sdf_map_pc

    def generate_sdf_map_for_vis(
        self, coord, sdf_pred, mc_mask, min_sdf=-1.0, max_sdf=1.0, cmap="bwr"
    ):  # 'jet','bwr','viridis'
        """
        Generate the SDF map for visualization
        The sdf is clipped to [min_sdf, max_sdf] and mapped to a color with the
        given colormap; masked-out points are dropped when a mask is given.
        """
        # do the masking or not
        if mc_mask is not None:
            coord = coord[mc_mask > 0]
            sdf_pred = sdf_pred[mc_mask > 0]

        coord_np = coord.detach().cpu().numpy().astype(np.float64)

        sdf_pred_show = np.clip((sdf_pred - min_sdf) / (max_sdf - min_sdf), 0.0, 1.0)

        try:
            color_map = cm.get_cmap(cmap)  # or 'jet'
        except AttributeError:
            # cm.get_cmap is not available in recent matplotlib releases,
            # fall back to the colormap registry
            import matplotlib

            color_map = matplotlib.colormaps[cmap]
        # change to blue (+) --> red (-)
        colors = color_map(1.0 - sdf_pred_show)[:, :3].astype(np.float64)

        sdf_map_pc = o3d.geometry.PointCloud()
        sdf_map_pc.points = o3d.utility.Vector3dVector(coord_np)
        sdf_map_pc.colors = o3d.utility.Vector3dVector(colors)
        if not np.array_equal(self.global_transform, np.eye(4)):
            sdf_map_pc.transform(self.global_transform)

        return sdf_map_pc

    def assign_to_bbx(self, sdf_pred, sem_pred, color_pred, mc_mask, voxel_num_xyz):
        """assign the queried sdf, semantic label and marching cubes mask back to the 3D grids in the specified bounding box
        Args:
            sdf_pred: Ndim np.array/torch.tensor
            sem_pred: Ndim np.array/torch.tensor
            color_pred: NxC np.array/torch.tensor
            mc_mask: Ndim bool np.array/torch.tensor
            voxel_num_xyz: 3dim numpy array/torch.tensor, the number of voxels on each axis for the bbx
        Returns:
            sdf_pred: a*b*c np.array/torch.tensor, 3d grids of sign distance values
            sem_pred: a*b*c np.array/torch.tensor, 3d grids of semantic labels
            color_pred: a*b*c (single channel) or a*b*c*C np.array/torch.tensor, 3d grids of colors
            mc_mask: a*b*c np.array/torch.tensor, 3d grids of marching cube masks, marching cubes only on where
                the mask is true
            (inputs that are None are returned as None)
        """
        grid_shape = (
            int(voxel_num_xyz[0]),
            int(voxel_num_xyz[1]),
            int(voxel_num_xyz[2]),
        )

        if sdf_pred is not None:
            sdf_pred = sdf_pred.reshape(grid_shape)

        if sem_pred is not None:
            sem_pred = sem_pred.reshape(grid_shape)

        if color_pred is not None:
            if color_pred.ndim > 1 and color_pred.shape[-1] != 1:
                # keep the channel axis, a*b*c alone cannot hold N*C values
                color_pred = color_pred.reshape(grid_shape + (color_pred.shape[-1],))
            else:
                color_pred = color_pred.reshape(grid_shape)

        if mc_mask is not None:
            mc_mask = mc_mask.reshape(grid_shape)

        return sdf_pred, sem_pred, color_pred, mc_mask

    def mc_mesh(self, mc_sdf, mc_mask, voxel_size, mc_origin):
        """use the marching cubes algorithm to get mesh vertices and faces
        Args:
            mc_sdf: a*b*c np.array, 3d grids of sign distance values
            mc_mask: a*b*c np.array, 3d grids of marching cube masks, marching cubes only on where
                the mask is true
            voxel_size: scalar, marching cubes voxel size with unit m
            mc_origin: 3*1 np.array, the coordinate of the bottom-left corner of the 3d grids for
                marching cubes, in world coordinate system with unit m
        Returns:
            ([verts], [faces]), mesh vertices and triangle faces
            (both empty when no surface could be extracted)
        """
        if not self.silence:
            print("Marching cubes ...")
        # the input are all already numpy arrays
        # faces hold vertex indices, so the empty default is an integer array
        verts, faces = np.zeros((0, 3)), np.zeros((0, 3), dtype=np.int32)
        try:
            # allow_degenerate=False removes degenerate (i.e. zero-area)
            # triangles, at the cost of making the algorithm slower
            verts, faces, _, _ = skimage.measure.marching_cubes(
                mc_sdf, level=0.0, allow_degenerate=False, mask=mc_mask
            )
        except (ValueError, RuntimeError):
            # best effort: no surface at level 0 in this grid, keep the empty
            # mesh (other exceptions, e.g. KeyboardInterrupt, are not hidden)
            pass

        verts = mc_origin + verts * voxel_size

        return verts, faces

    def estimate_vertices_sem(self, mesh, verts, filter_free_space_vertices=True):
        """
        Predict the semantic label of the vertices
        The vertices are colored by their label and, optionally, the vertices
        with a label <= 0 are removed from the mesh.
        """
        if len(verts) == 0:
            return mesh

        verts_torch = torch.tensor(verts, dtype=self.dtype, device=self.device)
        _, verts_sem, _, _ = self.query_points(
            verts_torch,
            self.config.infer_bs,
            query_sdf=False,
            query_sem=True,
            query_color=False,
            query_mask=False,
        )
        verts_sem_rgb = [sem_kitti_color_map[sem_label] for sem_label in list(verts_sem)]
        verts_sem_rgb = np.asarray(verts_sem_rgb, dtype=np.float64) / 255.0
        mesh.vertex_colors = o3d.utility.Vector3dVector(verts_sem_rgb)

        # filter the freespace vertices
        if filter_free_space_vertices:
            # True for the vertices that get removed (label <= 0)
            freespace_mask = verts_sem <= 0
            mesh.remove_vertices_by_mask(freespace_mask)

        return mesh

    def estimate_vertices_color(self, mesh, verts):
        """
        Predict the color of the vertices
        """
        if len(verts) == 0:
            return mesh

        verts_torch = torch.tensor(verts, dtype=self.dtype, device=self.device)
        _, _, verts_color, _ = self.query_points(
            verts_torch,
            self.config.infer_bs,
            query_sdf=False,
            query_sem=False,
            query_color=True,
            query_mask=False,
        )

        if self.config.color_channel == 1:
            # single channel: scale by 2 and replicate it to the three color channels
            verts_color = np.repeat(verts_color * 2.0, 3, axis=1)

        mesh.vertex_colors = o3d.utility.Vector3dVector(verts_color)

        return mesh

    def filter_isolated_vertices(self, mesh, filter_cluster_min_tri=300):
        """
        Cluster connected triangles and remove the small clusters
        (clusters with less than ``filter_cluster_min_tri`` triangles)
        """
        triangle_clusters, cluster_n_triangles, _ = mesh.cluster_connected_triangles()
        triangle_clusters = np.asarray(triangle_clusters)
        cluster_n_triangles = np.asarray(cluster_n_triangles)
        # per triangle: size of the cluster it belongs to
        triangles_to_remove = (
            cluster_n_triangles[triangle_clusters] < filter_cluster_min_tri
        )
        mesh.remove_triangles_by_mask(triangles_to_remove)

        return mesh

    def generate_bbx_sdf_hor_slice(
        self, bbx, slice_z, voxel_size, query_locally=False, min_sdf=-1.0, max_sdf=1.0, mask_min_nn_count=5
    ):
        """
        Generate the SDF slice at height (slice_z)
        """
        coord, _, _ = self.get_query_from_hor_slice(bbx, slice_z, voxel_size)
        sdf_pred, _, _, mc_mask = self.query_points(
            coord,
            self.config.infer_bs,
            True,
            False,
            False,
            self.config.mc_mask_on,
            query_locally=query_locally,
            mask_min_nn_count=mask_min_nn_count,
        )
        sdf_map_pc = self.generate_sdf_map_for_vis(
            coord, sdf_pred, mc_mask, min_sdf, max_sdf
        )

        return sdf_map_pc

    def generate_bbx_sdf_ver_slice(
        self, bbx, slice_x, voxel_size, query_locally=False, min_sdf=-1.0, max_sdf=1.0, mask_min_nn_count=5
    ):
        """
        Generate the SDF slice at x position (slice_x)
        """
        coord, _, _ = self.get_query_from_ver_slice(bbx, slice_x, voxel_size)
        sdf_pred, _, _, mc_mask = self.query_points(
            coord,
            self.config.infer_bs,
            True,
            False,
            False,
            self.config.mc_mask_on,
            query_locally=query_locally,
            mask_min_nn_count=mask_min_nn_count,
        )
        sdf_map_pc = self.generate_sdf_map_for_vis(
            coord, sdf_pred, mc_mask, min_sdf, max_sdf
        )

        return sdf_map_pc

    # reconstruct the mesh from the map defined by a collection of bounding boxes
    def recon_aabb_collections_mesh(
        self,
        aabbs,
        voxel_size,
        mesh_path=None,
        query_locally=False,
        estimate_sem=False,
        estimate_color=False,
        mesh_normal=True,
        filter_isolated_mesh=False,
        filter_free_space_vertices=True,
        mesh_min_nn=10,
        use_torch_mc=False,
    ):
        """
        Reconstruct the mesh from a collection of bounding boxes
        Every box is meshed on its own (see ``recon_aabb_mesh``) and the results
        are merged; the merged mesh is written to ``mesh_path`` if given.
        """
        if not self.silence:
            print("# Chunk for meshing: ", len(aabbs))

        mesh_merged = o3d.geometry.TriangleMesh()
        for bbx in tqdm(aabbs, disable=self.silence):
            cur_mesh = self.recon_aabb_mesh(
                bbx,
                voxel_size,
                mesh_path=None,
                query_locally=query_locally,
                estimate_sem=estimate_sem,
                estimate_color=estimate_color,
                mesh_normal=mesh_normal,
                filter_isolated_mesh=filter_isolated_mesh,
                filter_free_space_vertices=filter_free_space_vertices,
                mesh_min_nn=mesh_min_nn,
                use_torch_mc=use_torch_mc,
            )
            # recon_aabb_mesh returns None for a chunk that is too large to be
            # queried; skip it instead of failing on `mesh += None`
            if cur_mesh is not None:
                mesh_merged += cur_mesh

            remove_gpu_cache()  # deal with high GPU memory consumption when meshing (TODO)

        mesh_merged.remove_duplicated_vertices()

        if mesh_normal:
            mesh_merged.compute_vertex_normals()

        if mesh_path is not None:
            o3d.io.write_triangle_mesh(mesh_path, mesh_merged)
            if not self.silence:
                print("save the mesh to %s\n" % (mesh_path))

        return mesh_merged

    def recon_aabb_mesh(
        self,
        bbx,
        voxel_size,
        mesh_path=None,
        query_locally=False,
        estimate_sem=False,
        estimate_color=False,
        mesh_normal=True,
        filter_isolated_mesh=False,
        filter_free_space_vertices=True,
        mesh_min_nn=10,
        use_torch_mc=False,
    ):
        """
        Reconstruct the mesh from a given bounding box
        bbx and voxel_size are all with unit m, in world coordinate system.
        Returns the open3d triangle mesh, or None when the bounding box needs
        too many query points (use smaller chunks in this case).
        """
        coord, voxel_num_xyz, voxel_origin = self.get_query_from_bbx(
            bbx, voxel_size, self.config.pad_voxel, self.config.skip_top_voxel
        )
        if coord is None:  # use chunks in this case
            return None

        sdf_pred, _, _, mc_mask = self.query_points(
            coord,
            self.config.infer_bs,
            query_sdf=True,
            query_sem=False,
            query_color=False,
            query_mask=self.config.mc_mask_on,
            query_locally=query_locally,
            mask_min_nn_count=mesh_min_nn,
            out_torch=use_torch_mc,
        )

        mc_sdf, _, _, mc_mask = self.assign_to_bbx(
            sdf_pred, None, None, mc_mask, voxel_num_xyz
        )
        if use_torch_mc:
            # torch version
            # NOTE(review): mc_mesh_torch is not defined in the part of the
            # class visible here - confirm it exists before enabling use_torch_mc
            verts, faces = self.mc_mesh_torch(
                mc_sdf, mc_mask, voxel_size, torch.tensor(voxel_origin).to(mc_sdf)
            )  # has some double faces problem
            mesh = o3d.t.geometry.TriangleMesh(device=o3d.core.Device("cuda:0"))
            mesh.vertex.positions = o3d.core.Tensor.from_dlpack(
                torch.utils.dlpack.to_dlpack(verts)
            )
            mesh.triangle.indices = o3d.core.Tensor.from_dlpack(
                torch.utils.dlpack.to_dlpack(faces)
            )
            mesh = mesh.to_legacy()
            mesh.remove_duplicated_vertices()
            mesh.compute_vertex_normals()
        else:
            # np cpu version
            # the mask is None when config.mc_mask_on is off
            mc_mask_bool = mc_mask.astype(bool) if mc_mask is not None else None
            verts, faces = self.mc_mesh(
                mc_sdf, mc_mask_bool, voxel_size, voxel_origin
            )  # too slow ? (actually not, the slower part is the querying)
            # directly use open3d to get mesh
            mesh = o3d.geometry.TriangleMesh(
                o3d.utility.Vector3dVector(verts.astype(np.float64)),
                o3d.utility.Vector3iVector(faces),
            )

        # the semantic colors take precedence over the regressed colors
        if estimate_sem:
            mesh = self.estimate_vertices_sem(mesh, verts, filter_free_space_vertices)
        else:
            if estimate_color:
                mesh = self.estimate_vertices_color(mesh, verts)

        mesh.remove_duplicated_vertices()

        if mesh_normal:
            mesh.compute_vertex_normals()

        if filter_isolated_mesh:
            mesh = self.filter_isolated_vertices(mesh, self.config.min_cluster_vertices)

        # global transform (to world coordinate system) before output
        if not np.array_equal(self.global_transform, np.eye(4)):
            mesh.transform(self.global_transform)

        # write the mesh to ply file
        if mesh_path is not None:
            o3d.io.write_triangle_mesh(mesh_path, mesh)
            if not self.silence:
                print("save the mesh to %s\n" % (mesh_path))

        return mesh


# --------------------------------------------------------------------------------
# /utils/point_cloud2.py:
# --------------------------------------------------------------------------------
# Copyright 2008 Willow Garage, Inc.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#    * Redistributions of source code must retain the above copyright
#      notice, this list of conditions and the following disclaimer.
8 | # 9 | # * Redistributions in binary form must reproduce the above copyright 10 | # notice, this list of conditions and the following disclaimer in the 11 | # documentation and/or other materials provided with the distribution. 12 | # 13 | # * Neither the name of the Willow Garage, Inc. nor the names of its 14 | # contributors may be used to endorse or promote products derived from 15 | # this software without specific prior written permission. 16 | # 17 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 21 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | # POSSIBILITY OF SUCH DAMAGE. 28 | 29 | """ 30 | This file is based on https://github.com/ros2/common_interfaces/blob/4bac182a0a582b5e6b784d9fa9f0dabc1aca4d35/sensor_msgs_py/sensor_msgs_py/point_cloud2.py 31 | All rights reserved to the original authors: Tim Field and Florian Vahl. 
import sys
from typing import Iterable, List, Optional, Tuple

import numpy as np

try:
    from rosbags.typesys.types import sensor_msgs__msg__PointCloud2 as PointCloud2
    from rosbags.typesys.types import sensor_msgs__msg__PointField as PointField
except ImportError as e:
    raise ImportError('rosbags library not installed, run "pip install -U rosbags"') from e


# Lookup from PointField datatype codes to the NumPy dtype used to decode them.
_DATATYPES = {}
_DATATYPES[PointField.INT8] = np.dtype(np.int8)
_DATATYPES[PointField.UINT8] = np.dtype(np.uint8)
_DATATYPES[PointField.INT16] = np.dtype(np.int16)
_DATATYPES[PointField.UINT16] = np.dtype(np.uint16)
_DATATYPES[PointField.INT32] = np.dtype(np.int32)
_DATATYPES[PointField.UINT32] = np.dtype(np.uint32)
_DATATYPES[PointField.FLOAT32] = np.dtype(np.float32)
_DATATYPES[PointField.FLOAT64] = np.dtype(np.float64)

# Name prefix given to fields whose name is the empty string.
DUMMY_FIELD_PREFIX = "unnamed_field"


def read_point_cloud(msg: PointCloud2) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Extract points and timestamps from a PointCloud2 message.

    Points with a NaN coordinate are dropped, and the same rows are dropped
    from the timestamps so both arrays stay aligned row by row.

    :param msg: The PointCloud2 message to decode.
    :return: Tuple of [points, timestamps]
        points: array of x, y, z points, dtype float64, shape: (N, 3)
        timestamps: per-point timestamps normalized to [0, 1] over the whole
            scan, dtype float64, shape: (N,); all zeros when every point
            carries the same timestamp; None when the message has no field
            named "t", "timestamp", "time" or "ts".
    """
    field_names = ["x", "y", "z"]
    t_field = None
    for field in msg.fields:
        if field.name in ["t", "timestamp", "time", "ts"]:
            t_field = field.name
            field_names.append(t_field)
            break

    points_structured = read_points(msg, field_names=field_names)
    points = np.column_stack(
        [points_structured["x"], points_structured["y"], points_structured["z"]]
    )

    # Rows to keep; reused below so points and timestamps are filtered identically
    valid = ~np.any(np.isnan(points), axis=1)
    points = points[valid].astype(np.float64)

    if t_field is None:
        return points, None

    timestamps = points_structured[t_field].astype(np.float64)
    if timestamps.size > 0:
        # Normalize against the full scan (before NaN filtering) so the
        # relative time of each remaining point is unchanged by the filtering
        min_timestamp = np.min(timestamps)
        span = np.max(timestamps) - min_timestamp
        if span == 0:
            # Every point shares one timestamp: avoid a 0/0 division
            timestamps = np.zeros_like(timestamps)
        else:
            timestamps = (timestamps - min_timestamp) / span
    return points, timestamps[valid]


def read_points(
    cloud: PointCloud2,
    field_names: Optional[List[str]] = None,
    uvs: Optional[Iterable] = None,
    reshape_organized_cloud: bool = False,
) -> np.ndarray:
    """
    Read points from a sensor_msgs.PointCloud2 message.

    :param cloud: The point cloud to read from sensor_msgs.PointCloud2.
    :param field_names: The names of fields to read. If None, read all fields.
                        (Type: Iterable, Default: None)
    :param uvs: If specified, then only return the points at the given
        coordinates. (Type: Iterable, Default: None)
    :param reshape_organized_cloud: Returns the array as a 2D organized point cloud if set.
    :return: Structured NumPy array containing all points.
    """
    # Cast bytes to numpy array (no copy: the array is a view on cloud.data)
    points = np.ndarray(
        shape=(cloud.width * cloud.height,),
        dtype=dtype_from_fields(cloud.fields, point_step=cloud.point_step),
        buffer=cloud.data,
    )

    # Keep only the requested fields
    if field_names is not None:
        assert all(
            field_name in points.dtype.names for field_name in field_names
        ), "Requests field is not in the fields of the PointCloud!"
        # Mask fields
        points = points[list(field_names)]

    # Swap array if byte order does not match
    if bool(sys.byteorder != "little") != bool(cloud.is_bigendian):
        points = points.byteswap(inplace=True)

    # Select points indexed by the uvs field
    if uvs is not None:
        # Don't convert to numpy array if it is already one
        if not isinstance(uvs, np.ndarray):
            uvs = np.fromiter(uvs, int)
        # Index requested points
        points = points[uvs]

    # Cast into 2d array if cloud is 'organized'
    # NOTE(review): the shape is (width, height); if the data is stored row by
    # row, (height, width) may be the intended layout -- confirm before relying on it.
    if reshape_organized_cloud and cloud.height > 1:
        points = points.reshape(cloud.width, cloud.height)

    return points


def dtype_from_fields(fields: Iterable[PointField], point_step: Optional[int] = None) -> np.dtype:
    """
    Convert an Iterable of sensor_msgs.msg.PointField messages to a np.dtype.

    :param fields: The point cloud fields.
                   (Type: iterable of sensor_msgs.msg.PointField)
    :param point_step: Point step size in bytes. Calculated from the given fields by default.
                       (Type: optional of integer)
    :returns: NumPy datatype
    """
    # Create lists containing the names, offsets and datatypes of all fields
    field_names = []
    field_offsets = []
    field_datatypes = []
    for i, field in enumerate(fields):
        # Datatype as numpy datatype
        datatype = _DATATYPES[field.datatype]
        # Name field
        if field.name == "":
            name = f"{DUMMY_FIELD_PREFIX}_{i}"
        else:
            name = field.name
        # Handle fields with count > 1 by creating subfields with a suffix consisting
        # of "_" followed by the subfield counter [0 -> (count - 1)]
        assert field.count > 0, "Can't process fields with count = 0."
        for a in range(field.count):
            # Add suffix if we have multiple subfields
            if field.count > 1:
                subfield_name = f"{name}_{a}"
            else:
                subfield_name = name
            assert subfield_name not in field_names, "Duplicate field names are not allowed!"
            field_names.append(subfield_name)
            # Create new offset that includes subfields
            field_offsets.append(field.offset + a * datatype.itemsize)
            field_datatypes.append(datatype.str)

    # Create dtype; an explicit itemsize keeps any padding bytes between points
    dtype_dict = {"names": field_names, "formats": field_datatypes, "offsets": field_offsets}
    if point_step is not None:
        dtype_dict["itemsize"] = point_step
    return np.dtype(dtype_dict)
def get_random_rgb(n):
    """Map an integer id to a deterministic pseudo-random RGB color.

    The id is scrambled with two multiply/xor-shift rounds on 32 bits and
    the top 24 bits of the result are split into three bytes.

    Args:
        n: integer id (anything accepted by ``int()``, e.g. a NumPy integer).

    Returns:
        Tuple ``(r, g, b)`` of ints, each in [0, 255].
    """
    n = int(n)  # NumPy integer scalars have no ``to_bytes``
    n = ((n ^ n >> 15) * 2246822519) & 0xFFFFFFFF
    n = ((n ^ n >> 13) * 3266489917) & 0xFFFFFFFF
    # After this shift the value fits in 24 bits, i.e. exactly three bytes
    n = (n ^ n >> 16) >> 8
    return tuple(n.to_bytes(3, "big"))
import torch
import numpy as np


def vec2skew(v: torch.Tensor):
    """Return the 3x3 skew-symmetric matrix S of a 3-vector v, so that S @ w == cross(v, w).

    The result is a new tensor with the device and dtype of ``v``.
    """
    zero = torch.zeros_like(v[0])
    return torch.tensor([
        [zero, -v[2], v[1]],
        [v[2], zero, -v[0]],
        [-v[1], v[0], zero]
    ], device=v.device, dtype=v.dtype)


def vectors_to_skew_symmetric(vectors: torch.Tensor):
    """
    Convert a batch of vectors to a batch of skew-symmetric matrices.

    Parameters:
    vectors : torch.Tensor
        Input tensor containing vectors. Shape [m, 3]

    Returns:
    skew_matrices : torch.Tensor
        Output tensor containing skew-symmetric matrices. Shape [m, 3, 3]
    """
    skew_matrices = torch.zeros((vectors.shape[0], 3, 3), dtype=vectors.dtype, device=vectors.device)
    skew_matrices[:, 0, 1] = -vectors[:, 2]
    skew_matrices[:, 0, 2] = vectors[:, 1]
    skew_matrices[:, 1, 0] = vectors[:, 2]
    skew_matrices[:, 1, 2] = -vectors[:, 0]
    skew_matrices[:, 2, 0] = -vectors[:, 1]
    skew_matrices[:, 2, 1] = vectors[:, 0]

    return skew_matrices


def so3Exp(so3: torch.Tensor):
    """Convert an so(3) rotation vector to an SO(3) rotation matrix (Rodrigues' formula).

    Args:
        so3 (torch.Tensor): rotation vector of shape (3,).

    Returns:
        torch.Tensor: rotation matrix of shape (3, 3). The identity is
        returned when the norm of ``so3`` is at most 1e-7.
    """
    so3_norm = torch.norm(so3)
    if so3_norm <= 1e-7:
        return torch.eye(3, device=so3.device, dtype=so3.dtype)

    so3_skew_sym = vec2skew(so3)
    identity = torch.eye(3, device=so3.device, dtype=so3.dtype)

    SO3 = identity + (so3_skew_sym / so3_norm) * torch.sin(so3_norm) + \
        (so3_skew_sym @ so3_skew_sym / (so3_norm * so3_norm)) * (1 - torch.cos(so3_norm))
    return SO3


def SO3Log(SO3: torch.Tensor):
    """Convert an SO(3) rotation matrix to its so(3) rotation vector (Lie group -> Lie algebra).

    Args:
        SO3 (torch.Tensor): rotation matrix of shape (3, 3).

    Returns:
        torch.Tensor: rotation vector of shape (3,), with the device and dtype
        of the input and detached from the autograd graph. For a rotation of
        exactly pi both signs of the axis are valid; one of them is returned.
    """
    rot = SO3.detach()
    trace = rot.trace()
    # Rounding can push (trace - 1) / 2 slightly outside [-1, 1], where acos is NaN
    cos_theta = torch.clamp((trace - 1) / 2, -1.0, 1.0)
    theta = torch.acos(cos_theta)

    # Vee of (R - R^T); equals 2 * sin(theta) * axis
    vee = torch.stack([
        rot[2, 1] - rot[1, 2],
        rot[0, 2] - rot[2, 0],
        rot[1, 0] - rot[0, 1]
    ])

    # Small angle: theta / sin(theta) -> 1
    if trace > 3 - 1e-6 or theta < 0.001:
        return 0.5 * vee

    # Near pi: sin(theta) and vee both vanish, so theta / sin(theta) * vee
    # degenerates to 0/0. Recover the axis from the symmetric part instead:
    # (R + R^T) / 2 = cos(theta) * I + (1 - cos(theta)) * axis axis^T
    if np.pi - theta < 1e-3:
        identity = torch.eye(3, device=rot.device, dtype=rot.dtype)
        outer = (0.5 * (rot + rot.T) - cos_theta * identity) / (1 - cos_theta)
        # Use the column with the largest diagonal entry for the best conditioning
        k = int(torch.argmax(torch.diagonal(outer)))
        axis = outer[:, k] / torch.sqrt(outer[k, k])
        # The outer product fixes the axis only up to sign; vee carries the sign
        if torch.dot(axis, vee) < 0:
            axis = -axis
        return theta * axis

    return 0.5 * theta / torch.sin(theta) * vee


def A_T(v: torch.Tensor):
    """Compute I + (1 - cos t) / t^2 * S + (1 - sin t / t) / t^2 * S @ S for a 3-vector v.

    Here t = ||v|| and S = vec2skew(v). This is the closed form of the left
    Jacobian of SO(3) at v; it is NOT a rotation matrix. It satisfies
    so3Exp(v) == I + S @ A_T(v).

    Args:
        v (torch.Tensor): rotation vector of shape (3,).

    Returns:
        torch.Tensor: matrix of shape (3, 3). The identity is returned when
        the norm of ``v`` is below 1e-11.
    """
    squared_norm = torch.dot(v, v)  # squared norm of v
    norm = torch.sqrt(squared_norm)  # norm of v
    identity = torch.eye(3, device=v.device, dtype=v.dtype)

    if norm < 1e-11:
        return identity
    else:
        S = vec2skew(v)  # skew-symmetric matrix of v
        term1 = (1 - torch.cos(norm)) / squared_norm
        term2 = (1 - torch.sin(norm) / norm) / squared_norm
        return identity + term1 * S + term2 * torch.matmul(S, S)
o3d.geometry.TriangleMesh() 44 | self.mesh = o3d.geometry.TriangleMesh() 45 | self.sdf = o3d.geometry.PointCloud() 46 | self.neural_points = o3d.geometry.PointCloud() 47 | self.data_pool = o3d.geometry.PointCloud() 48 | 49 | self.odom_traj_pcd = o3d.geometry.PointCloud() 50 | self.pgo_traj_pcd = o3d.geometry.PointCloud() 51 | self.gt_traj_pcd = o3d.geometry.PointCloud() 52 | 53 | self.odom_traj = o3d.geometry.LineSet() 54 | self.pgo_traj = o3d.geometry.LineSet() 55 | self.gt_traj = o3d.geometry.LineSet() 56 | 57 | self.pgo_edges = o3d.geometry.LineSet() 58 | 59 | self.log_path = "./" 60 | self.sdf_slice_height = 0.0 61 | self.mc_res_m = 0.1 62 | self.mesh_min_nn = 10 63 | self.keep_local_mesh = True 64 | 65 | self.frame_axis_len = 0.5 66 | 67 | if config is not None: 68 | self.log_path = os.path.join(config.run_path, "log") 69 | self.frame_axis_len = config.vis_frame_axis_len 70 | self.sdf_slice_height = config.sdf_slice_height 71 | if self.config.sensor_cad_path is not None: 72 | self.sensor_cad = o3d.io.read_triangle_mesh(config.sensor_cad_path) 73 | self.sensor_cad.compute_vertex_normals() 74 | self.mc_res_m = config.mc_res_m 75 | self.mesh_min_nn = config.mesh_min_nn 76 | self.keep_local_mesh = config.keep_local_mesh 77 | 78 | self.before_pgo = True 79 | self.last_pose = np.eye(4) 80 | 81 | # Initialize visualizer 82 | self.vis = o3d.visualization.VisualizerWithKeyCallback() 83 | self._register_key_callbacks() 84 | self._initialize_visualizer() 85 | 86 | # Visualization options 87 | self.render_mesh: bool = True 88 | self.render_pointcloud: bool = True 89 | self.render_frame_axis: bool = True 90 | self.render_trajectory: bool = True 91 | self.render_gt_trajectory: bool = False 92 | self.render_odom_trajectory: bool = ( 93 | True # when pgo is on, visualize the odom or not 94 | ) 95 | self.render_neural_points: bool = False 96 | self.render_data_pool: bool = False 97 | self.render_sdf: bool = False 98 | self.render_pgo: bool = self.render_trajectory 99 | 100 | 
self.sdf_slice_height_step: float = 0.1 101 | 102 | self.vis_pc_color: bool = True 103 | self.pc_uniform_color: bool = False 104 | 105 | self.vis_only_cur_samples: bool = False 106 | 107 | self.mc_res_change_interval_m: float = 0.2 * self.mc_res_m 108 | 109 | self.vis_global: bool = False 110 | 111 | self.ego_view: bool = False 112 | self.ego_change_flag: bool = False 113 | 114 | self.debug_mode: int = 0 115 | 116 | self.neural_points_vis_mode: int = 0 117 | 118 | self.global_viewpoint: bool = False 119 | self.view_control = self.vis.get_view_control() 120 | self.camera_params = self.view_control.convert_to_pinhole_camera_parameters() 121 | 122 | def update_view(self): 123 | self.vis.poll_events() 124 | self.vis.update_renderer() 125 | 126 | def pause_view(self): 127 | while self.block_vis: 128 | self.update_view() 129 | if self.play_crun: 130 | break 131 | 132 | def update( 133 | self, 134 | scan=None, 135 | pose=None, 136 | sdf=None, 137 | mesh=None, 138 | neural_points=None, 139 | data_pool=None, 140 | pause_now=False, 141 | ): 142 | self._update_geometries(scan, pose, sdf, mesh, neural_points, data_pool) 143 | self.update_view() 144 | self.pause_view() 145 | if pause_now: 146 | self.stop() 147 | 148 | def update_traj( 149 | self, 150 | cur_pose=None, 151 | odom_poses=None, 152 | gt_poses=None, 153 | pgo_poses=None, 154 | pgo_edges=None, 155 | ): 156 | self._update_traj(cur_pose, odom_poses, gt_poses, pgo_poses, pgo_edges) 157 | self.update_view() 158 | self.pause_view() 159 | 160 | def update_pointcloud(self, scan): 161 | self._update_pointcloud(scan) 162 | self.update_view() 163 | self.pause_view() 164 | 165 | def update_mesh(self, mesh): 166 | self._update_mesh(mesh) 167 | self.update_view() 168 | self.pause_view() 169 | 170 | def destroy_window(self): 171 | self.vis.destroy_window() 172 | 173 | def stop(self): 174 | self.play_crun = not self.play_crun 175 | while self.block_vis: 176 | self.vis.poll_events() 177 | self.vis.update_renderer() 178 | if 
self.play_crun: 179 | break 180 | 181 | def _initialize_visualizer(self): 182 | w_name = "📍 PIN-SLAM Visualizer" 183 | self.vis.create_window( 184 | window_name=w_name, width=2560, height=1600 185 | ) # 1920, 1080 186 | self.vis.add_geometry(self.scan) 187 | self.vis.add_geometry(self.sdf) 188 | self.vis.add_geometry(self.frame_axis) 189 | self.vis.add_geometry(self.mesh) 190 | self.vis.add_geometry(self.neural_points) 191 | self.vis.add_geometry(self.data_pool) 192 | self.vis.add_geometry(self.odom_traj_pcd) 193 | self.vis.add_geometry(self.gt_traj_pcd) 194 | self.vis.add_geometry(self.pgo_traj_pcd) 195 | self.vis.add_geometry(self.pgo_edges) 196 | 197 | self.vis.get_render_option().line_width = 500 198 | self.vis.get_render_option().light_on = True 199 | self.vis.get_render_option().mesh_shade_option = ( 200 | o3d.visualization.MeshShadeOption.Color 201 | ) 202 | 203 | if self.config is not None: 204 | self.vis.get_render_option().point_size = self.config.vis_point_size 205 | if self.config.mesh_vis_normal: 206 | self.vis.get_render_option().mesh_color_option = ( 207 | o3d.visualization.MeshColorOption.Normal 208 | ) 209 | 210 | print( 211 | f"{w_name} initialized. 
Press:\n" 212 | "\t[SPACE] to pause/resume\n" 213 | "\t[ESC/Q] to exit\n" 214 | "\t [G] to toggle on/off the global/local map visualization\n" 215 | "\t [E] to toggle on/off the ego/map viewpoint\n" 216 | "\t [F] to toggle on/off the current point cloud\n" 217 | "\t [M] to toggle on/off the mesh\n" 218 | "\t [T] to toggle on/off PIN SLAM trajectory\n" 219 | "\t [Y] to toggle on/off the reference trajectory\n" 220 | "\t [U] to toggle on/off PIN odometry trajectory\n" 221 | "\t [A] to toggle on/off the current frame axis\n" 222 | "\t [P] to toggle on/off the neural points map\n" 223 | "\t [D] to toggle on/off the data pool\n" 224 | "\t [I] to toggle on/off the sdf map slice\n" 225 | "\t [R] to center the view point\n" 226 | "\t [Z] to save the currently visualized entities in the log folder\n" 227 | ) 228 | 229 | def _register_key_callback(self, keys: List, callback: Callable): 230 | for key in keys: 231 | self.vis.register_key_callback(ord(str(key)), partial(callback)) 232 | 233 | def _register_key_callbacks(self): 234 | self._register_key_callback(["Ā", "Q"], self._quit) 235 | self._register_key_callback([" "], self._start_stop) 236 | self._register_key_callback(["R"], self._center_viewpoint) 237 | self._register_key_callback(["E"], self._toggle_ego) 238 | self._register_key_callback(["F"], self._toggle_pointcloud) 239 | self._register_key_callback(["A"], self._toggle_frame_axis) 240 | self._register_key_callback(["I"], self._toggle_sdf) 241 | self._register_key_callback(["M"], self._toggle_mesh) 242 | self._register_key_callback(["P"], self._toggle_neural_points) 243 | self._register_key_callback(["D"], self._toggle_data_pool) 244 | self._register_key_callback(["T"], self._toggle_trajectory) 245 | self._register_key_callback(["Y"], self._toggle_gt_trajectory) 246 | self._register_key_callback(["U"], self._toggle_odom_trajectory) 247 | self._register_key_callback(["G"], self._toggle_global) 248 | self._register_key_callback(["Z"], self._save_cur_vis) 249 | 
self._register_key_callback([";"], self._toggle_loop_debug) 250 | self._register_key_callback( 251 | ["/"], self._toggle_neural_point_vis_mode 252 | ) # vis neural point color using feature, ts or certainty 253 | self._register_key_callback(["'"], self._toggle_vis_cur_sample) 254 | self._register_key_callback(["]"], self._toggle_increase_mesh_res) 255 | self._register_key_callback(["["], self._toggle_decrease_mesh_res) 256 | self._register_key_callback(["."], self._toggle_increase_mesh_nn) # '>' 257 | self._register_key_callback([","], self._toggle_decrease_mesh_nn) # '<' 258 | self._register_key_callback(["5"], self._toggle_point_color) 259 | self._register_key_callback(["6"], self._toggle_uniform_color) 260 | self._register_key_callback(["7"], self._switch_background) 261 | # self.vis.register_key_callback(262, partial(self._toggle_)) # right arrow # for future 262 | # self.vis.register_key_callback(263, partial(self._toggle_)) # left arrow 263 | self.vis.register_key_callback( 264 | 265, partial(self._toggle_increase_slice_height) 265 | ) # up arrow 266 | self.vis.register_key_callback( 267 | 264, partial(self._toggle_decrease_slice_height) 268 | ) # down arrow 269 | # leave C and V as the view copying, pasting function 270 | # use alt + prt sc for the window screenshot 271 | 272 | def _switch_background(self, vis): 273 | cur_background_color = vis.get_render_option().background_color 274 | vis.get_render_option().background_color = np.ones(3) - cur_background_color 275 | 276 | def _center_viewpoint(self, vis): 277 | self.reset_bounding_box = not self.reset_bounding_box 278 | vis.reset_view_point(True) 279 | 280 | def _toggle_point_color( 281 | self, vis 282 | ): # actually used to show the source point cloud weight for registration 283 | self.vis_pc_color = not self.vis_pc_color 284 | 285 | def _quit(self, vis): 286 | print("Destroying Visualizer") 287 | vis.destroy_window() 288 | os._exit(0) 289 | 290 | def _save_cur_vis(self, vis): 291 | if 
self.data_pool.has_points(): 292 | data_pool_pc_name = str(self.cur_frame_id) + "_training_sdf_pool.ply" 293 | data_pool_pc_path = os.path.join(self.log_path, data_pool_pc_name) 294 | o3d.io.write_point_cloud(data_pool_pc_path, self.data_pool) 295 | print("Output current training data pool to: ", data_pool_pc_path) 296 | if self.scan.has_points(): 297 | if self.vis_pc_color: 298 | scan_pc_name = str(self.cur_frame_id) + "_scan_map.ply" 299 | else: 300 | scan_pc_name = str(self.cur_frame_id) + "_scan_reg.ply" 301 | scan_pc_path = os.path.join(self.log_path, scan_pc_name) 302 | o3d.io.write_point_cloud(scan_pc_path, self.scan) 303 | print("Output current scan to: ", scan_pc_path) 304 | if self.neural_points.has_points(): 305 | neural_point_name = str(self.cur_frame_id) + "_neural_points.ply" 306 | neural_point_path = os.path.join(self.log_path, neural_point_name) 307 | o3d.io.write_point_cloud(neural_point_path, self.neural_points) 308 | print("Output current neural points to: ", neural_point_path) 309 | if self.sdf.has_points(): 310 | sdf_slice_name = str(self.cur_frame_id) + "_sdf_slice.ply" 311 | sdf_slice_path = os.path.join(self.log_path, sdf_slice_name) 312 | o3d.io.write_point_cloud(sdf_slice_path, self.sdf) 313 | print("Output current SDF slice to: ", sdf_slice_path) 314 | if self.mesh.has_triangles(): 315 | mesh_name = str(self.cur_frame_id) + "_mesh_vis.ply" 316 | mesh_path = os.path.join(self.log_path, mesh_name) 317 | o3d.io.write_triangle_mesh(mesh_path, self.mesh) 318 | print("Output current mesh to: ", mesh_path) 319 | if self.frame_axis.has_triangles(): 320 | ego_name = str(self.cur_frame_id) + "_sensor_vis.ply" 321 | ego_path = os.path.join(self.log_path, ego_name) 322 | o3d.io.write_triangle_mesh(ego_path, self.frame_axis) 323 | print("Output current sensor model to: ", ego_path) 324 | 325 | def _next_frame(self, vis): # FIXME 326 | self.block_vis = not self.block_vis 327 | 328 | def _start_stop(self, vis): 329 | self.play_crun = not self.play_crun 
330 | 331 | def _toggle_pointcloud(self, vis): 332 | self.render_pointcloud = not self.render_pointcloud 333 | 334 | def _toggle_frame_axis(self, vis): 335 | self.render_frame_axis = not self.render_frame_axis 336 | 337 | def _toggle_trajectory(self, vis): 338 | self.render_trajectory = not self.render_trajectory 339 | 340 | def _toggle_gt_trajectory(self, vis): 341 | self.render_gt_trajectory = not self.render_gt_trajectory 342 | 343 | def _toggle_odom_trajectory(self, vis): 344 | self.render_odom_trajectory = not self.render_odom_trajectory 345 | 346 | def _toggle_pgo(self, vis): 347 | self.render_pgo = not self.render_pgo 348 | 349 | def _toggle_sdf(self, vis): 350 | self.render_sdf = not self.render_sdf 351 | 352 | def _toggle_mesh(self, vis): 353 | self.render_mesh = not self.render_mesh 354 | print("Show mesh: ", self.render_mesh) 355 | 356 | def _toggle_neural_points(self, vis): 357 | self.render_neural_points = not self.render_neural_points 358 | 359 | def _toggle_data_pool(self, vis): 360 | self.render_data_pool = not self.render_data_pool 361 | 362 | def _toggle_global(self, vis): 363 | self.vis_global = not self.vis_global 364 | 365 | def _toggle_ego(self, vis): 366 | self.ego_view = not self.ego_view 367 | self.ego_change_flag = True # ego->global or global->ego 368 | self.reset_bounding_box = not self.reset_bounding_box 369 | vis.reset_view_point(True) 370 | 371 | def _toggle_uniform_color(self, vis): 372 | self.pc_uniform_color = not self.pc_uniform_color 373 | 374 | def _toggle_loop_debug(self, vis): 375 | self.debug_mode = ( 376 | self.debug_mode + 1 377 | ) % 3 # 0,1,2 # switch between different debug mode 378 | print("Switch to debug mode:", self.debug_mode) 379 | 380 | def _toggle_vis_cur_sample(self, vis): 381 | self.vis_only_cur_samples = not self.vis_only_cur_samples 382 | 383 | def _toggle_neural_point_vis_mode(self, vis): 384 | self.neural_points_vis_mode = ( 385 | self.neural_points_vis_mode + 1 386 | ) % 5 # 0,1,2,3,4 # switch between 
different vis mode 387 | print("Switch to neural point visualization mode:", self.neural_points_vis_mode) 388 | 389 | def _toggle_increase_mesh_res(self, vis): 390 | self.mc_res_m += self.mc_res_change_interval_m 391 | print("Current marching cubes voxel size [m]:", f"{self.mc_res_m:.2f}") 392 | 393 | def _toggle_decrease_mesh_res(self, vis): 394 | self.mc_res_m = max( 395 | self.mc_res_change_interval_m, self.mc_res_m - self.mc_res_change_interval_m 396 | ) 397 | print("Current marching cubes voxel size [m]:", f"{self.mc_res_m:.2f}") 398 | 399 | def _toggle_increase_slice_height(self, vis): 400 | self.sdf_slice_height += self.sdf_slice_height_step 401 | print("Current SDF slice height [m]:", f"{self.sdf_slice_height:.2f}") 402 | 403 | def _toggle_decrease_slice_height(self, vis): 404 | self.sdf_slice_height -= self.sdf_slice_height_step 405 | print("Current SDF slice height [m]:", f"{self.sdf_slice_height:.2f}") 406 | 407 | def _toggle_increase_mesh_nn(self, vis): 408 | self.mesh_min_nn += 1 409 | print("Current marching cubes mask nn count:", self.mesh_min_nn) 410 | 411 | def _toggle_decrease_mesh_nn(self, vis): 412 | self.mesh_min_nn = max(5, self.mesh_min_nn - 1) 413 | print("Current marching cubes mask nn count:", self.mesh_min_nn) 414 | 415 | def _toggle_help(self, vis): 416 | print( 417 | f"Instructions. 
Press:\n" 418 | "\t[SPACE] to pause/resume\n" 419 | "\t[ESC/Q] to exit\n" 420 | "\t [G] to toggle on/off the global/local map visualization\n" 421 | "\t [E] to toggle on/off the ego/map viewpoint\n" 422 | "\t [F] to toggle on/off the current point cloud\n" 423 | "\t [M] to toggle on/off the mesh\n" 424 | "\t [T] to toggle on/off PIN SLAM trajectory\n" 425 | "\t [Y] to toggle on/off the reference trajectory\n" 426 | "\t [U] to toggle on/off PIN odometry trajectory\n" 427 | "\t [A] to toggle on/off the current frame axis\n" 428 | "\t [P] to toggle on/off the neural points map\n" 429 | "\t [D] to toggle on/off the data pool\n" 430 | "\t [I] to toggle on/off the sdf map slice\n" 431 | "\t [R] to center the view point\n" 432 | "\t [Z] to save the currently visualized entities in the log folder\n" 433 | ) 434 | self.play_crun = not self.play_crun 435 | return False 436 | 437 | def _update_mesh(self, mesh): 438 | if self.render_mesh: 439 | if mesh is not None: 440 | self.vis.remove_geometry(self.mesh, self.reset_bounding_box) 441 | self.mesh = mesh 442 | self.vis.add_geometry(self.mesh, self.reset_bounding_box) 443 | else: 444 | self.vis.remove_geometry(self.mesh, self.reset_bounding_box) 445 | 446 | def _update_pointcloud(self, scan): 447 | if scan is not None: 448 | self.vis.remove_geometry(self.scan, self.reset_bounding_box) 449 | self.scan = scan 450 | self.vis.add_geometry(self.scan, self.reset_bounding_box) 451 | 452 | if self.reset_bounding_box: 453 | self.vis.reset_view_point(True) 454 | self.reset_bounding_box = False 455 | 456 | def _update_geometries( 457 | self, 458 | scan=None, 459 | pose=None, 460 | sdf=None, 461 | mesh=None, 462 | neural_points=None, 463 | data_pool=None, 464 | ): 465 | 466 | # Scan (toggled by "F") 467 | if self.render_pointcloud: 468 | if scan is not None: 469 | self.scan.points = o3d.utility.Vector3dVector(scan.points) 470 | self.scan.colors = o3d.utility.Vector3dVector(scan.colors) 471 | self.scan.normals = 
o3d.utility.Vector3dVector(scan.normals) 472 | if self.pc_uniform_color or ( 473 | self.vis_pc_color 474 | and (self.config.color_channel == 0) 475 | and (not self.config.semantic_on) 476 | and (not self.config.dynamic_filter_on) 477 | ): 478 | self.scan.paint_uniform_color(GOLDEN) 479 | else: 480 | self.scan.points = o3d.utility.Vector3dVector() 481 | else: 482 | self.scan.points = o3d.utility.Vector3dVector() 483 | # self.scan.colors = o3d.utility.Vector3dVector() 484 | if self.ego_view and pose is not None: 485 | self.scan.transform(np.linalg.inv(pose)) 486 | self.vis.update_geometry(self.scan) 487 | 488 | # Mesh Map (toggled by "M") 489 | if self.render_mesh: 490 | if mesh is not None: 491 | if not self.keep_local_mesh: 492 | self.vis.remove_geometry( 493 | self.mesh, self.reset_bounding_box 494 | ) # if comment, then we keep the previous reconstructed mesh (for the case we use local map reconstruction) 495 | self.mesh = mesh 496 | if self.ego_view and pose is not None: 497 | self.mesh.transform(np.linalg.inv(pose)) 498 | self.vis.add_geometry(self.mesh, self.reset_bounding_box) 499 | else: # None, meshing for every frame can be time consuming, we just keep the mesh reconstructed from last frame for vis 500 | if self.ego_view and pose is not None: 501 | self.vis.remove_geometry(self.mesh, self.reset_bounding_box) 502 | if self.ego_change_flag: # global -> ego view 503 | self.mesh.transform(np.linalg.inv(pose)) 504 | self.ego_change_flag = False 505 | else: 506 | self.mesh.transform(np.linalg.inv(pose) @ self.last_pose) 507 | self.vis.add_geometry(self.mesh, self.reset_bounding_box) 508 | elif self.ego_change_flag: # ego -> global view 509 | self.vis.remove_geometry(self.mesh, self.reset_bounding_box) 510 | self.mesh.transform(self.last_pose) 511 | self.vis.add_geometry(self.mesh, self.reset_bounding_box) 512 | self.ego_change_flag = False 513 | else: 514 | self.vis.remove_geometry(self.mesh, self.reset_bounding_box) 515 | 516 | # Neural Points Map (toggled by 
"P") 517 | if neural_points is not None: 518 | if self.render_neural_points: 519 | self.neural_points.points = o3d.utility.Vector3dVector( 520 | neural_points.points 521 | ) 522 | self.neural_points.colors = o3d.utility.Vector3dVector( 523 | neural_points.colors 524 | ) 525 | else: 526 | self.neural_points.points = o3d.utility.Vector3dVector() 527 | else: 528 | self.neural_points.points = o3d.utility.Vector3dVector() 529 | if self.ego_view and pose is not None: 530 | self.neural_points.transform(np.linalg.inv(pose)) 531 | self.vis.update_geometry(self.neural_points) 532 | 533 | # Data Pool (toggled by "D") 534 | if data_pool is not None: 535 | if self.render_data_pool: 536 | self.data_pool.points = o3d.utility.Vector3dVector(data_pool.points) 537 | self.data_pool.colors = o3d.utility.Vector3dVector(data_pool.colors) 538 | else: 539 | self.data_pool.points = o3d.utility.Vector3dVector() 540 | else: 541 | self.data_pool.points = o3d.utility.Vector3dVector() 542 | if self.ego_view and pose is not None: 543 | self.data_pool.transform(np.linalg.inv(pose)) 544 | self.vis.update_geometry(self.data_pool) 545 | 546 | # SDF map (toggled by "I") 547 | if sdf is not None: 548 | if self.render_sdf: 549 | self.sdf.points = o3d.utility.Vector3dVector(sdf.points) 550 | self.sdf.colors = o3d.utility.Vector3dVector(sdf.colors) 551 | else: 552 | self.sdf.points = o3d.utility.Vector3dVector() 553 | else: 554 | self.sdf.points = o3d.utility.Vector3dVector() 555 | if self.ego_view and pose is not None: 556 | self.sdf.transform(np.linalg.inv(pose)) 557 | self.vis.update_geometry(self.sdf) 558 | 559 | # Coordinate frame axis (toggled by "A") 560 | if self.render_frame_axis: 561 | if pose is not None: 562 | self.vis.remove_geometry(self.frame_axis, self.reset_bounding_box) 563 | self.frame_axis = o3d.geometry.TriangleMesh.create_coordinate_frame( 564 | size=self.frame_axis_len, origin=np.zeros(3) 565 | ) 566 | self.frame_axis += self.sensor_cad 567 | if not self.ego_view: 568 | 
self.frame_axis = self.frame_axis.transform(pose) 569 | self.vis.add_geometry(self.frame_axis, self.reset_bounding_box) 570 | else: 571 | self.vis.remove_geometry(self.frame_axis, self.reset_bounding_box) 572 | 573 | if pose is not None: 574 | self.last_pose = pose 575 | 576 | if self.reset_bounding_box: 577 | self.vis.reset_view_point(True) 578 | self.reset_bounding_box = False 579 | 580 | # show traj as lineset 581 | # long list to np conversion time 582 | def _update_traj( 583 | self, 584 | cur_pose=None, 585 | odom_poses_np=None, 586 | gt_poses_np=None, 587 | pgo_poses_np=None, 588 | loop_edges=None, 589 | ): 590 | 591 | self.vis.remove_geometry(self.odom_traj, self.reset_bounding_box) 592 | self.vis.remove_geometry(self.gt_traj, self.reset_bounding_box) 593 | self.vis.remove_geometry(self.pgo_traj, self.reset_bounding_box) 594 | self.vis.remove_geometry(self.pgo_edges, self.reset_bounding_box) 595 | 596 | if (self.render_trajectory and odom_poses_np is not None and odom_poses_np.shape[0] > 1): 597 | if pgo_poses_np is not None and (not self.render_odom_trajectory): 598 | self.odom_traj = o3d.geometry.LineSet() 599 | else: 600 | odom_position_np = odom_poses_np[:, :3, 3] 601 | self.odom_traj.points = o3d.utility.Vector3dVector(odom_position_np) 602 | odom_edges = np.array([[i, i + 1] for i in range(odom_poses_np.shape[0] - 1)]) 603 | self.odom_traj.lines = o3d.utility.Vector2iVector(odom_edges) 604 | 605 | if pgo_poses_np is None or self.before_pgo: 606 | self.odom_traj.paint_uniform_color(RED) 607 | else: 608 | self.odom_traj.paint_uniform_color(BLUE) 609 | 610 | if self.ego_view and cur_pose is not None: 611 | self.odom_traj.transform(np.linalg.inv(cur_pose)) 612 | else: 613 | self.odom_traj = o3d.geometry.LineSet() 614 | 615 | if ( 616 | self.render_trajectory 617 | and pgo_poses_np is not None 618 | and pgo_poses_np.shape[0] > 1 619 | and (not self.before_pgo) 620 | ): 621 | pgo_position_np = pgo_poses_np[:, :3, 3] 622 | 623 | self.pgo_traj.points = 
o3d.utility.Vector3dVector(pgo_position_np) 624 | pgo_traj_edges = np.array([[i, i + 1] for i in range(pgo_poses_np.shape[0] - 1)]) 625 | self.pgo_traj.lines = o3d.utility.Vector2iVector(pgo_traj_edges) 626 | self.pgo_traj.paint_uniform_color(RED) 627 | 628 | if self.ego_view and cur_pose is not None: 629 | self.pgo_traj.transform(np.linalg.inv(cur_pose)) 630 | 631 | if self.render_pgo and loop_edges is not None and len(loop_edges) > 0: 632 | edges = np.array(loop_edges) 633 | self.pgo_edges.points = o3d.utility.Vector3dVector(pgo_position_np) 634 | self.pgo_edges.lines = o3d.utility.Vector2iVector(edges) 635 | self.pgo_edges.paint_uniform_color(GREEN) 636 | 637 | if self.ego_view and cur_pose is not None: 638 | self.pgo_edges.transform(np.linalg.inv(cur_pose)) 639 | else: 640 | self.pgo_edges = o3d.geometry.LineSet() 641 | else: 642 | self.pgo_traj = o3d.geometry.LineSet() 643 | self.pgo_edges = o3d.geometry.LineSet() 644 | 645 | if ( 646 | self.render_trajectory 647 | and self.render_gt_trajectory 648 | and gt_poses_np is not None 649 | and gt_poses_np.shape[0] > 1 650 | ): 651 | gt_position_np = gt_poses_np[:, :3, 3] 652 | self.gt_traj.points = o3d.utility.Vector3dVector(gt_position_np) 653 | gt_edges = np.array([[i, i + 1] for i in range(gt_poses_np.shape[0] - 1)]) 654 | self.gt_traj.lines = o3d.utility.Vector2iVector(gt_edges) 655 | self.gt_traj.paint_uniform_color(BLACK) 656 | if odom_poses_np is None: 657 | self.gt_traj.paint_uniform_color(RED) 658 | if self.ego_view and cur_pose is not None: 659 | self.gt_traj.transform(np.linalg.inv(cur_pose)) 660 | else: 661 | self.gt_traj = o3d.geometry.LineSet() 662 | 663 | self.vis.add_geometry(self.odom_traj, self.reset_bounding_box) 664 | self.vis.add_geometry(self.gt_traj, self.reset_bounding_box) 665 | self.vis.add_geometry(self.pgo_traj, self.reset_bounding_box) 666 | self.vis.add_geometry(self.pgo_edges, self.reset_bounding_box) 667 | 668 | def _toggle_view(self, vis): 669 | self.global_viewpoint = not 
self.global_viewpoint 670 | vis.update_renderer() 671 | vis.reset_view_point(True) 672 | current_camera = self.view_control.convert_to_pinhole_camera_parameters() 673 | if self.camera_params and not self.global_viewpoint: 674 | self.view_control.convert_from_pinhole_camera_parameters(self.camera_params) 675 | self.camera_params = current_camera 676 | -------------------------------------------------------------------------------- /vis_pin_map.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # @file vis_pin_map.py 3 | # @author Yue Pan [yue.pan@igg.uni-bonn.de] 4 | # Copyright (c) 2024 Yue Pan, all rights reserved 5 | 6 | import glob 7 | import os 8 | import sys 9 | import time 10 | 11 | import numpy as np 12 | import open3d as o3d 13 | import torch 14 | import torch.multiprocessing as mp 15 | import dtyper as typer 16 | from rich import print 17 | 18 | from model.decoder import Decoder 19 | from model.neural_points import NeuralPoints 20 | from utils.config import Config 21 | from utils.mesher import Mesher 22 | from utils.tools import setup_experiment, split_chunks, load_decoders, remove_gpu_cache 23 | 24 | from gui import slam_gui 25 | from gui.gui_utils import ParamsGUI, VisPacket 26 | 27 | 28 | ''' 29 | load the pin-map and do the reconstruction 30 | ''' 31 | 32 | app = typer.Typer(add_completion=False, rich_markup_mode="rich", context_settings={"help_option_names": ["-h", "--help"]}) 33 | 34 | docstring = f""" 35 | :round_pushpin: Inspect the PIN Map \n 36 | 37 | [bold green]Examples: [/bold green] 38 | 39 | # Inspect the PIN Map stored in a mapping result folder, showing both the neural points and the mesh reconstructed with a certain marching cubes resolution 40 | $ python3 vis_pin_map.ply :open_file_folder: -m 41 | 42 | # Additionally, you can specify the cropped point cloud file and the output mesh file 43 | $ python3 vis_pin_map.ply :open_file_folder: -m -c :page_facing_up: -o 
:page_facing_up: 44 | 45 | """ 46 | 47 | @app.command(help=docstring) 48 | def vis_pin_map( 49 | result_folder: str = typer.Argument(..., help='Path to the result folder'), 50 | mesh_res_m: float = typer.Option(None, '--mesh_res_m', '-m', help='Resolution of the mesh in meters'), 51 | cropped_ply_filename: str = typer.Option("neural_points.ply", '--cropped_ply_filename', '-c', help='Path to the cropped point cloud file'), 52 | output_mesh_filename: str = typer.Option(None, '--output_mesh_filename', '-o', help='Path to the output mesh file'), 53 | mc_nn: int = typer.Option(9, '--mc_nn', '-n', help='Minimum number of neighbors for SDF querying for marching cubes'), 54 | o3d_vis_on: bool = typer.Option(True, '--visualize_on', '-v', help='Turn on the visualizer'), 55 | ): 56 | 57 | config = Config() 58 | 59 | yaml_files = glob.glob(f"{result_folder}/*.yaml") 60 | if len(yaml_files) > 1: # Check if there is exactly one YAML file 61 | sys.exit("There are multiple YAML files. Please handle accordingly.") 62 | elif len(yaml_files) == 0: # If no YAML files are found 63 | sys.exit("No YAML files found in the specified path.") 64 | config.load(yaml_files[0]) 65 | config.model_path = os.path.join(result_folder, "model", "pin_map.pth") 66 | 67 | print("[bold green]Load and inspect PIN Map[/bold green]","📍" ) 68 | 69 | run_path = setup_experiment(config, sys.argv, debug_mode=True) 70 | 71 | mp.set_start_method("spawn") # don't forget this 72 | 73 | # initialize the mlp decoder 74 | geo_mlp = Decoder(config, config.geo_mlp_hidden_dim, config.geo_mlp_level, 1) 75 | sem_mlp = Decoder(config, config.sem_mlp_hidden_dim, config.sem_mlp_level, config.sem_class_count + 1) if config.semantic_on else None 76 | color_mlp = Decoder(config, config.color_mlp_hidden_dim, config.color_mlp_level, config.color_channel) if config.color_on else None 77 | 78 | mlp_dict = {"sdf":geo_mlp, "semantic": sem_mlp,"color":color_mlp} 79 | 80 | # initialize the neural point features 81 | neural_points: 
NeuralPoints = NeuralPoints(config) 82 | 83 | # Load the map 84 | loaded_model = torch.load(config.model_path, weights_only=False) 85 | neural_points = loaded_model["neural_points"] 86 | load_decoders(loaded_model, mlp_dict) 87 | neural_points.temporal_local_map_on = False 88 | neural_points.recreate_hash(neural_points.neural_points[0], torch.eye(3).cuda(), False, False) 89 | neural_points.compute_feature_principle_components(down_rate = 59) 90 | print("PIN Map loaded") 91 | 92 | # mesh reconstructor 93 | mesher = Mesher(config, neural_points, mlp_dict) 94 | 95 | mesh_on = (mesh_res_m is not None) 96 | if mesh_on: 97 | config.mc_res_m = mesh_res_m 98 | config.mesh_min_nn = mc_nn 99 | 100 | q_main2vis = q_vis2main = None 101 | if o3d_vis_on: 102 | # communicator between the processes 103 | q_main2vis = mp.Queue() 104 | q_vis2main = mp.Queue() 105 | 106 | params_gui = ParamsGUI( 107 | q_main2vis=q_main2vis, 108 | q_vis2main=q_vis2main, 109 | config=config, 110 | local_map_default_on=False, 111 | mesh_default_on=mesh_on, 112 | neural_point_map_default_on=config.neural_point_map_default_on, 113 | ) 114 | gui_process = mp.Process(target=slam_gui.run, args=(params_gui,)) 115 | gui_process.start() 116 | time.sleep(3) # second 117 | 118 | cur_mesh = None 119 | out_mesh_path = None 120 | 121 | if mesh_on: 122 | cropped_ply_path = os.path.join(result_folder, "map", cropped_ply_filename) 123 | if os.path.exists(cropped_ply_path): 124 | cropped_pc = o3d.io.read_point_cloud(cropped_ply_path) 125 | print("Load region for meshing from {}".format(cropped_ply_path)) 126 | 127 | else: 128 | cropped_pc = neural_points.get_neural_points_o3d(query_global=True, random_down_ratio=23) 129 | 130 | mesh_aabb = cropped_pc.get_axis_aligned_bounding_box() 131 | chunks_aabb = split_chunks(cropped_pc, mesh_aabb, mesh_res_m*300) 132 | 133 | if output_mesh_filename is not None: 134 | out_mesh_path = os.path.join(result_folder, "mesh", output_mesh_filename) 135 | print("Output the mesh to: ", 
out_mesh_path) 136 | 137 | mc_cm_str = str(round(mesh_res_m*1e2)) 138 | print("Reconstructing the mesh with resolution {} cm".format(mc_cm_str)) 139 | cur_mesh = mesher.recon_aabb_collections_mesh(chunks_aabb, mesh_res_m, out_mesh_path, False, config.semantic_on, 140 | config.color_on, filter_isolated_mesh=True, mesh_min_nn=mc_nn) 141 | print("Reconstructing the global mesh done") 142 | 143 | remove_gpu_cache() 144 | 145 | if o3d_vis_on: 146 | 147 | while True: 148 | if not q_vis2main.empty(): 149 | q_vis2main.get() 150 | 151 | packet_to_vis: VisPacket = VisPacket(slam_finished=True) 152 | 153 | if not neural_points.is_empty(): 154 | packet_to_vis.add_neural_points_data(neural_points, only_local_map=False, pca_color_on=True) 155 | 156 | if cur_mesh is not None: 157 | packet_to_vis.add_mesh(np.array(cur_mesh.vertices, dtype=np.float64), np.array(cur_mesh.triangles), np.array(cur_mesh.vertex_colors, dtype=np.float64)) 158 | cur_mesh = None 159 | 160 | q_main2vis.put(packet_to_vis) 161 | time.sleep(1.0) 162 | 163 | 164 | if __name__ == "__main__": 165 | app() --------------------------------------------------------------------------------