├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── CLID-SLAM.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── LICENSE
├── README.md
├── assets
│   ├── GUI_Mesh.png
│   └── GUI_Neural_Points.png
├── cad
│   ├── camera.ply
│   ├── drone.ply
│   ├── ipb_car.ply
│   └── kitti_car.ply
├── config
│   ├── run_SubT_MRS.yaml
│   └── run_ncd128.yaml
├── dataset
│   └── converter
│       ├── config
│       │   └── rosbag2dataset.yaml
│       └── rosbag2dataset_parallel.py
├── experiment
│   └── .gitkeep
├── gui
│   ├── __pycache__
│   │   ├── gui_utils.cpython-312.pyc
│   │   └── slam_gui.cpython-312.pyc
│   ├── gui_utils.py
│   └── slam_gui.py
├── model
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-312.pyc
│   │   ├── decoder.cpython-312.pyc
│   │   ├── local_point_cloud_map.cpython-312.pyc
│   │   └── neural_points.cpython-312.pyc
│   ├── decoder.py
│   ├── local_point_cloud_map.py
│   └── neural_points.py
├── requirements.txt
├── slam.py
├── tools.ipynb
├── utils
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-312.pyc
│   │   ├── config.cpython-312.pyc
│   │   ├── data_sampler.cpython-312.pyc
│   │   ├── dataset_indexing.cpython-312.pyc
│   │   ├── error_state_iekf.cpython-312.pyc
│   │   ├── eval_traj_utils.cpython-312.pyc
│   │   ├── loss.cpython-312.pyc
│   │   ├── mapper.cpython-312.pyc
│   │   ├── mesher.cpython-312.pyc
│   │   ├── semantic_kitti_utils.cpython-312.pyc
│   │   ├── slam_dataset.cpython-312.pyc
│   │   ├── so3_math.cpython-312.pyc
│   │   └── tools.cpython-312.pyc
│   ├── config.py
│   ├── data_sampler.py
│   ├── dataset_indexing.py
│   ├── error_state_iekf.py
│   ├── eval_traj_utils.py
│   ├── loss.py
│   ├── mapper.py
│   ├── mesher.py
│   ├── point_cloud2.py
│   ├── semantic_kitti_utils.py
│   ├── slam_dataset.py
│   ├── so3_math.py
│   ├── tools.py
│   └── visualizer.py
└── vis_pin_map.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/CLID-SLAM.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Photogrammetry & Robotics Bonn
4 | Copyright (c) 2025 Dalian University of Technology Intelligent Robotics Laboratory
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # ⚔️CLID-SLAM: A Coupled LiDAR-Inertial Neural Implicit Dense SLAM with Region-Specific SDF Estimation
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | | Mesh | Neural Points |
16 | |-------------------------------|-----------------------------------|
17 | | ![Mesh](./assets/GUI_Mesh.png) | ![Neural Points](./assets/GUI_Neural_Points.png) |
18 |
19 | ## TODO 📝
20 |
21 | - [x] Release the source code
22 | - [ ] Enhance the README.md
23 | - [ ] Include the theory derivations
24 |
25 | ## Installation
26 |
27 | ### Platform Requirements
28 | - Ubuntu 20.04
29 | - GPU (tested on RTX 4090)
30 |
31 | ### Steps
32 | 1. **Clone the repository**
33 | ```bash
34 | git clone git@github.com:DUTRobot/CLID-SLAM.git
35 | cd CLID-SLAM
36 | ```
37 |
38 | 2. **Create Conda Environment**
39 | ```bash
40 | conda create -n slam python=3.12
41 | conda activate slam
42 | ```
43 |
44 | 3. **Install PyTorch**
45 | ```bash
46 | pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
47 | ```
48 |
49 | 4. **Install ROS Dependencies**
50 | ```bash
51 | sudo apt install ros-noetic-rosbag ros-noetic-sensor-msgs
52 | ```
53 |
54 | 5. **Install Other Dependencies**
55 | ```bash
56 | pip3 install -r requirements.txt
57 | ```
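
Optionally, run a quick environment sanity check before anything else (a minimal sketch, not part of the original instructions; it only assumes the packages installed above):

```python
# Verify that PyTorch sees the GPU and that Open3D imports cleanly.
import torch
import open3d as o3d

print("torch", torch.__version__, "CUDA available:", torch.cuda.is_available())
print("open3d", o3d.__version__)
```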
58 |
59 | ## Data Preparation
60 |
61 | ### Download ROSbag Files
62 | Download the ROS bag datasets supported by the provided configs:
63 | - [**Newer College Dataset**](https://ori-drs.github.io/newer-college-dataset/)
64 | - [**SubT-MRS Dataset**](https://superodometry.com/iccv23_challenge_LiI)
65 |
66 | ### Convert to Sequences
67 | 1. Edit `./dataset/converter/config/rosbag2dataset.yaml`.
68 | 2. Run:
69 | ```bash
70 | python3 ./dataset/converter/rosbag2dataset_parallel.py
71 | ```
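
Based on `rosbag2dataset_parallel.py`, each converted sequence folder should roughly look like this (frame counts depend on the bag):

```
<output_folder>/
├── imu/                  # one CSV of IMU measurements per LiDAR frame
├── lidar/                # one PLY per frame with x, y, z, intensity, timestamp fields
└── pose_timestamps.txt   # LiDAR frame timestamps
```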
72 | ## Run CLID-SLAM
73 | ```bash
74 | python3 slam.py ./config/run_ncd128.yaml
75 | ```
76 | ## Acknowledgements 🙏
77 |
78 | This project is built upon the open-source project [**PIN-SLAM**](https://github.com/PRBonn/PIN_SLAM), developed by [**PRBonn/YuePanEdward**](https://github.com/YuePanEdward). A huge thanks to the contributors of **PIN-SLAM** for their outstanding work and dedication!
79 |
--------------------------------------------------------------------------------
/assets/GUI_Mesh.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/assets/GUI_Mesh.png
--------------------------------------------------------------------------------
/assets/GUI_Neural_Points.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/assets/GUI_Neural_Points.png
--------------------------------------------------------------------------------
/cad/camera.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/cad/camera.ply
--------------------------------------------------------------------------------
/cad/drone.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/cad/drone.ply
--------------------------------------------------------------------------------
/cad/ipb_car.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/cad/ipb_car.ply
--------------------------------------------------------------------------------
/cad/kitti_car.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/cad/kitti_car.ply
--------------------------------------------------------------------------------
/config/run_SubT_MRS.yaml:
--------------------------------------------------------------------------------
1 | setting:
2 |   name: "SubT_MRS"
3 |   output_root: "./experiment"
4 |   imu_path: "./dataset/SubT_MRS/Final_Challenge_UGV1/sequences/imu"
5 |   pc_path: "./dataset/SubT_MRS/Final_Challenge_UGV1/sequences/lidar"
6 |   pose_ts_path: "./dataset/SubT_MRS/Final_Challenge_UGV1/sequences/pose_ts.txt"
7 |   deskew: True
8 |
9 | process:
10 |   min_range_m: 1.0
11 |   max_range_m: 60.0
12 |   min_z_m: -10.0
13 |   vox_down_m: 0.1
14 | sampler:
15 |   local_voxel_size_m: 0.2
16 |   surface_sample_range_m: 0.25
17 |   surface_sample_n: 4
18 |   free_sample_begin_ratio: 0.8
19 |   free_front_sample_n: 2
20 | neuralpoints:
21 |   voxel_size_m: 0.4
22 |   num_nei_cells: 2
23 |   search_alpha: 0.5
24 |   weighted_first: True
25 |   layer_norm_on: True
26 | loss:
27 |   sigma_sigmoid_m: 0.1
28 |   loss_weight_on: True
29 |   dist_weight_scale: 0.8
30 | continual:
31 |   batch_size_new_sample: 1000
32 |   pool_capacity: 1e7
33 | tracker:
34 |   measurement_noise_covariance: 0.01
35 |   bias_noise_covariance: 0.0001
36 |   source_vox_down_m: 0.6
37 |   iter_n: 50
38 |   T_imu_lidar:
39 |     - [ 1.0, 0, 0, 0]
40 |     - [ 0, 1.0, 0, 0]
41 |     - [ 0, 0, 1.0, 0]
42 |     - [ 0, 0, 0, 1.0]
43 | optimizer:
44 |   iters: 10
45 |   batch_size: 16384
46 |   learning_rate: 0.01
47 |   adaptive_iters: True
48 | eval:
49 |   wandb_vis_on: False
50 |   o3d_vis_on: False
51 |   silence_log: True
52 |   mesh_freq_frame: 50
53 |   mesh_min_nn: 15
54 |   save_map: True
--------------------------------------------------------------------------------
/config/run_ncd128.yaml:
--------------------------------------------------------------------------------
1 | setting:
2 |   name: "ncd128"
3 |   output_root: "./experiment"
4 |   imu_path: "./dataset/ncd128/collection1/quad_easy/sequences/imu"
5 |   pc_path: "./dataset/ncd128/collection1/quad_easy/sequences/lidar"
6 |   pose_ts_path: "./dataset/ncd128/collection1/quad_easy/sequences/pose_ts.txt"
7 |   deskew: True
8 |
9 | process:
10 |   min_range_m: 1.0
11 |   max_range_m: 60.0
12 |   min_z_m: -10.0
13 |   vox_down_m: 0.1
14 | sampler:
15 |   local_voxel_size_m: 0.2
16 |   surface_sample_range_m: 0.25
17 |   surface_sample_n: 4
18 |   free_sample_begin_ratio: 0.5
19 |   free_sample_end_dist_m: 1.2
20 |   free_front_sample_n: 2
21 | neuralpoints:
22 |   voxel_size_m: 0.4
23 |   num_nei_cells: 2
24 |   search_alpha: 0.5
25 |   weighted_first: True
26 | loss:
27 |   sigma_sigmoid_m: 0.1
28 |   loss_weight_on: True
29 |   dist_weight_scale: 0.8
30 | continual:
31 |   batch_size_new_sample: 1000
32 |   pool_capacity: 1e7
33 | tracker:
34 |   measurement_noise_covariance: 0.01
35 |   bias_noise_covariance: 0.0001
36 |   source_vox_down_m: 0.6
37 |   iter_n: 50
38 |   T_imu_lidar:
39 |     - [ 1.0, 0, 0, 0.014 ]
40 |     - [ 0, 1.0, 0, -0.012 ]
41 |     - [ 0, 0, 1.0, -0.015 ]
42 |     - [ 0, 0, 0, 1.0 ]
43 | optimizer:
44 |   iters: 10
45 |   batch_size: 16384
46 |   learning_rate: 0.01
47 |   adaptive_iters: True
48 | eval:
49 |   wandb_vis_on: False
50 |   o3d_vis_on: True
51 |   silence_log: True
52 |   mesh_freq_frame: 50
53 |   mesh_min_nn: 15
54 |   save_map: True
--------------------------------------------------------------------------------
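
A short aside on the `T_imu_lidar` entries in the two configs above: under the usual `T_target_source` naming convention this is the 4x4 homogeneous LiDAR-to-IMU extrinsic, so a LiDAR point is mapped into the IMU frame as in the following standalone sketch (not part of the repository):

```python
import numpy as np

# T_imu_lidar from run_ncd128.yaml: identity rotation, small translational lever arm.
T_imu_lidar = np.array([
    [1.0, 0.0, 0.0,  0.014],
    [0.0, 1.0, 0.0, -0.012],
    [0.0, 0.0, 1.0, -0.015],
    [0.0, 0.0, 0.0,  1.0],
])

p_lidar = np.array([10.0, 0.0, 1.0, 1.0])  # homogeneous point in the LiDAR frame
p_imu = T_imu_lidar @ p_lidar              # the same point expressed in the IMU frame
print(p_imu[:3])
```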
/dataset/converter/config/rosbag2dataset.yaml:
--------------------------------------------------------------------------------
1 | input_bag: './dataset/ncd128/collection1/quad-easy.bag'
2 | output_folder: './dataset/ncd128/collection1/quad_easy'
3 | imu_topic: '/os_cloud_node/imu'
4 | lidar_topic: '/os_cloud_node/points'
5 | #imu_topic: '/imu/data'
6 | #lidar_topic: '/velodyne_points'
7 | batch_size: 100 # Number of messages per batch
8 | end_frame: -1 # -1 means process the entire bag file
--------------------------------------------------------------------------------
/dataset/converter/rosbag2dataset_parallel.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # @file rosbag2dataset_parallel.py
3 | # @author Junlong Jiang [jiangjunlong@mail.dlut.edu.cn]
4 | # Copyright (c) 2025 Junlong Jiang, all rights reserved
5 | import csv
6 | import os
7 | import yaml
8 | from multiprocessing import Process, Queue
9 | from typing import List, Tuple
10 |
11 | import numpy as np
12 | import rosbag
13 | import sensor_msgs.point_cloud2 as pc2
14 | from plyfile import PlyData, PlyElement
15 |
16 | G_M_S2 = 9.81 # Gravitational constant in m/s^2
17 |
18 | def load_config(path: str) -> dict:
19 | """Load configuration from a YAML file."""
20 | with open(path, 'r') as file:
21 | return yaml.safe_load(file)
22 |
23 |
24 | def write_ply(filename: str, data: tuple) -> bool:
25 | """
26 | Writes point cloud data along with timestamps to a PLY file.
27 |
28 | Args:
29 | filename (str): Path to the output .ply file.
30 | data (tuple): Tuple of two elements: the point data and the timestamps.
31 | data[0] should be a 2D NumPy array of shape (n, 4) where n is the number of points.
32 | data[1] should be a 1D NumPy array of length n representing timestamps.
33 | The PLY fields written are x, y, z, intensity, and timestamp.
34 |
35 | Returns:
36 | bool: True if the file was written successfully, False otherwise.
37 | """
38 | # Ensure timestamp data is a 2D array with one column
39 | points, timestamps = data
40 | combined_data = np.hstack([points, timestamps.reshape(-1, 1)])
41 | structured_array = np.core.records.fromarrays(combined_data.transpose(),
42 | names=['x', 'y', 'z', 'intensity', 'timestamp'])
43 | PlyData([PlyElement.describe(structured_array, 'vertex')], text=False).write(filename)
44 | return True
45 |
46 | # import pandas as pd
47 | # from pyntcloud import PyntCloud
48 | # def write_ply_pyntcloud(filename, data):
49 | # """
50 | # Writes point cloud data along with timestamps to a PLY file using PyntCloud.
51 | #
52 | # Args:
53 | # filename (str): Path to the output .ply file.
54 | # data (list): List containing two elements, the point data and the timestamps.
55 | # data[0] should be a 2D NumPy array of shape (n, 4) where n is the number of points.
56 | # data[1] should be a 1D NumPy array of length n representing timestamps.
57 | #
58 | # Returns:
59 | # bool: True if the file was written successfully, False otherwise.
60 | # """
61 | # # Create a DataFrame from the provided data
62 | # points, timestamps = data
63 | # df = pd.DataFrame(points, columns=['x', 'y', 'z', 'intensity'])
64 | # df['timestamp'] = timestamps
65 | #
66 | # # Convert DataFrame to a PyntCloud object
67 | # cloud = PyntCloud(df)
68 | #
69 | # # Save to a PLY file
70 | # try:
71 | # cloud.to_file(filename)
72 | # return True
73 | # except Exception as e:
74 | # print(f"Error writing PLY file with PyntCloud: {e}")
75 | # return False
76 |
77 | def write_csv(filename: str, imu_data_pool: List[Tuple[float, float, float, float, float, float, float]]) -> None:
78 | """Write IMU data to a CSV file."""
79 | with open(filename, 'w', newline='') as file:
80 | writer = csv.writer(file)
81 | writer.writerow(['timestamp', 'acc_x', 'acc_y', 'acc_z', 'gyro_x', 'gyro_y', 'gyro_z'])
82 | for imu_data in imu_data_pool:
83 | writer.writerow(imu_data)
84 |
85 |
86 | def extract_lidar_data(msg) -> Tuple[np.ndarray, np.ndarray]:
87 | """Extract point cloud data and timestamps from a LiDAR message."""
88 | pc_data = list(pc2.read_points(msg, skip_nans=True))
89 | pc_array = np.array(pc_data)
90 | timestamps = pc_array[:, 4] * 1e-9 # Convert to seconds
91 | return pc_array[:, :4], timestamps
92 |
93 |
94 | def process_lidar_data(batch_data: List[Tuple[str, Tuple[np.ndarray, np.ndarray]]]) -> None:
95 | """Process a batch of LiDAR data and save as PLY files."""
96 | for i, (ply_file_path, data) in enumerate(batch_data):
97 | if write_ply(ply_file_path, data):
98 | print(f"Exported LiDAR point cloud PLY file: {ply_file_path}")
99 |
100 |
101 | def sync_and_save(config: dict) -> None:
102 | """Synchronize and save LiDAR and IMU data from a ROS bag file."""
103 | os.makedirs(config["output_folder"], exist_ok=True)
104 | os.makedirs(os.path.join(config["output_folder"], "lidar"), exist_ok=True)
105 | os.makedirs(os.path.join(config["output_folder"], "imu"), exist_ok=True)
106 |
107 | in_bag = rosbag.Bag(config["input_bag"])
108 |
109 | frame_index = 0
110 | start_flag = False
111 | imu_last_timestamp = None
112 | imu_data_pool = []
113 | lidar_timestamp_queue = Queue()
114 |
115 | processes = []
116 | batch_size = config["batch_size"] # Number of messages per batch
117 | batch_lidar_data = []
118 |
119 | for topic, msg, t in in_bag.read_messages(topics=[config["imu_topic"], config["lidar_topic"]]):
120 | current_timestamp = t.to_sec()
121 |
122 | if topic == config["lidar_topic"]:
123 | if not start_flag:
124 | start_flag = True
125 | else:
126 | csv_file_path = os.path.join(config["output_folder"], "imu", f"{frame_index}.csv")
127 | write_csv(csv_file_path, imu_data_pool)
128 | imu_data_pool = []
129 | print(f"Exported IMU measurement CSV file: {csv_file_path}")
130 |
131 | if len(batch_lidar_data) >= batch_size:
132 | p = Process(target=process_lidar_data, args=(batch_lidar_data,))
133 | p.start()
134 | processes.append(p)
135 | batch_lidar_data = []
136 |
137 | lidar_frame_timestamp = msg.header.stamp.to_sec()
138 | lidar_timestamp_queue.put(lidar_frame_timestamp)
139 |
140 | ply_file_path = os.path.join(config["output_folder"], "lidar", f"{frame_index}.ply")
141 | point_cloud_data = extract_lidar_data(msg)
142 | batch_lidar_data.append((ply_file_path, point_cloud_data))
143 |
144 | imu_last_timestamp = current_timestamp
145 | frame_index += 1
146 |
147 | if 0 < config["end_frame"] <= frame_index:
148 | break
149 |
150 | elif topic == config["imu_topic"]:
151 | if start_flag:
152 | time_delta = current_timestamp - imu_last_timestamp
153 | imu_last_timestamp = current_timestamp
154 | imu_data = (
155 | time_delta,
156 | msg.linear_acceleration.x,
157 | msg.linear_acceleration.y,
158 | msg.linear_acceleration.z,
159 | msg.angular_velocity.x,
160 | msg.angular_velocity.y,
161 | msg.angular_velocity.z
162 | )
163 | imu_data_pool.append(imu_data)
164 |
165 | if batch_lidar_data:
166 | p = Process(target=process_lidar_data, args=(batch_lidar_data,))
167 | p.start()
168 | processes.append(p)
169 |
170 | for p in processes:
171 | p.join()
172 |
173 | with open(os.path.join(config["output_folder"], "pose_timestamps.txt"), 'w', newline='') as file:
174 | print("Writing pose timestamps...")
175 | writer = csv.writer(file)
176 | writer.writerow(['timestamp'])
177 | while not lidar_timestamp_queue.empty():
178 | lidar_timestamp = lidar_timestamp_queue.get()
179 | writer.writerow([lidar_timestamp])
180 | print("Pose timestamps written successfully.")
181 |
182 |
183 | if __name__ == "__main__":
184 | config = load_config('./dataset/converter/config/rosbag2dataset.yaml')
185 | sync_and_save(config)
--------------------------------------------------------------------------------
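
For reference, a frame written by `write_ply()` above can be read back with `plyfile` as in this minimal sketch (the path is just an example matching the defaults in `rosbag2dataset.yaml`):

```python
import numpy as np
from plyfile import PlyData

# Example output path of the converter for frame 0 of the quad_easy sequence.
ply = PlyData.read("./dataset/ncd128/collection1/quad_easy/lidar/0.ply")
vertex = ply["vertex"]

points = np.stack([vertex["x"], vertex["y"], vertex["z"], vertex["intensity"]], axis=-1)  # (n, 4)
timestamps = np.asarray(vertex["timestamp"])  # per-point time in seconds

print(points.shape, float(timestamps.min()), float(timestamps.max()))
```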
/experiment/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/experiment/.gitkeep
--------------------------------------------------------------------------------
/gui/__pycache__/gui_utils.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/gui/__pycache__/gui_utils.cpython-312.pyc
--------------------------------------------------------------------------------
/gui/__pycache__/slam_gui.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/gui/__pycache__/slam_gui.cpython-312.pyc
--------------------------------------------------------------------------------
/gui/gui_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # @file gui_utils.py
3 | # @author Yue Pan [yue.pan@igg.uni-bonn.de]
4 |
5 | # This GUI is built on top of the great work of MonoGS (https://github.com/muskie82/MonoGS/blob/main/gui/gui_utils.py)
6 |
7 | import queue
8 |
9 | from utils.tools import feature_pca_torch
10 | from model.neural_points import NeuralPoints
11 |
12 |
13 | class VisPacket:
14 | def __init__(
15 | self,
16 | frame_id = None,
17 | finish=False,
18 | current_pointcloud_xyz=None,
19 | current_pointcloud_rgb=None,
20 | mesh_verts=None,
21 | mesh_faces=None,
22 | mesh_verts_rgb=None,
23 | odom_poses=None,
24 | gt_poses=None,
25 | slam_poses=None,
26 | travel_dist=None,
27 | slam_finished=False,
28 | ):
29 | self.has_neural_points = False
30 |
31 | self.neural_points_data = None
32 |
33 | self.frame_id = frame_id
34 |
35 | self.add_scan(current_pointcloud_xyz, current_pointcloud_rgb)
36 |
37 | self.add_mesh(mesh_verts, mesh_faces, mesh_verts_rgb)
38 |
39 | self.add_traj(odom_poses, gt_poses, slam_poses)
40 |
41 | self.sdf_slice_xyz = None
42 | self.sdf_slice_rgb = None
43 |
44 | self.sdf_pool_xyz = None
45 | self.sdf_pool_rgb = None
46 |
47 | self.travel_dist = travel_dist
48 | self.slam_finished = slam_finished
49 |
50 | self.finish = finish
51 |
52 | # the surrounding map is also added here
53 | def add_neural_points_data(self, neural_points: NeuralPoints, only_local_map: bool = True,
54 | pca_color_on: bool = True):
55 |
56 | if neural_points is not None:
57 | self.has_neural_points = True
58 | self.neural_points_data = {}
59 | self.neural_points_data["count"] = neural_points.count()
60 | self.neural_points_data["local_count"] = neural_points.local_count()
61 | self.neural_points_data["map_memory_mb"] = neural_points.cur_memory_mb
62 | self.neural_points_data["resolution"] = neural_points.resolution
63 |
64 | if only_local_map:
65 | self.neural_points_data["position"] = neural_points.local_neural_points
66 | self.neural_points_data["orientation"] = neural_points.local_point_orientations
67 | self.neural_points_data["geo_feature"] = neural_points.local_geo_features.detach()
68 | if neural_points.color_on:
69 | self.neural_points_data["color_feature"] = neural_points.local_color_features.detach()
70 | self.neural_points_data["ts"] = neural_points.local_point_ts_update
71 | self.neural_points_data["stability"] = neural_points.local_point_certainties
72 |
73 | if pca_color_on:
74 | local_geo_feature_3d, _ = feature_pca_torch((self.neural_points_data["geo_feature"])[:-1], principal_components=neural_points.geo_feature_pca, down_rate=17)
75 | self.neural_points_data["color_pca_geo"] = local_geo_feature_3d
76 |
77 | if neural_points.color_on:
78 | local_color_feature_3d, _ = feature_pca_torch((self.neural_points_data["color_feature"])[:-1], principal_components=neural_points.color_feature_pca, down_rate=17)
79 | self.neural_points_data["color_pca_color"] = local_color_feature_3d
80 |
81 | else:
82 | self.neural_points_data["position"] = neural_points.neural_points
83 | self.neural_points_data["orientation"] = neural_points.point_orientations
84 | self.neural_points_data["geo_feature"] = neural_points.geo_features
85 | if neural_points.color_on:
86 | self.neural_points_data["color_feature"] = neural_points.color_features
87 | self.neural_points_data["ts"] = neural_points.point_ts_update
88 | self.neural_points_data["stability"] = neural_points.point_certainties
89 | if neural_points.local_mask is not None:
90 | self.neural_points_data["local_mask"] = neural_points.local_mask[:-1]
91 |
92 | if pca_color_on:
93 | geo_feature_3d, _ = feature_pca_torch(neural_points.geo_features[:-1], principal_components=neural_points.geo_feature_pca, down_rate=97)
94 | self.neural_points_data["color_pca_geo"] = geo_feature_3d
95 |
96 | if neural_points.color_on:
97 | color_feature_3d, _ = feature_pca_torch(neural_points.color_features[:-1], principal_components=neural_points.color_feature_pca, down_rate=97)
98 | self.neural_points_data["color_pca_color"] = color_feature_3d
99 |
100 |
101 | def add_scan(self, current_pointcloud_xyz=None, current_pointcloud_rgb=None):
102 | self.current_pointcloud_xyz = current_pointcloud_xyz
103 | self.current_pointcloud_rgb = current_pointcloud_rgb
104 |
105 | # TODO: add normal later
106 |
107 | def add_sdf_slice(self, sdf_slice_xyz=None, sdf_slice_rgb=None):
108 | self.sdf_slice_xyz = sdf_slice_xyz
109 | self.sdf_slice_rgb = sdf_slice_rgb
110 |
111 | def add_sdf_training_pool(self, sdf_pool_xyz=None, sdf_pool_rgb=None):
112 | self.sdf_pool_xyz = sdf_pool_xyz
113 | self.sdf_pool_rgb = sdf_pool_rgb
114 |
115 | def add_mesh(self, mesh_verts=None, mesh_faces=None, mesh_verts_rgb=None):
116 | self.mesh_verts = mesh_verts
117 | self.mesh_faces = mesh_faces
118 | self.mesh_verts_rgb = mesh_verts_rgb
119 |
120 | def add_traj(self, odom_poses=None, gt_poses=None, slam_poses=None, loop_edges=None):
121 |
122 | self.odom_poses = odom_poses
123 | self.gt_poses = gt_poses
124 | self.slam_poses = slam_poses
125 |
126 | if slam_poses is None:
127 | self.slam_poses = odom_poses
128 |
129 | self.loop_edges = loop_edges
130 |
131 |
132 | def get_latest_queue(q):
133 | message = None
134 | while True:
135 | try:
136 | message_latest = q.get_nowait()
137 | if message is not None:
138 | del message
139 | message = message_latest
140 | except queue.Empty:
141 | if q.empty():
142 | break
143 | return message
144 |
145 |
146 | class ControlPacket:
147 | flag_pause = False
148 | flag_vis = True
149 | flag_mesh = False
150 | flag_sdf = False
151 | flag_global = False
152 | flag_source = False
153 | mc_res_m = 0.2
154 | mesh_min_nn = 10
155 | mesh_freq_frame = 50
156 | sdf_freq_frame = 50
157 | sdf_slice_height = 0.2
158 | sdf_res_m = 0.2
159 | cur_frame_id = 0
160 |
161 | class ParamsGUI:
162 | def __init__(
163 | self,
164 | q_main2vis=None,
165 | q_vis2main=None,
166 | config=None,
167 | local_map_default_on: bool = True,
168 | robot_default_on: bool = True,
169 | mesh_default_on: bool = False,
170 | sdf_default_on: bool = False,
171 | neural_point_map_default_on: bool = False,
172 | neural_point_color_default_mode: int = 1, # 1: geo feature pca, 2: photo feature pca, 3: time, 4: height
173 | neural_point_vis_down_rate: int = 1,
174 | ):
175 | self.q_main2vis = q_main2vis
176 | self.q_vis2main = q_vis2main
177 | self.config = config
178 |
179 | self.robot_default_on = robot_default_on
180 | self.neural_point_map_default_on = neural_point_map_default_on
181 | self.mesh_default_on = mesh_default_on
182 | self.sdf_default_on = sdf_default_on
183 | self.local_map_default_on = local_map_default_on
184 | self.neural_point_color_default_mode = neural_point_color_default_mode
185 | self.neural_point_vis_down_rate = neural_point_vis_down_rate
186 |
187 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/model/__pycache__/__init__.cpython-312.pyc
--------------------------------------------------------------------------------
/model/__pycache__/decoder.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/model/__pycache__/decoder.cpython-312.pyc
--------------------------------------------------------------------------------
/model/__pycache__/local_point_cloud_map.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/model/__pycache__/local_point_cloud_map.cpython-312.pyc
--------------------------------------------------------------------------------
/model/__pycache__/neural_points.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/model/__pycache__/neural_points.cpython-312.pyc
--------------------------------------------------------------------------------
/model/decoder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # @file decoder.py
3 | # @author Yue Pan [yue.pan@igg.uni-bonn.de]
4 | # Copyright (c) 2024 Yue Pan, all rights reserved
5 |
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from utils.config import Config
12 |
13 |
14 | class Decoder(nn.Module):
15 | def __init__(
16 | self,
17 | config: Config,
18 | hidden_dim,
19 | hidden_level,
20 | out_dim,
21 | is_time_conditioned=False,
22 | ):
23 |
24 | super().__init__()
25 |
26 | self.out_dim = out_dim
27 | self.use_leaky_relu = config.mlp_leaky_relu
28 | bias_on = config.mlp_bias_on
29 |
30 | # default not used
31 | if config.use_gaussian_pe:
32 | position_dim = config.pos_input_dim + 2 * config.pos_encoding_band
33 | else:
34 | position_dim = config.pos_input_dim * (2 * config.pos_encoding_band + 1)
35 |
36 | feature_dim = config.feature_dim
37 | input_dim = feature_dim + position_dim
38 |
39 | # default not used
40 | if is_time_conditioned:
41 | input_dim += 1
42 |
43 | # predict sdf (currently it only predicts the raw sdf, without a further sigmoid)
44 | # Initialize the structure of the shared MLP
45 | layers = []
46 | for i in range(hidden_level):
47 | if i == 0:
48 | layers.append(nn.Linear(input_dim, hidden_dim, bias_on))
49 | else:
50 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias_on))
51 | self.layers = nn.ModuleList(layers)
52 | self.lout = nn.Linear(hidden_dim, out_dim, bias_on)
53 |
54 | self.sdf_scale = 1.0
55 | if config.main_loss_type == "bce":
56 | self.sdf_scale = config.logistic_gaussian_ratio * config.sigma_sigmoid_m
57 |
58 | self.to(config.device)
59 | # torch.cuda.empty_cache()
60 |
61 | def mlp(self, features):
62 | # linear (feature_dim -> hidden_dim)
63 | # relu
64 | # linear (hidden_dim -> hidden_dim)
65 | # relu
66 | # linear (hidden_dim -> 1)
67 | for k, l in enumerate(self.layers):
68 | if k == 0:
69 | if self.use_leaky_relu:
70 | h = F.leaky_relu(l(features))
71 | else:
72 | h = F.relu(l(features))
73 | else:
74 | if self.use_leaky_relu:
75 | h = F.leaky_relu(l(h))
76 | else:
77 | h = F.relu(l(h))
78 | out = self.lout(h)
79 | return out
80 |
81 | # predict the sdf (opposite sign to the actual sdf)
82 | # unit is already m
83 | def sdf(self, features):
84 | out = self.mlp(features).squeeze(1) * self.sdf_scale
85 | return out
86 |
87 | def time_conditionded_sdf(self, features, ts):
88 | nn_k = features.shape[1]
89 | ts_nn_k = ts.repeat(nn_k).view(-1, nn_k, 1)
90 | time_conditioned_feature = torch.cat((features, ts_nn_k), dim=-1)
91 | out = self.sdf(time_conditioned_feature)
92 | return out
93 |
94 | # predict the occupancy probability
95 | def occupancy(self, features):
96 | out = torch.sigmoid(self.sdf(features) / -self.sdf_scale) # to [0, 1]
97 | return out
98 |
99 | # predict the probability of each semantic label
100 | def sem_label_prob(self, features):
101 | out = F.log_softmax(self.mlp(features), dim=-1)
102 | return out
103 |
104 | def sem_label(self, features):
105 | out = torch.argmax(self.sem_label_prob(features), dim=1)
106 | return out
107 |
108 | # def regress_color(self, features):
109 | # out = torch.clamp(self.mlp(features), 0.0, 1.0)
110 | # return out
111 |
112 | def regress_color(self, features):
113 | out = torch.sigmoid(self.mlp(features)) # sigmoid map to [0,1]
114 | return out
--------------------------------------------------------------------------------
/model/local_point_cloud_map.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # @file local_point_cloud_map.py
3 | # @author Junlong Jiang [jiangjunlong@mail.dlut.edu.cn]
4 | # Copyright (c) 2025 Junlong Jiang, all rights reserved
5 | import math
6 | import torch
7 | from utils.config import Config
8 | from utils.tools import voxel_down_sample_torch
9 |
10 |
11 | class LocalPointCloudMap:
12 | def __init__(self, config: Config) -> None:
13 | self.config = config
14 | self.idx_dtype = torch.int64
15 | self.dtype = config.dtype
16 | self.device = config.device
17 | self.resolution = config.local_voxel_size_m
18 | self.buffer_size = config.local_buffer_size
19 |
20 | self.buffer_pt_index = -torch.ones(self.buffer_size, dtype=self.idx_dtype, device=self.device) # hash table
21 | self.local_point_cloud_map = torch.empty((0, 3), dtype=torch.float32, device=self.device)
22 |
23 | self.primes = torch.tensor([73856093, 19349663, 83492791], dtype=self.idx_dtype, device=self.device)
24 | self.neighbor_idx = None
25 | self.max_valid_range = None
26 | self.set_search_neighborhood()
27 | self.map_size = config.local_map_size
28 |
29 | def voxel_hash(self, points):
30 | grid_coords = (points / self.resolution).floor().to(self.primes)
31 | hash_values = torch.fmod((grid_coords * self.primes).sum(-1), self.buffer_size)
32 | return hash_values
33 |
34 | def insert_points(self, points):
35 | sample_idx = voxel_down_sample_torch(points, self.resolution)
36 | sample_points = points[sample_idx]
37 | hash_values = self.voxel_hash(sample_points)
38 | hash_idx = self.buffer_pt_index[hash_values]
39 |
40 | update_mask = (hash_idx == -1)
41 | new_points = sample_points[update_mask]
42 |
43 | cur_pt_count = self.local_point_cloud_map.shape[0]
44 | self.buffer_pt_index[hash_values[update_mask]] = torch.arange(new_points.shape[0], device=self.device) + cur_pt_count
45 |
46 | self.local_point_cloud_map = torch.cat(( self.local_point_cloud_map, new_points), 0)
47 |
48 | def update_map(self, sensor_position, points):
49 | self.insert_points(points)
50 | distances = torch.norm(self.local_point_cloud_map - sensor_position, dim=-1)
51 | keep_mask = distances < self.map_size
52 | self.local_point_cloud_map = self.local_point_cloud_map[keep_mask]
53 |
54 | new_buffer_pt_index = -torch.ones(self.buffer_size, dtype=self.idx_dtype, device=self.device)
55 | new_hash_values = self.voxel_hash(self.local_point_cloud_map)
56 | new_buffer_pt_index[new_hash_values] = torch.arange(self.local_point_cloud_map.shape[0], device=self.device)
57 |
58 | self.buffer_pt_index = new_buffer_pt_index
59 |
60 | def set_search_neighborhood(
61 | self, num_nei_cells: int = 1, search_alpha: float = 0.2
62 | ):
63 | dx = torch.arange(-num_nei_cells, num_nei_cells + 1, device=self.primes.device, dtype=self.primes.dtype,)
64 |
65 | coords = torch.meshgrid(dx, dx, dx, indexing="ij")
66 | dx = torch.stack(coords, dim=-1).reshape(-1, 3) # [K,3]
67 |
68 | dx2 = torch.sum(dx**2, dim=-1)
69 | self.neighbor_idx = dx[dx2 < (num_nei_cells + search_alpha) ** 2]
70 | self.max_valid_range = 1.732*(num_nei_cells + 1) * self.resolution
71 | # in the sphere --> smaller K --> faster training
72 | # when num_cells = 2 when num_cells = 3
73 | # alpha 0.2, K = 33 alpha 0.2, K = 147
74 | # alpha 0.3, K = 57 alpha 0.5, K = 179
75 | # alpha 0.5, K = 81 alpha 1.0, K = 251
76 | # alpha 1.0, K = 93
77 | # alpha 2.0, K = 125
78 |
79 | def region_specific_sdf_estimation(self, points: torch.Tensor):
80 | point_num = points.shape[0]
81 | sdf_abs = torch.ones(point_num, device=points.device)*self.max_valid_range
82 | surface_mask = torch.ones(point_num, dtype=torch.bool, device=self.config.device)
83 |
84 | bs = 262144 # 256 × 1024
85 | iter_n = math.ceil(point_num / bs)
86 | # process in batches to avoid running out of GPU memory
87 | for n in range(iter_n):
88 | head, tail = n * bs, min((n + 1) * bs, point_num)
89 | batch_points = points[head:tail, :]
90 | batch_coords = (batch_points / self.resolution).floor().to(self.primes)
91 | batch_neighbord_cells = (batch_coords[..., None, :] + self.neighbor_idx)
92 | batch_hash = torch.fmod((batch_neighbord_cells * self.primes).sum(-1), self.buffer_size)
93 | batch_neighb_idx = self.buffer_pt_index[batch_hash]
94 | batch_neighb_pts = self.local_point_cloud_map[batch_neighb_idx]
95 | batch_dist = torch.norm(batch_neighb_pts - batch_points.view(-1, 1, 3), dim=-1)
96 | batch_dist = torch.where(batch_neighb_idx == -1, self.max_valid_range, batch_dist)
97 |
98 | # k nearest neighbor points
99 | batch_sdf_abs, batch_min_idx = torch.topk(batch_dist, 4, largest=False, dim=1)
100 | batch_min_idx_expanded = batch_min_idx.unsqueeze(-1).expand(-1, -1, 3)
101 | batch_knn_points = torch.gather(batch_neighb_pts, 1, batch_min_idx_expanded)
102 | valid_fit_mask = batch_sdf_abs[:, 3] < self.max_valid_range
103 | valid_batch_knn_points = batch_knn_points[valid_fit_mask]
104 | unit_normal_vector = torch.zeros_like(batch_points)
105 | plane_constant = torch.zeros(batch_points.size(0), device=batch_points.device)
106 | fit_success = torch.zeros(batch_points.size(0), dtype=torch.bool, device=batch_points.device)
107 |
108 | valid_unit_normal_vector, valid_plane_constant, valid_fit_success = estimate_plane(valid_batch_knn_points)
109 | unit_normal_vector[valid_fit_mask] = valid_unit_normal_vector
110 | plane_constant[valid_fit_mask] = valid_plane_constant
111 | fit_success[valid_fit_mask] = valid_fit_success
112 |
113 | fit_success &= batch_sdf_abs[:, 3] < self.max_valid_range # treat the plane fit as failed when neighbors are out of range
114 | surface_mask[head:tail] &= (batch_sdf_abs[:, 0] < self.max_valid_range)
115 | distance = torch.abs(torch.sum(unit_normal_vector * batch_points, dim=1) + plane_constant)
116 | sdf_abs[head:tail][fit_success] = distance[fit_success]
117 | sdf_abs[head:tail][~fit_success] = batch_sdf_abs[:, 0][~fit_success]
118 |
119 | if not self.config.silence:
120 | print(surface_mask.sum().item() / surface_mask.numel())
121 | return sdf_abs, surface_mask
122 |
123 | def estimate_plane(points: torch.Tensor, eta_threshold: float = 0.2, threshold: float = 0.1):
124 | """Estimates planes from a given set of 3D points using Singular Value Decomposition (SVD)"""
125 | def fit_planes(points: torch.Tensor):
126 | """Fits multiple planes using SVD"""
127 | centroid = points.mean(dim=1, keepdim=True)
128 | centered_points = points - centroid
129 | U, S, Vh = torch.linalg.svd(centered_points, full_matrices=False) # Perform SVD.
130 |
131 | # The normal vector of the plane is the last row of Vh (since Vh is the transpose of V).
132 | normals = Vh[:, -1, :]
133 | return normals, centroid.squeeze(1), S
134 |
135 | def is_valid_planes(singular_values: torch.Tensor, eta_threshold: float):
136 | """Determines whether the fitted planes are valid based on the η value."""
137 | lambda_min = singular_values[:, -1] # The smallest singular value.
138 | lambda_mid = singular_values[:, 1] # The middle singular value.
139 |
140 | eta = lambda_min / (lambda_mid + 1e-6)
141 | return eta <= eta_threshold
142 |
143 | m, num_points, _ = points.shape
144 | normal_vector, centroids, singular_values = fit_planes(points) # Fit planes to the input points.
145 |
146 | # Initialize normal vectors with zeros.
147 | unit_normal_vector = torch.zeros((m, 3), dtype=points.dtype, device=points.device)
148 |
149 | valid_mask = is_valid_planes(singular_values, eta_threshold)
150 | normal_vector = normal_vector[valid_mask]
151 | unit_normal_vector[valid_mask] = normal_vector
152 | plane_constant = -1.0 * torch.sum(unit_normal_vector * centroids, dim=1)
153 |
154 | # Compute the distance of each point to its respective plane.
155 | # distances = torch.abs((points @ unit_normal_vector.unsqueeze(-1)).squeeze() + plane_constant.unsqueeze(-1))
156 | distances = torch.abs(torch.bmm(points, unit_normal_vector.unsqueeze(-1)).squeeze() + plane_constant.unsqueeze(-1))
157 | fit_success = torch.max(distances, dim=1).values <= threshold
158 | mask = fit_success & valid_mask
159 |
160 | return unit_normal_vector, plane_constant, mask
--------------------------------------------------------------------------------
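
The core of `estimate_plane()` above is a batched SVD plane fit: the plane normal is the right-singular vector belonging to the smallest singular value, and the ratio η of the smallest to the middle singular value measures how planar the neighborhood is. A standalone sketch of just that step on synthetic data (not part of the repository):

```python
import torch

# Two synthetic neighborhoods of 4 points each, nearly lying in the z = 0 plane.
pts = torch.randn(2, 4, 3)
pts[..., 2] *= 0.01

centroid = pts.mean(dim=1, keepdim=True)         # (m, 1, 3)
centered = pts - centroid
U, S, Vh = torch.linalg.svd(centered, full_matrices=False)

normals = Vh[:, -1, :]                           # one unit normal per neighborhood
eta = S[:, -1] / (S[:, 1] + 1e-6)                # flatness ratio used as the validity test
d = -(normals * centroid.squeeze(1)).sum(dim=1)  # offsets so that n·x + d = 0 on the plane

print(normals, eta, d)                           # normals ≈ ±[0, 0, 1], eta ≈ 0
```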
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.5.1
2 | rospkg==1.5.1
3 | catkin_pkg==1.0.0
4 | opencv-python==4.11.0.86
5 | scikit-image==0.25.2
6 | plyfile==1.1
7 | evo==1.28.0
8 | gnupg==2.3.1
9 | laspy==2.5.3
10 | natsort==8.1.0
11 | open3d==0.19.0
12 | pycryptodomex==3.20.0
13 | pypose==0.6.8
14 | pyquaternion==0.9.9
15 | rerun-sdk==0.17.0
16 | rich==12.5.1
17 | roma==1.5.0
18 | rospkg==1.5.1
19 | wandb==0.19.8
--------------------------------------------------------------------------------
/slam.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # @file slam.py
3 | # @author Yue Pan [yue.pan@igg.uni-bonn.de]
4 | # Copyright (c) 2024 Yue Pan, all rights reserved
5 | # Modifications by:
6 | # Junlong Jiang [jiangjunlong@mail.dlut.edu.cn]
7 | # Copyright (c) 2025 Junlong Jiang, all rights reserved.
8 |
9 | import os
10 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' # by default 0, change it here if you want to use other GPU
11 | import sys
12 | import time
13 |
14 | import numpy as np
15 | import open3d as o3d
16 | import torch
17 | import torch.multiprocessing as mp
18 | import wandb
19 | from rich import print
20 | from tqdm import tqdm
21 |
22 | from gui import slam_gui
23 | from gui.gui_utils import ParamsGUI, VisPacket, ControlPacket, get_latest_queue
24 | from model.decoder import Decoder
25 | from model.local_point_cloud_map import LocalPointCloudMap
26 | from model.neural_points import NeuralPoints
27 | from utils.config import Config
28 | from utils.dataset_indexing import set_dataset_path
29 | from utils.error_state_iekf import IEKFOM
30 | from utils.mapper import Mapper
31 | from utils.mesher import Mesher
32 | from utils.slam_dataset import SLAMDataset
33 | from utils.tools import freeze_model, get_time, save_implicit_map, setup_experiment, split_chunks, remove_gpu_cache, \
34 | create_bbx_o3d
35 |
36 |
37 | def run_slam(config_path=None, dataset_name=None, sequence_name=None, seed=None):
38 | torch.set_num_threads(16) # limit the number of CPU threads to 16; using too many threads can make the machine unresponsive
39 | config = Config()
40 | if config_path is not None:
41 | config.load(config_path)
42 | set_dataset_path(config, dataset_name, sequence_name)
43 | if seed is not None:
44 | config.seed = seed
45 | argv = ['slam.py', config_path, dataset_name, sequence_name, str(seed)]
46 | run_path = setup_experiment(config, argv)
47 | else:
48 | if len(sys.argv) > 1:
49 | config.load(sys.argv[1])
50 | else:
51 | sys.exit("Please provide the path to the config file.\nTry: \
52 | python3 slam.py path_to_config.yaml [dataset_name] [sequence_name] [random_seed]")
53 | # specific dataset [optional]
54 | if len(sys.argv) == 3:
55 | set_dataset_path(config, sys.argv[2])
56 | if len(sys.argv) > 3:
57 | set_dataset_path(config, sys.argv[2], sys.argv[3])
58 | if len(sys.argv) > 4: # random seed [optional]
59 | config.seed = int(sys.argv[4])
60 | run_path = setup_experiment(config, sys.argv)
61 | print("⚔️", "[bold green]CLID-SLAM starts[/bold green]")
62 |
63 | if config.o3d_vis_on:
64 | mp.set_start_method("spawn") # don't forget this
65 |
66 | # initialize the MLP decoder
67 | geo_mlp = Decoder(config, config.geo_mlp_hidden_dim, config.geo_mlp_level, 1)
68 | mlp_dict = {"sdf": geo_mlp, "semantic": None, "color": None}
69 |
70 | # initialize the neural point cloud map
71 | neural_points = NeuralPoints(config)
72 | local_point_cloud_map = LocalPointCloudMap(config)
73 |
74 | # initialize the dataset
75 | dataset = SLAMDataset(config)
76 |
77 | # odometry tracking module
78 | iekfom = IEKFOM(config, neural_points, geo_mlp)
79 | dataset.tracker = iekfom
80 |
81 | # mapping module
82 | mapper = Mapper(config, dataset, neural_points, local_point_cloud_map, geo_mlp)
83 |
84 | # mesh reconstruction
85 | mesher = Mesher(config, neural_points, mlp_dict)
86 |
87 | last_frame = dataset.total_pc_count - 1
88 |
89 | # visualization
90 | q_main2vis = q_vis2main = None
91 | if config.o3d_vis_on:
92 | # communicator between the processes
93 | q_main2vis = mp.Queue()
94 | q_vis2main = mp.Queue()
95 |
96 | params_gui = ParamsGUI(
97 | q_main2vis=q_main2vis,
98 | q_vis2main=q_vis2main,
99 | config=config,
100 | local_map_default_on=config.local_map_default_on,
101 | mesh_default_on=config.mesh_default_on,
102 | sdf_default_on=config.sdf_default_on,
103 | neural_point_map_default_on=config.neural_point_map_default_on,
104 | )
105 | gui_process = mp.Process(target=slam_gui.run, args=(params_gui,))
106 | gui_process.start()
107 | time.sleep(3) # second
108 |
109 | # visualizer configs
110 | vis_visualize_on = True
111 | vis_source_pc_weight = False
112 | vis_global_on = not config.local_map_default_on
113 | vis_mesh_on = config.mesh_default_on
114 | vis_mesh_freq_frame = config.mesh_freq_frame
115 | vis_mesh_mc_res_m = config.mc_res_m
116 | vis_mesh_min_nn = config.mesh_min_nn
117 | vis_sdf_on = config.sdf_default_on
118 | vis_sdf_freq_frame = config.sdfslice_freq_frame
119 | vis_sdf_slice_height = config.sdf_slice_height
120 | vis_sdf_res_m = config.vis_sdf_res_m
121 |
122 | cur_mesh = None
123 | cur_sdf_slice = None
124 |
125 | for frame_id in tqdm(range(dataset.total_pc_count)):
126 |
127 | # I. load data and preprocess
128 | T0 = get_time()
129 | dataset.read_frame(frame_id)
130 |
131 | T1 = get_time()
132 | valid_frame = dataset.preprocess_frame()
133 | if not valid_frame:
134 | dataset.processed_frame += 1
135 | continue
136 |
137 | T2 = get_time()
138 |
139 | # II. odometry tracking
140 | if frame_id > 0:
141 | if config.track_on:
142 | cur_pose_torch, valid_flag = iekfom.update_iterated(dataset.cur_source_points)
143 | dataset.lose_track = not valid_flag
144 | dataset.update_odom_pose(cur_pose_torch) # update dataset.cur_pose_torch
145 |
146 | travel_dist = dataset.travel_dist[:frame_id+1]
147 | neural_points.travel_dist = torch.tensor(travel_dist, device=config.device, dtype=config.dtype) # always update this
148 | valid_mapping_flag = (not dataset.lose_track) and (not dataset.stop_status)
149 |
150 | T3 = get_time()
151 | # III: mapping and bundle adjustment
152 | # if we lose track, we do not update the map and data pool (don't let a wrong pose corrupt the map)
153 | # if the robot stops, also skip this frame, since there are no new observations
154 | if not dataset.lose_track and valid_mapping_flag:
155 | mapper.process_frame(dataset.cur_point_cloud_torch, dataset.cur_sem_labels_torch,
156 | dataset.cur_pose_torch, frame_id, (config.dynamic_filter_on and frame_id > 0))
157 | else:
158 | mapper.determine_used_pose()
159 | neural_points.reset_local_map(dataset.cur_pose_torch[:3, 3], None, frame_id) # not efficient for large map
160 |
161 | T4 = get_time()
162 |
163 | # for the first frame, we need more iterations to do the initialization (warm-up)
164 | # compute the number of mapping iterations for the current frame
165 | cur_iter_num = config.iters * config.init_iter_ratio if frame_id == 0 else config.iters
166 | if dataset.stop_status:
167 | cur_iter_num = max(1, cur_iter_num - 10)
168 | # freeze the decoder parameters after a certain frame
169 | if frame_id == config.freeze_after_frame: # freeze the decoder after certain frame
170 | freeze_model(geo_mlp)
171 |
172 | # mapping with fixed poses (every frame)
173 | if frame_id % config.mapping_freq_frame == 0:
174 | mapper.mapping(cur_iter_num)
175 |
176 | T5 = get_time()
177 |
178 | # regular saving logs
179 | if config.log_freq_frame > 0 and (frame_id + 1) % config.log_freq_frame == 0:
180 | dataset.write_results_log()
181 |
182 | remove_gpu_cache()
183 |
184 | # IV: mesh reconstruction and visualization
185 | if config.o3d_vis_on:
186 |
187 | if not q_vis2main.empty():
188 | control_packet: ControlPacket = get_latest_queue(q_vis2main)
189 |
190 | vis_visualize_on = control_packet.flag_vis
191 | vis_global_on = control_packet.flag_global
192 | vis_mesh_on = control_packet.flag_mesh
193 | vis_sdf_on = control_packet.flag_sdf
194 | vis_source_pc_weight = control_packet.flag_source
195 | vis_mesh_mc_res_m = control_packet.mc_res_m
196 | vis_mesh_min_nn = control_packet.mesh_min_nn
197 | vis_mesh_freq_frame = control_packet.mesh_freq_frame
198 | vis_sdf_slice_height = control_packet.sdf_slice_height
199 | vis_sdf_freq_frame = control_packet.sdf_freq_frame
200 | vis_sdf_res_m = control_packet.sdf_res_m
201 |
202 | while control_packet.flag_pause:
203 | time.sleep(0.1)
204 | if not q_vis2main.empty():
205 | control_packet = get_latest_queue(q_vis2main)
206 | if not control_packet.flag_pause:
207 | break
208 |
209 | if vis_visualize_on:
210 |
211 | dataset.update_o3d_map()
212 | # Only PIN-SLAM has
213 | # if config.track_on and frame_id > 0 and vis_source_pc_weight and (weight_pc_o3d is not None):
214 | # dataset.cur_frame_o3d = weight_pc_o3d
215 |
216 | # T7 = get_time()
217 | T6 = get_time()
218 |
219 | # reconstruction by marching cubes
220 | # Only PIN-SLAM has
221 | # if vis_mesh_on and (frame_id == 0 or frame_id == last_frame or (
222 | # frame_id + 1) % vis_mesh_freq_frame == 0 or pgm.last_loop_idx == frame_id):
223 | if vis_mesh_on and (frame_id == 0 or frame_id == last_frame or (
224 | frame_id + 1) % vis_mesh_freq_frame == 0):
225 | # update map bbx
226 | global_neural_pcd_down = neural_points.get_neural_points_o3d(query_global=True,
227 | random_down_ratio=37) # prime number
228 | dataset.map_bbx = global_neural_pcd_down.get_axis_aligned_bounding_box()
229 |
230 | if not vis_global_on: # only build the local mesh
231 | chunks_aabb = split_chunks(global_neural_pcd_down, dataset.cur_bbx,
232 | vis_mesh_mc_res_m * 100) # reconstruct in chunks
233 | cur_mesh = mesher.recon_aabb_collections_mesh(chunks_aabb, vis_mesh_mc_res_m, None, True,
234 | config.semantic_on, config.color_on,
235 | filter_isolated_mesh=True,
236 | mesh_min_nn=vis_mesh_min_nn)
237 | else:
238 | aabb = global_neural_pcd_down.get_axis_aligned_bounding_box()
239 | chunks_aabb = split_chunks(global_neural_pcd_down, aabb,
240 | vis_mesh_mc_res_m * 300) # reconstruct in chunks
241 | cur_mesh = mesher.recon_aabb_collections_mesh(chunks_aabb, vis_mesh_mc_res_m, None, False,
242 | config.semantic_on, config.color_on,
243 | filter_isolated_mesh=True,
244 | mesh_min_nn=vis_mesh_min_nn)
245 | # cur_sdf_slice = None
246 |
247 | if vis_sdf_on and (frame_id == 0 or frame_id == last_frame or (frame_id + 1) % vis_sdf_freq_frame == 0):
248 | sdf_bound = config.surface_sample_range_m * 4.0
249 | vis_sdf_bbx = create_bbx_o3d(dataset.cur_pose_ref[:3, 3], config.max_range / 2)
250 | cur_sdf_slice_h = mesher.generate_bbx_sdf_hor_slice(vis_sdf_bbx, dataset.cur_pose_ref[
251 | 2, 3] + vis_sdf_slice_height, vis_sdf_res_m, True, -sdf_bound,
252 | sdf_bound) # horizontal slice (local)
253 | if config.vis_sdf_slice_v:
254 | cur_sdf_slice_v = mesher.generate_bbx_sdf_ver_slice(dataset.cur_bbx, dataset.cur_pose_ref[0, 3],
255 | vis_sdf_res_m, True, -sdf_bound,
256 | sdf_bound) # vertical slice (local)
257 | cur_sdf_slice = cur_sdf_slice_h + cur_sdf_slice_v
258 | else:
259 | cur_sdf_slice = cur_sdf_slice_h
260 |
261 | pool_pcd = mapper.get_data_pool_o3d(down_rate=37)
262 | odom_poses, gt_poses, pgo_poses = dataset.get_poses_np_for_vis()
263 | loop_edges = None
264 | # Only PIN-SLAM has
265 | # loop_edges = pgm.loop_edges_vis if config.pgo_on else None
266 |
267 | packet_to_vis: VisPacket = VisPacket(frame_id=frame_id, travel_dist=travel_dist[-1])
268 |
269 | if not neural_points.is_empty():
270 | packet_to_vis.add_neural_points_data(neural_points, only_local_map=(not vis_global_on),
271 | pca_color_on=config.decoder_freezed)
272 |
273 | if dataset.cur_frame_o3d is not None:
274 | packet_to_vis.add_scan(np.array(dataset.cur_frame_o3d.points, dtype=np.float64),
275 | np.array(dataset.cur_frame_o3d.colors, dtype=np.float64))
276 |
277 | if cur_mesh is not None:
278 | packet_to_vis.add_mesh(np.array(cur_mesh.vertices, dtype=np.float64), np.array(cur_mesh.triangles),
279 | np.array(cur_mesh.vertex_colors, dtype=np.float64))
280 |
281 | if cur_sdf_slice is not None:
282 | packet_to_vis.add_sdf_slice(np.array(cur_sdf_slice.points, dtype=np.float64),
283 | np.array(cur_sdf_slice.colors, dtype=np.float64))
284 |
285 | if pool_pcd is not None:
286 | packet_to_vis.add_sdf_training_pool(np.array(pool_pcd.points, dtype=np.float64),
287 | np.array(pool_pcd.colors, dtype=np.float64))
288 |
289 | packet_to_vis.add_traj(odom_poses, gt_poses, pgo_poses, loop_edges)
290 |
291 | q_main2vis.put(packet_to_vis)
292 |
293 | T8 = get_time()
294 |
295 | # if not config.silence:
296 | # print("time for o3d update (ms): {:.2f}".format((T7 - T6) * 1e3))
297 | # print("time for visualization (ms): {:.2f}".format((T8 - T7) * 1e3))
298 |
299 | cur_frame_process_time = np.array([T2 - T1, T3 - T2, T4 - T3, T5 - T4, 0])
300 | # cur_frame_process_time = np.array([T2 - T1, T3 - T2, T5 - T4, T6 - T5, T4 - T3]) # loop & pgo in the end, visualization and I/O time excluded
301 | dataset.time_table.append(cur_frame_process_time) # in s
302 |
303 | if config.wandb_vis_on:
304 | wandb_log_content = {'frame': frame_id, 'timing(s)/preprocess': T2 - T1, 'timing(s)/tracking': T3 - T2,
305 | 'timing(s)/pgo': T4 - T3, 'timing(s)/mapping': T6 - T4}
306 | wandb.log(wandb_log_content)
307 |
308 | dataset.processed_frame += 1
309 |
310 | # V. save results
311 | mapper.free_pool()
312 | pose_eval_results = dataset.write_results()
313 |
314 | neural_points.prune_map(config.max_prune_certainty, 0, True) # prune uncertain points for the final output
315 | neural_points.recreate_hash(None, None, False, False) # merge the final neural point map
316 | neural_pcd = neural_points.get_neural_points_o3d(query_global=True, color_mode=0)
317 | if config.save_map:
318 | o3d.io.write_point_cloud(os.path.join(run_path, "map", "neural_points.ply"), neural_pcd) # write the neural point cloud
319 | neural_points.clear_temp() # clear temp data for output
320 |
321 | output_mc_res_m = config.mc_res_m*0.6
322 | mc_cm_str = str(round(output_mc_res_m*1e2))
323 | if config.save_mesh:
324 | chunks_aabb = split_chunks(neural_pcd, neural_pcd.get_axis_aligned_bounding_box(), output_mc_res_m * 300) # reconstruct in chunks
325 | mesh_path = os.path.join(run_path, "mesh", "mesh_" + mc_cm_str + "cm.ply")
326 | print("Reconstructing the global mesh with resolution {} cm".format(mc_cm_str))
327 | cur_mesh = mesher.recon_aabb_collections_mesh(chunks_aabb, output_mc_res_m, mesh_path, False, config.semantic_on, config.color_on, filter_isolated_mesh=True, mesh_min_nn=config.mesh_min_nn)
328 | print("Reconstructing the global mesh done")
329 | neural_points.clear_temp() # clear temp data for output
330 | if config.save_map:
331 | save_implicit_map(run_path, neural_points, mlp_dict)
332 | # lcd_npmc.save_context_dict(mapper.used_poses, run_path)
333 | print("Use 'python vis_pin_map.py {} -m {} -o mesh_out_{}cm.ply' to inspect the map offline.".format(run_path, output_mc_res_m, mc_cm_str))
334 |
335 | if config.save_merged_pc:
336 | dataset.write_merged_point_cloud() # replay: save merged point cloud map
337 |
338 | remove_gpu_cache()
339 |
340 | if config.o3d_vis_on:
341 |
342 | while True:
343 | if not q_vis2main.empty():
344 | q_vis2main.get()
345 |
346 | packet_to_vis: VisPacket = VisPacket(frame_id=frame_id, travel_dist=travel_dist[-1], slam_finished=True)
347 |
348 | if not neural_points.is_empty():
349 | packet_to_vis.add_neural_points_data(neural_points, only_local_map=False,
350 | pca_color_on=config.decoder_freezed)
351 |
352 | if cur_mesh is not None:
353 | packet_to_vis.add_mesh(np.array(cur_mesh.vertices, dtype=np.float64), np.array(cur_mesh.triangles),
354 | np.array(cur_mesh.vertex_colors, dtype=np.float64))
355 | cur_mesh = None
356 |
357 | packet_to_vis.add_traj(odom_poses, gt_poses, pgo_poses, loop_edges)
358 |
359 | q_main2vis.put(packet_to_vis)
360 | time.sleep(1.0)
361 |
362 | return pose_eval_results
363 |
364 |
365 | if __name__ == "__main__":
366 | run_slam()
367 |
--------------------------------------------------------------------------------
/tools.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "metadata": {},
5 | "cell_type": "code",
6 | "outputs": [],
7 | "execution_count": null,
8 | "source": [
9 | "import os\n",
10 | "import sys\n",
11 | "import csv\n",
12 | "import numpy as np\n",
13 | "import open3d as o3d\n",
14 | "from rosbag import Bag"
15 | ],
16 | "id": "2ea90e0351ac99dd"
17 | },
18 | {
19 | "metadata": {},
20 | "cell_type": "markdown",
21 | "source": "# Merge Multiple ROSbag Files",
22 | "id": "7bc6c26b767ca971"
23 | },
24 | {
25 | "metadata": {},
26 | "cell_type": "code",
27 | "outputs": [],
28 | "execution_count": null,
29 | "source": [
30 | "def merge():\n",
31 | " input_folder = \"xxxxxx\"\n",
32 | " output_folder = \"xxxxxxx\"\n",
33 | " outbag_name = \"xxxxx\"\n",
34 | " input_bags = os.listdir(input_folder)\n",
35 |     "    input_bags.sort() # sort the bag files by filename\n",
36 | " print(\"Writing bag file: \" + outbag_name)\n",
37 | "\n",
38 | " with Bag(os.path.join(output_folder, outbag_name), 'w') as ob:\n",
39 | " for ifile in input_bags:\n",
40 | " print(\"> Reading bag file: \" + ifile)\n",
41 | " with Bag(os.path.join(input_folder, ifile), 'r') as ib:\n",
42 | " for topic, msg, t in ib:\n",
43 | " ob.write(topic, msg, t)\n",
44 | "merge()"
45 | ],
46 | "id": "7de5937f52f49bc"
47 | },
48 | {
49 | "metadata": {},
50 | "cell_type": "markdown",
51 | "source": [
52 | "\n",
53 | "# Convert Pose Format"
54 | ],
55 | "id": "dbb04396689dba20"
56 | },
57 | {
58 | "metadata": {},
59 | "cell_type": "code",
60 | "outputs": [],
61 | "execution_count": null,
62 | "source": [
63 |     "# Convert the CSV ground-truth file into the standard TUM pose format\n",
64 | "\n",
65 | "input_file = './dataset/SubT_MRS/SubT_MRS_Urban_Challenge_UGV2/poses/ground_truth_path.csv'\n",
66 | "output_file = './dataset/SubT_MRS/SubT_MRS_Urban_Challenge_UGV2/poses/gt_poses_tum.txt'\n",
67 | "\n",
68 | "with open(input_file, 'r') as file:\n",
69 | " reader = csv.reader(file)\n",
70 | " header = next(reader) # Skip the header\n",
71 | " with open(output_file, 'w') as outfile:\n",
72 | " for row in reader:\n",
73 | " nsec, x, y, z, qx, qy, qz, qw = map(float, row)\n",
74 | " sec = nsec * 1e-9 # Convert nanoseconds to seconds\n",
75 | " output_line = f\"{sec} {x} {y} {z} {qx} {qy} {qz} {qw}\\n\"\n",
76 | " outfile.write(output_line)\n",
77 | "\n",
78 | "print(\"Conversion completed, file saved as\", output_file)\n"
79 | ],
80 | "id": "38bfd4ef6b5b3eea"
81 | },
82 | {
83 | "metadata": {},
84 | "cell_type": "markdown",
85 | "source": [
86 | "\n",
87 | "# Mapping Performance Evaluation\n"
88 | ],
89 | "id": "d7bda514d936edca"
90 | },
91 | {
92 | "metadata": {},
93 | "cell_type": "code",
94 | "outputs": [],
95 | "execution_count": null,
96 | "source": [
97 | "def quaternion_to_rotation_matrix(qx: float, qy: float, qz: float, qw: float) -> np.ndarray:\n",
98 | " \"\"\"\n",
99 | " Converts a quaternion into a 3x3 rotation matrix.\n",
100 | "\n",
101 | " Parameters:\n",
102 | " - qx (float): X component of the quaternion.\n",
103 | " - qy (float): Y component of the quaternion.\n",
104 | " - qz (float): Z component of the quaternion.\n",
105 | " - qw (float): W (scalar) component of the quaternion.\n",
106 | "\n",
107 | " Returns:\n",
108 | " - np.ndarray: A 3x3 NumPy array representing the rotation matrix.\n",
109 | " \"\"\"\n",
110 | " # Normalize the quaternion to ensure a valid rotation\n",
111 | " norm = np.sqrt(qx**2 + qy**2 + qz**2 + qw**2)\n",
112 | " qx /= norm\n",
113 | " qy /= norm\n",
114 | " qz /= norm\n",
115 | " qw /= norm\n",
116 | "\n",
117 | " # Compute the rotation matrix using the normalized quaternion\n",
118 | " rotation_matrix = np.array([\n",
119 | " [1 - 2 * (qy**2 + qz**2), 2 * (qx * qy - qz * qw), 2 * (qx * qz + qy * qw)],\n",
120 | " [2 * (qx * qy + qz * qw), 1 - 2 * (qx**2 + qz**2), 2 * (qy * qz - qx * qw)],\n",
121 | " [2 * (qx * qz - qy * qw), 2 * (qy * qz + qx * qw), 1 - 2 * (qx**2 + qy**2)]\n",
122 | " ])\n",
123 | "\n",
124 | " return rotation_matrix"
125 | ],
126 | "id": "5431eea935e8ad5e"
127 | },
128 | {
129 | "metadata": {},
130 | "cell_type": "markdown",
131 |     "source": "## Newer College Dataset NCD Sequence",
132 | "id": "6d77fde250a16151"
133 | },
134 | {
135 | "metadata": {},
136 | "cell_type": "code",
137 | "outputs": [],
138 | "execution_count": null,
139 | "source": [
140 | "# ImMesh\n",
141 | "matrix_values = [\n",
142 | " 5.304626993818075675e-01, -8.474622417882305969e-01, 2.042261633276983360e-02, -8.377865843928848644e-02,\n",
143 | " 8.463450710843981595e-01, 5.308216667107832354e-01, 4.391332211528171242e-02, 3.370663104058911230e+00,\n",
144 | " -4.805564502133059107e-02, -6.009799083416286596e-03, 9.988265833638158009e-01, 7.037440120229881968e-01\n",
145 | "]\n",
146 | "T_lidar = np.vstack([np.array(matrix_values).reshape(3, 4), [0, 0, 0, 1]])\n",
147 | "\n",
148 | "T_lidar_imu = np.array([[-1.0, 0, 0, -0.006253],\n",
149 | " [0, -1.0, 0, 0.011775],\n",
150 | " [0, 0, 1.0, -0.028535],\n",
151 | " [0, 0, 0, 1]])"
152 | ],
153 | "id": "fb8b78b5c4feb986"
154 | },
155 | {
156 | "metadata": {},
157 | "cell_type": "markdown",
158 | "source": "## Newer College Dataset Extension Math Easy Sequence",
159 | "id": "c864e629b6013c7c"
160 | },
161 | {
162 | "metadata": {},
163 | "cell_type": "code",
164 | "outputs": [],
165 | "execution_count": null,
166 | "source": [
167 | "T_lidar = np.eye(4)\n",
168 | "\n",
169 | "# SLAMesh\n",
170 | "T_lidar[:3, :3] = quaternion_to_rotation_matrix(-0.00987445, 0.00774057, 0.842868, 0.537974)\n",
171 | "T_lidar[:3, 3] = np.array([-23.7176, -31.2646, 1.03258])\n",
172 | "\n",
173 | "# ImMesh\n",
174 | "# T_lidar[:3, :3] = quaternion_to_rotation_matrix(-0.00248205, 0.00444627, 0.842838, 0.538143)\n",
175 | "# T_lidar[:3, 3] = np.array([-23.7202, -31.2861, 1.04326])\n",
176 | "\n",
177 | "T_lidar_imu = np.array([\n",
178 | " [1.0, 0, 0, 0.014],\n",
179 | " [0, 1.0, 0, -0.012],\n",
180 | " [0, 0, 1.0, -0.015],\n",
181 | " [0, 0, 0, 1.0]])\n",
182 | "\n",
183 | "T = T_lidar @ T_lidar_imu"
184 | ],
185 | "id": "e9a22013e913de27"
186 | },
187 | {
188 | "metadata": {},
189 | "cell_type": "markdown",
190 |     "source": "## Transform Mesh To Align With The Ground Truth",
191 | "id": "61456f82338bb7e0"
192 | },
193 | {
194 | "metadata": {},
195 | "cell_type": "code",
196 | "outputs": [],
197 | "execution_count": null,
198 | "source": [
199 | "def mesh_transform():\n",
200 | " mesh_file = \"./math_easy.ply\"\n",
201 | " output_file = \"./math_easy_transformed.ply\"\n",
202 | "\n",
203 | " if not os.path.exists(mesh_file):\n",
204 | " sys.exit(f\"Mesh file {mesh_file} does not exist.\")\n",
205 | " print(\"Loading Mesh file: \", mesh_file)\n",
206 | "\n",
207 | " # Load Mesh\n",
208 | " mesh = o3d.io.read_triangle_mesh(mesh_file)\n",
209 | " mesh.compute_vertex_normals()\n",
210 | " transformation_matrix = np.array([[-4.20791735e-01, -9.07157072e-01, 6.01526210e-04, -2.37152142e+01],\n",
211 | " [ 9.07112929e-01, -4.20764518e-01, 1.01663692e-02, -3.12685037e+01],\n",
212 | " [-8.96939285e-03, 4.82357635e-03, 9.99948140e-01, 1.02807732e+00],\n",
213 | " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00000000e+00]])\n",
214 | "\n",
215 | " mesh.transform(transformation_matrix)\n",
216 | " o3d.io.write_triangle_mesh(output_file, mesh)\n",
217 | "\n",
218 | "mesh_transform()"
219 | ],
220 | "id": "532df31dd565ed8b"
221 | },
222 | {
223 | "metadata": {},
224 | "cell_type": "markdown",
225 | "source": "# Visualize Mesh\n",
226 | "id": "52c9c86e3ad5e34f"
227 | },
228 | {
229 | "metadata": {},
230 | "cell_type": "code",
231 | "outputs": [],
232 | "execution_count": null,
233 | "source": [
234 | "def vis_mesh():\n",
235 | " mesh_file = \"./ours_mesh_20cm.ply\" # the path of the mesh which you want to visualize\n",
236 | " if not os.path.exists(mesh_file):\n",
237 | " sys.exit(f\"Mesh file {mesh_file} does not exist.\")\n",
238 | " print(\"Loading Mesh file: \", mesh_file)\n",
239 | "\n",
240 | " # Load the mesh\n",
241 | " mesh = o3d.io.read_triangle_mesh(mesh_file)\n",
242 | " mesh.compute_vertex_normals()\n",
243 | " # Check if the mesh was loaded successfully\n",
244 | " if not mesh.has_vertices():\n",
245 | " sys.exit(\"Failed to load the mesh. No vertices found.\")\n",
246 | " vis = o3d.visualization.Visualizer()\n",
247 | " vis.create_window(window_name=\"Mesh Visualization\")\n",
248 | " vis.add_geometry(mesh)\n",
249 | " opt = vis.get_render_option()\n",
250 | " opt.light_on = True # Enable lighting to show the mesh color\n",
251 | " opt.mesh_show_back_face = True\n",
252 | " # Enable shortcuts in the console (e.g., Ctrl+9)\n",
253 | " vis.run()\n",
254 | " vis.destroy_window()\n",
255 | "\n",
256 | "vis_mesh()"
257 | ],
258 | "id": "abe2bf68d78b341e"
259 | }
260 | ],
261 | "metadata": {
262 | "kernelspec": {
263 | "display_name": "Python 3",
264 | "language": "python",
265 | "name": "python3"
266 | },
267 | "language_info": {
268 | "codemirror_mode": {
269 | "name": "ipython",
270 | "version": 2
271 | },
272 | "file_extension": ".py",
273 | "mimetype": "text/x-python",
274 | "name": "python",
275 | "nbconvert_exporter": "python",
276 | "pygments_lexer": "ipython2",
277 | "version": "2.7.6"
278 | }
279 | },
280 | "nbformat": 4,
281 | "nbformat_minor": 5
282 | }
283 |
--------------------------------------------------------------------------------
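A quick sanity check for the quaternion_to_rotation_matrix cell above (a sketch meant to be run inside tools.ipynb after that cell, not part of the repository; the quaternion is the SLAMesh one quoted in the math_easy cell, the checks themselves are illustrative):

    import numpy as np

    # SLAMesh orientation for the NCD extension math_easy sequence, copied from the notebook cell
    R = quaternion_to_rotation_matrix(-0.00987445, 0.00774057, 0.842868, 0.537974)
    print(np.allclose(R @ R.T, np.eye(3)))    # True: rows/columns are orthonormal
    print(np.isclose(np.linalg.det(R), 1.0))  # True: determinant +1, i.e. a proper rotation
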
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/__init__.cpython-312.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/config.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/config.cpython-312.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/data_sampler.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/data_sampler.cpython-312.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dataset_indexing.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/dataset_indexing.cpython-312.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/error_state_iekf.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/error_state_iekf.cpython-312.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/eval_traj_utils.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/eval_traj_utils.cpython-312.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/loss.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/loss.cpython-312.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/mapper.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/mapper.cpython-312.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/mesher.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/mesher.cpython-312.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/semantic_kitti_utils.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/semantic_kitti_utils.cpython-312.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/slam_dataset.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/slam_dataset.cpython-312.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/so3_math.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/so3_math.cpython-312.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/tools.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUTRobot/CLID-SLAM/54427e27e316f7ce1c4042c3d0fc287cdb65d1e9/utils/__pycache__/tools.cpython-312.pyc
--------------------------------------------------------------------------------
/utils/data_sampler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # @file data_sampler.py
3 | # @author Junlong Jiang [jiangjunlong@mail.dlut.edu.cn]
4 | # Copyright (c) 2025 Junlong Jiang, all rights reserved
5 |
6 | import torch
7 |
8 | from model.local_point_cloud_map import LocalPointCloudMap
9 | from utils.config import Config
10 | from utils.tools import transform_torch
11 |
12 |
13 | class DataSampler:
14 | def __init__(self, config: Config):
15 | self.config = config
16 | self.dev = config.device
17 |
18 | def sample(self, points_torch, local_point_cloud_map: LocalPointCloudMap, cur_pose_torch):
19 | dev = self.dev
20 | surface_sample_range = self.config.surface_sample_range_m
21 | surface_sample_n = self.config.surface_sample_n
22 | freespace_behind_sample_n = self.config.free_behind_n
23 | freespace_front_sample_n = self.config.free_front_n
24 | all_sample_n = (
25 | surface_sample_n + freespace_behind_sample_n + freespace_front_sample_n + 1
26 | ) # 1 as the exact measurement
27 | free_front_min_ratio = self.config.free_sample_begin_ratio
28 | free_sample_end_dist = self.config.free_sample_end_dist_m
29 |
30 | # get sample points
31 | point_num = points_torch.shape[0]
32 | distances = torch.linalg.norm(
33 | points_torch, dim=1, keepdim=True
34 | ) # ray distances (scaled)
35 |
36 | # Part 0. the exact measured point
37 | measured_sample_displacement = torch.zeros_like(distances)
38 | measured_sample_dist_ratio = torch.ones_like(distances)
39 |
40 | # Part 1. close-to-surface uniform sampling
41 | # uniform sample in the close-to-surface range (+- range)
42 | surface_sample_displacement = (
43 | torch.randn(point_num * surface_sample_n, 1, device=dev)
44 | * surface_sample_range
45 | )
46 |
47 | repeated_dist = distances.repeat(surface_sample_n, 1)
48 | surface_sample_dist_ratio = (
49 | surface_sample_displacement / repeated_dist + 1.0
50 | ) # 1.0 means on the surface
51 |
52 |
53 | # Part 2. free space (in front of surface) uniform sampling
54 | # if you want to reconstruct the thin objects (like poles, tree branches) well, you need more freespace samples to have
55 | # a space carving effect
56 |
57 | sigma_ratio = 2.0
58 | repeated_dist = distances.repeat(freespace_front_sample_n, 1)
59 | free_max_ratio = 1.0 - sigma_ratio * surface_sample_range / repeated_dist
60 | free_diff_ratio = free_max_ratio - free_front_min_ratio
61 | free_sample_front_dist_ratio = (
62 | torch.rand(point_num * freespace_front_sample_n, 1, device=dev)
63 | * free_diff_ratio
64 | + free_front_min_ratio
65 | )
66 | free_sample_front_displacement = (
67 | free_sample_front_dist_ratio - 1.0
68 | ) * repeated_dist
69 |
70 | # Part 3. free space (behind surface) uniform sampling
71 | repeated_dist = distances.repeat(freespace_behind_sample_n, 1)
72 | free_max_ratio = free_sample_end_dist / repeated_dist + 1.0
73 | free_behind_min_ratio = 1.0 + sigma_ratio * surface_sample_range / repeated_dist
74 | free_diff_ratio = free_max_ratio - free_behind_min_ratio
75 |
76 | free_sample_behind_dist_ratio = (
77 | torch.rand(point_num * freespace_behind_sample_n, 1, device=dev)
78 | * free_diff_ratio
79 | + free_behind_min_ratio
80 | )
81 |
82 | free_sample_behind_displacement = (
83 | free_sample_behind_dist_ratio - 1.0
84 | ) * repeated_dist
85 |
86 |
87 | # all together
88 | all_sample_displacement = torch.cat(
89 | (
90 | measured_sample_displacement,
91 | surface_sample_displacement,
92 | free_sample_front_displacement,
93 | free_sample_behind_displacement,
94 | ),
95 | 0,
96 | )
97 | all_sample_dist_ratio = torch.cat(
98 | (
99 | measured_sample_dist_ratio,
100 | surface_sample_dist_ratio,
101 | free_sample_front_dist_ratio,
102 | free_sample_behind_dist_ratio,
103 | ),
104 | 0,
105 | )
106 |
107 | repeated_points = points_torch.repeat(all_sample_n, 1)
108 | repeated_dist = distances.repeat(all_sample_n, 1)
109 | all_sample_points = repeated_points * all_sample_dist_ratio
110 | ####################################### Added By Jiang Junlong #################################################
111 |         # Determine the sign of the SDF label from the surface-sample displacement (in front of the surface -> positive)
112 | sdf_sign = torch.where(surface_sample_displacement.squeeze(1) < 0, 1, -1)
113 | mask = torch.ones(point_num * all_sample_n, dtype=torch.bool, device=self.config.device)
114 |
115 | sdf_label_tensor = -1 * all_sample_displacement.squeeze(1)
116 | surface_sample_count = point_num * (surface_sample_n + 1)
117 | surface_sample_points = all_sample_points[point_num: surface_sample_count]
118 | surface_sample_points_G = transform_torch(surface_sample_points, cur_pose_torch)
119 | dist, valid_mask = local_point_cloud_map.region_specific_sdf_estimation(surface_sample_points_G)
120 | mask[point_num:surface_sample_count] = valid_mask
121 | sdf_label_tensor[point_num:surface_sample_count] = sdf_sign * dist
122 | # sdf_label = torch.clamp(sdf_label, -0.4, 0.4)
123 |
124 | # depth tensor of all the samples
125 | depths_tensor = repeated_dist * all_sample_dist_ratio
126 | # get the weight vector as the inverse of sigma
127 | weight_tensor = torch.ones_like(depths_tensor)
128 | if self.config.dist_weight_on: # far away surface samples would have lower weight
129 | weight_tensor[:surface_sample_count] = (
130 | 1
131 | + self.config.dist_weight_scale * 0.5
132 | - (repeated_dist[:surface_sample_count] / self.config.max_range)
133 | * self.config.dist_weight_scale
134 | ) # [0.6, 1.4]
135 |
136 | weight_tensor[surface_sample_count:] *= -1.0
137 |
138 | # Convert from the all ray surface + all ray free order to the ray-wise (surface + free) order
139 |
140 | all_sample_points = (
141 | all_sample_points.reshape(all_sample_n, -1, 3)
142 | .transpose(0, 1)
143 | .reshape(-1, 3)
144 | )
145 | sdf_label_tensor = (
146 | sdf_label_tensor.reshape(all_sample_n, -1).transpose(0, 1).reshape(-1)
147 | )
148 |
149 | weight_tensor = (
150 | weight_tensor.reshape(all_sample_n, -1).transpose(0, 1).reshape(-1)
151 | )
152 |
153 | mask = (
154 | mask.reshape(all_sample_n, -1).transpose(0, 1).reshape(-1)
155 | )
156 | return all_sample_points[mask], sdf_label_tensor[mask], weight_tensor[mask]
157 |
--------------------------------------------------------------------------------
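The sampler above encodes every sample by its depth ratio along the ray: a displacement of d metres on a ray of length r corresponds to the ratio 1 + d/r, where 1.0 is exactly the measured surface point. A standalone sketch of that bookkeeping (illustrative values only, not part of the repository):

    import torch

    r = torch.tensor([[10.0]])             # ray length in metres
    d = torch.tensor([[0.25], [-0.25]])    # behind (+) and in front of (-) the measured surface
    ratio = d / r + 1.0                    # same formula as surface_sample_dist_ratio above
    print(ratio)                           # tensor([[1.0250], [0.9750]])
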
/utils/dataset_indexing.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # @file dataset_indexing.py
3 | # @author Yue Pan [yue.pan@igg.uni-bonn.de]
4 | # Copyright (c) 2024 Yue Pan, all rights reserved
5 |
6 | import os
7 |
8 | from utils.config import Config
9 |
10 | def set_dataset_path(config: Config, dataset_name: str = '', seq: str = ''):
11 |
12 | config.name = config.name + '_' + dataset_name + '_' + seq.replace("/", "")
13 |
14 | if config.use_kiss_dataloader:
15 | config.data_loader_name = dataset_name
16 | config.data_loader_seq = seq
17 | print('KISS-ICP data loaders used')
18 | from kiss_icp.datasets import available_dataloaders
19 | print('Available dataloaders:', available_dataloaders())
20 |
21 | else:
22 | if dataset_name == "kitti":
23 | base_path = config.pc_path.rsplit('/', 3)[0]
24 | config.pc_path = os.path.join(base_path, 'sequences', seq, "velodyne") # input point cloud folder
25 | pose_file_name = seq + '.txt'
26 | config.pose_path = os.path.join(base_path, 'poses', pose_file_name) # input pose file
27 | config.calib_path = os.path.join(base_path, 'sequences', seq, "calib.txt") # input calib file (to sensor frame)
28 | config.label_path = os.path.join(base_path, 'sequences', seq, "labels") # input point-wise label path, for semantic mapping (optional)
29 | config.kitti_correction_on = True
30 | config.correction_deg = 0.195
31 | elif dataset_name == "mulran":
32 | config.name = config.name + "_mulran_" + seq
33 | base_path = config.pc_path.rsplit('/', 2)[0]
34 | config.pc_path = os.path.join(base_path, seq, "Ouster") # input point cloud folder
35 | config.pose_path = os.path.join(base_path, seq, "poses.txt") # input pose file
36 | elif dataset_name == "kitti_carla":
37 | config.name = config.name + "_kitti_carla_" + seq
38 | base_path = config.pc_path.rsplit('/', 3)[0]
39 | config.pc_path = os.path.join(base_path, seq, "generated", "frames") # input point cloud folder
40 | config.pose_path = os.path.join(base_path, seq, "generated", "poses.txt") # input pose file
41 | config.calib_path = os.path.join(base_path, seq, "generated", "calib.txt") # input calib file (to sensor frame)
42 | elif dataset_name == "ncd":
43 | config.name = config.name + "_ncd_" + seq
44 | base_path = config.pc_path.rsplit('/', 2)[0]
45 | config.pc_path = os.path.join(base_path, seq, "bin") # input point cloud folder
46 | config.pose_path = os.path.join(base_path, seq, "poses.txt") # input pose file
47 | config.calib_path = os.path.join(base_path, seq, "calib.txt") # input calib file (to sensor frame)
48 | elif dataset_name == "ncd128":
49 | config.name = config.name + "_ncd128_" + seq
50 | base_path = config.pc_path.rsplit('/', 2)[0]
51 | config.pc_path = os.path.join(base_path, seq, "ply") # input point cloud folder
52 | config.pose_path = os.path.join(base_path, seq, "poses.txt") # input pose file
53 | elif dataset_name == "ipbcar":
54 | config.name = config.name + "_ipbcar_" + seq
55 | base_path = config.pc_path.rsplit('/', 2)[0]
56 | config.pc_path = os.path.join(base_path, seq, "ouster") # input point cloud folder
57 | config.pose_path = os.path.join(base_path, seq, "poses.txt") # input pose file
58 | config.calib_path = os.path.join(base_path, seq, "calib.txt") # input calib file (to sensor frame)
59 | elif dataset_name == "hilti":
60 | config.name = config.name + "_hilti_" + seq
61 | base_path = config.pc_path.rsplit('/', 2)[0]
62 | config.pc_path = os.path.join(base_path, seq, "ply") # input point cloud folder
63 | # config.pose_path = os.path.join(base_path, seq, "poses.txt") # input pose file
64 | elif dataset_name == "m2dgr":
65 | config.name = config.name + "_m2dgr_" + seq
66 | base_path = config.pc_path.rsplit('/', 2)[0]
67 | config.pc_path = os.path.join(base_path, seq, "points") # input point cloud folder
68 | config.pose_path = os.path.join(base_path, seq, "poses.txt") # input pose file
69 | elif dataset_name == "replica":
70 | config.name = config.name + "_replica_" + seq
71 | base_path = config.pc_path.rsplit('/', 2)[0]
72 | config.pc_path = os.path.join(base_path, seq, "rgbd_down_ply") # input point cloud folder
73 | #config.pc_path = os.path.join(base_path, seq, "rgbd_ply") # input point cloud folder
74 | config.pose_path = os.path.join(base_path, seq, "poses.txt") # input pose file
75 |
--------------------------------------------------------------------------------
/utils/error_state_iekf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # @file error_state_iekf.py
3 | # @author Junlong Jiang [jiangjunlong@mail.dlut.edu.cn]
4 | # Copyright (c) 2025 Junlong Jiang, all rights reserved
5 |
6 | import math
7 |
8 | import numpy as np
9 | import torch
10 |
11 | from model.decoder import Decoder
12 | from model.neural_points import NeuralPoints
13 | from utils.config import Config
14 | from utils.so3_math import vec2skew, so3Exp, SO3Log, vectors_to_skew_symmetric
15 | from utils.tools import get_gradient, transform_torch
16 |
17 | G_m_s2 = 9.81 # gravitational acceleration (m/s^2)
18 |
19 |
20 | class StateIkfom:
21 |     """18-dimensional state x, ordered as: rotation (3), position (3), velocity (3), gyroscope bias (3), accelerometer bias (3), gravity vector (3)"""
22 | def __init__(
23 | self,
24 | dtype,
25 | pos=None,
26 | rot=None,
27 | vel=None, bg=None,
28 | ba=None,
29 | grav=None
30 | ):
31 | self.dtype = dtype
32 | self.rot = torch.eye(3, dtype=self.dtype) if rot is None else rot
33 | self.pos = torch.zeros(3, dtype=self.dtype) if pos is None else pos
34 | self.vel = torch.zeros(3, dtype=self.dtype) if vel is None else vel
35 | self.bg = torch.zeros(3, dtype=self.dtype) if bg is None else bg
36 | self.ba = torch.zeros(3, dtype=self.dtype) if ba is None else ba
37 | self.grav = torch.tensor([0.0, 0.0, -G_m_s2], dtype=self.dtype) if grav is None else grav
38 |
39 | def cpu(self):
40 |         """Move all state tensors to the CPU"""
41 | self.rot = self.rot.cpu()
42 | self.pos = self.pos.cpu()
43 | self.vel = self.vel.cpu()
44 | self.bg = self.bg.cpu()
45 | self.ba = self.ba.cpu()
46 | self.grav = self.grav.cpu()
47 |
48 | def cuda(self):
49 |         """Move all state tensors to the GPU"""
50 | self.rot = self.rot.cuda()
51 | self.pos = self.pos.cuda()
52 | self.vel = self.vel.cuda()
53 | self.bg = self.bg.cuda()
54 | self.ba = self.ba.cuda()
55 | self.grav = self.grav.cuda()
56 |
57 |
58 | class InputIkfom:
59 |     """Input vector holding the gyroscope and accelerometer measurements."""
60 |
61 | def __init__(
62 | self,
63 | dtype,
64 |         acc: np.ndarray,
65 |         gyro: np.ndarray
66 | ):
67 | self.dtype = dtype
68 | self.acc = torch.tensor(acc, dtype=self.dtype)
69 | self.gyro = torch.tensor(gyro, dtype=self.dtype)
70 |
71 |
72 | def boxplus(state: StateIkfom, delta: torch.tensor):
73 |     """Box-plus: generalized addition on the state manifold."""
74 | new_state = StateIkfom(state.dtype)
75 | new_state.rot = state.rot @ so3Exp(delta[0:3])
76 | new_state.pos = state.pos + delta[3:6]
77 | new_state.vel = state.vel + delta[6:9]
78 | new_state.bg = state.bg + delta[9:12]
79 | new_state.ba = state.ba + delta[12:15]
80 | new_state.grav = state.grav + delta[15:18]
81 | return new_state
82 |
83 |
84 | def boxminus(x1: StateIkfom, x2: StateIkfom):
85 |     """Box-minus: generalized subtraction giving the difference between two states."""
86 | delta_rot = SO3Log(x2.rot.T @ x1.rot)
87 | delta_pos = x1.pos - x2.pos
88 | delta_vel = x1.vel - x2.vel
89 | delta_bg = x1.bg - x2.bg
90 | delta_ba = x1.ba - x2.ba
91 | delta_grav = x1.grav - x2.grav
92 | delta = torch.concatenate([delta_rot, delta_pos, delta_vel, delta_bg, delta_ba, delta_grav])
93 | return delta
94 |
95 |
96 | class IEKFOM:
97 |     """Iterated error-state extended Kalman filter."""
98 |
99 | def __init__(
100 | self,
101 | config: Config,
102 | neural_points: NeuralPoints,
103 | geo_decoder: Decoder,
104 | ):
105 | self.config = config
106 | self.silence = config.silence
107 | self.neural_points = neural_points
108 | self.geo_decoder = geo_decoder
109 | self.device = self.config.device
110 | self.dtype = config.dtype
111 | self.tran_dtype = config.tran_dtype
112 |
113 |         self.x = StateIkfom(self.tran_dtype)  # initial state
114 |         self.P = torch.eye(18, dtype=self.tran_dtype)  # initial state covariance
115 |         self.P[9:12, 9:12] = self.P[9:12, 9:12] * 1e-4  # initial gyroscope bias covariance
116 |         self.P[12:15, 12:15] = self.P[12:15, 12:15] * 1e-3  # initial accelerometer bias covariance
117 |         self.P[15:18, 15:18] = self.P[15:18, 15:18] * 1e-4  # initial gravity covariance
118 |         self.Q = self.process_noise_covariance()  # process (white) noise covariance for forward propagation
119 |         self.R_inv = None  # inverse of the measurement noise covariance
120 |         self.eps = 0.001  # convergence threshold
121 |         self.max_iteration = self.config.reg_iter_n  # maximum number of update iterations
122 |
123 | def process_noise_covariance(self):
124 |         """Initialize the process noise covariance Q."""
125 | Q = torch.zeros((12, 12), dtype=self.config.tran_dtype)
126 | Q[:3, :3] = self.config.measurement_noise_covariance * torch.eye(3)
127 | Q[3:6, 3:6] = self.config.measurement_noise_covariance * torch.eye(3)
128 | Q[6:9, 6:9] = self.config.bias_noise_covariance * torch.eye(3)
129 | Q[9:12, 9:12] = self.config.bias_noise_covariance * torch.eye(3)
130 | return Q
131 |
132 | def df_dx(self, s: StateIkfom, in_: InputIkfom, dt: float):
133 |         """Jacobian of the state-transition function w.r.t. the error state."""
134 | # omega_ = in_.gyro - s.bg
135 | acc_ = in_.acc - s.ba
136 |
137 | df_dx = torch.eye(18, dtype=self.tran_dtype)
138 | I_dt = torch.eye(3, dtype=self.tran_dtype) * dt
139 | # cov[0:3, 0:3] = so3Exp(-omega_ * dt)
140 |         df_dx[0:3, 0:3] = torch.eye(3, dtype=self.tran_dtype) # so3Exp(-omega_ * dt) is approximated by the identity
141 | df_dx[0:3, 9:12] = -I_dt
142 | df_dx[3:6, 6:9] = I_dt
143 | df_dx[6:9, 0:3] = -s.rot @ vec2skew(acc_) * dt
144 | df_dx[6:9, 12:15] = -s.rot * dt
145 | df_dx[6:9, 15:18] = I_dt
146 |
147 | return df_dx
148 |
149 | def df_dw(self, s: StateIkfom, in_: InputIkfom, dt: float):
150 |         """Jacobian of the state-transition function w.r.t. the process noise."""
151 | # omega_ = in_.gyro - s.bg
152 | I = torch.eye(3, dtype=self.tran_dtype)
153 | cov = torch.zeros((18, 12), dtype=self.tran_dtype)
154 | # cov[0:3, 0:3] = -A_T(omega_ * dt)
155 |         cov[0:3, 0:3] = -I # -A(w dt) simplified to -I
156 | cov[6:9, 3:6] = -s.rot # -R
157 | cov[9:12, 6:9] = I # I
158 | cov[12:15, 9:12] = I # I
159 | cov = cov * dt
160 |
161 | return cov
162 |
163 | def predict(self, i_in: InputIkfom, dt: float):
164 |         """Forward propagation (running it on the CPU is much faster)."""
165 | f = self.f_model(self.x, i_in)
166 | df_dx = self.df_dx(self.x, i_in, dt)
167 | df_dw = self.df_dw(self.x, i_in, dt)
168 |
169 | self.x = boxplus(self.x, f * dt)
170 | self.P = df_dx @ self.P @ df_dx.T + df_dw @ self.Q @ df_dw.T
171 |
172 | def f_model(self, s: StateIkfom, in_: InputIkfom):
173 |         """Continuous-time motion model describing how the state evolves over time."""
174 | res = torch.zeros(18, dtype=self.tran_dtype)
175 | a_inertial = s.rot @ (in_.acc - s.ba) + s.grav
176 | res[:3] = in_.gyro - s.bg
177 | res[3:6] = s.vel
178 | res[6:9] = a_inertial
179 | return res
180 |
181 | def h_model(self, point_cloud_imu: torch.tensor):
182 | bs = self.config.infer_bs
183 | mask_min_nn_count = self.config.track_mask_query_nn_k
184 | min_grad_norm = self.config.reg_min_grad_norm
185 | max_grad_norm = self.config.reg_max_grad_norm
186 |
187 | pose = torch.eye(4)
188 | pose[:3, :3] = self.x.rot
189 | pose[:3, 3] = self.x.pos
190 |
191 | point_cloud_global = transform_torch(point_cloud_imu, pose)
192 | sample_count = point_cloud_global.shape[0]
193 | iter_n = math.ceil(sample_count / bs)
194 |
195 | sdf_pred = torch.zeros(sample_count, device=point_cloud_global.device)
196 | sdf_std = torch.zeros(sample_count, device=point_cloud_global.device)
197 | mc_mask = torch.zeros(sample_count, device=point_cloud_global.device, dtype=torch.bool)
198 | sdf_grad = torch.zeros((sample_count, 3), device=point_cloud_global.device)
199 | certainty = torch.zeros(sample_count, device=point_cloud_global.device)
200 |
201 |         # Query in batches to avoid running out of GPU memory
202 |         # Predict the SDF value of every point in the cloud
203 | for n in range(iter_n):
204 | head = n * bs
205 | tail = min((n + 1) * bs, sample_count)
206 | batch_coord = point_cloud_global[head:tail, :]
207 | batch_coord.requires_grad_(True)
208 |
209 | (
210 | batch_geo_feature,
211 | _,
212 | weight_knn,
213 | nn_count,
214 | batch_certainty,
215 | ) = self.neural_points.query_feature(
216 | batch_coord,
217 | training_mode=False,
218 | query_locally=True,
219 | query_color_feature=False,
220 | ) # inference mode
221 |
222 | batch_sdf = self.geo_decoder.sdf(batch_geo_feature)
223 | if not self.config.weighted_first:
224 | batch_sdf_mean = torch.sum(batch_sdf * weight_knn, dim=1)
225 | batch_sdf_var = torch.sum((weight_knn * (batch_sdf - batch_sdf_mean.unsqueeze(-1)) ** 2), dim=1)
226 | batch_sdf_std = torch.sqrt(batch_sdf_var).squeeze(1)
227 | batch_sdf = batch_sdf_mean.squeeze(1)
228 | sdf_std[head:tail] = batch_sdf_std.detach()
229 |
230 | batch_sdf_grad = get_gradient(batch_coord, batch_sdf)
231 | sdf_grad[head:tail, :] = batch_sdf_grad.detach()
232 | sdf_pred[head:tail] = batch_sdf.detach()
233 | mc_mask[head:tail] = nn_count >= mask_min_nn_count
234 | certainty[head:tail] = batch_certainty.detach()
235 |
236 |         # Reject outlier predictions (these SDF values serve as the filter's observations)
237 | grad_norm = sdf_grad.norm(dim=-1, keepdim=True).squeeze()
238 | max_sdf_std = self.config.surface_sample_range_m * self.config.max_sdf_std_ratio
239 | valid_idx = (
240 | mc_mask
241 | & (grad_norm < max_grad_norm)
242 | & (grad_norm > min_grad_norm)
243 | & (sdf_std < max_sdf_std)
244 | )
245 | valid_points = point_cloud_global[valid_idx]
246 | valid_point_count = valid_points.shape[0]
247 | point_cloud_imu = point_cloud_imu[valid_idx]
248 | grad_norm = grad_norm[valid_idx]
249 | sdf_pred = sdf_pred[valid_idx]
250 | sdf_grad = sdf_grad[valid_idx]
251 |
252 |         # Measurement Jacobian H: columns 0-2 w.r.t. the rotation error, columns 3-5 w.r.t. the position error
253 | H = torch.zeros((valid_point_count, 18), device=self.device, dtype=self.tran_dtype)
254 | pc_imu_hat = vectors_to_skew_symmetric(point_cloud_imu)
255 | rotation = self.x.rot.to(dtype=self.dtype).unsqueeze(0)
256 | A = torch.bmm(rotation.repeat(valid_point_count, 1, 1), pc_imu_hat)
257 | H[:, 0: 3] = -torch.bmm(sdf_grad.unsqueeze(1), A).squeeze(1)
258 | H[:, 3: 6] = sdf_grad
259 |
260 |         # Per-observation confidence weighting (gives a slight accuracy improvement)
261 | sdf_residual = sdf_pred.to(dtype=self.tran_dtype)
262 | grad_anomaly = (grad_norm - 1.0).to(dtype=self.tran_dtype)
263 | w_grad = 1 / (1 + grad_anomaly ** 2)
264 | w_res = 0.4 / (0.4 + sdf_residual ** 2)
265 | self.R_inv = w_grad * w_res * 1000
266 |
267 | return sdf_residual, H, valid_points
268 |
269 | def update_iterated(self, point_cloud_imu: torch.tensor):
270 | """
271 |         Update the state estimate with an iterated measurement update.
272 | 
273 |         Args:
274 |             point_cloud_imu (torch.tensor): measured point cloud in the IMU frame,
275 |                 assumed to be an Nx3 tensor.
276 | """
277 |         # Move the state and covariance matrix to the GPU
278 | self.x.cuda()
279 | self.P = self.P.cuda()
280 | valid_flag = True
281 | converged = False
282 |
283 | x_propagated = self.x
284 | P_inv = torch.linalg.inv(self.P)
285 | I = torch.eye(18, device=self.device, dtype=self.tran_dtype)
286 | term_thre_deg = self.config.reg_term_thre_deg
287 | term_thre_m = self.config.reg_term_thre_m
288 |
289 | for i in range(self.max_iteration):
290 | dx_new = boxminus(self.x, x_propagated)
291 | z, H, valid_points = self.h_model(point_cloud_imu)
292 | valid_point_count = valid_points.shape[0]
293 | source_point_count = point_cloud_imu.shape[0]
294 |
295 | if valid_point_count / source_point_count < 0.2 and i == self.max_iteration - 1:
296 | if not self.config.silence:
297 | print("[bold yellow](Warning) registration failed: not enough valid points[/bold yellow]")
298 | valid_flag = False
299 |
300 | H_T_R_inv = H.T * self.R_inv
301 | S = H_T_R_inv @ H
302 |
303 | K_front = torch.linalg.inv(S + P_inv)
304 | K = K_front @ H_T_R_inv
305 |
306 | dx_ = -K @ z + (K @ H - I) @ dx_new
307 | self.x = boxplus(self.x, dx_)
308 | tran_m = dx_[3:6].norm()
309 | rot_angle_deg = dx_[0:3].norm() * 180.0 / np.pi
310 |
311 |             # Termination criterion 1: physically meaningful thresholds on the rotation/translation update
312 | if rot_angle_deg < term_thre_deg and tran_m < term_thre_m and torch.all(torch.abs(dx_[6:]) < self.eps):
313 | if not self.config.silence:
314 | print("Converged after", i, "iterations")
315 | converged = True
316 |
317 |             # Termination criterion 2 (alternative): all error-state components below eps
318 | # if torch.all(torch.abs(dx_) < self.eps):
319 | # if not self.config.silence:
320 | # print("Converged after", i, "iterations")
321 | # converged = True
322 |
323 | if not valid_flag or converged:
324 | break
325 |
326 | self.P = (I - K @ H) @ self.P
327 | updated_pose = torch.eye(4, dtype=self.dtype, device=self.device)
328 | updated_pose[:3, :3] = self.x.rot.to(self.dtype)
329 | updated_pose[:3, 3] = self.x.pos.to(self.dtype)
330 |
331 |         # Move the state and covariance matrix back to the CPU
332 | self.x.cpu()
333 | self.P = self.P.cpu()
334 | return updated_pose, valid_flag
--------------------------------------------------------------------------------
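boxplus and boxminus above form the retraction pair of the error-state filter: the rotation block lives on SO(3) via so3Exp/SO3Log, while the remaining 15 dimensions are plain Euclidean additions. A minimal round-trip sketch (assuming the repository root is on PYTHONPATH and its dependencies are installed; the increment and tolerance are illustrative):

    import torch
    from utils.error_state_iekf import StateIkfom, boxplus, boxminus

    x = StateIkfom(torch.float64)                           # identity rotation, zero vectors, nominal gravity
    delta = 1e-3 * torch.arange(18, dtype=torch.float64)    # small 18-dim error-state increment
    recovered = boxminus(boxplus(x, delta), x)
    print(torch.allclose(recovered, delta, atol=1e-8))      # True: boxminus undoes boxplus for small deltas
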
/utils/eval_traj_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # @file eval_traj_utils.py
3 | # @author Yue Pan [yue.pan@igg.uni-bonn.de]
4 | # Copyright (c) 2024 Yue Pan, all rights reserved
5 |
6 | from collections import defaultdict
7 | from typing import Dict, List, Tuple
8 |
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 |
12 |
13 | def absolute_error(
14 | poses_gt: np.ndarray, poses_result: np.ndarray, align_on: bool = True
15 | ):
16 | assert poses_gt.shape[0] == poses_result.shape[0], "poses length should be identical"
17 | align_mat = np.eye(4)
18 | if align_on:
19 | align_rot, align_tran, _ = align_traj(poses_result, poses_gt)
20 | align_mat[:3, :3] = align_rot
21 | align_mat[:3, 3] = np.squeeze(align_tran)
22 |
23 | frame_count = poses_gt.shape[0]
24 |
25 | rot_errors = []
26 | tran_errors = []
27 |
28 | for i in range(frame_count):
29 | cur_results_pose_aligned = align_mat @ poses_result[i]
30 | cur_gt_pose = poses_gt[i]
31 | delta_rot = (
32 | np.linalg.inv(cur_gt_pose[:3, :3]) @ cur_results_pose_aligned[:3, :3]
33 | )
34 | delta_tran = cur_gt_pose[:3, 3] - cur_results_pose_aligned[:3, 3]
35 |
36 | # the one used for kiss-icp
37 | # delta_tran = cur_gt_pose[:3,3] - delta_rot @ cur_results_pose_aligned[:3,3]
38 |
39 | delta_rot_theta = rotation_error(delta_rot)
40 | delta_t = np.linalg.norm(delta_tran)
41 |
42 | rot_errors.append(delta_rot_theta)
43 | tran_errors.append(delta_t)
44 |
45 | rot_errors = np.array(rot_errors)
46 | tran_errors = np.array(tran_errors)
47 |
48 | rot_rmse = np.sqrt(np.dot(rot_errors, rot_errors) / frame_count) * 180.0 / np.pi # this seems to have some problem
49 | tran_rmse = np.sqrt(np.dot(tran_errors, tran_errors) / frame_count)
50 |
51 | # rot_mean = np.mean(rot_errors)
52 | # tran_mean = np.mean(tran_errors)
53 |
54 | # rot_median = np.median(rot_errors)
55 | # tran_median = np.median(tran_errors)
56 |
57 | # rot_std = np.std(rot_errors)
58 | # tran_std = np.std(tran_errors)
59 |
60 | return rot_rmse, tran_rmse, align_mat
61 |
62 |
63 | def align_traj(poses_np_1, poses_np_2):
64 |
65 | traj_1 = poses_np_1[:,:3,3].squeeze().T
66 | traj_2 = poses_np_2[:,:3,3].squeeze().T
67 |
68 | return align(traj_1, traj_2)
69 |
70 |
71 | def align(model, data):
72 | """Align two trajectories using the method of Horn (closed-form).
73 |
74 | Input:
75 | model -- first trajectory (3xn)
76 | data -- second trajectory (3xn)
77 |
78 | Output:
79 | rot -- rotation matrix (3x3)
80 | trans -- translation vector (3x1)
81 | trans_error -- translational error per point (1xn)
82 |
83 | Borrowed from NICE-SLAM
84 | """
85 | model_zerocentered = model - model.mean(1, keepdims=True)
86 | data_zerocentered = data - data.mean(1, keepdims=True)
87 |
88 | W = np.zeros((3, 3))
89 | for column in range(model.shape[1]):
90 | W += np.outer(model_zerocentered[:, column], data_zerocentered[:, column])
91 |     U, d, Vh = np.linalg.svd(W.transpose())
92 | S = np.matrix(np.identity(3))
93 | if np.linalg.det(U) * np.linalg.det(Vh) < 0:
94 | S[2, 2] = -1
95 | rot = U * S * Vh
96 | trans = data.mean(1, keepdims=True) - rot * model.mean(1, keepdims=True)
97 |
98 | model_aligned = rot * model + trans
99 |
100 | alignment_error = model_aligned - data
101 |
102 | trans_error = np.sqrt(np.sum(np.multiply(alignment_error, alignment_error), 0)).A[
103 | 0
104 | ] # as RMSE
105 |
106 | return rot, trans, trans_error
107 |
108 |
109 | def relative_error(poses_gt, poses_result):
110 | """calculate sequence error (kitti metric, relative drifting error)
111 | Args:
112 | poses_gt, kx4x4 np.array, ground truth poses
113 | poses_result, kx4x4 np.array, predicted poses
114 | Returns:
115 | err (list list): [first_frame, rotation error, translation error, length, speed]
116 |         - first_frame: first frame index
117 | - rotation error: rotation error per length
118 | - translation error: translation error per length
119 | - length: evaluation trajectory length
120 | - speed: car speed (#FIXME: 10FPS is assumed)
121 | """
122 | assert poses_gt.shape[0] == poses_result.shape[0], "poses length should be identical"
123 | err = []
124 | dist = trajectory_distances(poses_gt)
125 | step_size = 10
126 |
127 | lengths = [100, 200, 300, 400, 500, 600, 700, 800] # unit: m
128 | num_lengths = len(lengths)
129 |
130 | for first_frame in range(0, poses_gt.shape[0], step_size):
131 | for i in range(num_lengths):
132 | len_ = lengths[i]
133 | last_frame = last_frame_from_segment_length(dist, first_frame, len_)
134 |
135 | # Continue if sequence not long enough
136 | if last_frame == -1:
137 | continue
138 |
139 | # compute rotational and translational errors
140 | pose_delta_gt = np.linalg.inv(poses_gt[first_frame]) @ poses_gt[last_frame]
141 | pose_delta_result = (
142 | np.linalg.inv(poses_result[first_frame]) @ poses_result[last_frame]
143 | )
144 |
145 | pose_error = np.linalg.inv(pose_delta_result) @ pose_delta_gt
146 |
147 | r_err = rotation_error(pose_error)
148 | t_err = translation_error(pose_error)
149 |
150 | # compute speed
151 | num_frames = last_frame - first_frame + 1.0
152 | speed = len_ / (0.1 * num_frames)
153 |
154 | err.append([first_frame, r_err / len_, t_err / len_, len_, speed])
155 |
156 | t_err = 0
157 | r_err = 0
158 |
159 | if len(err) == 0: # the case when the trajectory is not long enough
160 | return 0, 0
161 |
162 | for i in range(len(err)):
163 | r_err += err[i][1]
164 | t_err += err[i][2]
165 |
166 | r_err /= len(err)
167 | t_err /= len(err)
168 | drift_ate = t_err * 100.0
169 | drift_are = r_err / np.pi * 180.0
170 |
171 | return drift_ate, drift_are
172 |
173 |
174 | def trajectory_distances(poses_np):
175 | """Compute distance for each pose w.r.t frame-0
176 | Args:
177 | poses kx4x4 np.array
178 | Returns:
179 | dist (float list): distance of each pose w.r.t frame-0
180 | """
181 | dist = [0]
182 |
183 | for i in range(poses_np.shape[0] - 1):
184 | rela_dist = np.linalg.norm(poses_np[i+1] - poses_np[i])
185 | dist.append(dist[i] + rela_dist)
186 |
187 | return dist
188 |
189 |
190 | def rotation_error(pose_error):
191 | """Compute rotation error
192 | From a rotation matrix to the axis angle, use the angle as the result
193 | Args:
194 | pose_error (4x4 or 3x3 array): relative pose error
195 | Returns:
196 | rot_error (float): rotation error
197 | """
198 | a = pose_error[0, 0]
199 | b = pose_error[1, 1]
200 | c = pose_error[2, 2]
201 | # 0.5 * (trace - 1)
202 | d = 0.5 * (a + b + c - 1.0)
203 | # make sure the rot_mat is valid (trace < 3, det = 1)
204 | rot_error = np.arccos(max(min(d, 1.0), -1.0)) # in rad
205 | return rot_error
206 |
207 |
208 | def translation_error(pose_error):
209 | """Compute translation error
210 | Args:
211 | pose_error (4x4 array): relative pose error
212 | Returns:
213 | trans_error (float): translation error
214 | """
215 | dx = pose_error[0, 3]
216 | dy = pose_error[1, 3]
217 | dz = pose_error[2, 3]
218 | trans_error = np.sqrt(dx**2 + dy**2 + dz**2)
219 | return trans_error
220 |
221 |
222 | def last_frame_from_segment_length(dist, first_frame, length):
223 | """Find frame (index) that away from the first_frame with
224 | the required distance
225 | Args:
226 | dist (float list): distance of each pose w.r.t frame-0
227 | first_frame (int): start-frame index
228 | length (float): required distance
229 | Returns:
230 | i (int) / -1: end-frame index. if not found return -1
231 | """
232 | for i in range(first_frame, len(dist), 1):
233 | if dist[i] > (dist[first_frame] + length):
234 | return i
235 | return -1
236 |
237 |
238 | def plot_trajectories(
239 | traj_plot_path: str,
240 | poses_est,
241 | poses_ref,
242 | poses_est_2=None,
243 | plot_3d: bool = True,
244 | grid_on: bool = True,
245 | plot_start_end_markers: bool = True,
246 | vis_now: bool = False,
247 | close_all: bool = True,
248 | ) -> None:
249 | # positions_est, positions_ref, positions_est_2 as list of numpy array
250 |
251 | from evo.core.trajectory import PosePath3D
252 | from evo.tools import plot as evoplot
253 | from evo.tools.settings import SETTINGS
254 |
255 | # without alignment
256 |
257 | if close_all:
258 | plt.close("all")
259 |
260 | poses = PosePath3D(poses_se3=poses_est)
261 | gt_poses = PosePath3D(poses_se3=poses_ref)
262 | if poses_est_2 is not None:
263 | poses_2 = PosePath3D(poses_se3=poses_est_2)
264 |
265 | if plot_3d:
266 | plot_mode = evoplot.PlotMode.xyz
267 | else:
268 | plot_mode = evoplot.PlotMode.xy
269 |
270 |     fig = plt.figure("Trajectory results")
271 | ax = evoplot.prepare_axis(fig, plot_mode)
272 | evoplot.traj(
273 | ax=ax,
274 | plot_mode=plot_mode,
275 | traj=gt_poses,
276 | label="ground truth",
277 | style=SETTINGS.plot_reference_linestyle,
278 | color=SETTINGS.plot_reference_color,
279 | alpha=SETTINGS.plot_reference_alpha,
280 | plot_start_end_markers=False,
281 | )
282 | evoplot.traj(
283 | ax=ax,
284 | plot_mode=plot_mode,
285 | traj=poses,
286 | label="PIN-SLAM",
287 | style=SETTINGS.plot_trajectory_linestyle,
288 | color="#4c72b0bf",
289 | alpha=SETTINGS.plot_trajectory_alpha,
290 | plot_start_end_markers=plot_start_end_markers,
291 | )
292 | if poses_est_2 is not None: # better to change color (or the alpha)
293 | evoplot.traj(
294 | ax=ax,
295 | plot_mode=plot_mode,
296 | traj=poses_2,
297 | label="PIN-Odom",
298 | style=SETTINGS.plot_trajectory_linestyle,
299 | color="#FF940E",
300 | alpha=SETTINGS.plot_trajectory_alpha / 3.0,
301 | plot_start_end_markers=False,
302 | )
303 |
304 | plt.tight_layout()
305 | ax.legend(frameon=grid_on)
306 |
307 | if traj_plot_path is not None:
308 | plt.savefig(traj_plot_path, dpi=600)
309 |
310 | if vis_now:
311 | plt.show()
312 |
313 |
314 | def read_kitti_format_calib(filename: str):
315 | """
316 | read calibration file (with the kitti format)
317 | returns -> dict calibration matrices as 4*4 numpy arrays
318 | """
319 | calib = {}
320 | calib_file = open(filename)
321 |
322 | for line in calib_file:
323 | key, content = line.strip().split(":")
324 | values = [float(v) for v in content.strip().split()]
325 | pose = np.zeros((4, 4))
326 |
327 | pose[0, 0:4] = values[0:4]
328 | pose[1, 0:4] = values[4:8]
329 | pose[2, 0:4] = values[8:12]
330 | pose[3, 3] = 1.0
331 |
332 | calib[key] = pose
333 |
334 | calib_file.close()
335 | return calib
336 |
337 |
338 | def read_kitti_format_poses(filename: str) -> List[np.ndarray]:
339 | """
340 | read pose file (with the kitti format)
341 |     returns -> list of 4x4 np.ndarray, poses before applying the calibration transformation
342 | """
343 | pose_file = open(filename)
344 |
345 | poses = []
346 |
347 | for line in pose_file:
348 | values = [float(v) for v in line.strip().split()]
349 |
350 | pose = np.zeros((4, 4))
351 | pose[0, 0:4] = values[0:4]
352 | pose[1, 0:4] = values[4:8]
353 | pose[2, 0:4] = values[8:12]
354 | pose[3, 3] = 1.0
355 | poses.append(pose)
356 |
357 | pose_file.close()
358 | return poses
359 |
360 |
361 | # copyright: Nacho et al. KISS-ICP
362 | def apply_kitti_format_calib(poses: List[np.ndarray], calib_T_cl) -> List[np.ndarray]:
363 | """Converts from Velodyne to Camera Frame (# T_camera<-lidar)"""
364 | poses_calib = []
365 | for pose in poses:
366 | poses_calib.append(calib_T_cl @ pose @ np.linalg.inv(calib_T_cl))
367 | return poses_calib
368 |
369 |
370 | # copyright: Nacho et al. KISS-ICP
371 | def write_kitti_format_poses(filename: str, poses: List[np.ndarray]):
372 | def _to_kitti_format(poses: np.ndarray) -> np.ndarray:
373 | return np.array([np.concatenate((pose[0], pose[1], pose[2])) for pose in poses])
374 |
375 | np.savetxt(fname=f"{filename}_kitti.txt", X=_to_kitti_format(poses))
376 |
377 |
378 | # for LiDAR dataset
379 | def get_metrics(seq_result: List[Dict]):
380 | odom_ate = (seq_result[0])["Average Translation Error [%]"]
381 | odom_are = (seq_result[0])["Average Rotational Error [deg/m]"] * 100.0
382 | slam_rmse = (seq_result[1])["Absoulte Trajectory Error [m]"]
383 | metrics_dict = {
384 | "Odometry ATE [%]": odom_ate,
385 | "Odometry ARE [deg/100m]": odom_are,
386 | "SLAM RMSE [m]": slam_rmse,
387 | }
388 | return metrics_dict
389 |
390 |
391 | def mean_metrics(seq_metrics: List[Dict]):
392 | sums = defaultdict(float)
393 | counts = defaultdict(int)
394 |
395 | for seq_metric in seq_metrics:
396 | for key, value in seq_metric.items():
397 | sums[key] += value
398 | counts[key] += 1
399 |
400 | mean_metrics = {key: sums[key] / counts[key] for key in sums}
401 | return mean_metrics
402 |
--------------------------------------------------------------------------------
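Since absolute_error reduces to a per-frame RMSE over rotation and translation differences, a constant offset between two otherwise identical trajectories should show up directly in the translation term. A toy check (a sketch, assuming the repository root is on PYTHONPATH; the poses are made up):

    import numpy as np
    from utils.eval_traj_utils import absolute_error

    poses_gt = np.stack([np.eye(4) for _ in range(3)])   # three poses, 1 m apart along x
    for i in range(3):
        poses_gt[i, 0, 3] = float(i)
    poses_est = poses_gt.copy()
    poses_est[:, 1, 3] += 0.01                            # constant 1 cm offset in y

    rot_rmse, tran_rmse, _ = absolute_error(poses_gt, poses_est, align_on=False)
    print(rot_rmse, tran_rmse)                            # ~0 deg, ~0.01 m
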
/utils/loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # @file loss.py
3 | # @author Yue Pan [yue.pan@igg.uni-bonn.de]
4 | # Copyright (c) 2024 Yue Pan, all rights reserved
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | def sdf_diff_loss(pred, label, weight, scale=1.0, l2_loss=True):
11 | count = pred.shape[0]
12 | diff = pred - label
13 | diff_m = diff / scale # so it's still in m unit
14 | if l2_loss:
15 | loss = (weight * (diff_m**2)).sum() / count # l2 loss
16 | else:
17 | loss = (weight * torch.abs(diff_m)).sum() / count # l1 loss
18 | return loss
19 |
20 |
21 | def sdf_l1_loss(pred, label):
22 | loss = torch.abs(pred - label)
23 | return loss.mean()
24 |
25 |
26 | def sdf_l2_loss(pred, label):
27 | loss = (pred - label) ** 2
28 | return loss.mean()
29 |
30 |
31 | def color_diff_loss(pred, label, weight, weighted=False, l2_loss=False):
32 | diff = pred - label
33 | if not weighted:
34 | weight = 1.0
35 | else:
36 | weight = weight.unsqueeze(1)
37 | if l2_loss:
38 | loss = (weight * (diff**2)).mean()
39 | else:
40 | loss = (weight * torch.abs(diff)).mean()
41 | return loss
42 |
43 |
44 | # used by our approach
45 | def sdf_bce_loss(pred, label, sigma, weight, weighted=False, bce_reduction="mean"):
46 | """ Calculate the binary cross entropy (BCE) loss for SDF supervision
47 | Args:
48 |         pred (torch.tensor): batch of predicted SDF values
49 | label (torch.tensor): batch of the target SDF values
50 | sigma (float): scale factor for the sigmoid function
51 |         weight (torch.tensor): batch of the per-sample weight
52 | weighted (bool, optional): apply the weight or not
53 | bce_reduction (string, optional): specifies the reduction to apply to the output
54 | Returns:
55 | loss (torch.tensor): BCE loss for the batch
56 | """
57 | if weighted:
58 | loss_bce = nn.BCEWithLogitsLoss(reduction=bce_reduction, weight=weight)
59 | else:
60 | loss_bce = nn.BCEWithLogitsLoss(reduction=bce_reduction)
61 | label_op = torch.sigmoid(label / sigma) # occupancy prob
62 | loss = loss_bce(pred / sigma, label_op)
63 | return loss
64 |
65 |
66 | # the loss devised by Starry Zhong
67 | def sdf_zhong_loss(pred, label, trunc_dist=None, weight=None, weighted=False):
68 | if not weighted:
69 | weight = 1.0
70 | else:
71 | weight = weight
72 | loss = torch.zeros_like(label, dtype=label.dtype, device=label.device)
73 | middle_point = label / 2.0
74 | middle_point_abs = torch.abs(middle_point)
75 | shift_difference_abs = torch.abs(pred - middle_point)
76 | mask = shift_difference_abs > middle_point_abs
77 | loss[mask] = (shift_difference_abs - middle_point_abs)[
78 | mask
79 | ] # not masked region simply has a loss of zero, masked region L1 loss
80 | if trunc_dist is not None:
81 | surface_mask = torch.abs(label) < trunc_dist
82 | loss[surface_mask] = torch.abs(pred - label)[surface_mask]
83 | loss *= weight
84 | return loss.mean()
85 |
86 |
87 | # not used
88 | def smooth_sdf_loss(pred, label, delta=20.0, weight=None, weighted=False):
89 | if not weighted:
90 | weight = 1.0
91 | else:
92 | weight = weight
93 | sign_factors = torch.ones_like(label, dtype=label.dtype, device=label.device)
94 | sign_factors[label < 0.0] = -1.0
95 | sign_loss = -sign_factors * delta * pred / 2.0
96 | no_loss = torch.zeros_like(pred, dtype=pred.dtype, device=pred.device)
97 | truncated_loss = sign_factors * delta * (pred / 2.0 - label)
98 | losses = torch.stack((sign_loss, no_loss, truncated_loss), dim=0)
99 | final_loss = torch.logsumexp(losses, dim=0)
100 | final_loss = ((2.0 / delta) * final_loss * weight).mean()
101 | return final_loss
102 |
103 |
104 | def ray_estimation_loss(x, y, d_meas): # for each ray
105 | # x as depth
106 | # y as sdf prediction
107 | # d_meas as measured depth
108 |
109 | # regard each sample as a ray
110 | mat_A = torch.vstack((x, torch.ones_like(x))).transpose(0, 1)
111 | vec_b = y.view(-1, 1)
112 | least_square_estimate = torch.linalg.lstsq(mat_A, vec_b).solution
113 |
114 | a = least_square_estimate[0] # -> -1 (added in ekional loss term)
115 | b = least_square_estimate[1]
116 |
117 | d_estimate = torch.clamp(-b / a, min=1.0, max=40.0) # -> d
118 |
119 | d_error = torch.abs(d_estimate - d_meas)
120 |
121 | return d_error
122 |
123 |
124 | def ray_rendering_loss(x, y, d_meas): # for each ray [should run in batch]
125 | # x as depth
126 | # y as occ.prob. prediction
127 | x = x.squeeze(1)
128 | sort_x, indices = torch.sort(x)
129 | sort_y = y[indices]
130 |
131 | w = torch.ones_like(y)
132 | for i in range(sort_x.shape[0]):
133 | w[i] = sort_y[i]
134 | for j in range(i):
135 | w[i] *= 1.0 - sort_y[j]
136 |
137 | d_render = (w * x).sum()
138 |
139 | d_error = torch.abs(d_render - d_meas)
140 |
141 | return d_error
142 |
143 |
144 | def batch_ray_rendering_loss(x, y, d_meas, neus_on=True): # for all rays in a batch
145 | # x as depth [ray number * sample number]
146 | # y as prediction (the alpha in volume rendering) [ray number * sample number]
147 | # d_meas as measured depth [ray number]
148 | # w as the raywise weight [ray number]
149 | # neus_on determine if using the loss defined in NEUS
150 |
151 | sort_x, indices = torch.sort(x, 1) # for each row
152 | sort_y = torch.gather(y, 1, indices) # for each row
153 |
154 | if neus_on:
155 | neus_alpha = (sort_y[:, 1:] - sort_y[:, 0:-1]) / (1.0 - sort_y[:, 0:-1] + 1e-10)
156 | # avoid dividing by 0 (nan)
157 | # print(neus_alpha)
158 | alpha = torch.clamp(neus_alpha, min=0.0, max=1.0)
159 | else:
160 | alpha = sort_y
161 |
162 | one_minus_alpha = torch.ones_like(alpha) - alpha + 1e-10
163 |
164 | cum_mat = torch.cumprod(one_minus_alpha, 1)
165 |
166 | weights = cum_mat / one_minus_alpha * alpha
167 |
168 | weights_x = weights * sort_x[:, 0 : alpha.shape[1]]
169 |
170 | d_render = torch.sum(weights_x, 1)
171 |
172 | d_error = torch.abs(d_render - d_meas)
173 |
174 | d_error_mean = torch.mean(d_error)
175 |
176 | return d_error_mean
177 |
--------------------------------------------------------------------------------
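sdf_bce_loss above casts SDF regression as occupancy classification: the label is mapped to an occupancy probability through a sigmoid with scale sigma and compared against the prediction treated as a logit pred / sigma. A toy call (a sketch, assuming the repository root is on PYTHONPATH; sigma and the sample values are illustrative):

    import torch
    from utils.loss import sdf_bce_loss

    pred = torch.tensor([0.05, -0.10, 0.00])   # predicted SDF values in metres
    label = torch.tensor([0.04, -0.12, 0.02])  # sampled SDF labels in metres
    weight = torch.ones_like(pred)             # per-sample weights
    print(sdf_bce_loss(pred, label, sigma=0.05, weight=weight, weighted=True))
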
/utils/mesher.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # @file mesher.py
3 | # @author Yue Pan [yue.pan@igg.uni-bonn.de]
4 | # Copyright (c) 2024 Yue Pan, all rights reserved
5 |
6 | import math
7 |
8 | import matplotlib.cm as cm
9 | import numpy as np
10 | import open3d as o3d
11 | import skimage.measure
12 | import torch
13 | from tqdm import tqdm
14 |
15 | from model.decoder import Decoder
16 | from model.neural_points import NeuralPoints
17 | from utils.config import Config
18 | from utils.semantic_kitti_utils import sem_kitti_color_map
19 | from utils.tools import remove_gpu_cache
20 |
21 | class Mesher:
22 | def __init__(
23 | self,
24 | config: Config,
25 | neural_points: NeuralPoints,
26 | decoders: dict,
27 | ):
28 |
29 | self.config = config
30 | self.silence = config.silence
31 | self.neural_points = neural_points
32 | self.sdf_mlp = decoders["sdf"]
33 | self.sem_mlp = decoders["semantic"]
34 | self.color_mlp = decoders["color"]
35 | self.device = config.device
36 | self.cur_device = self.device
37 | self.dtype = config.dtype
38 | self.global_transform = np.eye(4)
39 |
40 | def query_points(
41 | self,
42 | coord,
43 | bs,
44 | query_sdf=True,
45 | query_sem=False,
46 | query_color=False,
47 | query_mask=True,
48 | query_locally=False,
49 | mask_min_nn_count: int = 4,
50 | out_torch: bool = False,
51 | ):
52 | """query the sdf value, semantic label and marching cubes mask for points
53 | Args:
54 | coord: Nx3 torch tensor, the coordinates of all N (axbxc) query points in the scaled
55 | kaolin coordinate system [-1,1]
56 | bs: batch size for the inference
57 | Returns:
58 | sdf_pred: Ndim numpy array or torch tensor, signed distance value (scaled) at each query point
59 |             sem_pred: Ndim numpy array or torch tensor, semantic label prediction at each query point
60 | mc_mask: Ndim bool numpy array or torch tensor, marching cubes mask at each query point
61 | """
62 | # the coord torch tensor is already scaled in the [-1,1] coordinate system
63 | sample_count = coord.shape[0]
64 | iter_n = math.ceil(sample_count / bs)
65 | if query_sdf:
66 | if out_torch:
67 | sdf_pred = torch.zeros(sample_count)
68 | else:
69 | sdf_pred = np.zeros(sample_count)
70 | else:
71 | sdf_pred = None
72 | if query_sem:
73 | if out_torch:
74 | sem_pred = torch.zeros(sample_count)
75 | else:
76 | sem_pred = np.zeros(sample_count)
77 | else:
78 | sem_pred = None
79 | if query_color:
80 | if out_torch:
81 | color_pred = torch.zeros((sample_count, self.config.color_channel))
82 | else:
83 | color_pred = np.zeros((sample_count, self.config.color_channel))
84 | else:
85 | color_pred = None
86 | if query_mask:
87 | if out_torch:
88 | mc_mask = torch.zeros(sample_count)
89 | else:
90 | mc_mask = np.zeros(sample_count)
91 | else:
92 | mc_mask = None
93 |
94 | with torch.no_grad(): # eval step
95 | for n in tqdm(range(iter_n), disable=self.silence):
96 | head = n * bs
97 | tail = min((n + 1) * bs, sample_count)
98 | batch_coord = coord[head:tail, :]
99 | batch_size = batch_coord.shape[0]
100 | if self.cur_device == "cpu" and self.device == "cuda":
101 | batch_coord = batch_coord.cuda()
102 | (
103 | batch_geo_feature,
104 | batch_color_feature,
105 | weight_knn,
106 | nn_count,
107 | _,
108 | ) = self.neural_points.query_feature(
109 | batch_coord,
110 | training_mode=False,
111 | query_locally=query_locally, # inference mode, query globally
112 | query_geo_feature=query_sdf or query_sem,
113 | query_color_feature=query_color,
114 | )
115 |
116 | pred_mask = nn_count >= 1 # only query sdf here
117 | if query_sdf:
118 | if self.config.weighted_first:
119 | batch_sdf = torch.zeros(batch_size, device=self.device)
120 | else:
121 | batch_sdf = torch.zeros(
122 | batch_size,
123 | batch_geo_feature.shape[1],
124 | 1,
125 | device=self.device,
126 | )
127 | # predict the sdf with the feature, only do for the unmasked part (not in the unknown freespace)
128 | batch_sdf[pred_mask] = self.sdf_mlp.sdf(
129 | batch_geo_feature[pred_mask]
130 | )
131 |
132 | if not self.config.weighted_first:
133 | batch_sdf = torch.sum(batch_sdf * weight_knn, dim=1).squeeze(1)
134 | if out_torch:
135 | sdf_pred[head:tail] = batch_sdf.detach()
136 | else:
137 | sdf_pred[head:tail] = batch_sdf.detach().cpu().numpy()
138 | if query_sem:
139 | batch_sem_prob = self.sem_mlp.sem_label_prob(batch_geo_feature)
140 | if not self.config.weighted_first:
141 | batch_sem_prob = torch.sum(batch_sem_prob * weight_knn, dim=1)
142 | batch_sem = torch.argmax(batch_sem_prob, dim=1)
143 | if out_torch:
144 | sem_pred[head:tail] = batch_sem.detach()
145 | else:
146 | sem_pred[head:tail] = batch_sem.detach().cpu().numpy()
147 | if query_color:
148 | batch_color = self.color_mlp.regress_color(batch_color_feature)
149 | if not self.config.weighted_first:
150 | batch_color = torch.sum(batch_color * weight_knn, dim=1) # N, C
151 | if out_torch:
152 | color_pred[head:tail] = batch_color.detach()
153 | else:
154 | color_pred[head:tail] = (
155 | batch_color.detach().cpu().numpy().astype(dtype=np.float64)
156 | )
157 | if query_mask:
158 | # do marching cubes only when there are at least K near neural points
159 | mask_mc = nn_count >= mask_min_nn_count
160 | if out_torch:
161 | mc_mask[head:tail] = mask_mc.detach()
162 | else:
163 | mc_mask[head:tail] = mask_mc.detach().cpu().numpy()
164 |
165 | return sdf_pred, sem_pred, color_pred, mc_mask
166 |
167 | def get_query_from_bbx(self, bbx, voxel_size, pad_voxel=0, skip_top_voxel=0):
168 | """
169 | get grid query points inside a given bounding box (bbx)
170 | Args:
171 | bbx: open3d bounding box, in world coordinate system, with unit m
172 | voxel_size: scalar, marching cubes voxel size with unit m
173 | Returns:
174 | coord: Nx3 torch tensor, the coordinates of all N (axbxc) query points in the scaled
175 | kaolin coordinate system [-1,1]
176 | voxel_num_xyz: 3dim numpy array, the number of voxels on each axis for the bbx
177 |             voxel_origin: 3dim numpy array, the coordinate of the bottom-left corner of the 3d grids
178 | for marching cubes, in world coordinate system with unit m
179 | """
180 | # bbx and voxel_size are all in the world coordinate system
181 | min_bound = bbx.get_min_bound()
182 | max_bound = bbx.get_max_bound()
183 | len_xyz = max_bound - min_bound
184 | voxel_num_xyz = (np.ceil(len_xyz / voxel_size) + pad_voxel * 2).astype(np.int_)
185 | voxel_origin = min_bound - pad_voxel * voxel_size
186 |         # pad an additional voxel underground to guarantee the reconstruction of the ground
187 | voxel_origin[2] -= voxel_size
188 | voxel_num_xyz[2] += 1
189 | voxel_num_xyz[2] -= skip_top_voxel
190 |
191 | voxel_count_total = voxel_num_xyz[0] * voxel_num_xyz[1] * voxel_num_xyz[2]
192 | if voxel_count_total > 5e8: # this value is determined by your gpu memory
193 | print("too many query points, use smaller chunks")
194 | return None, None, None
195 |             # self.cur_device = "cpu" # first store in cpu memory (which is usually larger than the gpu's)
196 |             # print("too many query points, use cpu memory")
197 | x = torch.arange(voxel_num_xyz[0], dtype=torch.int16, device=self.cur_device)
198 | y = torch.arange(voxel_num_xyz[1], dtype=torch.int16, device=self.cur_device)
199 | z = torch.arange(voxel_num_xyz[2], dtype=torch.int16, device=self.cur_device)
200 |
201 | # order: [0,0,0], [0,0,1], [0,0,2], [0,1,0], [0,1,1], [0,1,2] ...
202 | x, y, z = torch.meshgrid(x, y, z, indexing="ij")
203 | # get the vector of all the grid point's 3D coordinates
204 | coord = (
205 | torch.stack((x.flatten(), y.flatten(), z.flatten())).transpose(0, 1).float()
206 | )
207 | # transform to world coordinate system
208 | coord *= voxel_size
209 | coord += torch.tensor(voxel_origin, dtype=self.dtype, device=self.cur_device)
210 |
211 | return coord, voxel_num_xyz, voxel_origin
212 |
213 | def get_query_from_hor_slice(self, bbx, slice_z, voxel_size):
214 | """
215 | get grid query points inside a given bounding box (bbx) at slice height (slice_z)
216 | """
217 | # bbx and voxel_size are all in the world coordinate system
218 | min_bound = bbx.get_min_bound()
219 | max_bound = bbx.get_max_bound()
220 | len_xyz = max_bound - min_bound
221 | voxel_num_xyz = (np.ceil(len_xyz / voxel_size)).astype(np.int_)
222 | voxel_num_xyz[2] = 1
223 | voxel_origin = min_bound
224 | voxel_origin[2] = slice_z
225 |
226 | query_count_total = voxel_num_xyz[0] * voxel_num_xyz[1]
227 | if query_count_total > 1e8: # avoid gpu memory issue, dirty fix
228 | self.cur_device = (
229 |                 "cpu" # first store in cpu memory (which is usually larger than the gpu's)
230 |             )
231 |             print("too many query points, use cpu memory")
232 | x = torch.arange(voxel_num_xyz[0], dtype=torch.int16, device=self.cur_device)
233 | y = torch.arange(voxel_num_xyz[1], dtype=torch.int16, device=self.cur_device)
234 | z = torch.arange(voxel_num_xyz[2], dtype=torch.int16, device=self.cur_device)
235 |
236 | # order: [0,0,0], [0,0,1], [0,0,2], [0,1,0], [0,1,1], [0,1,2] ...
237 | x, y, z = torch.meshgrid(x, y, z, indexing="ij")
238 | # get the vector of all the grid point's 3D coordinates
239 | coord = (
240 | torch.stack((x.flatten(), y.flatten(), z.flatten())).transpose(0, 1).float()
241 | )
242 | # transform to world coordinate system
243 | coord *= voxel_size
244 | coord += torch.tensor(voxel_origin, dtype=self.dtype, device=self.cur_device)
245 |
246 | return coord, voxel_num_xyz, voxel_origin
247 |
248 | def get_query_from_ver_slice(self, bbx, slice_x, voxel_size):
249 | """
250 | get grid query points inside a given bounding box (bbx) at slice position (slice_x)
251 | """
252 | # bbx and voxel_size are all in the world coordinate system
253 | min_bound = bbx.get_min_bound()
254 | max_bound = bbx.get_max_bound()
255 | len_xyz = max_bound - min_bound
256 | voxel_num_xyz = (np.ceil(len_xyz / voxel_size)).astype(np.int_)
257 | voxel_num_xyz[0] = 1
258 | voxel_origin = min_bound
259 | voxel_origin[0] = slice_x
260 |
261 | query_count_total = voxel_num_xyz[1] * voxel_num_xyz[2]
262 | if query_count_total > 1e8: # avoid gpu memory issue, dirty fix
263 | self.cur_device = (
264 |                 "cpu" # first store in cpu memory (which is usually larger than the gpu's)
265 |             )
266 |             print("too many query points, use cpu memory")
267 | x = torch.arange(voxel_num_xyz[0], dtype=torch.int16, device=self.cur_device)
268 | y = torch.arange(voxel_num_xyz[1], dtype=torch.int16, device=self.cur_device)
269 | z = torch.arange(voxel_num_xyz[2], dtype=torch.int16, device=self.cur_device)
270 |
271 | # order: [0,0,0], [0,0,1], [0,0,2], [0,1,0], [0,1,1], [0,1,2] ...
272 | x, y, z = torch.meshgrid(x, y, z, indexing="ij")
273 | # get the vector of all the grid point's 3D coordinates
274 | coord = (
275 | torch.stack((x.flatten(), y.flatten(), z.flatten())).transpose(0, 1).float()
276 | )
277 | # transform to world coordinate system
278 | coord *= voxel_size
279 | coord += torch.tensor(voxel_origin, dtype=self.dtype, device=self.cur_device)
280 |
281 | return coord, voxel_num_xyz, voxel_origin
282 |
283 | def generate_sdf_map(self, coord, sdf_pred, mc_mask):
284 | """
285 | Generate the SDF map for saving
286 | """
287 | device = o3d.core.Device("CPU:0")
288 | dtype = o3d.core.float32
289 | sdf_map_pc = o3d.t.geometry.PointCloud(device)
290 |
291 | coord_np = coord.detach().cpu().numpy()
292 |
293 | # the sdf (unit: m) would be saved in the intensity channel
294 | sdf_map_pc.point["positions"] = o3d.core.Tensor(coord_np, dtype, device)
295 | sdf_map_pc.point["intensities"] = o3d.core.Tensor(
296 | np.expand_dims(sdf_pred, axis=1), dtype, device
297 | ) # scaled sdf prediction
298 | if mc_mask is not None:
299 | # the marching cubes mask would be saved in the labels channel
300 | sdf_map_pc.point["labels"] = o3d.core.Tensor(
301 | np.expand_dims(mc_mask, axis=1), o3d.core.int32, device
302 | ) # mask
303 |
304 | # global transform (to world coordinate system) before output
305 | if not np.array_equal(self.global_transform, np.eye(4)):
306 | sdf_map_pc.transform(self.global_transform)
307 |
308 | return sdf_map_pc
309 |
310 | def generate_sdf_map_for_vis(
311 | self, coord, sdf_pred, mc_mask, min_sdf=-1.0, max_sdf=1.0, cmap="bwr"
312 | ): # 'jet','bwr','viridis'
313 | """
314 | Generate the SDF map for visualization
315 | """
316 | # do the masking or not
317 | if mc_mask is not None:
318 | coord = coord[mc_mask > 0]
319 | sdf_pred = sdf_pred[mc_mask > 0]
320 |
321 | coord_np = coord.detach().cpu().numpy().astype(np.float64)
322 |
323 | sdf_pred_show = np.clip((sdf_pred - min_sdf) / (max_sdf - min_sdf), 0.0, 1.0)
324 |
325 | color_map = cm.get_cmap(cmap) # or 'jet'
326 | colors = color_map(1.0 - sdf_pred_show)[:, :3].astype(np.float64) # change to blue (+) --> red (-)
327 |
328 | sdf_map_pc = o3d.geometry.PointCloud()
329 | sdf_map_pc.points = o3d.utility.Vector3dVector(coord_np)
330 | sdf_map_pc.colors = o3d.utility.Vector3dVector(colors)
331 | if not np.array_equal(self.global_transform, np.eye(4)):
332 | sdf_map_pc.transform(self.global_transform)
333 |
334 | return sdf_map_pc
335 |
336 | def assign_to_bbx(self, sdf_pred, sem_pred, color_pred, mc_mask, voxel_num_xyz):
337 | """assign the queried sdf, semantic label and marching cubes mask back to the 3D grids in the specified bounding box
338 | Args:
339 | sdf_pred: Ndim np.array/torch.tensor
340 | sem_pred: Ndim np.array/torch.tensor
341 | mc_mask: Ndim bool np.array/torch.tensor
342 | voxel_num_xyz: 3dim numpy array/torch.tensor, the number of voxels on each axis for the bbx
343 | Returns:
344 |             sdf_pred: a*b*c np.array/torch.tensor, 3d grids of signed distance values
345 | sem_pred: a*b*c np.array/torch.tensor, 3d grids of semantic labels
346 | mc_mask: a*b*c np.array/torch.tensor, 3d grids of marching cube masks, marching cubes only on where
347 | the mask is true
348 | """
349 | if sdf_pred is not None:
350 | sdf_pred = sdf_pred.reshape(
351 | voxel_num_xyz[0], voxel_num_xyz[1], voxel_num_xyz[2]
352 | )
353 |
354 | if sem_pred is not None:
355 | sem_pred = sem_pred.reshape(
356 | voxel_num_xyz[0], voxel_num_xyz[1], voxel_num_xyz[2]
357 | )
358 |
359 | if color_pred is not None:
360 | color_pred = color_pred.reshape(
361 | voxel_num_xyz[0], voxel_num_xyz[1], voxel_num_xyz[2]
362 | )
363 |
364 | if mc_mask is not None:
365 | mc_mask = mc_mask.reshape(
366 | voxel_num_xyz[0], voxel_num_xyz[1], voxel_num_xyz[2]
367 | )
368 |
369 | return sdf_pred, sem_pred, color_pred, mc_mask
370 |
371 | def mc_mesh(self, mc_sdf, mc_mask, voxel_size, mc_origin):
372 | """use the marching cubes algorithm to get mesh vertices and faces
373 | Args:
374 |             mc_sdf: a*b*c np.array, 3d grids of signed distance values
375 | mc_mask: a*b*c np.array, 3d grids of marching cube masks, marching cubes only on where
376 | the mask is true
377 | voxel_size: scalar, marching cubes voxel size with unit m
378 | mc_origin: 3*1 np.array, the coordinate of the bottom-left corner of the 3d grids for
379 | marching cubes, in world coordinate system with unit m
380 | Returns:
381 | ([verts], [faces]), mesh vertices and triangle faces
382 | """
383 | if not self.silence:
384 | print("Marching cubes ...")
385 |         # the inputs are already numpy arrays
386 | verts, faces = np.zeros((0, 3)), np.zeros((0, 3))
387 | try:
388 | verts, faces, _, _ = skimage.measure.marching_cubes(
389 | mc_sdf, level=0.0, allow_degenerate=False, mask=mc_mask
390 | )
391 | # Whether to allow degenerate (i.e. zero-area) triangles in the
392 | # end-result. Default True. If False, degenerate triangles are
393 | # removed, at the cost of making the algorithm slower.
394 |         except Exception:
395 | pass
396 |
397 | verts = mc_origin + verts * voxel_size
398 |
399 | return verts, faces
400 |
401 | def estimate_vertices_sem(self, mesh, verts, filter_free_space_vertices=True):
402 | """
403 | Predict the semantic label of the vertices
404 | """
405 | if len(verts) == 0:
406 | return mesh
407 |
408 | # print("predict semantic labels of the vertices")
409 | verts_torch = torch.tensor(verts, dtype=self.dtype, device=self.device)
410 | _, verts_sem, _, _ = self.query_points(
411 | verts_torch, self.config.infer_bs, False, True, False, False
412 | )
413 | verts_sem_list = list(verts_sem)
414 | verts_sem_rgb = [sem_kitti_color_map[sem_label] for sem_label in verts_sem_list]
415 | verts_sem_rgb = np.asarray(verts_sem_rgb, dtype=np.float64) / 255.0
416 | mesh.vertex_colors = o3d.utility.Vector3dVector(verts_sem_rgb)
417 |
418 | # filter the freespace vertices
419 | if filter_free_space_vertices:
420 | non_freespace_idx = verts_sem <= 0
421 | mesh.remove_vertices_by_mask(non_freespace_idx)
422 |
423 | return mesh
424 |
425 | def estimate_vertices_color(self, mesh, verts):
426 | """
427 | Predict the color of the vertices
428 | """
429 | if len(verts) == 0:
430 | return mesh
431 |
432 | # print("predict color labels of the vertices")
433 | verts_torch = torch.tensor(verts, dtype=self.dtype, device=self.device)
434 | _, _, verts_color, _ = self.query_points(
435 | verts_torch, self.config.infer_bs, False, False, True, False
436 | )
437 |
438 | if self.config.color_channel == 1:
439 | verts_color = np.repeat(verts_color * 2.0, 3, axis=1)
440 |
441 | mesh.vertex_colors = o3d.utility.Vector3dVector(verts_color)
442 |
443 | return mesh
444 |
445 | def filter_isolated_vertices(self, mesh, filter_cluster_min_tri=300):
446 | """
447 | Cluster connected triangles and remove the small clusters
448 | """
449 | triangle_clusters, cluster_n_triangles, _ = mesh.cluster_connected_triangles()
450 | triangle_clusters = np.asarray(triangle_clusters)
451 | cluster_n_triangles = np.asarray(cluster_n_triangles)
452 | # print("Remove the small clusters")
453 | triangles_to_remove = (
454 | cluster_n_triangles[triangle_clusters] < filter_cluster_min_tri
455 | )
456 | mesh.remove_triangles_by_mask(triangles_to_remove)
457 |
458 | return mesh
459 |
460 | def generate_bbx_sdf_hor_slice(
461 | self, bbx, slice_z, voxel_size, query_locally=False, min_sdf=-1.0, max_sdf=1.0, mask_min_nn_count=5
462 | ):
463 | """
464 | Generate the SDF slice at height (slice_z)
465 | """
466 |         # print("Generate the SDF slice at height %.2f (m)" % (slice_z))
467 | coord, _, _ = self.get_query_from_hor_slice(bbx, slice_z, voxel_size)
468 | sdf_pred, _, _, mc_mask = self.query_points(
469 | coord,
470 | self.config.infer_bs,
471 | True,
472 | False,
473 | False,
474 | self.config.mc_mask_on,
475 | query_locally=query_locally,
476 | mask_min_nn_count=mask_min_nn_count,
477 | )
478 | sdf_map_pc = self.generate_sdf_map_for_vis(
479 | coord, sdf_pred, mc_mask, min_sdf, max_sdf
480 | )
481 |
482 | return sdf_map_pc
483 |
484 | def generate_bbx_sdf_ver_slice(
485 | self, bbx, slice_x, voxel_size, query_locally=False, min_sdf=-1.0, max_sdf=1.0, mask_min_nn_count=5
486 | ):
487 | """
488 | Generate the SDF slice at x position (slice_x)
489 | """
490 | # print("Generate the SDF slice at x position %.2f (m)" % (slice_x))
491 | coord, _, _ = self.get_query_from_ver_slice(bbx, slice_x, voxel_size)
492 | sdf_pred, _, _, mc_mask = self.query_points(
493 | coord,
494 | self.config.infer_bs,
495 | True,
496 | False,
497 | False,
498 | self.config.mc_mask_on,
499 | query_locally=query_locally,
500 | mask_min_nn_count=mask_min_nn_count,
501 | )
502 | sdf_map_pc = self.generate_sdf_map_for_vis(
503 | coord, sdf_pred, mc_mask, min_sdf, max_sdf
504 | )
505 |
506 | return sdf_map_pc
507 |
508 |     # reconstruct the mesh from the map defined by a collection of bounding boxes
509 | def recon_aabb_collections_mesh(
510 | self,
511 | aabbs,
512 | voxel_size,
513 | mesh_path=None,
514 | query_locally=False,
515 | estimate_sem=False,
516 | estimate_color=False,
517 | mesh_normal=True,
518 | filter_isolated_mesh=False,
519 | filter_free_space_vertices=True,
520 | mesh_min_nn=10,
521 | use_torch_mc=False,
522 | ):
523 | """
524 | Reconstruct the mesh from a collection of bounding boxes
525 | """
526 | if not self.silence:
527 | print("# Chunk for meshing: ", len(aabbs))
528 |
529 | mesh_merged = o3d.geometry.TriangleMesh()
530 | for bbx in tqdm(aabbs, disable=self.silence):
531 | cur_mesh = self.recon_aabb_mesh(
532 | bbx,
533 | voxel_size,
534 | None,
535 | query_locally,
536 | estimate_sem,
537 | estimate_color,
538 | mesh_normal,
539 | filter_isolated_mesh,
540 | filter_free_space_vertices,
541 | mesh_min_nn,
542 | use_torch_mc,
543 | )
544 | mesh_merged += cur_mesh
545 |
546 | remove_gpu_cache() # deal with high GPU memory consumption when meshing (TODO)
547 |
548 | mesh_merged.remove_duplicated_vertices()
549 |
550 | if mesh_normal:
551 | mesh_merged.compute_vertex_normals()
552 |
553 | if mesh_path is not None:
554 | o3d.io.write_triangle_mesh(mesh_path, mesh_merged)
555 | if not self.silence:
556 | print("save the mesh to %s\n" % (mesh_path))
557 |
558 | return mesh_merged
559 |
560 | def recon_aabb_mesh(
561 | self,
562 | bbx,
563 | voxel_size,
564 | mesh_path=None,
565 | query_locally=False,
566 | estimate_sem=False,
567 | estimate_color=False,
568 | mesh_normal=True,
569 | filter_isolated_mesh=False,
570 | filter_free_space_vertices=True,
571 | mesh_min_nn=10,
572 | use_torch_mc=False,
573 | ):
574 | """
575 | Reconstruct the mesh from a given bounding box
576 | """
577 |         # reconstruct and save the (semantic) mesh from the neural point features and the decoders within a
578 |         # given bounding box. bbx and voxel_size are both given in meters, in the world coordinate system
579 | coord, voxel_num_xyz, voxel_origin = self.get_query_from_bbx(
580 | bbx, voxel_size, self.config.pad_voxel, self.config.skip_top_voxel
581 | )
582 | if coord is None: # use chunks in this case
583 | return None
584 |
585 | sdf_pred, _, _, mc_mask = self.query_points(
586 | coord,
587 | self.config.infer_bs,
588 | True,
589 | False,
590 | False,
591 | self.config.mc_mask_on,
592 | query_locally,
593 | mesh_min_nn,
594 | out_torch=use_torch_mc,
595 | )
596 |
597 | mc_sdf, _, _, mc_mask = self.assign_to_bbx(
598 | sdf_pred, None, None, mc_mask, voxel_num_xyz
599 | )
600 | if use_torch_mc:
601 | # torch version
602 | verts, faces = self.mc_mesh_torch(
603 | mc_sdf, mc_mask, voxel_size, torch.tensor(voxel_origin).to(mc_sdf)
604 |             ) # has a duplicated-faces problem
605 | mesh = o3d.t.geometry.TriangleMesh(device=o3d.core.Device("cuda:0"))
606 | mesh.vertex.positions = o3d.core.Tensor.from_dlpack(
607 | torch.utils.dlpack.to_dlpack(verts)
608 | )
609 | mesh.triangle.indices = o3d.core.Tensor.from_dlpack(
610 | torch.utils.dlpack.to_dlpack(faces)
611 | )
612 | mesh = mesh.to_legacy()
613 | mesh.remove_duplicated_vertices()
614 | mesh.compute_vertex_normals()
615 | else:
616 | # np cpu version
617 | verts, faces = self.mc_mesh(
618 | mc_sdf, mc_mask.astype(bool), voxel_size, voxel_origin
619 | ) # too slow ? (actually not, the slower part is the querying)
620 | # directly use open3d to get mesh
621 | mesh = o3d.geometry.TriangleMesh(
622 | o3d.utility.Vector3dVector(verts.astype(np.float64)),
623 | o3d.utility.Vector3iVector(faces),
624 | )
625 |
626 | # if not self.silence:
627 | # print("Marching cubes done")
628 |
629 | if estimate_sem:
630 | mesh = self.estimate_vertices_sem(mesh, verts, filter_free_space_vertices)
631 | else:
632 | if estimate_color:
633 | mesh = self.estimate_vertices_color(mesh, verts)
634 |
635 | mesh.remove_duplicated_vertices()
636 |
637 | if mesh_normal:
638 | mesh.compute_vertex_normals()
639 |
640 | if filter_isolated_mesh:
641 | mesh = self.filter_isolated_vertices(mesh, self.config.min_cluster_vertices)
642 |
643 | # global transform (to world coordinate system) before output
644 | if not np.array_equal(self.global_transform, np.eye(4)):
645 | mesh.transform(self.global_transform)
646 |
647 | # write the mesh to ply file
648 | if mesh_path is not None:
649 | o3d.io.write_triangle_mesh(mesh_path, mesh)
650 | if not self.silence:
651 | print("save the mesh to %s\n" % (mesh_path))
652 |
653 | return mesh
654 |
--------------------------------------------------------------------------------
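Note on the meshing pipeline above: recon_aabb_mesh queries SDF values on a regular grid in chunks (query_points), reshapes them into an a*b*c volume (assign_to_bbx), and runs marching cubes masked by the neural-point neighbor count (mc_mesh), shifting the vertices back to world coordinates. The standalone sketch below reproduces only the marching-cubes step on a synthetic sphere SDF; the voxel size, grid extent, and sphere radius are made-up values, and the Mesher class itself (which needs a Config, NeuralPoints and decoders) is not used.

    import numpy as np
    import open3d as o3d
    import skimage.measure

    # assumed synthetic SDF grid: 40^3 voxels of 5 cm, sphere of radius 0.6 m at the origin
    voxel_size = 0.05
    origin = np.array([-1.0, -1.0, -1.0])
    grid = np.mgrid[0:40, 0:40, 0:40].transpose(1, 2, 3, 0) * voxel_size + origin
    sdf = np.linalg.norm(grid, axis=-1) - 0.6

    # zero-level marching cubes, as in Mesher.mc_mesh (here without a validity mask)
    verts, faces, _, _ = skimage.measure.marching_cubes(sdf, level=0.0, allow_degenerate=False)
    verts = origin + verts * voxel_size  # voxel indices -> world coordinates

    mesh = o3d.geometry.TriangleMesh(
        o3d.utility.Vector3dVector(verts.astype(np.float64)),
        o3d.utility.Vector3iVector(faces.astype(np.int32)),
    )
    mesh.compute_vertex_normals()
    print(mesh)
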
/utils/point_cloud2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2008 Willow Garage, Inc.
2 | #
3 | # Redistribution and use in source and binary forms, with or without
4 | # modification, are permitted provided that the following conditions are met:
5 | #
6 | # * Redistributions of source code must retain the above copyright
7 | # notice, this list of conditions and the following disclaimer.
8 | #
9 | # * Redistributions in binary form must reproduce the above copyright
10 | # notice, this list of conditions and the following disclaimer in the
11 | # documentation and/or other materials provided with the distribution.
12 | #
13 | # * Neither the name of the Willow Garage, Inc. nor the names of its
14 | # contributors may be used to endorse or promote products derived from
15 | # this software without specific prior written permission.
16 | #
17 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
21 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 | # POSSIBILITY OF SUCH DAMAGE.
28 |
29 | """
30 | This file is based on https://github.com/ros2/common_interfaces/blob/4bac182a0a582b5e6b784d9fa9f0dabc1aca4d35/sensor_msgs_py/sensor_msgs_py/point_cloud2.py
31 | All rights reserved to the original authors: Tim Field and Florian Vahl.
32 | """
33 |
34 | import sys
35 | from typing import Iterable, List, Optional, Tuple
36 |
37 | import numpy as np
38 |
39 | try:
40 | from rosbags.typesys.types import sensor_msgs__msg__PointCloud2 as PointCloud2
41 | from rosbags.typesys.types import sensor_msgs__msg__PointField as PointField
42 | except ImportError as e:
43 | raise ImportError('rosbags library not installed, run "pip install -U rosbags"') from e
44 |
45 |
46 | _DATATYPES = {}
47 | _DATATYPES[PointField.INT8] = np.dtype(np.int8)
48 | _DATATYPES[PointField.UINT8] = np.dtype(np.uint8)
49 | _DATATYPES[PointField.INT16] = np.dtype(np.int16)
50 | _DATATYPES[PointField.UINT16] = np.dtype(np.uint16)
51 | _DATATYPES[PointField.INT32] = np.dtype(np.int32)
52 | _DATATYPES[PointField.UINT32] = np.dtype(np.uint32)
53 | _DATATYPES[PointField.FLOAT32] = np.dtype(np.float32)
54 | _DATATYPES[PointField.FLOAT64] = np.dtype(np.float64)
55 |
56 | DUMMY_FIELD_PREFIX = "unnamed_field"
57 |
58 |
59 | def read_point_cloud(msg: PointCloud2) -> Tuple[np.ndarray, np.ndarray]:
60 | """
61 |     Extract points and timestamps from a PointCloud2 message.
62 |
63 | :return: Tuple of [points, timestamps]
64 | points: array of x, y z points, shape: (N, 3)
65 |         timestamps: array of per-point timestamps, shape: (N,)
66 | """
67 | field_names = ["x", "y", "z"]
68 | t_field = None
69 | for field in msg.fields:
70 | if field.name in ["t", "timestamp", "time", "ts"]:
71 | t_field = field.name
72 | field_names.append(t_field)
73 | break
74 |
75 | points_structured = read_points(msg, field_names=field_names)
76 | points = np.column_stack(
77 | [points_structured["x"], points_structured["y"], points_structured["z"]]
78 | )
79 |
80 | # Remove nan if any
81 | points = points[~np.any(np.isnan(points), axis=1)]
82 |
83 | if t_field:
84 | timestamps = points_structured[t_field].astype(np.float64)
85 | min_timestamp = np.min(timestamps)
86 | max_timestamp = np.max(timestamps)
87 | timestamps = (timestamps - min_timestamp) / (max_timestamp - min_timestamp)
88 | else:
89 | timestamps = None
90 | return points.astype(np.float64), timestamps
91 |
92 |
93 | def read_points(
94 | cloud: PointCloud2,
95 | field_names: Optional[List[str]] = None,
96 | uvs: Optional[Iterable] = None,
97 | reshape_organized_cloud: bool = False,
98 | ) -> np.ndarray:
99 | """
100 | Read points from a sensor_msgs.PointCloud2 message.
101 | :param cloud: The point cloud to read from sensor_msgs.PointCloud2.
102 | :param field_names: The names of fields to read. If None, read all fields.
103 | (Type: Iterable, Default: None)
104 | :param uvs: If specified, then only return the points at the given
105 | coordinates. (Type: Iterable, Default: None)
106 |     :param reshape_organized_cloud: Returns the array as a 2D organized point cloud if set.
107 | :return: Structured NumPy array containing all points.
108 | """
109 | # Cast bytes to numpy array
110 | points = np.ndarray(
111 | shape=(cloud.width * cloud.height,),
112 | dtype=dtype_from_fields(cloud.fields, point_step=cloud.point_step),
113 | buffer=cloud.data,
114 | )
115 |
116 | # Keep only the requested fields
117 | if field_names is not None:
118 | assert all(
119 | field_name in points.dtype.names for field_name in field_names
120 |         ), "Requested field is not in the fields of the PointCloud!"
121 | # Mask fields
122 | points = points[list(field_names)]
123 |
124 | # Swap array if byte order does not match
125 | if bool(sys.byteorder != "little") != bool(cloud.is_bigendian):
126 | points = points.byteswap(inplace=True)
127 |
128 | # Select points indexed by the uvs field
129 | if uvs is not None:
130 | # Don't convert to numpy array if it is already one
131 | if not isinstance(uvs, np.ndarray):
132 | uvs = np.fromiter(uvs, int)
133 | # Index requested points
134 | points = points[uvs]
135 |
136 | # Cast into 2d array if cloud is 'organized'
137 | if reshape_organized_cloud and cloud.height > 1:
138 | points = points.reshape(cloud.width, cloud.height)
139 |
140 | return points
141 |
142 |
143 | def dtype_from_fields(fields: Iterable[PointField], point_step: Optional[int] = None) -> np.dtype:
144 | """
145 |     Convert an Iterable of sensor_msgs.msg.PointField messages to a np.dtype.
146 | :param fields: The point cloud fields.
147 | (Type: iterable of sensor_msgs.msg.PointField)
148 | :param point_step: Point step size in bytes. Calculated from the given fields by default.
149 | (Type: optional of integer)
150 | :returns: NumPy datatype
151 | """
152 |     # Create lists containing the names, offsets and datatypes of all fields
153 | field_names = []
154 | field_offsets = []
155 | field_datatypes = []
156 | for i, field in enumerate(fields):
157 | # Datatype as numpy datatype
158 | datatype = _DATATYPES[field.datatype]
159 | # Name field
160 | if field.name == "":
161 | name = f"{DUMMY_FIELD_PREFIX}_{i}"
162 | else:
163 | name = field.name
164 |         # Handle fields with count > 1 by creating subfields with a suffix consisting
165 | # of "_" followed by the subfield counter [0 -> (count - 1)]
166 | assert field.count > 0, "Can't process fields with count = 0."
167 | for a in range(field.count):
168 | # Add suffix if we have multiple subfields
169 | if field.count > 1:
170 | subfield_name = f"{name}_{a}"
171 | else:
172 | subfield_name = name
173 | assert subfield_name not in field_names, "Duplicate field names are not allowed!"
174 | field_names.append(subfield_name)
175 | # Create new offset that includes subfields
176 | field_offsets.append(field.offset + a * datatype.itemsize)
177 | field_datatypes.append(datatype.str)
178 |
179 | # Create dtype
180 | dtype_dict = {"names": field_names, "formats": field_datatypes, "offsets": field_offsets}
181 | if point_step is not None:
182 | dtype_dict["itemsize"] = point_step
183 | return np.dtype(dtype_dict)
184 |
--------------------------------------------------------------------------------
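Note on the PointCloud2 parsing above: dtype_from_fields builds a structured NumPy dtype with explicit per-field offsets and an itemsize equal to point_step, which lets read_points view the raw msg.data buffer without copying. The numpy-only sketch below mimics that mechanism with an assumed layout (x, y, z as float32 at offsets 0/4/8 and a 16-byte point step); it does not require rosbags or a real message.

    import numpy as np

    # assumed field layout: x, y, z float32 at offsets 0, 4, 8; 16-byte point step
    point_step = 16
    dtype = np.dtype({
        "names": ["x", "y", "z"],
        "formats": ["<f4", "<f4", "<f4"],
        "offsets": [0, 4, 8],
        "itemsize": point_step,
    })

    raw = np.zeros(3 * point_step, dtype=np.uint8)            # stand-in for cloud.data (3 points)
    points = np.ndarray(shape=(3,), dtype=dtype, buffer=raw)  # zero-copy structured view
    xyz = np.column_stack([points["x"], points["y"], points["z"]])
    print(xyz.shape)                                          # (3, 3)
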
/utils/semantic_kitti_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LabelDataConverter:
5 | """Convert .label binary data to instance id and rgb"""
6 |
7 | def __init__(self, labelscan):
8 |
9 | self.convertdata(labelscan)
10 |
11 | def convertdata(self, labelscan):
12 |
13 | self.semantic_id = []
14 | self.rgb_id = []
15 | self.instance_id = []
16 | self.rgb_arr_id = []
17 |
18 | for counting in range(len(labelscan)):
19 |
20 | sem_id = int(labelscan[counting]) & 0xFFFF # lower 16 bit
21 | rgb, rgb_arr = self.get_sem_rgb(sem_id)
22 | instance_id = int(labelscan[counting]) >> 16 # higher 16 bit
23 | # rgb = self.get_random_rgb(instance_id)
24 |
25 | # print("Sem label:", sem_id, "Ins label:", instance_id, "Color:", hex(rgb))
26 | # print(hex(rgb))
27 | # instance label is given in each semantic label
28 |
29 | self.semantic_id.append(sem_id)
30 | self.rgb_id.append(rgb)
31 | self.rgb_arr_id.append(rgb_arr)
32 | self.instance_id.append(instance_id)
33 |
34 |
35 | def get_random_rgb(n):
36 | n = ((n ^ n >> 15) * 2246822519) & 0xFFFFFFFF
37 | n = ((n ^ n >> 13) * 3266489917) & 0xFFFFFFFF
38 | n = (n ^ n >> 16) >> 8
39 | print(n)
40 | return tuple(n.to_bytes(3, "big"))
41 |
42 |
43 | sem_kitti_learning_map = {
44 | 0: 0, # "unlabeled"
45 | 1: 0, # "outlier" mapped to "unlabeled" --------------------------mapped
46 | 10: 1, # "car"
47 | 11: 2, # "bicycle"
48 | 13: 5, # "bus" mapped to "other-vehicle" --------------------------mapped
49 | 15: 3, # "motorcycle"
50 | 16: 5, # "on-rails" mapped to "other-vehicle" ---------------------mapped
51 | 18: 4, # "truck"
52 | 20: 5, # "other-vehicle"
53 | 30: 6, # "person"
54 | 31: 7, # "bicyclist"
55 | 32: 8, # "motorcyclist"
56 | 40: 9, # "road"
57 | 44: 10, # "parking"
58 | 48: 11, # "sidewalk"
59 | 49: 12, # "other-ground"
60 | 50: 13, # "building"
61 | 51: 14, # "fence"
62 | 52: 20, # "other-structure" mapped to "unlabeled" ------------------mapped
63 | 60: 9, # "lane-marking" to "road" ---------------------------------mapped
64 | 70: 15, # "vegetation"
65 | 71: 16, # "trunk"
66 | 72: 17, # "terrain"
67 | 80: 18, # "pole"
68 | 81: 19, # "traffic-sign"
69 | 99: 20, # "other-object" to "unlabeled" ----------------------------mapped
70 | 252: 1, # "moving-car"
71 | 253: 7, # "moving-bicyclist"
72 | 254: 6, # "moving-person"
73 | 255: 8, # "moving-motorcyclist"
74 | 256: 5, # "moving-on-rails" mapped to "moving-other-vehicle" ------mapped
75 | 257: 5, # "moving-bus" mapped to "moving-other-vehicle" -----------mapped
76 | 258: 4, # "moving-truck"
77 | 259: 5, # "moving-other-vehicle"
78 | }
79 |
80 |
81 | def sem_map_function(value):
82 | return sem_kitti_learning_map.get(value, value)
83 |
84 |
85 | sem_kitti_labels = {
86 | 0: "unlabeled",
87 | 1: "car",
88 | 2: "bicycle",
89 | 3: "motorcycle",
90 | 4: "truck",
91 | 5: "other-vehicle",
92 | 6: "person",
93 | 7: "bicyclist",
94 | 8: "motorcyclist",
95 | 9: "road",
96 | 10: "parking",
97 | 11: "sidewalk",
98 | 12: "other-ground",
99 | 13: "building",
100 | 14: "fence",
101 | 15: "vegetation",
102 | 16: "trunk",
103 | 17: "terrain",
104 | 18: "pole",
105 | 19: "traffic-sign",
106 | 20: "others",
107 | }
108 |
109 | sem_kitti_color_map = { # rgb
110 | 0: [255, 255, 255],
111 | 1: [100, 150, 245],
112 | 2: [100, 230, 245],
113 | 3: [30, 60, 150],
114 | 4: [80, 30, 180],
115 | 5: [0, 0, 255],
116 | 6: [255, 30, 30],
117 | 7: [255, 40, 200],
118 | 8: [150, 30, 90],
119 | 9: [255, 0, 255],
120 | 10: [255, 150, 255],
121 | 11: [75, 0, 75],
122 | 12: [175, 0, 75],
123 | 13: [255, 200, 0],
124 | 14: [255, 120, 50],
125 | 15: [0, 175, 0],
126 | 16: [135, 60, 0],
127 | 17: [150, 240, 80],
128 | 18: [255, 240, 150],
129 | 19: [255, 0, 0],
130 | 20: [30, 30, 30],
131 | }
132 |
--------------------------------------------------------------------------------
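Note on the label handling above: each SemanticKITTI .label value packs the semantic class in its lower 16 bits and the instance id in its upper 16 bits; sem_kitti_learning_map then remaps the raw class ids to the 20-class training set and sem_kitti_color_map assigns RGB colors. The sketch below decodes two assumed example labels; the import path presumes the repository root is on PYTHONPATH.

    import numpy as np

    from utils.semantic_kitti_utils import sem_kitti_color_map, sem_map_function

    # assumed example labels: a "car" (10) and a "moving-car" (252) with instance id 3
    labels = np.array([10, 252 | (3 << 16)], dtype=np.uint32)

    sem_id = labels & 0xFFFF   # lower 16 bits: raw semantic class -> [10, 252]
    inst_id = labels >> 16     # upper 16 bits: instance id        -> [0, 3]

    mapped = np.vectorize(sem_map_function)(sem_id)  # both remap to class 1 ("car")
    colors = np.array([sem_kitti_color_map[int(m)] for m in mapped]) / 255.0
    print(sem_id, inst_id, mapped)
    print(colors)
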
/utils/so3_math.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # @file so3_math.py
3 | # @author Junlong Jiang [jiangjunlong@mail.dlut.edu.cn]
4 | # Copyright (c) 2025 Junlong Jiang, all rights reserved
5 |
6 | import torch
7 | import numpy as np
8 |
9 | def vec2skew(v: torch.Tensor):
10 |     """Create the 3x3 skew-symmetric matrix corresponding to the cross product with vector v"""
11 | zero = torch.zeros_like(v[0])
12 | return torch.tensor([
13 | [zero, -v[2], v[1]],
14 | [v[2], zero, -v[0]],
15 | [-v[1], v[0], zero]
16 | ], device=v.device, dtype=v.dtype)
17 |
18 | def vectors_to_skew_symmetric(vectors: torch.Tensor):
19 | """
20 | Convert a batch of vectors to a batch of skew-symmetric matrices.
21 |
22 | Parameters:
23 | vectors : torch.Tensor
24 | Input tensor containing vectors. Shape [m, 3]
25 |
26 | Returns:
27 | skew_matrices : torch.Tensor
28 | Output tensor containing skew-symmetric matrices. Shape [m, 3, 3]
29 | """
30 | skew_matrices = torch.zeros((vectors.shape[0], 3, 3), dtype=vectors.dtype, device=vectors.device)
31 | skew_matrices[:, 0, 1] = -vectors[:, 2]
32 | skew_matrices[:, 0, 2] = vectors[:, 1]
33 | skew_matrices[:, 1, 0] = vectors[:, 2]
34 | skew_matrices[:, 1, 2] = -vectors[:, 0]
35 | skew_matrices[:, 2, 0] = -vectors[:, 1]
36 | skew_matrices[:, 2, 1] = vectors[:, 0]
37 |
38 | return skew_matrices
39 |
40 |
41 | def so3Exp(so3: torch.Tensor):
42 |     """Convert an so3 vector into an SO3 rotation matrix.
43 | 
44 |     Args:
45 |         so3 (torch.Tensor): so3 vector of shape (3,).
46 | 
47 |     Returns:
48 |         torch.Tensor: SO3 rotation matrix of shape (3, 3).
49 |     """
50 | so3_norm = torch.norm(so3)
51 | if so3_norm <= 1e-7:
52 | return torch.eye(3, device=so3.device, dtype=so3.dtype)
53 |
54 | so3_skew_sym = vec2skew(so3)
55 | I = torch.eye(3, device=so3.device, dtype=so3.dtype)
56 |
57 | SO3 = I + (so3_skew_sym / so3_norm) * torch.sin(so3_norm) + \
58 | (so3_skew_sym @ so3_skew_sym / (so3_norm * so3_norm)) * (1 - torch.cos(so3_norm))
59 | return SO3
60 |
61 |
62 | def SO3Log(SO3: torch.Tensor):
63 |     '''Convert a Lie group element (SO3) into its Lie algebra vector (so3).
64 | 
65 |     Args:
66 |         SO3 (torch.Tensor): SO3 rotation matrix of shape (3, 3).
67 | 
68 |     Returns:
69 |         torch.Tensor: so3 vector of shape (3,).
70 |     '''
71 |     # compute the rotation angle theta
72 | trace = SO3.trace()
73 | theta = torch.acos((trace - 1) / 2) if trace <= 3 - 1e-6 else 0.0
74 |
75 |     # compute the so3 vector
76 | so3 = torch.tensor([
77 | SO3[2, 1] - SO3[1, 2],
78 | SO3[0, 2] - SO3[2, 0],
79 | SO3[1, 0] - SO3[0, 1]
80 | ], device=SO3.device)
81 |
82 |     # rescale the so3 vector according to theta
83 | if abs(theta) < 0.001:
84 | so3 = 0.5 * so3
85 | else:
86 | so3 = 0.5 * theta / torch.sin(theta) * so3
87 |
88 | return so3
89 |
90 |
91 | def A_T(v: torch.Tensor):
92 |     """Compute the coefficient matrix A(v)^T for the given 3D vector v"""
93 |     squared_norm = torch.dot(v, v) # squared norm of the vector
94 |     norm = torch.sqrt(squared_norm) # norm of the vector
95 |     identity = torch.eye(3, device=v.device, dtype=v.dtype) # identity matrix
96 |
97 | if norm < 1e-11:
98 | return identity
99 | else:
100 |         S = vec2skew(v) # skew-symmetric matrix
101 | term1 = (1 - torch.cos(norm)) / squared_norm
102 | term2 = (1 - torch.sin(norm) / norm) / squared_norm
103 | return identity + term1 * S + term2 * torch.matmul(S, S)
--------------------------------------------------------------------------------
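Note on the SO(3) utilities above: so3Exp implements the Rodrigues formula mapping an axis-angle vector to a rotation matrix, and SO3Log inverts it via the trace-based angle recovery. The sketch below is a round-trip sanity check with an arbitrary test vector; it assumes the repository root is on PYTHONPATH so the module can be imported as utils.so3_math.

    import torch

    from utils.so3_math import SO3Log, so3Exp

    phi = torch.tensor([0.1, -0.2, 0.3])  # arbitrary axis-angle test vector
    R = so3Exp(phi)                       # so3 vector -> rotation matrix
    phi_back = SO3Log(R)                  # rotation matrix -> so3 vector

    print(torch.allclose(R @ R.T, torch.eye(3), atol=1e-5))  # R is orthonormal
    print(torch.allclose(phi, phi_back, atol=1e-5))          # log(exp(phi)) ≈ phi
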
/utils/visualizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # @file visualizer.py
3 | # @author Yue Pan [yue.pan@igg.uni-bonn.de]
4 | # Copyright (c) 2024 Yue Pan, all rights reserved
5 |
6 | # Adapted from Nacho's awesome lidar visualizer (https://github.com/PRBonn/lidar-visualizer)
7 | # This is deprecated, now we use the GUI in gui/slam_gui.py
8 |
9 | import os
10 | from functools import partial
11 | from typing import Callable, List
12 |
13 | import numpy as np
14 | import open3d as o3d
15 |
16 | from utils.config import Config
17 |
18 | YELLOW = np.array([1, 0.706, 0])
19 | RED = np.array([255, 0, 0]) / 255.0
20 | PURPLE = np.array([238, 130, 238]) / 255.0
21 | BLACK = np.array([0, 0, 0]) / 255.0
22 | GOLDEN = np.array([1.0, 0.843, 0.0])
23 | GREEN = np.array([0, 128, 0]) / 255.0
24 | BLUE = np.array([0, 0, 128]) / 255.0
25 | LIGHTBLUE = np.array([0.00, 0.65, 0.93])
26 |
27 |
28 | class MapVisualizer:
29 |     # Public Interface ----------------------------------------------------------------------------
30 | def __init__(self, config: Config = None):
31 |
32 | # Initialize GUI controls
33 | self.block_vis = True
34 | self.play_crun = True
35 | self.reset_bounding_box = True
36 | self.config = config
37 |
38 | self.cur_frame_id: int = 0
39 |
40 | # Create data
41 | self.scan = o3d.geometry.PointCloud()
42 | self.frame_axis = o3d.geometry.TriangleMesh()
43 | self.sensor_cad = o3d.geometry.TriangleMesh()
44 | self.mesh = o3d.geometry.TriangleMesh()
45 | self.sdf = o3d.geometry.PointCloud()
46 | self.neural_points = o3d.geometry.PointCloud()
47 | self.data_pool = o3d.geometry.PointCloud()
48 |
49 | self.odom_traj_pcd = o3d.geometry.PointCloud()
50 | self.pgo_traj_pcd = o3d.geometry.PointCloud()
51 | self.gt_traj_pcd = o3d.geometry.PointCloud()
52 |
53 | self.odom_traj = o3d.geometry.LineSet()
54 | self.pgo_traj = o3d.geometry.LineSet()
55 | self.gt_traj = o3d.geometry.LineSet()
56 |
57 | self.pgo_edges = o3d.geometry.LineSet()
58 |
59 | self.log_path = "./"
60 | self.sdf_slice_height = 0.0
61 | self.mc_res_m = 0.1
62 | self.mesh_min_nn = 10
63 | self.keep_local_mesh = True
64 |
65 | self.frame_axis_len = 0.5
66 |
67 | if config is not None:
68 | self.log_path = os.path.join(config.run_path, "log")
69 | self.frame_axis_len = config.vis_frame_axis_len
70 | self.sdf_slice_height = config.sdf_slice_height
71 | if self.config.sensor_cad_path is not None:
72 | self.sensor_cad = o3d.io.read_triangle_mesh(config.sensor_cad_path)
73 | self.sensor_cad.compute_vertex_normals()
74 | self.mc_res_m = config.mc_res_m
75 | self.mesh_min_nn = config.mesh_min_nn
76 | self.keep_local_mesh = config.keep_local_mesh
77 |
78 | self.before_pgo = True
79 | self.last_pose = np.eye(4)
80 |
81 | # Initialize visualizer
82 | self.vis = o3d.visualization.VisualizerWithKeyCallback()
83 | self._register_key_callbacks()
84 | self._initialize_visualizer()
85 |
86 | # Visualization options
87 | self.render_mesh: bool = True
88 | self.render_pointcloud: bool = True
89 | self.render_frame_axis: bool = True
90 | self.render_trajectory: bool = True
91 | self.render_gt_trajectory: bool = False
92 | self.render_odom_trajectory: bool = (
93 | True # when pgo is on, visualize the odom or not
94 | )
95 | self.render_neural_points: bool = False
96 | self.render_data_pool: bool = False
97 | self.render_sdf: bool = False
98 | self.render_pgo: bool = self.render_trajectory
99 |
100 | self.sdf_slice_height_step: float = 0.1
101 |
102 | self.vis_pc_color: bool = True
103 | self.pc_uniform_color: bool = False
104 |
105 | self.vis_only_cur_samples: bool = False
106 |
107 | self.mc_res_change_interval_m: float = 0.2 * self.mc_res_m
108 |
109 | self.vis_global: bool = False
110 |
111 | self.ego_view: bool = False
112 | self.ego_change_flag: bool = False
113 |
114 | self.debug_mode: int = 0
115 |
116 | self.neural_points_vis_mode: int = 0
117 |
118 | self.global_viewpoint: bool = False
119 | self.view_control = self.vis.get_view_control()
120 | self.camera_params = self.view_control.convert_to_pinhole_camera_parameters()
121 |
122 | def update_view(self):
123 | self.vis.poll_events()
124 | self.vis.update_renderer()
125 |
126 | def pause_view(self):
127 | while self.block_vis:
128 | self.update_view()
129 | if self.play_crun:
130 | break
131 |
132 | def update(
133 | self,
134 | scan=None,
135 | pose=None,
136 | sdf=None,
137 | mesh=None,
138 | neural_points=None,
139 | data_pool=None,
140 | pause_now=False,
141 | ):
142 | self._update_geometries(scan, pose, sdf, mesh, neural_points, data_pool)
143 | self.update_view()
144 | self.pause_view()
145 | if pause_now:
146 | self.stop()
147 |
148 | def update_traj(
149 | self,
150 | cur_pose=None,
151 | odom_poses=None,
152 | gt_poses=None,
153 | pgo_poses=None,
154 | pgo_edges=None,
155 | ):
156 | self._update_traj(cur_pose, odom_poses, gt_poses, pgo_poses, pgo_edges)
157 | self.update_view()
158 | self.pause_view()
159 |
160 | def update_pointcloud(self, scan):
161 | self._update_pointcloud(scan)
162 | self.update_view()
163 | self.pause_view()
164 |
165 | def update_mesh(self, mesh):
166 | self._update_mesh(mesh)
167 | self.update_view()
168 | self.pause_view()
169 |
170 | def destroy_window(self):
171 | self.vis.destroy_window()
172 |
173 | def stop(self):
174 | self.play_crun = not self.play_crun
175 | while self.block_vis:
176 | self.vis.poll_events()
177 | self.vis.update_renderer()
178 | if self.play_crun:
179 | break
180 |
181 | def _initialize_visualizer(self):
182 | w_name = "📍 PIN-SLAM Visualizer"
183 | self.vis.create_window(
184 | window_name=w_name, width=2560, height=1600
185 | ) # 1920, 1080
186 | self.vis.add_geometry(self.scan)
187 | self.vis.add_geometry(self.sdf)
188 | self.vis.add_geometry(self.frame_axis)
189 | self.vis.add_geometry(self.mesh)
190 | self.vis.add_geometry(self.neural_points)
191 | self.vis.add_geometry(self.data_pool)
192 | self.vis.add_geometry(self.odom_traj_pcd)
193 | self.vis.add_geometry(self.gt_traj_pcd)
194 | self.vis.add_geometry(self.pgo_traj_pcd)
195 | self.vis.add_geometry(self.pgo_edges)
196 |
197 | self.vis.get_render_option().line_width = 500
198 | self.vis.get_render_option().light_on = True
199 | self.vis.get_render_option().mesh_shade_option = (
200 | o3d.visualization.MeshShadeOption.Color
201 | )
202 |
203 | if self.config is not None:
204 | self.vis.get_render_option().point_size = self.config.vis_point_size
205 | if self.config.mesh_vis_normal:
206 | self.vis.get_render_option().mesh_color_option = (
207 | o3d.visualization.MeshColorOption.Normal
208 | )
209 |
210 | print(
211 | f"{w_name} initialized. Press:\n"
212 | "\t[SPACE] to pause/resume\n"
213 | "\t[ESC/Q] to exit\n"
214 | "\t [G] to toggle on/off the global/local map visualization\n"
215 | "\t [E] to toggle on/off the ego/map viewpoint\n"
216 | "\t [F] to toggle on/off the current point cloud\n"
217 | "\t [M] to toggle on/off the mesh\n"
218 | "\t [T] to toggle on/off PIN SLAM trajectory\n"
219 | "\t [Y] to toggle on/off the reference trajectory\n"
220 | "\t [U] to toggle on/off PIN odometry trajectory\n"
221 | "\t [A] to toggle on/off the current frame axis\n"
222 | "\t [P] to toggle on/off the neural points map\n"
223 | "\t [D] to toggle on/off the data pool\n"
224 | "\t [I] to toggle on/off the sdf map slice\n"
225 | "\t [R] to center the view point\n"
226 | "\t [Z] to save the currently visualized entities in the log folder\n"
227 | )
228 |
229 | def _register_key_callback(self, keys: List, callback: Callable):
230 | for key in keys:
231 | self.vis.register_key_callback(ord(str(key)), partial(callback))
232 |
233 | def _register_key_callbacks(self):
234 | self._register_key_callback(["Ā", "Q"], self._quit)
235 | self._register_key_callback([" "], self._start_stop)
236 | self._register_key_callback(["R"], self._center_viewpoint)
237 | self._register_key_callback(["E"], self._toggle_ego)
238 | self._register_key_callback(["F"], self._toggle_pointcloud)
239 | self._register_key_callback(["A"], self._toggle_frame_axis)
240 | self._register_key_callback(["I"], self._toggle_sdf)
241 | self._register_key_callback(["M"], self._toggle_mesh)
242 | self._register_key_callback(["P"], self._toggle_neural_points)
243 | self._register_key_callback(["D"], self._toggle_data_pool)
244 | self._register_key_callback(["T"], self._toggle_trajectory)
245 | self._register_key_callback(["Y"], self._toggle_gt_trajectory)
246 | self._register_key_callback(["U"], self._toggle_odom_trajectory)
247 | self._register_key_callback(["G"], self._toggle_global)
248 | self._register_key_callback(["Z"], self._save_cur_vis)
249 | self._register_key_callback([";"], self._toggle_loop_debug)
250 | self._register_key_callback(
251 | ["/"], self._toggle_neural_point_vis_mode
252 | ) # vis neural point color using feature, ts or certainty
253 | self._register_key_callback(["'"], self._toggle_vis_cur_sample)
254 | self._register_key_callback(["]"], self._toggle_increase_mesh_res)
255 | self._register_key_callback(["["], self._toggle_decrease_mesh_res)
256 | self._register_key_callback(["."], self._toggle_increase_mesh_nn) # '>'
257 | self._register_key_callback([","], self._toggle_decrease_mesh_nn) # '<'
258 | self._register_key_callback(["5"], self._toggle_point_color)
259 | self._register_key_callback(["6"], self._toggle_uniform_color)
260 | self._register_key_callback(["7"], self._switch_background)
261 | # self.vis.register_key_callback(262, partial(self._toggle_)) # right arrow # for future
262 | # self.vis.register_key_callback(263, partial(self._toggle_)) # left arrow
263 | self.vis.register_key_callback(
264 | 265, partial(self._toggle_increase_slice_height)
265 | ) # up arrow
266 | self.vis.register_key_callback(
267 | 264, partial(self._toggle_decrease_slice_height)
268 | ) # down arrow
269 | # leave C and V as the view copying, pasting function
270 | # use alt + prt sc for the window screenshot
271 |
272 | def _switch_background(self, vis):
273 | cur_background_color = vis.get_render_option().background_color
274 | vis.get_render_option().background_color = np.ones(3) - cur_background_color
275 |
276 | def _center_viewpoint(self, vis):
277 | self.reset_bounding_box = not self.reset_bounding_box
278 | vis.reset_view_point(True)
279 |
280 | def _toggle_point_color(
281 | self, vis
282 | ): # actually used to show the source point cloud weight for registration
283 | self.vis_pc_color = not self.vis_pc_color
284 |
285 | def _quit(self, vis):
286 | print("Destroying Visualizer")
287 | vis.destroy_window()
288 | os._exit(0)
289 |
290 | def _save_cur_vis(self, vis):
291 | if self.data_pool.has_points():
292 | data_pool_pc_name = str(self.cur_frame_id) + "_training_sdf_pool.ply"
293 | data_pool_pc_path = os.path.join(self.log_path, data_pool_pc_name)
294 | o3d.io.write_point_cloud(data_pool_pc_path, self.data_pool)
295 | print("Output current training data pool to: ", data_pool_pc_path)
296 | if self.scan.has_points():
297 | if self.vis_pc_color:
298 | scan_pc_name = str(self.cur_frame_id) + "_scan_map.ply"
299 | else:
300 | scan_pc_name = str(self.cur_frame_id) + "_scan_reg.ply"
301 | scan_pc_path = os.path.join(self.log_path, scan_pc_name)
302 | o3d.io.write_point_cloud(scan_pc_path, self.scan)
303 | print("Output current scan to: ", scan_pc_path)
304 | if self.neural_points.has_points():
305 | neural_point_name = str(self.cur_frame_id) + "_neural_points.ply"
306 | neural_point_path = os.path.join(self.log_path, neural_point_name)
307 | o3d.io.write_point_cloud(neural_point_path, self.neural_points)
308 | print("Output current neural points to: ", neural_point_path)
309 | if self.sdf.has_points():
310 | sdf_slice_name = str(self.cur_frame_id) + "_sdf_slice.ply"
311 | sdf_slice_path = os.path.join(self.log_path, sdf_slice_name)
312 | o3d.io.write_point_cloud(sdf_slice_path, self.sdf)
313 | print("Output current SDF slice to: ", sdf_slice_path)
314 | if self.mesh.has_triangles():
315 | mesh_name = str(self.cur_frame_id) + "_mesh_vis.ply"
316 | mesh_path = os.path.join(self.log_path, mesh_name)
317 | o3d.io.write_triangle_mesh(mesh_path, self.mesh)
318 | print("Output current mesh to: ", mesh_path)
319 | if self.frame_axis.has_triangles():
320 | ego_name = str(self.cur_frame_id) + "_sensor_vis.ply"
321 | ego_path = os.path.join(self.log_path, ego_name)
322 | o3d.io.write_triangle_mesh(ego_path, self.frame_axis)
323 | print("Output current sensor model to: ", ego_path)
324 |
325 | def _next_frame(self, vis): # FIXME
326 | self.block_vis = not self.block_vis
327 |
328 | def _start_stop(self, vis):
329 | self.play_crun = not self.play_crun
330 |
331 | def _toggle_pointcloud(self, vis):
332 | self.render_pointcloud = not self.render_pointcloud
333 |
334 | def _toggle_frame_axis(self, vis):
335 | self.render_frame_axis = not self.render_frame_axis
336 |
337 | def _toggle_trajectory(self, vis):
338 | self.render_trajectory = not self.render_trajectory
339 |
340 | def _toggle_gt_trajectory(self, vis):
341 | self.render_gt_trajectory = not self.render_gt_trajectory
342 |
343 | def _toggle_odom_trajectory(self, vis):
344 | self.render_odom_trajectory = not self.render_odom_trajectory
345 |
346 | def _toggle_pgo(self, vis):
347 | self.render_pgo = not self.render_pgo
348 |
349 | def _toggle_sdf(self, vis):
350 | self.render_sdf = not self.render_sdf
351 |
352 | def _toggle_mesh(self, vis):
353 | self.render_mesh = not self.render_mesh
354 | print("Show mesh: ", self.render_mesh)
355 |
356 | def _toggle_neural_points(self, vis):
357 | self.render_neural_points = not self.render_neural_points
358 |
359 | def _toggle_data_pool(self, vis):
360 | self.render_data_pool = not self.render_data_pool
361 |
362 | def _toggle_global(self, vis):
363 | self.vis_global = not self.vis_global
364 |
365 | def _toggle_ego(self, vis):
366 | self.ego_view = not self.ego_view
367 | self.ego_change_flag = True # ego->global or global->ego
368 | self.reset_bounding_box = not self.reset_bounding_box
369 | vis.reset_view_point(True)
370 |
371 | def _toggle_uniform_color(self, vis):
372 | self.pc_uniform_color = not self.pc_uniform_color
373 |
374 | def _toggle_loop_debug(self, vis):
375 | self.debug_mode = (
376 | self.debug_mode + 1
377 | ) % 3 # 0,1,2 # switch between different debug mode
378 | print("Switch to debug mode:", self.debug_mode)
379 |
380 | def _toggle_vis_cur_sample(self, vis):
381 | self.vis_only_cur_samples = not self.vis_only_cur_samples
382 |
383 | def _toggle_neural_point_vis_mode(self, vis):
384 | self.neural_points_vis_mode = (
385 | self.neural_points_vis_mode + 1
386 | ) % 5 # 0,1,2,3,4 # switch between different vis mode
387 | print("Switch to neural point visualization mode:", self.neural_points_vis_mode)
388 |
389 | def _toggle_increase_mesh_res(self, vis):
390 | self.mc_res_m += self.mc_res_change_interval_m
391 | print("Current marching cubes voxel size [m]:", f"{self.mc_res_m:.2f}")
392 |
393 | def _toggle_decrease_mesh_res(self, vis):
394 | self.mc_res_m = max(
395 | self.mc_res_change_interval_m, self.mc_res_m - self.mc_res_change_interval_m
396 | )
397 | print("Current marching cubes voxel size [m]:", f"{self.mc_res_m:.2f}")
398 |
399 | def _toggle_increase_slice_height(self, vis):
400 | self.sdf_slice_height += self.sdf_slice_height_step
401 | print("Current SDF slice height [m]:", f"{self.sdf_slice_height:.2f}")
402 |
403 | def _toggle_decrease_slice_height(self, vis):
404 | self.sdf_slice_height -= self.sdf_slice_height_step
405 | print("Current SDF slice height [m]:", f"{self.sdf_slice_height:.2f}")
406 |
407 | def _toggle_increase_mesh_nn(self, vis):
408 | self.mesh_min_nn += 1
409 | print("Current marching cubes mask nn count:", self.mesh_min_nn)
410 |
411 | def _toggle_decrease_mesh_nn(self, vis):
412 | self.mesh_min_nn = max(5, self.mesh_min_nn - 1)
413 | print("Current marching cubes mask nn count:", self.mesh_min_nn)
414 |
415 | def _toggle_help(self, vis):
416 | print(
417 | f"Instructions. Press:\n"
418 | "\t[SPACE] to pause/resume\n"
419 | "\t[ESC/Q] to exit\n"
420 | "\t [G] to toggle on/off the global/local map visualization\n"
421 | "\t [E] to toggle on/off the ego/map viewpoint\n"
422 | "\t [F] to toggle on/off the current point cloud\n"
423 | "\t [M] to toggle on/off the mesh\n"
424 | "\t [T] to toggle on/off PIN SLAM trajectory\n"
425 | "\t [Y] to toggle on/off the reference trajectory\n"
426 | "\t [U] to toggle on/off PIN odometry trajectory\n"
427 | "\t [A] to toggle on/off the current frame axis\n"
428 | "\t [P] to toggle on/off the neural points map\n"
429 | "\t [D] to toggle on/off the data pool\n"
430 | "\t [I] to toggle on/off the sdf map slice\n"
431 | "\t [R] to center the view point\n"
432 | "\t [Z] to save the currently visualized entities in the log folder\n"
433 | )
434 | self.play_crun = not self.play_crun
435 | return False
436 |
437 | def _update_mesh(self, mesh):
438 | if self.render_mesh:
439 | if mesh is not None:
440 | self.vis.remove_geometry(self.mesh, self.reset_bounding_box)
441 | self.mesh = mesh
442 | self.vis.add_geometry(self.mesh, self.reset_bounding_box)
443 | else:
444 | self.vis.remove_geometry(self.mesh, self.reset_bounding_box)
445 |
446 | def _update_pointcloud(self, scan):
447 | if scan is not None:
448 | self.vis.remove_geometry(self.scan, self.reset_bounding_box)
449 | self.scan = scan
450 | self.vis.add_geometry(self.scan, self.reset_bounding_box)
451 |
452 | if self.reset_bounding_box:
453 | self.vis.reset_view_point(True)
454 | self.reset_bounding_box = False
455 |
456 | def _update_geometries(
457 | self,
458 | scan=None,
459 | pose=None,
460 | sdf=None,
461 | mesh=None,
462 | neural_points=None,
463 | data_pool=None,
464 | ):
465 |
466 | # Scan (toggled by "F")
467 | if self.render_pointcloud:
468 | if scan is not None:
469 | self.scan.points = o3d.utility.Vector3dVector(scan.points)
470 | self.scan.colors = o3d.utility.Vector3dVector(scan.colors)
471 | self.scan.normals = o3d.utility.Vector3dVector(scan.normals)
472 | if self.pc_uniform_color or (
473 | self.vis_pc_color
474 | and (self.config.color_channel == 0)
475 | and (not self.config.semantic_on)
476 | and (not self.config.dynamic_filter_on)
477 | ):
478 | self.scan.paint_uniform_color(GOLDEN)
479 | else:
480 | self.scan.points = o3d.utility.Vector3dVector()
481 | else:
482 | self.scan.points = o3d.utility.Vector3dVector()
483 | # self.scan.colors = o3d.utility.Vector3dVector()
484 | if self.ego_view and pose is not None:
485 | self.scan.transform(np.linalg.inv(pose))
486 | self.vis.update_geometry(self.scan)
487 |
488 | # Mesh Map (toggled by "M")
489 | if self.render_mesh:
490 | if mesh is not None:
491 | if not self.keep_local_mesh:
492 | self.vis.remove_geometry(
493 | self.mesh, self.reset_bounding_box
494 |                     ) # if commented out, we keep the previously reconstructed mesh (for the case where local map reconstruction is used)
495 | self.mesh = mesh
496 | if self.ego_view and pose is not None:
497 | self.mesh.transform(np.linalg.inv(pose))
498 | self.vis.add_geometry(self.mesh, self.reset_bounding_box)
499 |             else: # None: meshing every frame can be time consuming, so we just keep the mesh reconstructed from the last frame for visualization
500 | if self.ego_view and pose is not None:
501 | self.vis.remove_geometry(self.mesh, self.reset_bounding_box)
502 | if self.ego_change_flag: # global -> ego view
503 | self.mesh.transform(np.linalg.inv(pose))
504 | self.ego_change_flag = False
505 | else:
506 | self.mesh.transform(np.linalg.inv(pose) @ self.last_pose)
507 | self.vis.add_geometry(self.mesh, self.reset_bounding_box)
508 | elif self.ego_change_flag: # ego -> global view
509 | self.vis.remove_geometry(self.mesh, self.reset_bounding_box)
510 | self.mesh.transform(self.last_pose)
511 | self.vis.add_geometry(self.mesh, self.reset_bounding_box)
512 | self.ego_change_flag = False
513 | else:
514 | self.vis.remove_geometry(self.mesh, self.reset_bounding_box)
515 |
516 | # Neural Points Map (toggled by "P")
517 | if neural_points is not None:
518 | if self.render_neural_points:
519 | self.neural_points.points = o3d.utility.Vector3dVector(
520 | neural_points.points
521 | )
522 | self.neural_points.colors = o3d.utility.Vector3dVector(
523 | neural_points.colors
524 | )
525 | else:
526 | self.neural_points.points = o3d.utility.Vector3dVector()
527 | else:
528 | self.neural_points.points = o3d.utility.Vector3dVector()
529 | if self.ego_view and pose is not None:
530 | self.neural_points.transform(np.linalg.inv(pose))
531 | self.vis.update_geometry(self.neural_points)
532 |
533 | # Data Pool (toggled by "D")
534 | if data_pool is not None:
535 | if self.render_data_pool:
536 | self.data_pool.points = o3d.utility.Vector3dVector(data_pool.points)
537 | self.data_pool.colors = o3d.utility.Vector3dVector(data_pool.colors)
538 | else:
539 | self.data_pool.points = o3d.utility.Vector3dVector()
540 | else:
541 | self.data_pool.points = o3d.utility.Vector3dVector()
542 | if self.ego_view and pose is not None:
543 | self.data_pool.transform(np.linalg.inv(pose))
544 | self.vis.update_geometry(self.data_pool)
545 |
546 | # SDF map (toggled by "I")
547 | if sdf is not None:
548 | if self.render_sdf:
549 | self.sdf.points = o3d.utility.Vector3dVector(sdf.points)
550 | self.sdf.colors = o3d.utility.Vector3dVector(sdf.colors)
551 | else:
552 | self.sdf.points = o3d.utility.Vector3dVector()
553 | else:
554 | self.sdf.points = o3d.utility.Vector3dVector()
555 | if self.ego_view and pose is not None:
556 | self.sdf.transform(np.linalg.inv(pose))
557 | self.vis.update_geometry(self.sdf)
558 |
559 | # Coordinate frame axis (toggled by "A")
560 | if self.render_frame_axis:
561 | if pose is not None:
562 | self.vis.remove_geometry(self.frame_axis, self.reset_bounding_box)
563 | self.frame_axis = o3d.geometry.TriangleMesh.create_coordinate_frame(
564 | size=self.frame_axis_len, origin=np.zeros(3)
565 | )
566 | self.frame_axis += self.sensor_cad
567 | if not self.ego_view:
568 | self.frame_axis = self.frame_axis.transform(pose)
569 | self.vis.add_geometry(self.frame_axis, self.reset_bounding_box)
570 | else:
571 | self.vis.remove_geometry(self.frame_axis, self.reset_bounding_box)
572 |
573 | if pose is not None:
574 | self.last_pose = pose
575 |
576 | if self.reset_bounding_box:
577 | self.vis.reset_view_point(True)
578 | self.reset_bounding_box = False
579 |
580 |     # Show the trajectories as line sets.
581 |     # Note: converting long pose lists to numpy arrays can take noticeable time.
582 | def _update_traj(
583 | self,
584 | cur_pose=None,
585 | odom_poses_np=None,
586 | gt_poses_np=None,
587 | pgo_poses_np=None,
588 | loop_edges=None,
589 | ):
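    |         # Rebuild the odometry / ground-truth / PGO trajectory line sets from the pose arrays;
    |         # the geometries are removed and re-added on every update because their topology changes.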
590 |
591 | self.vis.remove_geometry(self.odom_traj, self.reset_bounding_box)
592 | self.vis.remove_geometry(self.gt_traj, self.reset_bounding_box)
593 | self.vis.remove_geometry(self.pgo_traj, self.reset_bounding_box)
594 | self.vis.remove_geometry(self.pgo_edges, self.reset_bounding_box)
595 |
596 | if (self.render_trajectory and odom_poses_np is not None and odom_poses_np.shape[0] > 1):
597 | if pgo_poses_np is not None and (not self.render_odom_trajectory):
598 | self.odom_traj = o3d.geometry.LineSet()
599 | else:
600 | odom_position_np = odom_poses_np[:, :3, 3]
601 | self.odom_traj.points = o3d.utility.Vector3dVector(odom_position_np)
602 | odom_edges = np.array([[i, i + 1] for i in range(odom_poses_np.shape[0] - 1)])
603 | self.odom_traj.lines = o3d.utility.Vector2iVector(odom_edges)
604 |
605 | if pgo_poses_np is None or self.before_pgo:
606 | self.odom_traj.paint_uniform_color(RED)
607 | else:
608 | self.odom_traj.paint_uniform_color(BLUE)
609 |
610 | if self.ego_view and cur_pose is not None:
611 | self.odom_traj.transform(np.linalg.inv(cur_pose))
612 | else:
613 | self.odom_traj = o3d.geometry.LineSet()
614 |
615 | if (
616 | self.render_trajectory
617 | and pgo_poses_np is not None
618 | and pgo_poses_np.shape[0] > 1
619 | and (not self.before_pgo)
620 | ):
621 | pgo_position_np = pgo_poses_np[:, :3, 3]
622 |
623 | self.pgo_traj.points = o3d.utility.Vector3dVector(pgo_position_np)
624 | pgo_traj_edges = np.array([[i, i + 1] for i in range(pgo_poses_np.shape[0] - 1)])
625 | self.pgo_traj.lines = o3d.utility.Vector2iVector(pgo_traj_edges)
626 | self.pgo_traj.paint_uniform_color(RED)
627 |
628 | if self.ego_view and cur_pose is not None:
629 | self.pgo_traj.transform(np.linalg.inv(cur_pose))
630 |
631 | if self.render_pgo and loop_edges is not None and len(loop_edges) > 0:
632 | edges = np.array(loop_edges)
633 | self.pgo_edges.points = o3d.utility.Vector3dVector(pgo_position_np)
634 | self.pgo_edges.lines = o3d.utility.Vector2iVector(edges)
635 | self.pgo_edges.paint_uniform_color(GREEN)
636 |
637 | if self.ego_view and cur_pose is not None:
638 | self.pgo_edges.transform(np.linalg.inv(cur_pose))
639 | else:
640 | self.pgo_edges = o3d.geometry.LineSet()
641 | else:
642 | self.pgo_traj = o3d.geometry.LineSet()
643 | self.pgo_edges = o3d.geometry.LineSet()
644 |
645 | if (
646 | self.render_trajectory
647 | and self.render_gt_trajectory
648 | and gt_poses_np is not None
649 | and gt_poses_np.shape[0] > 1
650 | ):
651 | gt_position_np = gt_poses_np[:, :3, 3]
652 | self.gt_traj.points = o3d.utility.Vector3dVector(gt_position_np)
653 | gt_edges = np.array([[i, i + 1] for i in range(gt_poses_np.shape[0] - 1)])
654 | self.gt_traj.lines = o3d.utility.Vector2iVector(gt_edges)
655 | self.gt_traj.paint_uniform_color(BLACK)
656 | if odom_poses_np is None:
657 | self.gt_traj.paint_uniform_color(RED)
658 | if self.ego_view and cur_pose is not None:
659 | self.gt_traj.transform(np.linalg.inv(cur_pose))
660 | else:
661 | self.gt_traj = o3d.geometry.LineSet()
662 |
663 | self.vis.add_geometry(self.odom_traj, self.reset_bounding_box)
664 | self.vis.add_geometry(self.gt_traj, self.reset_bounding_box)
665 | self.vis.add_geometry(self.pgo_traj, self.reset_bounding_box)
666 | self.vis.add_geometry(self.pgo_edges, self.reset_bounding_box)
667 |
668 | def _toggle_view(self, vis):
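    |         # Toggle between the global viewpoint and the ego (sensor-following) viewpoint,
    |         # caching the camera parameters so the previous viewpoint can be restored.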
669 | self.global_viewpoint = not self.global_viewpoint
670 | vis.update_renderer()
671 | vis.reset_view_point(True)
672 | current_camera = self.view_control.convert_to_pinhole_camera_parameters()
673 | if self.camera_params and not self.global_viewpoint:
674 | self.view_control.convert_from_pinhole_camera_parameters(self.camera_params)
675 | self.camera_params = current_camera
676 |
--------------------------------------------------------------------------------
/vis_pin_map.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # @file vis_pin_map.py
3 | # @author Yue Pan [yue.pan@igg.uni-bonn.de]
4 | # Copyright (c) 2024 Yue Pan, all rights reserved
5 |
6 | import glob
7 | import os
8 | import sys
9 | import time
10 |
11 | import numpy as np
12 | import open3d as o3d
13 | import torch
14 | import torch.multiprocessing as mp
15 | import dtyper as typer
16 | from rich import print
17 |
18 | from model.decoder import Decoder
19 | from model.neural_points import NeuralPoints
20 | from utils.config import Config
21 | from utils.mesher import Mesher
22 | from utils.tools import setup_experiment, split_chunks, load_decoders, remove_gpu_cache
23 |
24 | from gui import slam_gui
25 | from gui.gui_utils import ParamsGUI, VisPacket
26 |
27 |
28 | '''
29 | Load the PIN map from a result folder and reconstruct the mesh from it.
30 | '''
31 |
32 | app = typer.Typer(add_completion=False, rich_markup_mode="rich", context_settings={"help_option_names": ["-h", "--help"]})
33 |
34 | docstring = f"""
35 | :round_pushpin: Inspect the PIN Map \n
36 |
37 | [bold green]Examples: [/bold green]
38 |
39 | # Inspect the PIN Map stored in a mapping result folder, showing both the neural points and the mesh reconstructed with a certain marching cubes resolution
40 | $ python3 vis_pin_map.py :open_file_folder: -m
41 |
42 | # Additionally, you can specify the cropped point cloud file and the output mesh file
43 | $ python3 vis_pin_map.py :open_file_folder: -m -c :page_facing_up: -o :page_facing_up:
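   |
   | # For instance, a concrete call might look like this (folder name and values are only illustrative):
   | $ python3 vis_pin_map.py ./experiment/ncd128_test -m 0.1 -c neural_points.ply -o mesh_10cm.ply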
44 |
45 | """
46 |
47 | @app.command(help=docstring)
48 | def vis_pin_map(
49 | result_folder: str = typer.Argument(..., help='Path to the result folder'),
50 | mesh_res_m: float = typer.Option(None, '--mesh_res_m', '-m', help='Resolution of the mesh in meters'),
51 | cropped_ply_filename: str = typer.Option("neural_points.ply", '--cropped_ply_filename', '-c', help='Path to the cropped point cloud file'),
52 | output_mesh_filename: str = typer.Option(None, '--output_mesh_filename', '-o', help='Path to the output mesh file'),
53 |     mc_nn: int = typer.Option(9, '--mc_nn', '-n', help='Minimum number of neural point neighbors required for a valid SDF query during marching cubes'),
54 | o3d_vis_on: bool = typer.Option(True, '--visualize_on', '-v', help='Turn on the visualizer'),
55 | ):
56 |
57 | config = Config()
58 |
59 | yaml_files = glob.glob(f"{result_folder}/*.yaml")
60 |     if len(yaml_files) > 1: # more than one YAML file found; the result folder should contain exactly one
61 | sys.exit("There are multiple YAML files. Please handle accordingly.")
62 | elif len(yaml_files) == 0: # If no YAML files are found
63 | sys.exit("No YAML files found in the specified path.")
64 | config.load(yaml_files[0])
65 | config.model_path = os.path.join(result_folder, "model", "pin_map.pth")
66 |
67 |     print("[bold green]Load and inspect PIN Map[/bold green]", "📍")
68 |
69 | run_path = setup_experiment(config, sys.argv, debug_mode=True)
70 |
71 |     mp.set_start_method("spawn") # use the spawn start method: CUDA does not support forked subprocesses
72 |
73 | # initialize the mlp decoder
74 | geo_mlp = Decoder(config, config.geo_mlp_hidden_dim, config.geo_mlp_level, 1)
75 | sem_mlp = Decoder(config, config.sem_mlp_hidden_dim, config.sem_mlp_level, config.sem_class_count + 1) if config.semantic_on else None
76 | color_mlp = Decoder(config, config.color_mlp_hidden_dim, config.color_mlp_level, config.color_channel) if config.color_on else None
77 |
78 |     mlp_dict = {"sdf": geo_mlp, "semantic": sem_mlp, "color": color_mlp}
79 |
80 | # initialize the neural point features
81 | neural_points: NeuralPoints = NeuralPoints(config)
82 |
83 | # Load the map
84 | loaded_model = torch.load(config.model_path, weights_only=False)
85 | neural_points = loaded_model["neural_points"]
86 | load_decoders(loaded_model, mlp_dict)
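    |     # Use the whole loaded map: switch off the temporal local map, rebuild the spatial hash
    |     # over all neural points, and precompute the principal components of the point features
    |     # (used for the PCA-based coloring in the visualizer).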
87 | neural_points.temporal_local_map_on = False
88 | neural_points.recreate_hash(neural_points.neural_points[0], torch.eye(3).cuda(), False, False)
89 | neural_points.compute_feature_principle_components(down_rate = 59)
90 | print("PIN Map loaded")
91 |
92 | # mesh reconstructor
93 | mesher = Mesher(config, neural_points, mlp_dict)
94 |
95 | mesh_on = (mesh_res_m is not None)
96 | if mesh_on:
97 | config.mc_res_m = mesh_res_m
98 | config.mesh_min_nn = mc_nn
99 |
100 | q_main2vis = q_vis2main = None
101 | if o3d_vis_on:
102 | # communicator between the processes
103 | q_main2vis = mp.Queue()
104 | q_vis2main = mp.Queue()
105 |
106 | params_gui = ParamsGUI(
107 | q_main2vis=q_main2vis,
108 | q_vis2main=q_vis2main,
109 | config=config,
110 | local_map_default_on=False,
111 | mesh_default_on=mesh_on,
112 | neural_point_map_default_on=config.neural_point_map_default_on,
113 | )
114 | gui_process = mp.Process(target=slam_gui.run, args=(params_gui,))
115 | gui_process.start()
116 |         time.sleep(3) # seconds; give the GUI process time to start up
117 |
118 | cur_mesh = None
119 | out_mesh_path = None
120 |
121 | if mesh_on:
122 | cropped_ply_path = os.path.join(result_folder, "map", cropped_ply_filename)
123 | if os.path.exists(cropped_ply_path):
124 | cropped_pc = o3d.io.read_point_cloud(cropped_ply_path)
125 |             print("Loading the meshing region from {}".format(cropped_ply_path))
126 |
127 | else:
128 | cropped_pc = neural_points.get_neural_points_o3d(query_global=True, random_down_ratio=23)
129 |
130 | mesh_aabb = cropped_pc.get_axis_aligned_bounding_box()
131 | chunks_aabb = split_chunks(cropped_pc, mesh_aabb, mesh_res_m*300)
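    |         # Split the meshing region into axis-aligned chunks (here roughly 300x the marching
    |         # cubes resolution per side), presumably so each chunk can be reconstructed on its own
    |         # without exhausting GPU memory.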
132 |
133 | if output_mesh_filename is not None:
134 | out_mesh_path = os.path.join(result_folder, "mesh", output_mesh_filename)
135 | print("Output the mesh to: ", out_mesh_path)
136 |
137 | mc_cm_str = str(round(mesh_res_m*1e2))
138 | print("Reconstructing the mesh with resolution {} cm".format(mc_cm_str))
139 | cur_mesh = mesher.recon_aabb_collections_mesh(chunks_aabb, mesh_res_m, out_mesh_path, False, config.semantic_on,
140 | config.color_on, filter_isolated_mesh=True, mesh_min_nn=mc_nn)
141 |         print("Global mesh reconstruction done")
142 |
143 | remove_gpu_cache()
144 |
145 | if o3d_vis_on:
146 |
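    |         # Keep the GUI process fed: repeatedly package the neural points and (once) the
    |         # reconstructed mesh into a VisPacket and push it onto the visualization queue.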
147 | while True:
148 | if not q_vis2main.empty():
149 | q_vis2main.get()
150 |
151 | packet_to_vis: VisPacket = VisPacket(slam_finished=True)
152 |
153 | if not neural_points.is_empty():
154 | packet_to_vis.add_neural_points_data(neural_points, only_local_map=False, pca_color_on=True)
155 |
156 | if cur_mesh is not None:
157 | packet_to_vis.add_mesh(np.array(cur_mesh.vertices, dtype=np.float64), np.array(cur_mesh.triangles), np.array(cur_mesh.vertex_colors, dtype=np.float64))
158 | cur_mesh = None
159 |
160 | q_main2vis.put(packet_to_vis)
161 | time.sleep(1.0)
162 |
163 |
164 | if __name__ == "__main__":
165 | app()
--------------------------------------------------------------------------------