├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── assets
│   ├── architecture.png
│   └── loopsplat.gif
├── configs
│   ├── Replica
│   │   ├── office0.yaml
│   │   ├── office1.yaml
│   │   ├── office2.yaml
│   │   ├── office3.yaml
│   │   ├── office4.yaml
│   │   ├── replica.yaml
│   │   ├── room0.yaml
│   │   ├── room1.yaml
│   │   └── room2.yaml
│   ├── ScanNet
│   │   ├── scannet.yaml
│   │   ├── scene0000_00.yaml
│   │   ├── scene0054_00.yaml
│   │   ├── scene0059_00.yaml
│   │   ├── scene0106_00.yaml
│   │   ├── scene0169_00.yaml
│   │   ├── scene0181_00.yaml
│   │   ├── scene0207_00.yaml
│   │   ├── scene0233_00.yaml
│   │   └── scene0465_00.yaml
│   ├── TUM_RGBD
│   │   ├── rgbd_dataset_freiburg1_desk.yaml
│   │   ├── rgbd_dataset_freiburg1_desk2.yaml
│   │   ├── rgbd_dataset_freiburg1_room.yaml
│   │   ├── rgbd_dataset_freiburg2_xyz.yaml
│   │   ├── rgbd_dataset_freiburg3_long_office_household.yaml
│   │   └── tum_rgbd.yaml
│   └── scannetpp
│       ├── 281bc17764.yaml
│       ├── 2e74812d00.yaml
│       ├── 8b5caf3398.yaml
│       ├── b20a261fdf.yaml
│       ├── fb05e13ad1.yaml
│       └── scannetpp.yaml
├── environment.yml
├── requirements.txt
├── run_evaluation.py
├── run_slam.py
├── scripts
│   ├── download_replica.sh
│   ├── download_tum.sh
│   ├── reproduce_sbatch.sh
│   └── scannet_preprocess.ipynb
└── src
    ├── entities
    │   ├── __init__.py
    │   ├── arguments.py
    │   ├── datasets.py
    │   ├── gaussian_model.py
    │   ├── gaussian_slam.py
    │   ├── lc.py
    │   ├── logger.py
    │   ├── losses.py
    │   ├── mapper.py
    │   ├── tracker.py
    │   └── visual_odometer.py
    ├── evaluation
    │   ├── __init__.py
    │   ├── evaluate_merged_map.py
    │   ├── evaluate_reconstruction.py
    │   ├── evaluate_trajectory.py
    │   └── evaluator.py
    ├── gsr
    │   ├── camera.py
    │   ├── descriptor.py
    │   ├── loss.py
    │   ├── overlap.py
    │   ├── pcr.py
    │   ├── renderer.py
    │   ├── se3
    │   │   ├── numpy_se3.py
    │   │   └── torch_se3.py
    │   ├── solver.py
    │   └── utils.py
    └── utils
        ├── __init__.py
        ├── eval_utils.py
        ├── gaussian_model_utils.py
        ├── graphics_utils.py
        ├── io_utils.py
        ├── mapper_utils.py
        ├── pose_utils.py
        ├── tracker_utils.py
        ├── utils.py
        └── vis_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | .vscode
3 | output
4 | build
5 | diff_rasterization/diff_rast.egg-info
6 | diff_rasterization/dist
7 | tensorboard_3d
8 | screenshots
9 | debug
10 | wandb
11 | data
12 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "thirdparty/Hierarchical-Localization"]
2 | path = thirdparty/Hierarchical-Localization
3 | url = https://github.com/cvg/Hierarchical-Localization.git
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | LoopSplat: Loop Closure by Registering 3D Gaussian Splats
4 |
5 | Liyuan Zhu1
6 | ·
7 | Yue Li2
8 | ·
9 | Erik Sandström3
10 | ·
11 | Shengyu Huang3
12 | ·
13 | Konrad Schindler3
14 | ·
15 | Iro Armeni1
16 |
17 |
18 | International Conference on 3D Vision (3DV) 2025
19 |
20 | 1Stanford University · 2University of Amsterdam · 3ETH Zurich
21 |
22 |
23 |
24 | [arXiv](https://arxiv.org/abs/2408.10154) [Project Page](https://loopsplat.github.io/) [License: MIT](https://opensource.org/licenses/MIT)
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 | ## 📃 Description
37 |
38 |
39 |
40 |
41 |
42 |
43 | **LoopSplat** is a coupled RGB-D SLAM system that uses Gaussian splats as a unified scene representation for tracking, mapping, and maintaining global consistency. In the front-end, it continuously estimates the camera position while constructing the scene using Gaussian splats submaps. When the camera traverses beyond a predefined threshold, the current submap is finalized, and a new one is initiated. Concurrently, the back-end loop closure module monitors for location revisits. Upon detecting a loop, the system generates a pose graph, incorporating loop edge constraints derived from our proposed 3DGS registration. Subsequently, pose graph optimization (PGO) is executed to refine both camera poses and submaps, ensuring overall spatial coherence.
44 |
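The sketch below is a minimal, illustrative Python rendering of this front-end/back-end split. Every name in it (e.g. `Submap`, `run_pipeline`) is a hypothetical placeholder rather than the repository's actual API, and the submap trigger is reduced to a fixed frame budget in the spirit of the `mapping.new_submap_every` setting found in the configs.

```python
from dataclasses import dataclass, field

@dataclass
class Submap:
    anchor_frame: int                           # first frame mapped into this submap
    frames: list = field(default_factory=list)  # frame ids assigned to this submap

def run_pipeline(num_frames: int, new_submap_every: int = 50):
    """Structural sketch only: tracking, mapping, loop closure and PGO are stubbed out."""
    submaps, active = [], Submap(anchor_frame=0)
    for frame_id in range(num_frames):
        # Front-end: estimate the camera pose against the active 3DGS submap,
        # then map the frame into it (both steps omitted in this sketch).
        active.frames.append(frame_id)
        # When the camera has traversed beyond a threshold, finalize the current
        # submap and start a new one (here simplified to a fixed frame budget).
        if len(active.frames) >= new_submap_every:
            submaps.append(active)
            active = Submap(anchor_frame=frame_id + 1)
    if active.frames:
        submaps.append(active)
    # Back-end (not shown): loop detection over submaps, 3DGS registration to derive
    # loop edges, and pose graph optimization to refine submap poses globally.
    return submaps

if __name__ == "__main__":
    print(len(run_pipeline(num_frames=500)))  # 10 submaps of 50 frames each
```
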
45 | # 🛠️ Setup
46 | The code has been tested on:
47 |
48 | - Ubuntu 22.04 LTS, Python 3.10.14, CUDA 12.2, GeForce RTX 4090/RTX 3090
49 | - CentOS Linux 7, Python 3.12.1, CUDA 12.4, A100/A6000
50 |
51 | ## 📦 Repository
52 |
53 | Clone the repo with `--recursive` because we have submodules:
54 |
55 | ```
56 | git clone --recursive git@github.com:GradientSpaces/LoopSplat.git
57 | cd LoopSplat
58 | ```
59 |
60 | ## 💻 Installation
61 | Make sure that gcc and g++ paths on your system are exported:
62 |
63 | ```
64 | export CC=<path-to-gcc>
65 | export CXX=<path-to-g++>
66 | ```
67 |
68 | To find the gcc and g++ paths on your machine, you can use `which gcc` and `which g++`.
69 |
70 |
71 | Then set up the environment from the provided conda environment file:
72 |
73 | ```
74 | conda create -n loop_splat -c nvidia/label/cuda-12.1.0 cuda=12.1 cuda-toolkit=12.1 cuda-nvcc=12.1
75 | conda env update --file environment.yml --prune
76 | conda activate loop_splat
77 | pip install -r requirements.txt
78 | ```
79 |
80 | You will also need to install hloc for loop detection and 3DGS registration.
81 | ```
82 | cd thirdparty/Hierarchical-Localization
83 | python -m pip install -e .
84 | cd ../..
85 | ```
86 |
87 | We tested our code on RTX 4090 and RTX A6000 GPUs under Ubuntu 22.04 and CentOS 7.5.
88 |
89 | ## 🚀 Usage
90 |
91 | Here we elaborate on how to load the necessary data, configure Gaussian-SLAM for your use case, debug it, and reproduce the results reported in the paper.
92 |
93 |
95 | ### Downloading the Datasets
96 | We tested our code on the Replica, TUM_RGBD, ScanNet, and ScanNet++ datasets. We also provide scripts for downloading Replica and TUM_RGBD in the `scripts` folder. Install git lfs before using the scripts by running ```git lfs install```.
97 |
98 | For reconstruction evaluation on Replica, we follow the [Co-SLAM](https://github.com/JingwenWang95/neural_slam_eval?tab=readme-ov-file#datasets) mesh culling protocol; please use their code to process the mesh first.
99 |
100 | To download ScanNet, follow the procedure described on the official ScanNet website.
101 | Pay attention: some frames in ScanNet have `inf` poses; we filter them out using the Jupyter notebook `scripts/scannet_preprocess.ipynb` (see the snippet below). Please change the path to your ScanNet data and run the cells.
102 |
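The filter in that notebook boils down to checking each per-frame pose file for non-finite values. The minimal sketch below shows the idea; `pose_folder` is a placeholder for your local ScanNet `data/pose` directory.

```python
import os
import numpy as np

def has_valid_pose(pose_file: str) -> bool:
    # ScanNet stores one 4x4 camera-to-world matrix per frame; frames whose
    # matrix contains inf (tracking failures) must be skipped.
    pose = np.loadtxt(pose_file)
    return bool(np.all(np.isfinite(pose)))

pose_folder = "scans/scene0000_00/data/pose"  # placeholder path, adjust to your layout
valid_frames = [f for f in sorted(os.listdir(pose_folder))
                if has_valid_pose(os.path.join(pose_folder, f))]
print(f"{len(valid_frames)} frames with valid poses")
```
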
103 | To download ScanNet++, follow the procedure described on the official ScanNet++ website.
104 |
105 | The config files are named after the sequences that we used for our method.
106 |
107 |
108 | ### Running the code
109 | Start the system with the command:
110 |
111 | ```
112 | python run_slam.py configs/<dataset_name>/<config_name>.yaml --input_path <path_to_the_scene> --output_path <output_path>
113 | ```
114 |
115 | You can also configure input and output paths in the config yaml file.
116 |
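As a concrete example (assuming the default data layout produced by `scripts/download_replica.sh`, and repeating the paths already set in `configs/Replica/room0.yaml`):

```
python run_slam.py configs/Replica/room0.yaml --input_path data/Replica-SLAM/room0/ --output_path output/Replica/room0/
```
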
117 |
118 | ### Reproducing Results
119 |
120 | You can reproduce the results for a single scene by running:
121 |
122 | ```
123 | python run_slam.py configs/<dataset_name>/<config_name>.yaml --input_path <path_to_the_scene> --output_path <output_path>
124 | ```
125 |
126 | If you are running on a SLURM cluster, you can reproduce the results for all scenes in a dataset by running the script:
127 | ```
128 | ./scripts/reproduce_sbatch.sh
129 | ```
130 | Please note that evaluating the ```depth_L1``` metric requires reconstructing the mesh, which in turn requires a headless installation of open3d if you are running on a cluster.
131 |
132 | ## 📧 Contact
133 | If you have any questions regarding this project, please contact Liyuan Zhu (liyzhu@stanford.edu). If you want to use our intermediate results for qualitative comparisons, please reach out to the same email.
134 |
135 | # ✏️ Acknowledgement
136 | Our implementation is heavily based on Gaussian-SLAM and MonoGS. We thank the authors for their open-source contributions. If you use the code that is based on their contribution, please cite them as well. We thank [Jianhao Zheng](https://jianhao-zheng.github.io/) for the help with datasets and [Yue Pan](https://github.com/YuePanEdward) for the fruitful discussion.
137 |
138 | # 🎓 Citation
139 |
140 | If you find our paper and code useful, please cite us:
141 |
142 | ```bib
143 | @inproceedings{zhu2025_loopsplat,
144 | title={LoopSplat: Loop Closure by Registering 3D Gaussian Splats},
145 | author={Liyuan Zhu and Yue Li and Erik Sandström and Shengyu Huang and Konrad Schindler and Iro Armeni},
146 | year={2025},
147 | booktitle = {International Conference on 3D Vision (3DV)},
148 | }
149 | ```
--------------------------------------------------------------------------------
/assets/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GradientSpaces/LoopSplat/676e18f950b2de6be39525a613120712b67768bb/assets/architecture.png
--------------------------------------------------------------------------------
/assets/loopsplat.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GradientSpaces/LoopSplat/676e18f950b2de6be39525a613120712b67768bb/assets/loopsplat.gif
--------------------------------------------------------------------------------
/configs/Replica/office0.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | data:
3 | scene_name: office0
4 | input_path: data/Replica-SLAM/Replica/office0/
5 | output_path: output/Replica/office0/
6 |
--------------------------------------------------------------------------------
/configs/Replica/office1.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | data:
3 | scene_name: office1
4 | input_path: data/Replica-SLAM/Replica/office1/
5 | output_path: output/Replica/office1/
6 |
--------------------------------------------------------------------------------
/configs/Replica/office2.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | data:
3 | scene_name: office2
4 | input_path: data/Replica-SLAM/Replica/office2/
5 | output_path: output/Replica/office2/
6 |
--------------------------------------------------------------------------------
/configs/Replica/office3.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | data:
3 | scene_name: office3
4 | input_path: data/Replica-SLAM/Replica/office3/
5 | output_path: output/Replica/office3/
6 |
--------------------------------------------------------------------------------
/configs/Replica/office4.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | data:
3 | scene_name: office4
4 | input_path: data/Replica-SLAM/Replica/office4/
5 | output_path: output/Replica/office4/
6 |
--------------------------------------------------------------------------------
/configs/Replica/replica.yaml:
--------------------------------------------------------------------------------
1 | project_name: "LoopSplat_SLAM_replica"
2 | dataset_name: "replica"
3 | checkpoint_path: null
4 | use_wandb: True
5 | frame_limit: -1 # for debugging, set to -1 to disable
6 | seed: 1
7 | mapping:
8 | new_submap_every: 50
9 | map_every: 5
10 | iterations: 100
11 | new_submap_iterations: 1000
12 | new_submap_points_num: 600000
13 | new_submap_gradient_points_num: 50000
14 | new_frame_sample_size: -1
15 | new_points_radius: 0.0000001
16 | current_view_opt_iterations: 0.4 # What portion of iterations to spend on the current view
17 | alpha_thre: 0.6
18 | pruning_thre: 0.1
19 | submap_using_motion_heuristic: True
20 | tracking:
21 | gt_camera: False
22 | w_color_loss: 0.95
23 | iterations: 60
24 | cam_rot_lr: 0.0002
25 | cam_trans_lr: 0.002
26 | odometry_type: "const_speed" # gt, const_speed, odometer
27 | help_camera_initialization: False # temp option to help const_init
28 | init_err_ratio: 5
29 | odometer_method: "point_to_plane" # hybrid or point_to_plane
30 | filter_alpha: False
31 | filter_outlier_depth: True
32 | alpha_thre: 0.98
33 | soft_alpha: True
34 | mask_invalid_depth: False
35 | enable_exposure: False
36 | cam:
37 | H: 680
38 | W: 1200
39 | fx: 600.0
40 | fy: 600.0
41 | cx: 599.5
42 | cy: 339.5
43 | depth_scale: 6553.5
44 | lc:
45 | min_similarity: 0.5
46 | pgo_edge_prune_thres: 0.25
47 | voxel_size: 0.02
48 | pgo_max_iterations: 500
49 | registration:
50 | method: "gs_reg"
51 | base_lr: 1.e-3
52 | min_overlap_ratio: 0.1
53 | use_render: False
54 | min_interval: 2
55 | final: False
--------------------------------------------------------------------------------
/configs/Replica/room0.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | data:
3 | scene_name: room0
4 | input_path: data/Replica-SLAM/room0/
5 | output_path: output/Replica/room0/
6 |
--------------------------------------------------------------------------------
/configs/Replica/room1.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | data:
3 | scene_name: room1
4 | input_path: data/Replica-SLAM/Replica/room1/
5 | output_path: output/Replica/room1/
6 |
--------------------------------------------------------------------------------
/configs/Replica/room2.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | data:
3 | scene_name: room2
4 | input_path: data/Replica-SLAM/Replica/room2/
5 | output_path: output/Replica/room2/
6 |
--------------------------------------------------------------------------------
/configs/ScanNet/scannet.yaml:
--------------------------------------------------------------------------------
1 | project_name: "LoopSplat_SLAM_scannet"
2 | dataset_name: "scan_net"
3 | checkpoint_path: null
4 | use_wandb: True
5 | frame_limit: -1 # for debugging, set to -1 to disable
6 | seed: 0
7 | mapping:
8 | new_submap_every: 50
9 | map_every: 1
10 | iterations: 100
11 | new_submap_iterations: 100
12 | new_submap_points_num: 100000
13 | new_submap_gradient_points_num: 50000
14 | new_frame_sample_size: 30000
15 | new_points_radius: 0.0001
16 | current_view_opt_iterations: 0.4 # What portion of iterations to spend on the current view
17 | alpha_thre: 0.6
18 | pruning_thre: 0.5
19 | submap_using_motion_heuristic: False
20 | tracking:
21 | gt_camera: False
22 | w_color_loss: 0.6
23 | iterations: 200
24 | cam_rot_lr: 0.002
25 | cam_trans_lr: 0.01
26 | odometry_type: "const_speed" # gt, const_speed, odometer
27 | help_camera_initialization: False # temp option to help const_init
28 | init_err_ratio: 5
29 | odometer_method: "hybrid" # hybrid or point_to_plane
30 | filter_alpha: True
31 | filter_outlier_depth: True
32 | alpha_thre: 0.98
33 | soft_alpha: True
34 | mask_invalid_depth: True
35 | enable_exposure: True
36 | cam:
37 | H: 480
38 | W: 640
39 | fx: 577.590698
40 | fy: 578.729797
41 | cx: 318.905426
42 | cy: 242.683609
43 | depth_scale: 1.
44 | crop_edge: 12
45 | lc:
46 | min_similarity: 0.5
47 | pgo_edge_prune_thres: 0.25
48 | voxel_size: 0.02
49 | pgo_max_iterations: 500
50 | registration:
51 | method: "gs_reg"
52 | base_lr: 5.e-3
53 | min_overlap_ratio: 0.2
54 | use_render: True # use rendered image as target for registration
55 | min_interval: 4
56 | final: False
57 |
--------------------------------------------------------------------------------
/configs/ScanNet/scene0000_00.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | data:
3 | input_path: /home/liyuanzhu/projects/GSR/MonoGS/datasets/scannet/scene0000_00
4 | output_path: output/ScanNet/scene0000
5 | scene_name: scene0000_00
6 | cam:
7 | H: 480
8 | W: 640
9 | fx: 577.6
10 | fy: 578.7
11 | cx: 318.9
12 | cy: 242.7
13 | depth_scale: 1.
14 | crop_edge: 12
--------------------------------------------------------------------------------
/configs/ScanNet/scene0054_00.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | data:
3 | input_path: /home/liyuanzhu/projects/GSR/MonoGS/datasets/scannet/scene0054_00
4 | output_path: output/ScanNet/scene0054
5 | scene_name: scene0054_00
6 | cam:
7 | H: 480
8 | W: 640
9 | fx: 578.0
10 | fy: 578.0
11 | cx: 319.5
12 | cy: 239.5
13 | depth_scale: 1.
14 | crop_edge: 12
--------------------------------------------------------------------------------
/configs/ScanNet/scene0059_00.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | data:
3 | input_path: /home/liyuanzhu/projects/GSR/MonoGS/datasets/scannet/scene0059_00
4 | output_path: output/ScanNet/scene0059
5 | scene_name: scene0059_00
6 | cam:
7 | H: 480
8 | W: 640
9 | fx: 577.6
10 | fy: 578.7
11 | cx: 318.9
12 | cy: 242.7
13 | depth_scale: 1.
14 | crop_edge: 12
--------------------------------------------------------------------------------
/configs/ScanNet/scene0106_00.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | data:
3 | input_path: /home/liyuanzhu/projects/GSR/MonoGS/datasets/scannet/scene0106_00
4 | output_path: output/ScanNet/scene0106
5 | scene_name: scene0106_00
6 | cam:
7 | H: 480
8 | W: 640
9 | fx: 577.6
10 | fy: 578.7
11 | cx: 318.9
12 | cy: 242.7
13 | depth_scale: 1.
14 | crop_edge: 12
--------------------------------------------------------------------------------
/configs/ScanNet/scene0169_00.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | data:
3 | input_path: /home/liyuanzhu/projects/GSR/MonoGS/datasets/scannet/scene0169_00
4 | output_path: output/ScanNet/scene0169
5 | scene_name: scene0169_00
6 | cam:
7 | H: 480
8 | W: 640
9 | fx: 574.5
10 | fy: 577.6
11 | cx: 322.5
12 | cy: 238.6
13 | depth_scale: 1.
14 | crop_edge: 12
--------------------------------------------------------------------------------
/configs/ScanNet/scene0181_00.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | data:
3 | input_path: /home/liyuanzhu/projects/GSR/MonoGS/datasets/scannet/scene0181_00
4 | output_path: output/ScanNet/scene0181
5 | scene_name: scene0181_00
6 | cam:
7 | fx: 575.547668
8 | fy: 577.459778
9 | cx: 323.171967
10 | cy: 236.417465
11 |
--------------------------------------------------------------------------------
/configs/ScanNet/scene0207_00.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | data:
3 | input_path: /home/liyuanzhu/projects/GSR/MonoGS/datasets/scannet/scene0207_00
4 | output_path: output/ScanNet/scene0207
5 | scene_name: scene0207_00
6 | cam:
7 | fx: 577.6
8 | fy: 578.7
9 | cx: 318.9
10 | cy: 242.7
--------------------------------------------------------------------------------
/configs/ScanNet/scene0233_00.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | data:
3 | input_path: /home/liyuanzhu/projects/GSR/MonoGS/datasets/scannet/scene0233_00
4 | output_path: output/ScanNet/scene0233
5 | scene_name: scene0233_00
6 | cam:
7 | fx: 577.9
8 | fy: 577.9
9 | cx: 319.5
10 | cy: 239.5
--------------------------------------------------------------------------------
/configs/ScanNet/scene0465_00.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | data:
3 | input_path: /home/liyuanzhu/projects/GSR/MonoGS/datasets/scannet/scene0465_00
4 | output_path: output/ScanNet/scene0465
5 | scene_name: scene0465_00
6 | cam:
7 | fx: 577.9
8 | fy: 577.9
9 | cx: 319.5
10 | cy: 239.5
--------------------------------------------------------------------------------
/configs/TUM_RGBD/rgbd_dataset_freiburg1_desk.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/TUM_RGBD/tum_rgbd.yaml
2 | data:
3 | input_path: data/TUM_RGBD-SLAM/rgbd_dataset_freiburg1_desk
4 | output_path: output/TUM_RGBD/rgbd_dataset_freiburg1_desk/
5 | scene_name: rgbd_dataset_freiburg1_desk
6 | cam: #intrinsic is different per scene in TUM
7 | H: 480
8 | W: 640
9 | fx: 517.3
10 | fy: 516.5
11 | cx: 318.6
12 | cy: 255.3
13 | crop_edge: 50
14 | distortion: [0.2624, -0.9531, -0.0054, 0.0026, 1.1633]
--------------------------------------------------------------------------------
/configs/TUM_RGBD/rgbd_dataset_freiburg1_desk2.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/TUM_RGBD/tum_rgbd.yaml
2 | data:
3 | input_path: data/TUM_RGBD-SLAM/rgbd_dataset_freiburg1_desk2
4 | output_path: output/TUM_RGBD/rgbd_dataset_freiburg1_desk2/
5 | scene_name: rgbd_dataset_freiburg1_desk2
6 | cam: #intrinsic is different per scene in TUM
7 | H: 480
8 | W: 640
9 | fx: 517.3
10 | fy: 516.5
11 | cx: 318.6
12 | cy: 255.3
13 | crop_edge: 50
14 | distortion: [0.2624, -0.9531, -0.0054, 0.0026, 1.1633]
--------------------------------------------------------------------------------
/configs/TUM_RGBD/rgbd_dataset_freiburg1_room.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/TUM_RGBD/tum_rgbd.yaml
2 | data:
3 | input_path: data/TUM_RGBD-SLAM/rgbd_dataset_freiburg1_room
4 | output_path: output/TUM_RGBD/rgbd_dataset_freiburg1_room/
5 | scene_name: rgbd_dataset_freiburg1_room
6 | cam: #intrinsic is different per scene in TUM
7 | H: 480
8 | W: 640
9 | fx: 517.3
10 | fy: 516.5
11 | cx: 318.6
12 | cy: 255.3
13 | crop_edge: 50
14 | distortion: [0.2624, -0.9531, -0.0054, 0.0026, 1.1633]
--------------------------------------------------------------------------------
/configs/TUM_RGBD/rgbd_dataset_freiburg2_xyz.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/TUM_RGBD/tum_rgbd.yaml
2 | data:
3 | input_path: data/TUM_RGBD-SLAM/rgbd_dataset_freiburg2_xyz
4 | output_path: output/TUM_RGBD/rgbd_dataset_freiburg2_xyz/
5 | scene_name: rgbd_dataset_freiburg2_xyz
6 | cam: #intrinsic is different per scene in TUM
7 | H: 480
8 | W: 640
9 | fx: 520.9
10 | fy: 521.0
11 | cx: 325.1
12 | cy: 249.7
13 | crop_edge: 8
14 | distortion: [0.2312, -0.7849, -0.0033, -0.0001, 0.9172]
--------------------------------------------------------------------------------
/configs/TUM_RGBD/rgbd_dataset_freiburg3_long_office_household.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/TUM_RGBD/tum_rgbd.yaml
2 | data:
3 | input_path: data/TUM_RGBD-SLAM/rgbd_dataset_freiburg3_long_office_household/
4 | output_path: output/TUM_RGBD/rgbd_dataset_freiburg3_long_office_household/
5 | scene_name: rgbd_dataset_freiburg3_long_office_household
6 | cam: #intrinsic is different per scene in TUM
7 | H: 480
8 | W: 640
9 | fx: 517.3
10 | fy: 516.5
11 | cx: 318.6
12 | cy: 255.3
13 | crop_edge: 50
14 | distortion: [0.2624, -0.9531, -0.0054, 0.0026, 1.1633]
--------------------------------------------------------------------------------
/configs/TUM_RGBD/tum_rgbd.yaml:
--------------------------------------------------------------------------------
1 | project_name: "LoopSplat_SLAM_tumrgbd"
2 | dataset_name: "tum_rgbd"
3 | checkpoint_path: null
4 | use_wandb: True
5 | frame_limit: -1 # for debugging, set to -1 to disable
6 | seed: 0
7 | mapping:
8 | new_submap_every: 50
9 | map_every: 1
10 | iterations: 100
11 | new_submap_iterations: 100
12 | new_submap_points_num: 100000
13 | new_submap_gradient_points_num: 50000
14 | new_frame_sample_size: 30000
15 | new_points_radius: 0.0001
16 | current_view_opt_iterations: 0.4 # What portion of iterations to spend on the current view
17 | alpha_thre: 0.6
18 | pruning_thre: 0.5
19 | submap_using_motion_heuristic: True
20 | tracking:
21 | gt_camera: False
22 | w_color_loss: 0.6
23 | iterations: 200
24 | cam_rot_lr: 0.002
25 | cam_trans_lr: 0.01
26 | odometry_type: "const_speed" # gt, const_speed, odometer
27 | help_camera_initialization: False # temp option to help const_init
28 | init_err_ratio: 5
29 | odometer_method: "hybrid" # hybrid or point_to_plane
30 | filter_alpha: False
31 | filter_outlier_depth: False
32 | alpha_thre: 0.98
33 | soft_alpha: True
34 | mask_invalid_depth: True
35 | enable_exposure: False
36 | cam:
37 | crop_edge: 16
38 | depth_scale: 5000.0
39 | lc:
40 | min_similarity: 0.5
41 | pgo_edge_prune_thres: 0.25
42 | voxel_size: 0.02
43 | pgo_max_iterations: 500
44 | registration:
45 | method: "gs_reg"
46 | base_lr: 5.e-3
47 | min_overlap_ratio: 0.2
48 | use_render: False
49 | min_interval: 3
50 | final: False
--------------------------------------------------------------------------------
/configs/scannetpp/281bc17764.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/scannetpp/scannetpp.yaml
2 | data:
3 | input_path: data/scannetpp/data/281bc17764
4 | output_path: output/ScanNetPP/281bc17764
5 | scene_name: "281bc17764"
6 | use_train_split: True
7 | frame_limit: 250
8 | cam:
9 | H: 584
10 | W: 876
11 | fx: 312.79197434640764
12 | fy: 313.48022477591036
13 | cx: 438.0
14 | cy: 292.0
--------------------------------------------------------------------------------
/configs/scannetpp/2e74812d00.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/scannetpp/scannetpp.yaml
2 | data:
3 | input_path: data/scannetpp/data/2e74812d00
4 | output_path: output/ScanNetPP/2e74812d00
5 | scene_name: "2e74812d00"
6 | use_train_split: True
7 | frame_limit: 250
8 | cam:
9 | H: 584
10 | W: 876
11 | fx: 312.0984049779051
12 | fy: 312.4823067146056
13 | cx: 438.0
14 | cy: 292.0
--------------------------------------------------------------------------------
/configs/scannetpp/8b5caf3398.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/scannetpp/scannetpp.yaml
2 | data:
3 | input_path: data/scannetpp/data/8b5caf3398
4 | output_path: output/ScanNetPP/8b5caf3398
5 | scene_name: "8b5caf3398"
6 | use_train_split: True
7 | frame_limit: 250
8 | cam:
9 | H: 584
10 | W: 876
11 | fx: 316.3837659917395
12 | fy: 319.18649362678593
13 | cx: 438.0
14 | cy: 292.0
15 |
--------------------------------------------------------------------------------
/configs/scannetpp/b20a261fdf.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/scannetpp/scannetpp.yaml
2 | data:
3 | input_path: data/scannetpp/data/b20a261fdf
4 | output_path: output/ScanNetPP/b20a261fdf
5 | scene_name: "b20a261fdf"
6 | use_train_split: True
7 | frame_limit: 250
8 | cam:
9 | H: 584
10 | W: 876
11 | fx: 312.7099188244687
12 | fy: 313.5121746848229
13 | cx: 438.0
14 | cy: 292.0
--------------------------------------------------------------------------------
/configs/scannetpp/fb05e13ad1.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/scannetpp/scannetpp.yaml
2 | data:
3 | input_path: data/scannetpp/data/fb05e13ad1
4 | output_path: output/ScanNetPP/fb05e13ad1
5 | scene_name: "fb05e13ad1"
6 | use_train_split: True
7 | frame_limit: 250
8 | cam:
9 | H: 584
10 | W: 876
11 | fx: 231.8197441948914
12 | fy: 231.9980523882361
13 | cx: 438.0
14 | cy: 292.0
--------------------------------------------------------------------------------
/configs/scannetpp/scannetpp.yaml:
--------------------------------------------------------------------------------
1 | project_name: "LoopSplat_SLAM_scannetpp"
2 | dataset_name: "scannetpp"
3 | checkpoint_path: null
4 | use_wandb: True
5 | frame_limit: -1 # set to -1 to disable
6 | seed: 0
7 | mapping:
8 | new_submap_every: 100
9 | map_every: 2
10 | iterations: 500
11 | new_submap_iterations: 500
12 | new_submap_points_num: 400000
13 | new_submap_gradient_points_num: 50000
14 | new_frame_sample_size: 100000
15 | new_points_radius: 0.00000001
16 | current_view_opt_iterations: 0.4 # What portion of iterations to spend on the current view
17 | alpha_thre: 0.6
18 | pruning_thre: 0.5
19 | submap_using_motion_heuristic: False
20 | tracking:
21 | gt_camera: False
22 | w_color_loss: 0.5
23 | iterations: 300
24 | cam_rot_lr: 0.002
25 | cam_trans_lr: 0.01
26 | odometry_type: "const_speed" # gt, const_speed, odometer
27 | help_camera_initialization: True
28 | init_err_ratio: 50
29 | odometer_method: "point_to_plane" # hybrid or point_to_plane
30 | filter_alpha: True
31 | filter_outlier_depth: True
32 | alpha_thre: 0.98
33 | soft_alpha: False
34 | mask_invalid_depth: True
35 | enable_exposure: False
36 | cam:
37 | crop_edge: 0
38 | depth_scale: 1000.0
39 | lc:
40 | min_similarity: 0.34
41 | pgo_edge_prune_thres: 0.25
42 | voxel_size: 0.02
43 | pgo_max_iterations: 500
44 | registration:
45 | method: "gs_reg"
46 | base_lr: 5.e-3
47 | min_overlap_ratio: 0.2
48 | min_interval: 0
49 | final: False
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: loop_splat
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - nvidia/label/cuda-12.1.0
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - python=3.10
10 | - faiss-gpu=1.8.0
11 | - cuda-toolkit=12.1
12 | - pytorch=2.1.2
13 | - pytorch-cuda=12.1
14 | - torchvision=0.16.2
15 | - pip
16 | - pip:
17 | - open3d==0.18.0
18 | - wandb
19 | - trimesh==4.0.10
20 | - pytorch_msssim
21 | - torchmetrics
22 | - tqdm
23 | - imageio
24 | - opencv-python
25 | - plyfile
26 | - roma
27 | - einops==0.8.0
28 | - numpy==1.26.4
29 | - PyQt5==5.15.11
30 | - matplotlib==3.5.1
31 | - evo==1.11.0
32 | - python-pycg
33 | - einops
34 | - git+https://github.com/eriksandstroem/evaluate_3d_reconstruction_lib.git@9b3cc08be5440db9c375cc21e3bd65bb4a337db7
35 | - git+https://github.com/VladimirYugay/simple-knn.git@c7e51a06a4cd84c25e769fee29ab391fe5d5ff8d
36 | - git+https://github.com/VladimirYugay/gaussian_rasterizer.git@9c40173fcc8d9b16778a1a8040295bc2f9ebf129
37 | - git+https://github.com/rmurai0610/diff-gaussian-rasterization-w-pose.git@43e21bff91cd24986ee3dd52fe0bb06952e50ec7
38 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | open3d==0.18.0
2 | wandb
3 | trimesh==4.0.10
4 | pytorch_msssim
5 | torchmetrics
6 | tqdm
7 | imageio
8 | opencv-python
9 | plyfile
10 | roma
11 | einops==0.8.0
12 | numpy==1.26.4
13 | PyQt5==5.15.11
14 | matplotlib==3.5.1
15 | evo==1.11.0
16 | python-pycg
17 | einops
18 | git+https://github.com/VladimirYugay/simple-knn.git@c7e51a06a4cd84c25e769fee29ab391fe5d5ff8d
19 | git+https://github.com/eriksandstroem/evaluate_3d_reconstruction_lib.git@9b3cc08be5440db9c375cc21e3bd65bb4a337db7
20 | git+https://github.com/VladimirYugay/gaussian_rasterizer.git@9c40173fcc8d9b16778a1a8040295bc2f9ebf129
21 | git+https://github.com/rmurai0610/diff-gaussian-rasterization-w-pose.git@43e21bff91cd24986ee3dd52fe0bb06952e50ec7
--------------------------------------------------------------------------------
/run_evaluation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from src.evaluation.evaluator import Evaluator
4 |
5 |
6 | def get_args():
7 | parser = argparse.ArgumentParser(description='Arguments to compute the mesh')
8 | parser.add_argument('--checkpoint_path', type=str, help='SLAM checkpoint path', default="output/slam/full_experiment/")
9 | parser.add_argument('--config_path', type=str, help='Config path', default="")
10 | return parser.parse_args()
11 |
12 |
13 | if __name__ == "__main__":
14 | args = get_args()
15 | if args.config_path == "":
16 | args.config_path = Path(args.checkpoint_path) / "config.yaml"
17 |
18 | evaluator = Evaluator(Path(args.checkpoint_path), Path(args.config_path))
19 | evaluator.run()
20 |
--------------------------------------------------------------------------------
/run_slam.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import uuid
5 |
6 | import wandb
7 |
8 | from src.entities.gaussian_slam import GaussianSLAM
9 | from src.evaluation.evaluator import Evaluator
10 | from src.utils.io_utils import load_config
11 | from src.utils.utils import setup_seed
12 |
13 |
14 | def get_args():
15 | parser = argparse.ArgumentParser(
16 | description='Arguments to compute the mesh')
17 | parser.add_argument('config_path', type=str,
18 | help='Path to the configuration yaml file')
19 | parser.add_argument('--input_path', default="")
20 | parser.add_argument('--output_path', default="")
21 | parser.add_argument('--track_w_color_loss', type=float)
22 | parser.add_argument('--track_alpha_thre', type=float)
23 | parser.add_argument('--track_iters', type=int)
24 | parser.add_argument('--track_filter_alpha', action='store_true')
25 | parser.add_argument('--track_filter_outlier', action='store_true')
26 | parser.add_argument('--track_wo_filter_alpha', action='store_true')
27 | parser.add_argument("--track_wo_filter_outlier", action="store_true")
28 | parser.add_argument("--track_cam_trans_lr", type=float)
29 | parser.add_argument('--alpha_seeding_thre', type=float)
30 | parser.add_argument('--map_every', type=int)
31 | parser.add_argument("--map_iters", type=int)
32 | parser.add_argument('--new_submap_every', type=int)
33 | parser.add_argument('--project_name', type=str)
34 | parser.add_argument('--group_name', type=str)
35 | parser.add_argument('--gt_camera', action='store_true')
36 | parser.add_argument('--help_camera_initialization', action='store_true')
37 | parser.add_argument('--soft_alpha', action='store_true')
38 | parser.add_argument('--seed', type=int)
39 | parser.add_argument('--submap_using_motion_heuristic', action='store_true')
40 | parser.add_argument('--new_submap_points_num', type=int)
41 | return parser.parse_args()
42 |
43 |
44 | def update_config_with_args(config, args):
45 | if args.input_path:
46 | config["data"]["input_path"] = args.input_path
47 | if args.output_path:
48 | config["data"]["output_path"] = args.output_path
49 | if args.track_w_color_loss is not None:
50 | config["tracking"]["w_color_loss"] = args.track_w_color_loss
51 | if args.track_iters is not None:
52 | config["tracking"]["iterations"] = args.track_iters
53 | if args.track_filter_alpha:
54 | config["tracking"]["filter_alpha"] = True
55 | if args.track_wo_filter_alpha:
56 | config["tracking"]["filter_alpha"] = False
57 | if args.track_filter_outlier:
58 | config["tracking"]["filter_outlier_depth"] = True
59 | if args.track_wo_filter_outlier:
60 | config["tracking"]["filter_outlier_depth"] = False
61 | if args.track_alpha_thre is not None:
62 | config["tracking"]["alpha_thre"] = args.track_alpha_thre
63 | if args.map_every:
64 | config["mapping"]["map_every"] = args.map_every
65 | if args.map_iters:
66 | config["mapping"]["iterations"] = args.map_iters
67 | if args.new_submap_every:
68 | config["mapping"]["new_submap_every"] = args.new_submap_every
69 | if args.project_name:
70 | config["project_name"] = args.project_name
71 | if args.alpha_seeding_thre is not None:
72 | config["mapping"]["alpha_thre"] = args.alpha_seeding_thre
73 | if args.seed:
74 | config["seed"] = args.seed
75 | if args.help_camera_initialization:
76 | config["tracking"]["help_camera_initialization"] = True
77 | if args.soft_alpha:
78 | config["tracking"]["soft_alpha"] = True
79 | if args.submap_using_motion_heuristic:
80 | config["mapping"]["submap_using_motion_heuristic"] = True
81 | if args.new_submap_points_num:
82 | config["mapping"]["new_submap_points_num"] = args.new_submap_points_num
83 | if args.track_cam_trans_lr:
84 | config["tracking"]["cam_trans_lr"] = args.track_cam_trans_lr
85 | return config
86 |
87 |
88 | if __name__ == "__main__":
89 | args = get_args()
90 | config = load_config(args.config_path)
91 | config = update_config_with_args(config, args)
92 |
93 | if os.getenv('DISABLE_WANDB') == 'true':
94 | config["use_wandb"] = False
95 | if config["use_wandb"]:
96 | wandb.init(
97 | project=config["project_name"],
98 | config=config,
99 | dir="/home/yli3/scratch/outputs/slam/wandb",
100 | group=config["data"]["scene_name"]
101 | if not args.group_name
102 | else args.group_name,
103 | name=f'{config["data"]["scene_name"]}_{time.strftime("%Y%m%d_%H%M%S", time.localtime())}_{str(uuid.uuid4())[:5]}',
104 | )
105 | wandb.run.log_code(".", include_fn=lambda path: path.endswith(".py"))
106 |
107 | setup_seed(config["seed"])
108 | gslam = GaussianSLAM(config)
109 | gslam.run()
110 |
111 | evaluator = Evaluator(gslam.output_path, gslam.output_path / "config.yaml")
112 | evaluator.run()
113 | if config["use_wandb"]:
114 | wandb.finish()
115 | print("All done.✨")
116 |
--------------------------------------------------------------------------------
/scripts/download_replica.sh:
--------------------------------------------------------------------------------
1 | mkdir -p data
2 | cd data
3 | git clone https://huggingface.co/datasets/voviktyl/Replica-SLAM
--------------------------------------------------------------------------------
/scripts/download_tum.sh:
--------------------------------------------------------------------------------
1 | mkdir -p data
2 | cd data
3 | git clone https://huggingface.co/datasets/voviktyl/TUM_RGBD-SLAM
--------------------------------------------------------------------------------
/scripts/reproduce_sbatch.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --output=output/logs/%A_%a.log # please change accordingly
3 | #SBATCH --error=output/logs/%A_%a.log # please change accordingly
4 | #SBATCH -N 1
5 | #SBATCH -n 1
6 | #SBATCH --gpus-per-node=1
7 | #SBATCH --partition=gpu
8 | #SBATCH --cpus-per-task=12
9 | #SBATCH --time=24:00:00
10 | #SBATCH --array=0-4 # number of scenes, 0-7 for Replica, 0-2 for TUM_RGBD, 0-5 for ScanNet, 0-4 for ScanNet++
11 |
12 | dataset="Replica" # set dataset
13 | if [ "$dataset" == "Replica" ]; then
14 | scenes=("room0" "room1" "room2" "office0" "office1" "office2" "office3" "office4")
15 | INPUT_PATH="data/Replica-SLAM"
16 | elif [ "$dataset" == "TUM_RGBD" ]; then
17 | scenes=("rgbd_dataset_freiburg1_desk" "rgbd_dataset_freiburg2_xyz" "rgbd_dataset_freiburg3_long_office_household")
18 | INPUT_PATH="data/TUM_RGBD-SLAM"
19 | elif [ "$dataset" == "ScanNet" ]; then
20 | scenes=("scene0000_00" "scene0059_00" "scene0106_00" "scene0169_00" "scene0181_00" "scene0207_00")
21 | INPUT_PATH="data/scannet/scans"
22 | elif [ "$dataset" == "ScanNetPP" ]; then
23 | scenes=("b20a261fdf" "8b5caf3398" "fb05e13ad1" "2e74812d00" "281bc17764")
24 | INPUT_PATH="data/scannetpp/data"
25 | else
26 | echo "Dataset not recognized!"
27 | exit 1
28 | fi
29 |
30 | OUTPUT_PATH="output"
31 | CONFIG_PATH="configs/${dataset}"
32 | EXPERIMENT_NAME="reproduce"
33 | SCENE_NAME=${scenes[$SLURM_ARRAY_TASK_ID]}
34 |
35 | source # please change accordingly
36 | conda activate gslam
37 |
38 | echo "Job for dataset: $dataset, scene: $SCENE_NAME"
39 | echo "Starting on: $(date)"
40 | echo "Running on node: $(hostname)"
41 |
42 | # Your command to run the experiment
43 | python run_slam.py "${CONFIG_PATH}/${SCENE_NAME}.yaml" \
44 | --input_path "${INPUT_PATH}/${SCENE_NAME}" \
45 | --output_path "${OUTPUT_PATH}/${dataset}/${EXPERIMENT_NAME}/${SCENE_NAME}" \
46 | --group_name "${EXPERIMENT_NAME}" \
47 |
48 | echo "Job for scene $SCENE_NAME completed."
49 | echo "Started at: $START_TIME"
50 | echo "Finished at: $(date)"
--------------------------------------------------------------------------------
/scripts/scannet_preprocess.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "id": "8e3f10a3-fca1-4e2a-b39e-7cd8b44c724a",
7 | "metadata": {},
8 | "outputs": [
9 | {
10 | "name": "stdout",
11 | "output_type": "stream",
12 | "text": [
13 | "Jupyter environment detected. Enabling Open3D WebVisualizer.\n",
14 | "[Open3D INFO] WebRTC GUI backend enabled.\n",
15 | "[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.\n"
16 | ]
17 | }
18 | ],
19 | "source": [
20 | "import numpy as np\n",
21 | "import open3d as o3d\n",
22 | "import os\n",
23 | "import time\n",
24 | "import json\n",
25 | "import cv2\n",
26 | "from tqdm import tqdm\n",
27 | "import math\n",
28 | "from scipy.spatial.transform import Rotation as R\n",
29 | "import matplotlib.pyplot as plt\n",
30 | "import scipy\n",
31 | "import shutil\n",
32 | "\n",
33 | "from plyfile import PlyData, PlyElement\n",
34 | "import pandas as pd"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 3,
40 | "id": "3db0619f-a9bb-4aea-890d-d8c8cb3f7e18",
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "def read_intrinsic(data_folder):\n",
45 | " with open(os.path.join(data_folder,'data/intrinsic/intrinsic_depth.txt'), 'r') as f:\n",
46 | " lines = f.readlines()\n",
47 | " intrinsic = np.zeros((4,4))\n",
48 | " for i, line in enumerate(lines):\n",
49 | " for j, content in enumerate(line.split(' ')):\n",
50 | " intrinsic[i][j] = float(content)\n",
51 | " return intrinsic"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 4,
57 | "id": "9574fb19-fe0c-45ae-9e59-3125a17e3566",
58 | "metadata": {},
59 | "outputs": [
60 | {
61 | "name": "stderr",
62 | "output_type": "stream",
63 | "text": [
64 | "100%|███████████████████████████████████████| 5578/5578 [01:33<00:00, 59.58it/s]\n",
65 | "100%|███████████████████████████████████████| 1807/1807 [00:29<00:00, 61.95it/s]\n",
66 | "100%|███████████████████████████████████████| 2324/2324 [00:35<00:00, 65.18it/s]\n",
67 | "100%|███████████████████████████████████████| 2034/2034 [00:31<00:00, 63.85it/s]\n",
68 | "100%|███████████████████████████████████████| 2349/2349 [00:35<00:00, 65.97it/s]\n",
69 | "100%|███████████████████████████████████████| 1988/1988 [00:31<00:00, 63.08it/s]\n",
70 | "100%|███████████████████████████████████████| 7643/7643 [01:52<00:00, 68.11it/s]\n",
71 | "100%|███████████████████████████████████████| 6306/6306 [01:31<00:00, 69.10it/s]\n"
72 | ]
73 | }
74 | ],
75 | "source": [
76 | "fps_fake = 20 # fake camera frequency for offline ORB-SLAM\n",
77 | "raw_folder = \"./scans\"\n",
78 | "processed_folder = \"./processed\"\n",
79 | "\n",
80 | "scenes = ['scene0000_00', 'scene0054_00', 'scene0059_00', 'scene0106_00', 'scene0169_00', 'scene0181_00', 'scene0207_00', 'scene0233_00']\n",
81 | "\n",
82 | "for scene_idx, scene in enumerate(scenes):\n",
83 | " save_folder = os.path.join(processed_folder,scene)\n",
84 | " data_folder = os.path.join(raw_folder,scene)\n",
85 | " \n",
86 | " os.makedirs(save_folder)\n",
87 | " os.makedirs(os.path.join(save_folder,\"rgb\"))\n",
88 | " os.makedirs(os.path.join(save_folder,\"depth\"))\n",
89 | " \n",
90 | " shutil.copy(os.path.join(data_folder, 'data/intrinsic/intrinsic_depth.txt'), \n",
91 | " os.path.join(save_folder, 'intrinsic.txt'))\n",
92 | " \n",
93 | " with open(os.path.join(save_folder,'gt_pose.txt'), 'w') as f:\n",
94 | " f.write('# timestamp tx ty tz qx qy qz qw\\n')\n",
95 | " \n",
96 | " initial_time_stamp = time.time() \n",
97 | " \n",
98 | " color_folder = os.path.join(data_folder,\"data/color\")\n",
99 | " depth_folder = os.path.join(data_folder,\"data/depth\")\n",
100 | " pose_folder = os.path.join(data_folder,\"data/pose\")\n",
101 | " \n",
102 | " num_frames = len(os.listdir(color_folder))\n",
103 | " \n",
104 | " frame_idx = 0\n",
105 | " for raw_idx in tqdm(range(num_frames)):\n",
106 | " with open(os.path.join(pose_folder,\"{}.txt\".format(raw_idx)), \"r\") as f:\n",
107 | " lines = f.readlines()\n",
108 | " M_w_c = np.zeros((4,4))\n",
109 | " for i in range(4):\n",
110 | " content = lines[i].split(\" \")\n",
111 | " for j in range(4):\n",
112 | " M_w_c[i,j] = float(content[j])\n",
113 | " \n",
114 | " if \"inf\" in lines[0]:\n",
115 | " # invalid gt poses, skip this frame\n",
116 | " continue\n",
117 | "\n",
118 | " ######## convert depth to [m] and float type #########\n",
119 | " depth = cv2.imread(os.path.join(depth_folder,\"{}.png\".format(raw_idx)),cv2.IMREAD_UNCHANGED)\n",
120 | " depth = depth.astype(\"float32\")/1000.0\n",
121 | "\n",
122 | " ######## resize rgb to the same size of depth #########\n",
123 | " rgb = cv2.imread(os.path.join(color_folder,\"{}.jpg\".format(raw_idx)))\n",
124 | " rgb = cv2.resize(rgb,(depth.shape[1],depth.shape[0]),interpolation=cv2.INTER_CUBIC)\n",
125 | "\n",
126 | " cv2.imwrite(os.path.join(save_folder,\"rgb/frame_{}.png\".format(str(frame_idx).zfill(5))),rgb)\n",
127 | " cv2.imwrite(os.path.join(save_folder,\"depth/frame_{}.TIFF\".format(str(frame_idx).zfill(5))),depth)\n",
128 | "\n",
129 | " content = \"{:.4f}\".format(initial_time_stamp + frame_idx*1.0/fps_fake)\n",
130 | " for t in M_w_c[:3,3]:\n",
131 | " content += \" {:.9f}\".format(t)\n",
132 | " for q in R.from_matrix(M_w_c[:3,:3]).as_quat():\n",
133 | " content += \" {:.9f}\".format(q)\n",
134 | " \n",
135 | " with open(os.path.join(save_folder,'gt_pose.txt'), 'a') as f:\n",
136 | " f.write(content + '\\n')\n",
137 | " \n",
138 | " frame_idx += 1"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": null,
144 | "id": "1d1cc4cd-9868-4830-a849-d6e5702c5f20",
145 | "metadata": {},
146 | "outputs": [],
147 | "source": []
148 | }
149 | ],
150 | "metadata": {
151 | "kernelspec": {
152 | "display_name": "Python 3 (ipykernel)",
153 | "language": "python",
154 | "name": "python3"
155 | },
156 | "language_info": {
157 | "codemirror_mode": {
158 | "name": "ipython",
159 | "version": 3
160 | },
161 | "file_extension": ".py",
162 | "mimetype": "text/x-python",
163 | "name": "python",
164 | "nbconvert_exporter": "python",
165 | "pygments_lexer": "ipython3",
166 | "version": "3.11.4"
167 | }
168 | },
169 | "nbformat": 4,
170 | "nbformat_minor": 5
171 | }
172 |
--------------------------------------------------------------------------------
/src/entities/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GradientSpaces/LoopSplat/676e18f950b2de6be39525a613120712b67768bb/src/entities/__init__.py
--------------------------------------------------------------------------------
/src/entities/arguments.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import sys
14 | from argparse import ArgumentParser, Namespace
15 |
16 |
17 | class GroupParams:
18 | pass
19 |
20 |
21 | class ParamGroup:
22 | def __init__(self, parser: ArgumentParser, name: str, fill_none=False):
23 | group = parser.add_argument_group(name)
24 | for key, value in vars(self).items():
25 | shorthand = False
26 | if key.startswith("_"):
27 | shorthand = True
28 | key = key[1:]
29 | t = type(value)
30 | value = value if not fill_none else None
31 | if shorthand:
32 | if t == bool:
33 | group.add_argument(
34 | "--" + key, ("-" + key[0:1]), default=value, action="store_true")
35 | else:
36 | group.add_argument(
37 | "--" + key, ("-" + key[0:1]), default=value, type=t)
38 | else:
39 | if t == bool:
40 | group.add_argument(
41 | "--" + key, default=value, action="store_true")
42 | else:
43 | group.add_argument("--" + key, default=value, type=t)
44 |
45 | def extract(self, args):
46 | group = GroupParams()
47 | for arg in vars(args).items():
48 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
49 | setattr(group, arg[0], arg[1])
50 | return group
51 |
52 |
53 | class OptimizationParams(ParamGroup):
54 | def __init__(self, parser):
55 | self.iterations = 30_000
56 | self.position_lr_init = 0.0001
57 | self.position_lr_final = 0.0000016
58 | self.position_lr_delay_mult = 0.01
59 | self.position_lr_max_steps = 30_000
60 | self.feature_lr = 0.0025
61 | self.opacity_lr = 0.05
62 | self.scaling_lr = 0.005 # before 0.005
63 | self.rotation_lr = 0.001
64 | self.percent_dense = 0.01
65 | self.lambda_dssim = 0.2
66 | self.densification_interval = 100
67 | self.opacity_reset_interval = 3000
68 | self.densify_from_iter = 500
69 | self.densify_until_iter = 15_000
70 | self.densify_grad_threshold = 0.0002
71 | super().__init__(parser, "Optimization Parameters")
72 |
73 |
74 | def get_combined_args(parser: ArgumentParser):
75 | cmdlne_string = sys.argv[1:]
76 | cfgfile_string = "Namespace()"
77 | args_cmdline = parser.parse_args(cmdlne_string)
78 |
79 | try:
80 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
81 | print("Looking for config file in", cfgfilepath)
82 | with open(cfgfilepath) as cfg_file:
83 | print("Config file found: {}".format(cfgfilepath))
84 | cfgfile_string = cfg_file.read()
85 | except TypeError:
86 | print("Config file not found at")
87 | pass
88 | args_cfgfile = eval(cfgfile_string)
89 |
90 | merged_dict = vars(args_cfgfile).copy()
91 | for k, v in vars(args_cmdline).items():
92 | if v is not None:
93 | merged_dict[k] = v
94 | return Namespace(**merged_dict)
95 |
--------------------------------------------------------------------------------
/src/entities/datasets.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from pathlib import Path
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 | import json
9 | import imageio
10 | import trimesh
11 |
12 |
13 | class BaseDataset(torch.utils.data.Dataset):
14 |
15 | def __init__(self, dataset_config: dict):
16 | self.dataset_path = Path(dataset_config["input_path"])
17 | self.frame_limit = dataset_config.get("frame_limit", -1)
18 | self.dataset_config = dataset_config
19 | self.height = dataset_config["H"]
20 | self.width = dataset_config["W"]
21 | self.fx = dataset_config["fx"]
22 | self.fy = dataset_config["fy"]
23 | self.cx = dataset_config["cx"]
24 | self.cy = dataset_config["cy"]
25 |
26 | self.depth_scale = dataset_config["depth_scale"]
27 | self.distortion = np.array(
28 | dataset_config['distortion']) if 'distortion' in dataset_config else None
29 | self.crop_edge = dataset_config['crop_edge'] if 'crop_edge' in dataset_config else 0
30 | if self.crop_edge:
31 | self.height -= 2 * self.crop_edge
32 | self.width -= 2 * self.crop_edge
33 | self.cx -= self.crop_edge
34 | self.cy -= self.crop_edge
35 |
36 | self.fovx = 2 * math.atan(self.width / (2 * self.fx))
37 | self.fovy = 2 * math.atan(self.height / (2 * self.fy))
38 | self.intrinsics = np.array(
39 | [[self.fx, 0, self.cx], [0, self.fy, self.cy], [0, 0, 1]])
40 |
41 | self.color_paths = []
42 | self.depth_paths = []
43 |
44 | def __len__(self):
45 | return len(self.color_paths) if self.frame_limit < 0 else int(self.frame_limit)
46 |
47 |
48 | class Replica(BaseDataset):
49 |
50 | def __init__(self, dataset_config: dict):
51 | super().__init__(dataset_config)
52 | self.color_paths = sorted(
53 | list((self.dataset_path / "results").glob("frame*.jpg")))
54 | self.depth_paths = sorted(
55 | list((self.dataset_path / "results").glob("depth*.png")))
56 | self.load_poses(self.dataset_path / "traj.txt")
57 | print(f"Loaded {len(self.color_paths)} frames")
58 |
59 | def load_poses(self, path):
60 | self.poses = []
61 | with open(path, "r") as f:
62 | lines = f.readlines()
63 | for line in lines:
64 | c2w = np.array(list(map(float, line.split()))).reshape(4, 4)
65 | self.poses.append(c2w.astype(np.float32))
66 |
67 | def __getitem__(self, index):
68 | color_data = cv2.imread(str(self.color_paths[index]))
69 | color_data = cv2.cvtColor(color_data, cv2.COLOR_BGR2RGB)
70 | depth_data = cv2.imread(
71 | str(self.depth_paths[index]), cv2.IMREAD_UNCHANGED)
72 | depth_data = depth_data.astype(np.float32) / self.depth_scale
73 | return index, color_data, depth_data, self.poses[index]
74 |
75 |
76 | class TUM_RGBD(BaseDataset):
77 | def __init__(self, dataset_config: dict):
78 | super().__init__(dataset_config)
79 | self.color_paths, self.depth_paths, self.poses = self.loadtum(
80 | self.dataset_path, frame_rate=32)
81 |
82 | def parse_list(self, filepath, skiprows=0):
83 | """ read list data """
84 | return np.loadtxt(filepath, delimiter=' ', dtype=np.unicode_, skiprows=skiprows)
85 |
86 | def associate_frames(self, tstamp_image, tstamp_depth, tstamp_pose, max_dt=0.08):
87 | """ pair images, depths, and poses """
88 | associations = []
89 | for i, t in enumerate(tstamp_image):
90 | if tstamp_pose is None:
91 | j = np.argmin(np.abs(tstamp_depth - t))
92 | if (np.abs(tstamp_depth[j] - t) < max_dt):
93 | associations.append((i, j))
94 | else:
95 | j = np.argmin(np.abs(tstamp_depth - t))
96 | k = np.argmin(np.abs(tstamp_pose - t))
97 | if (np.abs(tstamp_depth[j] - t) < max_dt) and (np.abs(tstamp_pose[k] - t) < max_dt):
98 | associations.append((i, j, k))
99 | return associations
100 |
101 | def loadtum(self, datapath, frame_rate=-1):
102 | """ read video data in tum-rgbd format """
103 | if os.path.isfile(os.path.join(datapath, 'groundtruth.txt')):
104 | pose_list = os.path.join(datapath, 'groundtruth.txt')
105 | elif os.path.isfile(os.path.join(datapath, 'pose.txt')):
106 | pose_list = os.path.join(datapath, 'pose.txt')
107 |
108 | image_list = os.path.join(datapath, 'rgb.txt')
109 | depth_list = os.path.join(datapath, 'depth.txt')
110 |
111 | image_data = self.parse_list(image_list)
112 | depth_data = self.parse_list(depth_list)
113 | pose_data = self.parse_list(pose_list, skiprows=1)
114 | pose_vecs = pose_data[:, 1:].astype(np.float64)
115 |
116 | tstamp_image = image_data[:, 0].astype(np.float64)
117 | tstamp_depth = depth_data[:, 0].astype(np.float64)
118 | tstamp_pose = pose_data[:, 0].astype(np.float64)
119 | associations = self.associate_frames(
120 | tstamp_image, tstamp_depth, tstamp_pose)
121 |
122 | indicies = [0]
123 | for i in range(1, len(associations)):
124 | t0 = tstamp_image[associations[indicies[-1]][0]]
125 | t1 = tstamp_image[associations[i][0]]
126 | if t1 - t0 > 1.0 / frame_rate:
127 | indicies += [i]
128 |
129 | images, poses, depths = [], [], []
130 | inv_pose = None
131 | for ix in indicies:
132 | (i, j, k) = associations[ix]
133 | images += [os.path.join(datapath, image_data[i, 1])]
134 | depths += [os.path.join(datapath, depth_data[j, 1])]
135 | c2w = self.pose_matrix_from_quaternion(pose_vecs[k])
136 | if inv_pose is None:
137 | inv_pose = np.linalg.inv(c2w)
138 | c2w = np.eye(4)
139 | else:
140 | c2w = inv_pose@c2w
141 | poses += [c2w.astype(np.float32)]
142 |
143 | return images, depths, poses
144 |
145 | def pose_matrix_from_quaternion(self, pvec):
146 | """ convert 4x4 pose matrix to (t, q) """
147 | from scipy.spatial.transform import Rotation
148 |
149 | pose = np.eye(4)
150 | pose[:3, :3] = Rotation.from_quat(pvec[3:]).as_matrix()
151 | pose[:3, 3] = pvec[:3]
152 | return pose
153 |
154 | def __getitem__(self, index):
155 | color_data = cv2.imread(str(self.color_paths[index]))
156 | if self.distortion is not None:
157 | color_data = cv2.undistort(
158 | color_data, self.intrinsics, self.distortion)
159 | color_data = cv2.cvtColor(color_data, cv2.COLOR_BGR2RGB)
160 |
161 | depth_data = cv2.imread(
162 | str(self.depth_paths[index]), cv2.IMREAD_UNCHANGED)
163 | depth_data = depth_data.astype(np.float32) / self.depth_scale
164 | edge = self.crop_edge
165 | if edge > 0:
166 | color_data = color_data[edge:-edge, edge:-edge]
167 | depth_data = depth_data[edge:-edge, edge:-edge]
168 | # Interpolate depth values for splatting
169 | return index, color_data, depth_data, self.poses[index]
170 |
171 |
172 | class ScanNet(BaseDataset):
173 | def __init__(self, dataset_config: dict):
174 | super().__init__(dataset_config)
175 | self.color_paths = sorted(list(
176 | (self.dataset_path / "rgb").glob("*.png")), key=lambda x: int(os.path.basename(x)[-9:-4]))
177 | self.depth_paths = sorted(list(
178 | (self.dataset_path / "depth").glob("*.TIFF")), key=lambda x: int(os.path.basename(x)[-10:-5]))
179 | self.n_img = len(self.color_paths)
180 | self.load_poses(self.dataset_path / "gt_pose.txt")
181 |
182 | def load_poses(self, path):
183 | self.poses = []
184 | pose_data = np.loadtxt(path, delimiter=" ", dtype=np.unicode_, skiprows=1)
185 | pose_vecs = pose_data[:, 0:].astype(np.float64)
186 | for i in range(self.n_img):
187 | quat = pose_vecs[i][4:]
188 | trans = pose_vecs[i][1:4]
189 | T = trimesh.transformations.quaternion_matrix(np.roll(quat, 1))
190 | T[:3, 3] = trans
191 | pose = T
192 | self.poses.append(pose)
193 |
194 | def __getitem__(self, index):
195 | color_data = cv2.imread(str(self.color_paths[index]))
196 | if self.distortion is not None:
197 | color_data = cv2.undistort(
198 | color_data, self.intrinsics, self.distortion)
199 | color_data = cv2.cvtColor(color_data, cv2.COLOR_BGR2RGB)
200 | color_data = cv2.resize(color_data, (self.dataset_config["W"], self.dataset_config["H"]))
201 |
202 | depth_data = cv2.imread(
203 | str(self.depth_paths[index]), cv2.IMREAD_UNCHANGED)
204 | depth_data = depth_data.astype(np.float32) / self.depth_scale
205 | edge = self.crop_edge
206 | if edge > 0:
207 | color_data = color_data[edge:-edge, edge:-edge]
208 | depth_data = depth_data[edge:-edge, edge:-edge]
209 | # Interpolate depth values for splatting
210 | return index, color_data, depth_data, self.poses[index]
211 |
212 |
213 | class ScanNetPP(BaseDataset):
214 | def __init__(self, dataset_config: dict):
215 | super().__init__(dataset_config)
216 | self.use_train_split = dataset_config["use_train_split"]
217 | self.train_test_split = json.load(open(f"{self.dataset_path}/dslr/train_test_lists.json", "r"))
218 | if self.use_train_split:
219 | self.image_names = self.train_test_split["train"]
220 | else:
221 | self.image_names = self.train_test_split["test"]
222 | self.load_data()
223 |
224 | def load_data(self):
225 | self.poses = []
226 | cams_path = self.dataset_path / "dslr" / "nerfstudio" / "transforms_undistorted.json"
227 | cams_metadata = json.load(open(str(cams_path), "r"))
228 | frames_key = "frames" if self.use_train_split else "test_frames"
229 | frames_metadata = cams_metadata[frames_key]
230 | frame2idx = {frame["file_path"]: index for index, frame in enumerate(frames_metadata)}
231 | P = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]).astype(np.float32)
232 | for image_name in self.image_names:
233 | frame_metadata = frames_metadata[frame2idx[image_name]]
234 | # if self.ignore_bad and frame_metadata['is_bad']:
235 | # continue
236 | color_path = str(self.dataset_path / "dslr" / "undistorted_images" / image_name)
237 | depth_path = str(self.dataset_path / "dslr" / "undistorted_depths" / image_name.replace('.JPG', '.png'))
238 | self.color_paths.append(color_path)
239 | self.depth_paths.append(depth_path)
240 | c2w = np.array(frame_metadata["transform_matrix"]).astype(np.float32)
241 | c2w = P @ c2w @ P.T
242 | self.poses.append(c2w)
243 |
244 | def __len__(self):
245 | if self.use_train_split:
246 | return len(self.image_names) if self.frame_limit < 0 else int(self.frame_limit)
247 | else:
248 | return len(self.image_names)
249 |
250 | def __getitem__(self, index):
251 |
252 | color_data = np.asarray(imageio.imread(self.color_paths[index]), dtype=float)
253 | color_data = cv2.resize(color_data, (self.width, self.height), interpolation=cv2.INTER_LINEAR)
254 | color_data = color_data.astype(np.uint8)
255 |
256 | depth_data = np.asarray(imageio.imread(self.depth_paths[index]), dtype=np.int64)
257 | depth_data = cv2.resize(depth_data.astype(float), (self.width, self.height), interpolation=cv2.INTER_NEAREST)
258 | depth_data = depth_data.astype(np.float32) / self.depth_scale
259 | return index, color_data, depth_data, self.poses[index]
260 |
261 |
262 | def get_dataset(dataset_name: str):
263 | if dataset_name == "replica":
264 | return Replica
265 | elif dataset_name == "tum_rgbd":
266 | return TUM_RGBD
267 | elif dataset_name == "scan_net":
268 | return ScanNet
269 | elif dataset_name == "scannetpp":
270 | return ScanNetPP
271 | raise NotImplementedError(f"Dataset {dataset_name} not implemented")
272 |
--------------------------------------------------------------------------------
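A self-contained sketch of the timestamp association used in TUM_RGBD.loadtum above: every RGB timestamp is greedily matched to its nearest depth and pose timestamps and kept only if both lie within max_dt seconds. The synthetic 30 Hz / 100 Hz clocks below are assumptions for illustration, not repository data.

import numpy as np

def associate(tstamp_image, tstamp_depth, tstamp_pose, max_dt=0.08):
    # Same nearest-neighbour matching as TUM_RGBD.associate_frames: keep the
    # (image, depth, pose) triple only if depth and pose are within max_dt.
    associations = []
    for i, t in enumerate(tstamp_image):
        j = int(np.argmin(np.abs(tstamp_depth - t)))
        k = int(np.argmin(np.abs(tstamp_pose - t)))
        if abs(tstamp_depth[j] - t) < max_dt and abs(tstamp_pose[k] - t) < max_dt:
            associations.append((i, j, k))
    return associations

t_rgb = np.arange(0.0, 1.0, 1 / 30)      # synthetic 30 Hz RGB stream
t_depth = t_rgb + 0.005                  # depth clock offset by 5 ms
t_pose = np.arange(0.0, 1.0, 1 / 100)    # 100 Hz ground-truth poses
print(associate(t_rgb, t_depth, t_pose)[:3])   # [(0, 0, 0), (1, 1, 3), (2, 2, 7)]

--------------------------------------------------------------------------------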
/src/entities/gaussian_slam.py:
--------------------------------------------------------------------------------
1 | """ This module includes the Gaussian-SLAM class, which is responsible for controlling Mapper and Tracker
2 | It also decides when to start a new submap and when to update the estimated camera poses.
3 | """
4 | import os
5 | import pprint
6 | from argparse import ArgumentParser
7 | from datetime import datetime
8 | from pathlib import Path
9 |
10 | import numpy as np
11 | import torch
12 | import roma
13 |
14 | from src.entities.arguments import OptimizationParams
15 | from src.entities.datasets import get_dataset
16 | from src.entities.gaussian_model import GaussianModel
17 | from src.entities.mapper import Mapper
18 | from src.entities.tracker import Tracker
19 | from src.entities.lc import Loop_closure
20 | from src.entities.logger import Logger
21 | from src.utils.io_utils import save_dict_to_ckpt, save_dict_to_yaml
22 | from src.utils.mapper_utils import exceeds_motion_thresholds
23 | from src.utils.utils import np2torch, setup_seed, torch2np
24 | from src.utils.vis_utils import * # noqa - needed for debugging
25 |
26 |
27 | class GaussianSLAM(object):
28 |
29 | def __init__(self, config: dict) -> None:
30 |
31 | self._setup_output_path(config)
32 | self.device = "cuda"
33 | self.config = config
34 |
35 | self.scene_name = config["data"]["scene_name"]
36 | self.dataset_name = config["dataset_name"]
37 | self.dataset = get_dataset(config["dataset_name"])({**config["data"], **config["cam"]})
38 |
39 | n_frames = len(self.dataset)
40 | frame_ids = list(range(n_frames))
41 | self.mapping_frame_ids = frame_ids[::config["mapping"]["map_every"]] + [n_frames - 1]
42 |
43 | self.estimated_c2ws = torch.empty(len(self.dataset), 4, 4)
44 | self.estimated_c2ws[0] = torch.from_numpy(self.dataset[0][3])
45 | self.exposures_ab = torch.zeros(len(self.dataset), 2)
46 |
47 | save_dict_to_yaml(config, "config.yaml", directory=self.output_path)
48 |
49 | self.submap_using_motion_heuristic = config["mapping"]["submap_using_motion_heuristic"]
50 |
51 | self.keyframes_info = {}
52 | self.opt = OptimizationParams(ArgumentParser(description="Training script parameters"))
53 |
54 | if self.submap_using_motion_heuristic:
55 | self.new_submap_frame_ids = [0]
56 | else:
57 | self.new_submap_frame_ids = frame_ids[::config["mapping"]["new_submap_every"]] + [n_frames - 1]
58 | self.new_submap_frame_ids.pop(0)
59 |
60 | self.logger = Logger(self.output_path, config["use_wandb"])
61 | self.mapper = Mapper(config["mapping"], self.dataset, self.logger)
62 | self.tracker = Tracker(config["tracking"], self.dataset, self.logger)
63 | self.enable_exposure = self.tracker.enable_exposure
64 | self.loop_closer = Loop_closure(config, self.dataset, self.logger)
65 | self.loop_closer.submap_path = self.output_path / "submaps"
66 |
67 | print('Tracking config')
68 | pprint.PrettyPrinter().pprint(config["tracking"])
69 | print('Mapping config')
70 | pprint.PrettyPrinter().pprint(config["mapping"])
71 | print('Loop closure config')
72 | pprint.PrettyPrinter().pprint(config["lc"])
73 |
74 |
75 | def _setup_output_path(self, config: dict) -> None:
76 | """ Sets up the output path for saving results based on the provided configuration. If the output path is not
77 | specified in the configuration, it creates a new directory with a timestamp.
78 | Args:
79 | config: A dictionary containing the experiment configuration including data and output path information.
80 | """
81 | if "output_path" not in config["data"]:
82 | output_path = Path(config["data"]["output_path"])
83 | self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
84 | self.output_path = output_path / self.timestamp
85 | else:
86 | self.output_path = Path(config["data"]["output_path"])
87 | self.output_path.mkdir(exist_ok=True, parents=True)
88 |
89 | os.makedirs(self.output_path / "mapping_vis", exist_ok=True)
90 | os.makedirs(self.output_path / "tracking_vis", exist_ok=True)
91 |
92 | def should_start_new_submap(self, frame_id: int) -> bool:
93 | """ Determines whether a new submap should be started based on the motion heuristic or specific frame IDs.
94 | Args:
95 | frame_id: The ID of the current frame being processed.
96 | Returns:
97 | A boolean indicating whether to start a new submap.
98 | """
99 | if self.submap_using_motion_heuristic:
100 | if exceeds_motion_thresholds(
101 | self.estimated_c2ws[frame_id], self.estimated_c2ws[self.new_submap_frame_ids[-1]],
102 | rot_thre=50, trans_thre=0.5):
103 | print(f"\nNew submap at {frame_id}")
104 | return True
105 | elif frame_id in self.new_submap_frame_ids:
106 | return True
107 | return False
108 |
109 | def save_current_submap(self, gaussian_model: GaussianModel):
110 | """Saving the current submap's checkpoint and resetting the Gaussian model
111 |
112 | Args:
113 | gaussian_model (GaussianModel): The current GaussianModel instance to capture and reset for the new submap.
114 | """
115 |
116 | gaussian_params = gaussian_model.capture_dict()
117 | submap_ckpt_name = str(self.submap_id).zfill(6)
118 | submap_ckpt = {
119 | "gaussian_params": gaussian_params,
120 | "submap_keyframes": sorted(list(self.keyframes_info.keys()))
121 | }
122 | save_dict_to_ckpt(
123 | submap_ckpt, f"{submap_ckpt_name}.ckpt", directory=self.output_path / "submaps")
124 |
125 | def start_new_submap(self, frame_id: int, gaussian_model: GaussianModel) -> None:
126 | """ Initializes a new submap.
127 | This function updates the submap count and optionally marks the current frame ID for new submap initiation.
128 | Args:
129 | frame_id: The ID of the current frame at which the new submap is started.
130 | gaussian_model: The current GaussianModel instance to capture and reset for the new submap.
131 | Returns:
132 | A new, reset GaussianModel instance for the new submap.
133 | """
134 |
135 | gaussian_model = GaussianModel(0)
136 | gaussian_model.training_setup(self.opt)
137 | self.mapper.keyframes = []
138 | self.keyframes_info = {}
139 | if self.submap_using_motion_heuristic:
140 | self.new_submap_frame_ids.append(frame_id)
141 |         if frame_id not in self.mapping_frame_ids: self.mapping_frame_ids.append(frame_id)
142 | self.submap_id += 1
143 | self.loop_closer.submap_id += 1
144 | return gaussian_model
145 |
146 | def rigid_transform_gaussians(self, gaussian_params, tsfm_matrix):
147 | '''
148 | Apply a rigid transformation to the Gaussian parameters.
149 |
150 | Args:
151 | gaussian_params (dict): Dictionary containing Gaussian parameters.
152 | tsfm_matrix (torch.Tensor): 4x4 rigid transformation matrix.
153 |
154 | Returns:
155 | dict: Updated Gaussian parameters after applying the transformation.
156 | '''
157 | # Transform Gaussian centers (xyz)
158 | tsfm_matrix = torch.from_numpy(tsfm_matrix).float()
159 | xyz = gaussian_params['xyz']
160 | pts_ones = torch.ones((xyz.shape[0], 1))
161 | pts_homo = torch.cat([xyz, pts_ones], dim=1)
162 | transformed_xyz = (tsfm_matrix @ pts_homo.T).T[:, :3]
163 | gaussian_params['xyz'] = transformed_xyz
164 |
165 | # Rotate covariance matrix (rotation)
166 | rotation = gaussian_params['rotation']
167 | cur_rot = roma.unitquat_to_rotmat(rotation)
168 | rot_mat = tsfm_matrix[:3, :3].unsqueeze(0) # Adding batch dimension
169 | new_rot = rot_mat @ cur_rot
170 | new_quat = roma.rotmat_to_unitquat(new_rot)
171 | gaussian_params['rotation'] = new_quat.squeeze()
172 |
173 | return gaussian_params
174 |
175 | def update_keyframe_poses(self, lc_output, submaps_kf_ids, cur_frame_id):
176 | '''
177 |         Update the keyframe poses using the corrections from pose graph optimization; currently updates the frame range covered by each submap's keyframes.
178 |
179 | '''
180 | for correction in lc_output:
181 | submap_id = correction['submap_id']
182 | correct_tsfm = correction['correct_tsfm']
183 | submap_kf_ids = submaps_kf_ids[submap_id]
184 | min_id, max_id = min(submap_kf_ids), max(submap_kf_ids)
185 | self.estimated_c2ws[min_id:max_id + 1] = torch.from_numpy(correct_tsfm).float() @ self.estimated_c2ws[min_id:max_id + 1]
186 |
187 | # last tracked frame is based on last submap, update it as well
188 | self.estimated_c2ws[cur_frame_id] = torch.from_numpy(lc_output[-1]['correct_tsfm']).float() @ self.estimated_c2ws[cur_frame_id]
189 |
190 |
191 | def apply_correction_to_submaps(self, correction_list):
192 | submaps_kf_ids= {}
193 | for correction in correction_list:
194 | submap_id = correction['submap_id']
195 | correct_tsfm = correction['correct_tsfm']
196 |
197 | submap_ckpt_name = str(submap_id).zfill(6) + ".ckpt"
198 | submap_ckpt = torch.load(self.output_path / "submaps" / submap_ckpt_name)
199 | submaps_kf_ids[submap_id] = submap_ckpt["submap_keyframes"]
200 |
201 | gaussian_params = submap_ckpt["gaussian_params"]
202 | updated_gaussian_params = self.rigid_transform_gaussians(
203 | gaussian_params, correct_tsfm)
204 |
205 | submap_ckpt["gaussian_params"] = updated_gaussian_params
206 | torch.save(submap_ckpt, self.output_path / "submaps" / submap_ckpt_name)
207 | return submaps_kf_ids
208 |
209 | def run(self) -> None:
210 | """ Starts the main program flow for Gaussian-SLAM, including tracking and mapping. """
211 | setup_seed(self.config["seed"])
212 | gaussian_model = GaussianModel(0)
213 | gaussian_model.training_setup(self.opt)
214 | self.submap_id = 0
215 |
216 | for frame_id in range(len(self.dataset)):
217 |
218 | if frame_id in [0, 1]:
219 | estimated_c2w = self.dataset[frame_id][-1]
220 | exposure_ab = torch.nn.Parameter(torch.tensor(
221 | 0.0, device="cuda")), torch.nn.Parameter(torch.tensor(0.0, device="cuda"))
222 | else:
223 | estimated_c2w, exposure_ab = self.tracker.track(
224 | frame_id, gaussian_model,
225 | torch2np(self.estimated_c2ws[torch.tensor([0, frame_id - 2, frame_id - 1])]))
226 | exposure_ab = exposure_ab if self.enable_exposure else None
227 | self.estimated_c2ws[frame_id] = np2torch(estimated_c2w)
228 |
229 | # Reinitialize gaussian model for new segment
230 | if self.should_start_new_submap(frame_id):
231 | # first save current submap and its keyframe info
232 | self.save_current_submap(gaussian_model)
233 |
234 |             # update submap information for loop closer
235 | self.loop_closer.update_submaps_info(self.keyframes_info)
236 |
237 | # apply loop closure
238 | lc_output = self.loop_closer.loop_closure(self.estimated_c2ws)
239 |
240 | if len(lc_output) > 0:
241 | submaps_kf_ids = self.apply_correction_to_submaps(lc_output)
242 | self.update_keyframe_poses(lc_output, submaps_kf_ids, frame_id)
243 |
244 | save_dict_to_ckpt(self.estimated_c2ws[:frame_id + 1], "estimated_c2w.ckpt", directory=self.output_path)
245 |
246 | gaussian_model = self.start_new_submap(frame_id, gaussian_model)
247 |
248 | if frame_id in self.mapping_frame_ids:
249 | print("\nMapping frame", frame_id)
250 | gaussian_model.training_setup(self.opt, exposure_ab)
251 | estimate_c2w = torch2np(self.estimated_c2ws[frame_id])
252 | new_submap = not bool(self.keyframes_info)
253 | opt_dict = self.mapper.map(
254 | frame_id, estimate_c2w, gaussian_model, new_submap, exposure_ab)
255 |
256 | # Keyframes info update
257 | self.keyframes_info[frame_id] = {
258 | "keyframe_id": frame_id,
259 | "opt_dict": opt_dict,
260 | }
261 | if self.enable_exposure:
262 | self.keyframes_info[frame_id]["exposure_a"] = exposure_ab[0].item()
263 | self.keyframes_info[frame_id]["exposure_b"] = exposure_ab[1].item()
264 |
265 | if frame_id == len(self.dataset) - 1 and self.config['lc']['final']:
266 | print("\n Final loop closure ...")
267 | self.loop_closer.update_submaps_info(self.keyframes_info)
268 | lc_output = self.loop_closer.loop_closure(self.estimated_c2ws, final=True)
269 | if len(lc_output) > 0:
270 | submaps_kf_ids = self.apply_correction_to_submaps(lc_output)
271 | self.update_keyframe_poses(lc_output, submaps_kf_ids, frame_id)
272 | if self.enable_exposure:
273 | self.exposures_ab[frame_id] = torch.tensor([exposure_ab[0].item(), exposure_ab[1].item()])
274 |
275 | save_dict_to_ckpt(self.estimated_c2ws[:frame_id + 1], "estimated_c2w.ckpt", directory=self.output_path)
276 | if self.enable_exposure:
277 | save_dict_to_ckpt(self.exposures_ab, "exposures_ab.ckpt", directory=self.output_path)
278 |
--------------------------------------------------------------------------------
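A minimal sketch of the rigid-transform step in GaussianSLAM.rigid_transform_gaussians: submap Gaussian centres are moved as homogeneous points and each per-Gaussian rotation is composed with the loop-closure correction. The xyzw quaternion ordering below follows roma; whether the stored Gaussian rotations use the same ordering is an assumption of this sketch.

import math

import roma
import torch

# Five fake Gaussians: random centres, identity rotations (xyzw unit quaternions).
xyz = torch.randn(5, 3)
rotation = torch.tensor([[0.0, 0.0, 0.0, 1.0]]).repeat(5, 1)

# Loop-closure correction: rotate 90 degrees about z, then shift by (1, 0, 0).
c, s = math.cos(math.pi / 2), math.sin(math.pi / 2)
tsfm = torch.eye(4)
tsfm[:3, :3] = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
tsfm[:3, 3] = torch.tensor([1.0, 0.0, 0.0])

# Transform the centres as homogeneous points, as in rigid_transform_gaussians.
pts_homo = torch.cat([xyz, torch.ones(xyz.shape[0], 1)], dim=1)
new_xyz = (tsfm @ pts_homo.T).T[:, :3]

# Compose the correction rotation with every Gaussian's own rotation.
new_rot = tsfm[:3, :3].unsqueeze(0) @ roma.unitquat_to_rotmat(rotation)
new_quat = roma.rotmat_to_unitquat(new_rot)

# Sanity check: the homogeneous transform equals R @ x + t.
assert torch.allclose(new_xyz, (tsfm[:3, :3] @ xyz.T).T + tsfm[:3, 3], atol=1e-5)

--------------------------------------------------------------------------------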
/src/entities/logger.py:
--------------------------------------------------------------------------------
1 | """ This module includes the Logger class, which is responsible for logging for both Mapper and the Tracker """
2 | from pathlib import Path
3 | from typing import Union
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 | import wandb
9 |
10 |
11 | class Logger(object):
12 |
13 | def __init__(self, output_path: Union[Path, str], use_wandb=False) -> None:
14 | self.output_path = Path(output_path)
15 | (self.output_path / "mapping_vis").mkdir(exist_ok=True, parents=True)
16 | self.use_wandb = use_wandb
17 |
18 | def log_tracking_iteration(self, frame_id, cur_pose, gt_quat, gt_trans, total_loss,
19 | color_loss, depth_loss, iter, num_iters,
20 | wandb_output=False, print_output=False) -> None:
21 | """ Logs tracking iteration metrics including pose error, losses, and optionally reports to Weights & Biases.
22 | Logs the error between the current pose estimate and ground truth quaternion and translation,
23 | as well as various loss metrics. Can output to wandb if enabled and specified, and print to console.
24 | Args:
25 | frame_id: Identifier for the current frame.
26 | cur_pose: The current estimated pose as a tensor (quaternion + translation).
27 | gt_quat: Ground truth quaternion.
28 | gt_trans: Ground truth translation.
29 | total_loss: Total computed loss for the current iteration.
30 | color_loss: Computed color loss for the current iteration.
31 | depth_loss: Computed depth loss for the current iteration.
32 | iter: The current iteration number.
33 | num_iters: The total number of iterations planned.
34 | wandb_output: Whether to output the log to wandb.
35 | print_output: Whether to print the log output.
36 | """
37 |
38 | quad_err = torch.abs(cur_pose[:4] - gt_quat).mean().item()
39 | trans_err = torch.abs(cur_pose[4:] - gt_trans).mean().item()
40 | if self.use_wandb and wandb_output:
41 | wandb.log(
42 | {
43 | "Tracking/idx": frame_id,
44 | "Tracking/cam_quad_err": quad_err,
45 | "Tracking/cam_position_err": trans_err,
46 | "Tracking/total_loss": total_loss.item(),
47 | "Tracking/color_loss": color_loss.item(),
48 | "Tracking/depth_loss": depth_loss.item(),
49 | "Tracking/num_iters": num_iters,
50 | })
51 | if iter == num_iters - 1:
52 | msg = f"frame_id: {frame_id}, cam_quad_err: {quad_err:.5f}, cam_trans_err: {trans_err:.5f} "
53 | else:
54 | msg = f"iter: {iter}, color_loss: {color_loss.item():.5f}, depth_loss: {depth_loss.item():.5f} "
55 | msg = msg + f", cam_quad_err: {quad_err:.5f}, cam_trans_err: {trans_err:.5f}"
56 | if print_output:
57 | print(msg, flush=True)
58 |
59 | def log_mapping_iteration(self, frame_id, new_pts_num, model_size, iter_opt_time, opt_dict: dict) -> None:
60 | """ Logs mapping iteration metrics including the number of new points, model size, and optimization times,
61 | and optionally reports to Weights & Biases (wandb).
62 | Args:
63 | frame_id: Identifier for the current frame.
64 | new_pts_num: The number of new points added in the current mapping iteration.
65 | model_size: The total size of the model after the current mapping iteration.
66 | iter_opt_time: Time taken per optimization iteration.
67 | opt_dict: A dictionary containing optimization metrics such as PSNR, color loss, and depth loss.
68 | """
69 | if self.use_wandb:
70 | wandb.log({"Mapping/idx": frame_id,
71 | "Mapping/num_total_gs": model_size,
72 | "Mapping/num_new_gs": new_pts_num,
73 | "Mapping/per_iteration_time": iter_opt_time,
74 | "Mapping/psnr_render": opt_dict["psnr_render"],
75 | "Mapping/color_loss": opt_dict[frame_id]["color_loss"],
76 | "Mapping/depth_loss": opt_dict[frame_id]["depth_loss"]})
77 |
78 | def vis_mapping_iteration(self, frame_id, iter, color, depth, gt_color, gt_depth, seeding_mask=None, interval=10) -> None:
79 | """
80 |         Visualizes the rendered and ground truth color/depth images and saves them to file.
81 | 
82 |         Args:
83 |             frame_id (int): current frame index.
84 |             iter (int): the iteration number.
85 |             color, depth: rendered color and depth images.
86 |             gt_color, gt_depth: ground truth color and depth images.
87 |             seeding_mask: densification mask used in the mapper when adding gaussians, if not None.
88 | """
89 | if frame_id % interval != 0:
90 | return
91 | gt_depth_np = gt_depth.cpu().numpy()
92 | gt_color_np = gt_color.cpu().numpy()
93 |
94 | depth_np = depth.detach().cpu().numpy()
95 | color = torch.round(color * 255.0) / 255.0
96 | color_np = color.detach().cpu().numpy()
97 | depth_residual = np.abs(gt_depth_np - depth_np)
98 | depth_residual[gt_depth_np == 0.0] = 0.0
99 | # make errors >=5cm noticeable
100 | depth_residual = np.clip(depth_residual, 0.0, 0.05)
101 |
102 | color_residual = np.abs(gt_color_np - color_np)
103 | color_residual[np.squeeze(gt_depth_np == 0.0)] = 0.0
104 |
105 | # Determine Aspect Ratio and Figure Size
106 | aspect_ratio = color.shape[1] / color.shape[0]
107 | fig_height = 8
108 | # Adjust the multiplier as needed for better spacing
109 | fig_width = fig_height * aspect_ratio * 1.2
110 |
111 | fig, axs = plt.subplots(2, 3, figsize=(fig_width, fig_height))
112 | axs[0, 0].imshow(gt_depth_np, cmap="jet", vmin=0, vmax=6)
113 | axs[0, 0].set_title('Input Depth', fontsize=16)
114 | axs[0, 0].set_xticks([])
115 | axs[0, 0].set_yticks([])
116 | axs[0, 1].imshow(depth_np, cmap="jet", vmin=0, vmax=6)
117 | axs[0, 1].set_title('Rendered Depth', fontsize=16)
118 | axs[0, 1].set_xticks([])
119 | axs[0, 1].set_yticks([])
120 | axs[0, 2].imshow(depth_residual, cmap="plasma")
121 | axs[0, 2].set_title('Depth Residual', fontsize=16)
122 | axs[0, 2].set_xticks([])
123 | axs[0, 2].set_yticks([])
124 | gt_color_np = np.clip(gt_color_np, 0, 1)
125 | color_np = np.clip(color_np, 0, 1)
126 | color_residual = np.clip(color_residual, 0, 1)
127 | axs[1, 0].imshow(gt_color_np, cmap="plasma")
128 | axs[1, 0].set_title('Input RGB', fontsize=16)
129 | axs[1, 0].set_xticks([])
130 | axs[1, 0].set_yticks([])
131 | axs[1, 1].imshow(color_np, cmap="plasma")
132 | axs[1, 1].set_title('Rendered RGB', fontsize=16)
133 | axs[1, 1].set_xticks([])
134 | axs[1, 1].set_yticks([])
135 | if seeding_mask is not None:
136 | axs[1, 2].imshow(seeding_mask, cmap="gray")
137 | axs[1, 2].set_title('Densification Mask', fontsize=16)
138 | axs[1, 2].set_xticks([])
139 | axs[1, 2].set_yticks([])
140 | else:
141 | axs[1, 2].imshow(color_residual, cmap="plasma")
142 | axs[1, 2].set_title('RGB Residual', fontsize=16)
143 | axs[1, 2].set_xticks([])
144 | axs[1, 2].set_yticks([])
145 |
146 | for ax in axs.flatten():
147 | ax.axis('off')
148 | fig.tight_layout()
149 | plt.subplots_adjust(top=0.90) # Adjust top margin
150 | fig_name = str(self.output_path / "mapping_vis" / f'{frame_id:04d}_{iter:04d}.jpg')
151 | fig_title = f"Mapper Color/Depth at frame {frame_id:04d} iters {iter:04d}"
152 | plt.suptitle(fig_title, y=0.98, fontsize=20)
153 | plt.savefig(fig_name, dpi=250, bbox_inches='tight')
154 | plt.clf()
155 | plt.close()
156 | if self.use_wandb:
157 | log_title = "Mapping_vis/" + f'{frame_id:04d}_{iter:04d}'
158 | wandb.log({log_title: [wandb.Image(fig_name)]})
159 | print(f"Saved rendering vis of color/depth at {frame_id:04d}_{iter:04d}.jpg")
160 |
--------------------------------------------------------------------------------
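A small usage sketch of the console logging path above, with wandb disabled; the output directory and the dummy pose/loss tensors are placeholder assumptions. The 7-vector pose follows the (w, x, y, z, tx, ty, tz) layout that the Tracker passes in.

import torch

from src.entities.logger import Logger

logger = Logger("output/logger_demo", use_wandb=False)   # placeholder output path

cur_pose = torch.tensor([1.0, 0.0, 0.0, 0.0,   # estimated unit quaternion (w, x, y, z)
                         0.1, 0.0, 0.0])       # estimated translation
gt_quat = torch.tensor([1.0, 0.0, 0.0, 0.0])
gt_trans = torch.zeros(3)
loss = torch.tensor(0.5)

logger.log_tracking_iteration(
    frame_id=0, cur_pose=cur_pose, gt_quat=gt_quat, gt_trans=gt_trans,
    total_loss=loss, color_loss=loss, depth_loss=loss,
    iter=0, num_iters=60, wandb_output=False, print_output=True)

--------------------------------------------------------------------------------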
/src/entities/losses.py:
--------------------------------------------------------------------------------
1 | from math import exp
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 |
8 | def l1_loss(network_output: torch.Tensor, gt: torch.Tensor, agg="mean") -> torch.Tensor:
9 | """
10 | Computes the L1 loss, which is the mean absolute error between the network output and the ground truth.
11 |
12 | Args:
13 | network_output: The output from the network.
14 | gt: The ground truth tensor.
15 | agg: The aggregation method to be used. Defaults to "mean".
16 | Returns:
17 | The computed L1 loss.
18 | """
19 | l1_loss = torch.abs(network_output - gt)
20 | if agg == "mean":
21 | return l1_loss.mean()
22 | elif agg == "sum":
23 | return l1_loss.sum()
24 | elif agg == "none":
25 | return l1_loss
26 | else:
27 | raise ValueError("Invalid aggregation method.")
28 |
29 |
30 | def gaussian(window_size: int, sigma: float) -> torch.Tensor:
31 | """
32 | Creates a 1D Gaussian kernel.
33 |
34 | Args:
35 | window_size: The size of the window for the Gaussian kernel.
36 | sigma: The standard deviation of the Gaussian kernel.
37 |
38 | Returns:
39 | The 1D Gaussian kernel.
40 | """
41 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 /
42 | float(2 * sigma ** 2)) for x in range(window_size)])
43 | return gauss / gauss.sum()
44 |
45 |
46 | def create_window(window_size: int, channel: int) -> Variable:
47 | """
48 | Creates a 2D Gaussian window/kernel for SSIM computation.
49 |
50 | Args:
51 | window_size: The size of the window to be created.
52 | channel: The number of channels in the image.
53 |
54 | Returns:
55 | A 2D Gaussian window expanded to match the number of channels.
56 | """
57 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
58 | _2D_window = _1D_window.mm(
59 | _1D_window.t()).float().unsqueeze(0).unsqueeze(0)
60 | window = Variable(_2D_window.expand(
61 | channel, 1, window_size, window_size).contiguous())
62 | return window
63 |
64 |
65 | def ssim(img1: torch.Tensor, img2: torch.Tensor, window_size: int = 11, size_average: bool = True) -> torch.Tensor:
66 | """
67 | Computes the Structural Similarity Index (SSIM) between two images.
68 |
69 | Args:
70 | img1: The first image.
71 | img2: The second image.
72 | window_size: The size of the window to be used in SSIM computation. Defaults to 11.
73 | size_average: If True, averages the SSIM over all pixels. Defaults to True.
74 |
75 | Returns:
76 | The computed SSIM value.
77 | """
78 | channel = img1.size(-3)
79 | window = create_window(window_size, channel)
80 |
81 | if img1.is_cuda:
82 | window = window.cuda(img1.get_device())
83 | window = window.type_as(img1)
84 |
85 | return _ssim(img1, img2, window, window_size, channel, size_average)
86 |
87 |
88 | def _ssim(img1: torch.Tensor, img2: torch.Tensor, window: Variable, window_size: int,
89 | channel: int, size_average: bool = True) -> torch.Tensor:
90 | """
91 | Internal function to compute the Structural Similarity Index (SSIM) between two images.
92 |
93 | Args:
94 | img1: The first image.
95 | img2: The second image.
96 | window: The Gaussian window/kernel for SSIM computation.
97 | window_size: The size of the window to be used in SSIM computation.
98 | channel: The number of channels in the image.
99 | size_average: If True, averages the SSIM over all pixels.
100 |
101 | Returns:
102 | The computed SSIM value.
103 | """
104 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
105 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
106 |
107 | mu1_sq = mu1.pow(2)
108 | mu2_sq = mu2.pow(2)
109 | mu1_mu2 = mu1 * mu2
110 |
111 | sigma1_sq = F.conv2d(img1 * img1, window,
112 | padding=window_size // 2, groups=channel) - mu1_sq
113 | sigma2_sq = F.conv2d(img2 * img2, window,
114 | padding=window_size // 2, groups=channel) - mu2_sq
115 | sigma12 = F.conv2d(img1 * img2, window,
116 | padding=window_size // 2, groups=channel) - mu1_mu2
117 |
118 | C1 = 0.01 ** 2
119 | C2 = 0.03 ** 2
120 |
121 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
122 | ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
123 |
124 | if size_average:
125 | return ssim_map.mean()
126 | else:
127 | return ssim_map.mean(1).mean(1).mean(1)
128 |
129 |
130 | def isotropic_loss(scaling: torch.Tensor) -> torch.Tensor:
131 | """
132 | Computes loss enforcing isotropic scaling for the 3D Gaussians
133 | Args:
134 | scaling: scaling tensor of 3D Gaussians of shape (n, 3)
135 | Returns:
136 | The computed isotropic loss
137 | """
138 | mean_scaling = scaling.mean(dim=1, keepdim=True)
139 | isotropic_diff = torch.abs(scaling - mean_scaling * torch.ones_like(scaling))
140 | return isotropic_diff.mean()
141 |
--------------------------------------------------------------------------------
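A quick, self-contained check of the loss functions above on random tensors; ssim expects channel-first images, so the shapes below use a batched (N, C, H, W) layout. The tensor sizes are arbitrary choices for illustration.

import torch

from src.entities.losses import isotropic_loss, l1_loss, ssim

img1 = torch.rand(1, 3, 64, 64)
img2 = img1.clone()

print(l1_loss(img1, img2).item())                    # 0.0 for identical images
print(ssim(img1, img2).item())                       # ~1.0 for identical images
print(ssim(img1, torch.rand(1, 3, 64, 64)).item())   # noticeably lower for unrelated noise

scaling = torch.rand(100, 3)
print(isotropic_loss(scaling).item())                # mean |s - mean(s)| over the 3 axes

--------------------------------------------------------------------------------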
/src/entities/tracker.py:
--------------------------------------------------------------------------------
1 | """ This module includes the Mapper class, which is responsible scene mapping: Paper Section 3.4 """
2 | from argparse import ArgumentParser
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | import torchvision
8 | from scipy.spatial.transform import Rotation as R
9 |
10 | from src.entities.arguments import OptimizationParams
11 | from src.entities.losses import l1_loss
12 | from src.entities.gaussian_model import GaussianModel
13 | from src.entities.logger import Logger
14 | from src.entities.datasets import BaseDataset
15 | from src.entities.visual_odometer import VisualOdometer
16 | from src.utils.gaussian_model_utils import build_rotation
17 | from src.utils.tracker_utils import (compute_camera_opt_params,
18 | extrapolate_poses, multiply_quaternions,
19 | transformation_to_quaternion)
20 | from src.utils.utils import (get_render_settings, np2torch,
21 | render_gaussian_model, torch2np)
22 |
23 |
24 | class Tracker(object):
25 | def __init__(self, config: dict, dataset: BaseDataset, logger: Logger) -> None:
26 | """ Initializes the Tracker with a given configuration, dataset, and logger.
27 | Args:
28 | config: Configuration dictionary specifying hyperparameters and operational settings.
29 | dataset: The dataset object providing access to the sequence of frames.
30 | logger: Logger object for logging the tracking process.
31 | """
32 | self.dataset = dataset
33 | self.logger = logger
34 | self.config = config
35 | self.filter_alpha = self.config["filter_alpha"]
36 | self.filter_outlier_depth = self.config["filter_outlier_depth"]
37 | self.alpha_thre = self.config["alpha_thre"]
38 | self.soft_alpha = self.config["soft_alpha"]
39 | self.mask_invalid_depth_in_color_loss = self.config["mask_invalid_depth"]
40 | self.w_color_loss = self.config["w_color_loss"]
41 | self.transform = torchvision.transforms.ToTensor()
42 | self.opt = OptimizationParams(ArgumentParser(description="Training script parameters"))
43 | self.frame_depth_loss = []
44 | self.frame_color_loss = []
45 | self.odometry_type = self.config["odometry_type"]
46 | self.help_camera_initialization = self.config["help_camera_initialization"]
47 | self.init_err_ratio = self.config["init_err_ratio"]
48 | self.enable_exposure = self.config["enable_exposure"]
49 | self.odometer = VisualOdometer(self.dataset.intrinsics, self.config["odometer_method"])
50 |
51 | def compute_losses(self, gaussian_model: GaussianModel, render_settings: dict,
52 | opt_cam_rot: torch.Tensor, opt_cam_trans: torch.Tensor,
53 | gt_color: torch.Tensor, gt_depth: torch.Tensor, depth_mask: torch.Tensor,
54 | exposure_ab=None) -> tuple:
55 | """ Computes the tracking losses with respect to ground truth color and depth.
56 | Args:
57 | gaussian_model: The current state of the Gaussian model of the scene.
58 | render_settings: Dictionary containing rendering settings such as image dimensions and camera intrinsics.
59 | opt_cam_rot: Optimizable tensor representing the camera's rotation.
60 | opt_cam_trans: Optimizable tensor representing the camera's translation.
61 | gt_color: Ground truth color image tensor.
62 | gt_depth: Ground truth depth image tensor.
63 | depth_mask: Binary mask indicating valid depth values in the ground truth depth image.
64 | Returns:
65 | A tuple containing losses and renders
66 | """
67 | rel_transform = torch.eye(4).cuda().float()
68 | rel_transform[:3, :3] = build_rotation(F.normalize(opt_cam_rot[None]))[0]
69 | rel_transform[:3, 3] = opt_cam_trans
70 |
71 | pts = gaussian_model.get_xyz()
72 | pts_ones = torch.ones(pts.shape[0], 1).cuda().float()
73 | pts4 = torch.cat((pts, pts_ones), dim=1)
74 | transformed_pts = (rel_transform @ pts4.T).T[:, :3]
75 |
76 | quat = F.normalize(opt_cam_rot[None])
77 | _rotations = multiply_quaternions(gaussian_model.get_rotation(), quat.unsqueeze(0)).squeeze(0)
78 |
79 | render_dict = render_gaussian_model(gaussian_model, render_settings,
80 | override_means_3d=transformed_pts, override_rotations=_rotations)
81 | rendered_color, rendered_depth = render_dict["color"], render_dict["depth"]
82 | if self.enable_exposure:
83 | rendered_color = torch.clamp(torch.exp(exposure_ab[0]) * rendered_color + exposure_ab[1], 0, 1.)
84 | alpha_mask = render_dict["alpha"] > self.alpha_thre
85 |
86 | tracking_mask = torch.ones_like(alpha_mask).bool()
87 | tracking_mask &= depth_mask
88 | depth_err = torch.abs(rendered_depth - gt_depth) * depth_mask
89 |
90 | if self.filter_alpha:
91 | tracking_mask &= alpha_mask
92 | if self.filter_outlier_depth and torch.median(depth_err) > 0:
93 | tracking_mask &= depth_err < 50 * torch.median(depth_err)
94 |
95 | color_loss = l1_loss(rendered_color, gt_color, agg="none")
96 | depth_loss = l1_loss(rendered_depth, gt_depth, agg="none") * tracking_mask
97 |
98 | if self.soft_alpha:
99 | alpha = render_dict["alpha"] ** 3
100 | color_loss *= alpha
101 | depth_loss *= alpha
102 | if self.mask_invalid_depth_in_color_loss:
103 | color_loss *= tracking_mask
104 | else:
105 | color_loss *= tracking_mask
106 |
107 | color_loss = color_loss.sum()
108 | depth_loss = depth_loss.sum()
109 |
110 | return color_loss, depth_loss, rendered_color, rendered_depth, alpha_mask
111 |
112 | def track(self, frame_id: int, gaussian_model: GaussianModel, prev_c2ws: np.ndarray) -> np.ndarray:
113 | """
114 | Updates the camera pose estimation for the current frame based on the provided image and depth, using either ground truth poses,
115 | constant speed assumption, or visual odometry.
116 | Args:
117 | frame_id: Index of the current frame being processed.
118 | gaussian_model: The current Gaussian model of the scene.
119 | prev_c2ws: Array containing the camera-to-world transformation matrices for the frames (0, i - 2, i - 1)
120 | Returns:
121 | The updated camera-to-world transformation matrix for the current frame.
122 | """
123 | _, image, depth, gt_c2w = self.dataset[frame_id]
124 |
125 | if (self.help_camera_initialization or self.odometry_type == "odometer") and self.odometer.last_rgbd is None:
126 | _, last_image, last_depth, _ = self.dataset[frame_id - 1]
127 | self.odometer.update_last_rgbd(last_image, last_depth)
128 |
129 | if self.odometry_type == "gt":
130 |             return gt_c2w, None  # keep the (pose, exposure) return signature expected by the caller
131 | elif self.odometry_type == "const_speed":
132 | init_c2w = extrapolate_poses(prev_c2ws[1:])
133 | elif self.odometry_type == "odometer":
134 | odometer_rel = self.odometer.estimate_rel_pose(image, depth)
135 | init_c2w = prev_c2ws[-1] @ odometer_rel
136 | elif self.odometry_type == "previous":
137 | init_c2w = prev_c2ws[-1]
138 |
139 | last_c2w = prev_c2ws[-1]
140 | last_w2c = np.linalg.inv(last_c2w)
141 | init_rel = init_c2w @ np.linalg.inv(last_c2w)
142 | init_rel_w2c = np.linalg.inv(init_rel)
143 | reference_w2c = last_w2c
144 | render_settings = get_render_settings(
145 | self.dataset.width, self.dataset.height, self.dataset.intrinsics, reference_w2c)
146 | opt_cam_rot, opt_cam_trans = compute_camera_opt_params(init_rel_w2c)
147 | if self.enable_exposure:
148 | exposure_ab = torch.nn.Parameter(torch.tensor(
149 | 0.0, device="cuda")), torch.nn.Parameter(torch.tensor(0.0, device="cuda"))
150 | else:
151 | exposure_ab = None
152 | gaussian_model.training_setup_camera(opt_cam_rot, opt_cam_trans, self.config, exposure_ab)
153 |
154 | gt_color = self.transform(image).cuda()
155 | gt_depth = np2torch(depth, "cuda")
156 | depth_mask = gt_depth > 0.0
157 | gt_trans = np2torch(gt_c2w[:3, 3])
158 | gt_quat = np2torch(R.from_matrix(gt_c2w[:3, :3]).as_quat(canonical=True)[[3, 0, 1, 2]])
159 | num_iters = self.config["iterations"]
160 | current_min_loss = float("inf")
161 |
162 | print(f"\nTracking frame {frame_id}")
163 | # Initial loss check
164 | color_loss, depth_loss, _, _, _ = self.compute_losses(gaussian_model, render_settings, opt_cam_rot,
165 | opt_cam_trans, gt_color, gt_depth, depth_mask,
166 | exposure_ab)
167 | if len(self.frame_color_loss) > 0 and (
168 | color_loss.item() > self.init_err_ratio * np.median(self.frame_color_loss)
169 | or depth_loss.item() > self.init_err_ratio * np.median(self.frame_depth_loss)
170 | ):
171 | num_iters *= 2
172 | print(f"Higher initial loss, increasing num_iters to {num_iters}")
173 | if self.help_camera_initialization and self.odometry_type != "odometer":
174 | _, last_image, last_depth, _ = self.dataset[frame_id - 1]
175 | self.odometer.update_last_rgbd(last_image, last_depth)
176 | odometer_rel = self.odometer.estimate_rel_pose(image, depth)
177 | init_c2w = last_c2w @ odometer_rel
178 | init_rel = init_c2w @ np.linalg.inv(last_c2w)
179 | init_rel_w2c = np.linalg.inv(init_rel)
180 | opt_cam_rot, opt_cam_trans = compute_camera_opt_params(init_rel_w2c)
181 | gaussian_model.training_setup_camera(opt_cam_rot, opt_cam_trans, self.config, exposure_ab)
182 | render_settings = get_render_settings(
183 | self.dataset.width, self.dataset.height, self.dataset.intrinsics, last_w2c)
184 | print(f"re-init with odometer for frame {frame_id}")
185 |
186 | for iter in range(num_iters):
187 |             color_loss, depth_loss, _, _, _ = self.compute_losses(
188 | gaussian_model, render_settings, opt_cam_rot, opt_cam_trans, gt_color, gt_depth, depth_mask, exposure_ab)
189 |
190 | total_loss = (self.w_color_loss * color_loss + (1 - self.w_color_loss) * depth_loss)
191 | total_loss.backward()
192 | gaussian_model.optimizer.step()
193 | # gaussian_model.scheduler.step(total_loss, epoch=iter)
194 | gaussian_model.optimizer.zero_grad(set_to_none=True)
195 |
196 | with torch.no_grad():
197 | if total_loss.item() < current_min_loss:
198 | current_min_loss = total_loss.item()
199 | best_w2c = torch.eye(4)
200 | best_w2c[:3, :3] = build_rotation(F.normalize(opt_cam_rot[None].clone().detach().cpu()))[0]
201 | best_w2c[:3, 3] = opt_cam_trans.clone().detach().cpu()
202 |
203 | cur_quat, cur_trans = F.normalize(opt_cam_rot[None].clone().detach()), opt_cam_trans.clone().detach()
204 | cur_rel_w2c = torch.eye(4)
205 | cur_rel_w2c[:3, :3] = build_rotation(cur_quat)[0]
206 | cur_rel_w2c[:3, 3] = cur_trans
207 | if iter == num_iters - 1:
208 | cur_w2c = torch.from_numpy(reference_w2c) @ best_w2c
209 | else:
210 | cur_w2c = torch.from_numpy(reference_w2c) @ cur_rel_w2c
211 | cur_c2w = torch.inverse(cur_w2c)
212 | cur_cam = transformation_to_quaternion(cur_c2w)
213 | if (gt_quat * cur_cam[:4]).sum() < 0: # for logging purpose
214 | gt_quat *= -1
215 | if iter == num_iters - 1:
216 | self.frame_color_loss.append(color_loss.item())
217 | self.frame_depth_loss.append(depth_loss.item())
218 | self.logger.log_tracking_iteration(
219 | frame_id, cur_cam, gt_quat, gt_trans, total_loss, color_loss, depth_loss, iter, num_iters,
220 | wandb_output=True, print_output=True)
221 | elif iter % 20 == 0:
222 | self.logger.log_tracking_iteration(
223 | frame_id, cur_cam, gt_quat, gt_trans, total_loss, color_loss, depth_loss, iter, num_iters,
224 | wandb_output=False, print_output=True)
225 |
226 | final_c2w = torch.inverse(torch.from_numpy(reference_w2c) @ best_w2c)
227 | final_c2w[-1, :] = torch.tensor([0., 0., 0., 1.], dtype=final_c2w.dtype, device=final_c2w.device)
228 | return torch2np(final_c2w), exposure_ab
229 |
--------------------------------------------------------------------------------
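A numpy/scipy sketch of the pose composition at the end of Tracker.track: the optimised relative world-to-camera transform is composed with the reference frame as cur_w2c = reference_w2c @ rel_w2c and the result is inverted back to a camera-to-world pose. The specific rotation and translation values here are arbitrary assumptions.

import numpy as np
from scipy.spatial.transform import Rotation as R

# Reference frame: world-to-camera of the previously tracked frame.
last_c2w = np.eye(4)
last_c2w[:3, 3] = [0.0, 0.0, 1.0]
reference_w2c = np.linalg.inv(last_c2w)

# Optimised relative w2c, standing in for (opt_cam_rot, opt_cam_trans).
rel_w2c = np.eye(4)
rel_w2c[:3, :3] = R.from_euler("y", 2.0, degrees=True).as_matrix()
rel_w2c[:3, 3] = [0.01, 0.0, 0.02]

# Same composition as cur_w2c = reference_w2c @ best_w2c in Tracker.track.
cur_w2c = reference_w2c @ rel_w2c
cur_c2w = np.linalg.inv(cur_w2c)
cur_c2w[-1, :] = [0.0, 0.0, 0.0, 1.0]   # enforce an exact homogeneous last row
print(cur_c2w[:3, 3])                   # new camera centre in world coordinates

--------------------------------------------------------------------------------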
/src/entities/visual_odometer.py:
--------------------------------------------------------------------------------
1 | """ This module includes the Odometer class, which is allows for fast pose estimation from RGBD neighbor frames """
2 | import numpy as np
3 | import open3d as o3d
4 | import open3d.core as o3c
5 |
6 |
7 | class VisualOdometer(object):
8 |
9 | def __init__(self, intrinsics: np.ndarray, method_name="hybrid", device="cuda"):
10 | """ Initializes the visual odometry system with specified intrinsics, method, and device.
11 | Args:
12 | intrinsics: Camera intrinsic parameters.
13 | method_name: The name of the odometry computation method to use ('hybrid' or 'point_to_plane').
14 | device: The computation device ('cuda' or 'cpu').
15 | """
16 | device = "CUDA:0" if device == "cuda" else "CPU:0"
17 | self.device = o3c.Device(device)
18 | self.intrinsics = o3d.core.Tensor(intrinsics, o3d.core.Dtype.Float64)
19 | self.last_abs_pose = None
20 | self.last_frame = None
21 | self.criteria_list = [
22 | o3d.t.pipelines.odometry.OdometryConvergenceCriteria(500),
23 | o3d.t.pipelines.odometry.OdometryConvergenceCriteria(500),
24 | o3d.t.pipelines.odometry.OdometryConvergenceCriteria(500)]
25 | self.setup_method(method_name)
26 | self.max_depth = 10.0
27 | self.depth_scale = 1.0
28 | self.last_rgbd = None
29 |
30 | def setup_method(self, method_name: str) -> None:
31 | """ Sets up the odometry computation method based on the provided method name.
32 | Args:
33 | method_name: The name of the odometry method to use ('hybrid' or 'point_to_plane').
34 | """
35 | if method_name == "hybrid":
36 | self.method = o3d.t.pipelines.odometry.Method.Hybrid
37 | elif method_name == "point_to_plane":
38 | self.method = o3d.t.pipelines.odometry.Method.PointToPlane
39 | else:
40 | raise ValueError("Odometry method does not exist!")
41 |
42 | def update_last_rgbd(self, image: np.ndarray, depth: np.ndarray) -> None:
43 | """ Updates the last RGB-D frame stored in the system with a new RGB-D frame constructed from provided image and depth.
44 | Args:
45 | image: The new RGB image as a numpy ndarray.
46 | depth: The new depth image as a numpy ndarray.
47 | """
48 | self.last_rgbd = o3d.t.geometry.RGBDImage(
49 | o3d.t.geometry.Image(np.ascontiguousarray(
50 | image).astype(np.float32)).to(self.device),
51 | o3d.t.geometry.Image(np.ascontiguousarray(depth).astype(np.float32)).to(self.device))
52 |
53 | def estimate_rel_pose(self, image: np.ndarray, depth: np.ndarray, init_transform=np.eye(4)):
54 | """ Estimates the relative pose of the current frame with respect to the last frame using RGB-D odometry.
55 | Args:
56 | image: The current RGB image as a numpy ndarray.
57 | depth: The current depth image as a numpy ndarray.
58 | init_transform: An initial transformation guess as a numpy ndarray. Defaults to the identity matrix.
59 | Returns:
60 | The relative transformation matrix as a numpy ndarray.
61 | """
62 | rgbd = o3d.t.geometry.RGBDImage(
63 | o3d.t.geometry.Image(np.ascontiguousarray(image).astype(np.float32)).to(self.device),
64 | o3d.t.geometry.Image(np.ascontiguousarray(depth).astype(np.float32)).to(self.device))
65 | rel_transform = o3d.t.pipelines.odometry.rgbd_odometry_multi_scale(
66 | self.last_rgbd, rgbd, self.intrinsics, o3c.Tensor(init_transform),
67 | self.depth_scale, self.max_depth, self.criteria_list, self.method)
68 | self.last_rgbd = rgbd.clone()
69 |
70 | # Adjust for the coordinate system difference
71 | rel_transform = rel_transform.transformation.cpu().numpy()
72 | rel_transform[0, [1, 2, 3]] *= -1
73 | rel_transform[1, [0, 2, 3]] *= -1
74 | rel_transform[2, [0, 1, 3]] *= -1
75 |
76 | return rel_transform
77 |
--------------------------------------------------------------------------------
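A numpy-only sketch of how the Tracker chains the odometer's relative pose onto the previous absolute pose (init_c2w = prev_c2w @ odometer_rel); the relative transform below is synthetic rather than coming from Open3D's RGB-D odometry.

import numpy as np

prev_c2w = np.eye(4)
prev_c2w[:3, 3] = [1.0, 0.0, 0.0]        # previous camera one metre along x

# Synthetic relative transform, standing in for VisualOdometer.estimate_rel_pose().
odometer_rel = np.eye(4)
odometer_rel[:3, 3] = [0.0, 0.0, 0.1]    # 10 cm of forward motion

init_c2w = prev_c2w @ odometer_rel       # pose initialisation used by the Tracker
print(init_c2w[:3, 3])                   # -> [1.  0.  0.1]

--------------------------------------------------------------------------------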
/src/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GradientSpaces/LoopSplat/676e18f950b2de6be39525a613120712b67768bb/src/evaluation/__init__.py
--------------------------------------------------------------------------------
/src/evaluation/evaluate_merged_map.py:
--------------------------------------------------------------------------------
1 | """ This module is responsible for merging submaps. """
2 | from argparse import ArgumentParser
3 |
4 | import faiss
5 | import numpy as np
6 | import open3d as o3d
7 | import torch
8 | from torch.utils.data import Dataset
9 | from tqdm import tqdm
10 |
11 | from src.entities.arguments import OptimizationParams
12 | from src.entities.gaussian_model import GaussianModel
13 | from src.entities.losses import isotropic_loss, l1_loss, ssim
14 | from src.utils.utils import (batch_search_faiss, get_render_settings,
15 | np2ptcloud, render_gaussian_model, torch2np)
16 | from src.utils.gaussian_model_utils import BasicPointCloud
17 |
18 |
19 | class RenderFrames(Dataset):
20 | """A dataset class for loading keyframes along with their estimated camera poses and render settings."""
21 | def __init__(self, dataset, render_poses: np.ndarray, height: int, width: int, fx: float, fy: float, exposures_ab=None):
22 | self.dataset = dataset
23 | self.render_poses = render_poses
24 | self.height = height
25 | self.width = width
26 | self.fx = fx
27 | self.fy = fy
28 | self.device = "cuda"
29 | self.stride = 1
30 | self.exposures_ab = exposures_ab
31 | if len(dataset) > 1000:
32 | self.stride = len(dataset) // 1000
33 |
34 | def __len__(self) -> int:
35 | return len(self.dataset) // self.stride
36 |
37 | def __getitem__(self, idx):
38 | idx = idx * self.stride
39 | color = (torch.from_numpy(
40 | self.dataset[idx][1]) / 255.0).float().to(self.device)
41 | depth = torch.from_numpy(self.dataset[idx][2]).float().to(self.device)
42 | estimate_c2w = self.render_poses[idx]
43 | estimate_w2c = np.linalg.inv(estimate_c2w)
44 | frame = {
45 | "frame_id": idx,
46 | "color": color,
47 | "depth": depth,
48 | "render_settings": get_render_settings(
49 | self.width, self.height, self.dataset.intrinsics, estimate_w2c)
50 | }
51 | if self.exposures_ab is not None:
52 | frame["exposure_ab"] = self.exposures_ab[idx]
53 | return frame
54 |
55 |
56 | def merge_submaps(submaps_paths: list, radius: float = 0.0001, device: str = "cuda") -> o3d.geometry.PointCloud:
57 | """ Merge submaps into a single point cloud, which is then used for global map refinement.
58 | Args:
59 |         submaps_paths (list): Paths to the saved submap checkpoints.
60 | radius (float, optional): Nearest neighbor distance threshold for adding a point. Defaults to 0.0001.
61 | device (str, optional): Defaults to "cuda".
62 |
63 | Returns:
64 | o3d.geometry.PointCloud: merged point cloud
65 | """
66 | pts_index = faiss.IndexFlatL2(3)
67 | if device == "cuda":
68 | pts_index = faiss.index_cpu_to_gpu(
69 | faiss.StandardGpuResources(),
70 | 0,
71 | faiss.IndexIVFFlat(faiss.IndexFlatL2(3), 3, 500, faiss.METRIC_L2))
72 | pts_index.nprobe = 5
73 | merged_pts = []
74 | for submap_path in tqdm(submaps_paths, desc="Merging submaps"):
75 | gaussian_params = torch.load(submap_path)["gaussian_params"]
76 | current_pts = gaussian_params["xyz"].to(device).float().contiguous()
77 | pts_index.train(current_pts)
78 | distances, _ = batch_search_faiss(pts_index, current_pts, 8)
79 | neighbor_num = (distances < radius).sum(axis=1).int()
80 | ids_to_include = torch.where(neighbor_num == 0)[0]
81 | pts_index.add(current_pts[ids_to_include])
82 | merged_pts.append(current_pts[ids_to_include])
83 | pts = torch2np(torch.vstack(merged_pts))
84 | pt_cloud = np2ptcloud(pts, np.zeros_like(pts))
85 |
86 | # Downsampling if the total number of points is too large
87 | if len(pt_cloud.points) > 1_000_000:
88 | voxel_size = 0.02
89 | pt_cloud = pt_cloud.voxel_down_sample(voxel_size)
90 | print(f"Downsampled point cloud to {len(pt_cloud.points)} points")
91 | filtered_pt_cloud, _ = pt_cloud.remove_statistical_outlier(nb_neighbors=40, std_ratio=3.0)
92 | del pts_index
93 | return filtered_pt_cloud
94 |
95 |
96 | def refine_global_map(pt_cloud: o3d.geometry.PointCloud, training_frames: list, max_iterations: int,
97 | export_refine_mesh=False, output_dir=".",
98 | len_frames=None, o3d_intrinsic=None, enable_sh=True, enable_exposure=False) -> GaussianModel:
99 |     """Refines a global map based on the merged point cloud and the training keyframes.
100 | Args:
101 | pt_cloud (o3d.geometry.PointCloud): The merged point cloud used for refinement.
102 | training_frames (list): A list of training frames for map refinement.
103 | max_iterations (int): The maximum number of iterations to perform for refinement.
104 | Returns:
105 | GaussianModel: The refined global map as a Gaussian model.
106 | """
107 | opt_params = OptimizationParams(ArgumentParser(description="Training script parameters"))
108 |
109 | gaussian_model = GaussianModel(3)
110 | gaussian_model.active_sh_degree = 0
111 | if pt_cloud is None:
112 | output_mesh = output_dir / "mesh" / "cleaned_mesh.ply"
113 | output_mesh = o3d.io.read_triangle_mesh(str(output_mesh))
114 | pcd = o3d.geometry.PointCloud()
115 | pcd.points = output_mesh.vertices
116 | pcd.colors = output_mesh.vertex_colors
117 | pcd = pcd.voxel_down_sample(voxel_size=0.02)
118 | pcd = BasicPointCloud(points=np.asarray(pcd.points),
119 | colors=np.asarray(pcd.colors))
120 | gaussian_model.create_from_pcd(pcd, 1.0)
121 | gaussian_model.training_setup(opt_params)
122 | else:
123 | gaussian_model.training_setup(opt_params)
124 | gaussian_model.add_points(pt_cloud)
125 |
126 | iteration = 0
127 | for iteration in tqdm(range(max_iterations), desc="Refinement"):
128 | training_frame = next(training_frames)
129 | gaussian_model.update_learning_rate(iteration)
130 | if enable_sh and iteration > 0 and iteration % 1000 == 0:
131 | gaussian_model.oneupSHdegree()
132 | gt_color, gt_depth, render_settings = (
133 | training_frame["color"].squeeze(0),
134 | training_frame["depth"].squeeze(0),
135 | training_frame["render_settings"])
136 |
137 | render_dict = render_gaussian_model(gaussian_model, render_settings)
138 | rendered_color, rendered_depth = (render_dict["color"].permute(1, 2, 0), render_dict["depth"])
139 | if enable_exposure and training_frame.get("exposure_ab") is not None:
140 | rendered_color = torch.clamp(
141 | rendered_color * torch.exp(training_frame["exposure_ab"][0,0]) + training_frame["exposure_ab"][0,1], 0, 1.)
142 |
143 | reg_loss = isotropic_loss(gaussian_model.get_scaling())
144 | depth_mask = (gt_depth > 0)
145 | color_loss = (1.0 - opt_params.lambda_dssim) * l1_loss(
146 | rendered_color[depth_mask, :], gt_color[depth_mask, :]
147 | ) + opt_params.lambda_dssim * (1.0 - ssim(rendered_color, gt_color))
148 | depth_loss = l1_loss(
149 | rendered_depth[:, depth_mask], gt_depth[depth_mask])
150 |
151 | total_loss = color_loss + depth_loss + reg_loss
152 | total_loss.backward()
153 |
154 | with torch.no_grad():
155 | if iteration % 500 == 0:
156 | prune_mask = (gaussian_model.get_opacity() < 0.005).squeeze()
157 | gaussian_model.prune_points(prune_mask)
158 |
159 | # Optimizer step
160 | gaussian_model.optimizer.step()
161 | gaussian_model.optimizer.zero_grad(set_to_none=True)
162 | iteration += 1
163 |
164 | try:
165 | if export_refine_mesh:
166 | output_dir = output_dir / "mesh" / "refined_mesh.ply"
167 | scale = 1.0
168 | volume = o3d.pipelines.integration.ScalableTSDFVolume(
169 | voxel_length=5.0 * scale / 512.0,
170 | sdf_trunc=0.04 * scale,
171 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8)
172 | for i in tqdm(range(len_frames), desc="Integrating mesh"): # one cycle
173 | training_frame = next(training_frames)
174 | gt_color, gt_depth, render_settings, estimate_w2c = (
175 | training_frame["color"].squeeze(0),
176 | training_frame["depth"].squeeze(0),
177 | training_frame["render_settings"],
178 | training_frame["estimate_w2c"])
179 |
180 | render_dict = render_gaussian_model(gaussian_model, render_settings)
181 | rendered_color, rendered_depth = (
182 | render_dict["color"].permute(1, 2, 0), render_dict["depth"])
183 | rendered_color = torch.clamp(rendered_color, min=0.0, max=1.0)
184 |
185 | rendered_color = (
186 | torch2np(rendered_color) * 255).astype(np.uint8)
187 | rendered_depth = torch2np(rendered_depth.squeeze())
188 | # rendered_depth = filter_depth_outliers(
189 | # rendered_depth, kernel_size=20, threshold=0.1)
190 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
191 | o3d.geometry.Image(np.ascontiguousarray(rendered_color)),
192 | o3d.geometry.Image(rendered_depth),
193 | depth_scale=scale,
194 | depth_trunc=30,
195 | convert_rgb_to_intensity=False)
196 | volume.integrate(
197 | rgbd, o3d_intrinsic, estimate_w2c.squeeze().cpu().numpy().astype(np.float64))
198 |
199 | o3d_mesh = volume.extract_triangle_mesh()
200 | o3d.io.write_triangle_mesh(str(output_dir), o3d_mesh)
201 | print(f"Refined mesh saved to {output_dir}")
202 |
203 | except Exception as e:
204 | print(f"Error export_refine_mesh in refine_global_map:\n {e}")
205 |
206 | return gaussian_model
207 |
--------------------------------------------------------------------------------
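A CPU-only sketch of the duplicate filtering idea in merge_submaps: a point from a new submap is added only when the index contains no existing point within radius. It uses a plain faiss.IndexFlatL2 and numpy instead of the GPU IVF index and batch_search_faiss used above, and since IndexFlatL2 returns squared L2 distances the threshold is compared as radius ** 2; the point counts and radius value are arbitrary.

import faiss
import numpy as np

radius = 0.01
index = faiss.IndexFlatL2(3)                      # exact L2 index over 3-D points

existing = np.random.rand(1000, 3).astype(np.float32)
index.add(existing)

new_pts = np.random.rand(200, 3).astype(np.float32)
dists, _ = index.search(new_pts, 1)               # squared distance to nearest neighbour
keep = dists[:, 0] > radius ** 2                  # no existing point within radius
index.add(new_pts[keep])
print(f"kept {int(keep.sum())} of {len(new_pts)} new points")

--------------------------------------------------------------------------------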
/src/evaluation/evaluate_reconstruction.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import open3d as o3d
7 | import torch
8 | import trimesh
9 | from evaluate_3d_reconstruction import run_evaluation
10 | from tqdm import tqdm
11 |
12 |
13 | def normalize(x):
14 | return x / np.linalg.norm(x)
15 |
16 |
17 | def get_align_transformation(rec_meshfile, gt_meshfile):
18 | """
19 | Get the transformation matrix to align the reconstructed mesh to the ground truth mesh.
20 | """
21 | o3d_rec_mesh = o3d.io.read_triangle_mesh(rec_meshfile)
22 | o3d_gt_mesh = o3d.io.read_triangle_mesh(gt_meshfile)
23 | o3d_rec_pc = o3d.geometry.PointCloud(points=o3d_rec_mesh.vertices)
24 | o3d_gt_pc = o3d.geometry.PointCloud(points=o3d_gt_mesh.vertices)
25 | trans_init = np.eye(4)
26 | threshold = 0.1
27 | reg_p2p = o3d.pipelines.registration.registration_icp(
28 | o3d_rec_pc,
29 | o3d_gt_pc,
30 | threshold,
31 | trans_init,
32 | o3d.pipelines.registration.TransformationEstimationPointToPoint(),
33 | )
34 | transformation = reg_p2p.transformation
35 | return transformation
36 |
37 |
38 | def check_proj(points, W, H, fx, fy, cx, cy, c2w):
39 | """
40 | Check if points can be projected into the camera view.
41 |
42 | Returns:
43 |         bool: True if any of the points can be projected into the camera view.
44 |
45 | """
46 | c2w = c2w.copy()
47 | c2w[:3, 1] *= -1.0
48 | c2w[:3, 2] *= -1.0
49 | points = torch.from_numpy(points).cuda().clone()
50 | w2c = np.linalg.inv(c2w)
51 | w2c = torch.from_numpy(w2c).cuda().float()
52 | K = torch.from_numpy(
53 | np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]]).reshape(3, 3)
54 | ).cuda()
55 | ones = torch.ones_like(points[:, 0]).reshape(-1, 1).cuda()
56 | homo_points = (
57 | torch.cat([points, ones], dim=1).reshape(-1, 4, 1).cuda().float()
58 | ) # (N, 4)
59 | cam_cord_homo = w2c @ homo_points # (N, 4, 1)=(4,4)*(N, 4, 1)
60 | cam_cord = cam_cord_homo[:, :3] # (N, 3, 1)
61 | cam_cord[:, 0] *= -1
62 | uv = K.float() @ cam_cord.float()
63 | z = uv[:, -1:] + 1e-5
64 | uv = uv[:, :2] / z
65 | uv = uv.float().squeeze(-1).cpu().numpy()
66 | edge = 0
67 | mask = (
68 | (0 <= -z[:, 0, 0].cpu().numpy())
69 | & (uv[:, 0] < W - edge)
70 | & (uv[:, 0] > edge)
71 | & (uv[:, 1] < H - edge)
72 | & (uv[:, 1] > edge)
73 | )
74 | return mask.sum() > 0
75 |
76 |
77 | def get_cam_position(gt_meshfile):
78 | mesh_gt = trimesh.load(gt_meshfile)
79 | to_origin, extents = trimesh.bounds.oriented_bounds(mesh_gt)
80 | extents[2] *= 0.7
81 | extents[1] *= 0.7
82 | extents[0] *= 0.3
83 | transform = np.linalg.inv(to_origin)
84 | transform[2, 3] += 0.4
85 | return extents, transform
86 |
87 |
88 | def viewmatrix(z, up, pos):
89 | vec2 = normalize(z)
90 | vec1_avg = up
91 | vec0 = normalize(np.cross(vec1_avg, vec2))
92 | vec1 = normalize(np.cross(vec2, vec0))
93 | m = np.stack([vec0, vec1, vec2, pos], 1)
94 | return m
95 |
96 |
97 | def calc_2d_metric(
98 | rec_meshfile, gt_meshfile, unseen_gt_pointcloud_file, align=True, n_imgs=1000
99 | ):
100 | """
101 | 2D reconstruction metric, depth L1 loss.
102 |
103 | """
104 | H = 500
105 | W = 500
106 | focal = 300
107 | fx = focal
108 | fy = focal
109 |     cx = W / 2.0 - 0.5
110 |     cy = H / 2.0 - 0.5
111 |
112 | gt_mesh = o3d.io.read_triangle_mesh(gt_meshfile)
113 | rec_mesh = o3d.io.read_triangle_mesh(rec_meshfile)
114 | pc_unseen = np.load(unseen_gt_pointcloud_file)
115 | if align:
116 | transformation = get_align_transformation(rec_meshfile, gt_meshfile)
117 | rec_mesh = rec_mesh.transform(transformation)
118 |
119 | # get vacant area inside the room
120 | extents, transform = get_cam_position(gt_meshfile)
121 |
122 | vis = o3d.visualization.Visualizer()
123 | vis.create_window(width=W, height=H, visible=False)
124 | vis.get_render_option().mesh_show_back_face = True
125 | errors = []
126 | for i in tqdm(range(n_imgs)):
127 | while True:
128 |             # sample a view and check that the unseen region is not inside the camera view;
129 |             # if it is, resample
130 | up = [0, 0, -1]
131 | origin = trimesh.sample.volume_rectangular(
132 | extents, 1, transform=transform)
133 | origin = origin.reshape(-1)
134 | tx = round(random.uniform(-10000, +10000), 2)
135 | ty = round(random.uniform(-10000, +10000), 2)
136 | tz = round(random.uniform(-10000, +10000), 2)
137 |             # the view direction is normalized later, so the sampling range only sets its direction
138 | target = [tx, ty, tz]
139 | target = np.array(target) - np.array(origin)
140 | c2w = viewmatrix(target, up, origin)
141 | tmp = np.eye(4)
142 | tmp[:3, :] = c2w # sample translations
143 | c2w = tmp
144 | # if unseen points are projected into current view (c2w)
145 | seen = check_proj(pc_unseen, W, H, fx, fy, cx, cy, c2w)
146 |             if not seen:
147 | break
148 |
149 | param = o3d.camera.PinholeCameraParameters()
150 | param.extrinsic = np.linalg.inv(c2w) # 4x4 numpy array
151 |
152 | param.intrinsic = o3d.camera.PinholeCameraIntrinsic(
153 | W, H, fx, fy, cx, cy)
154 |
155 | ctr = vis.get_view_control()
156 | ctr.set_constant_z_far(20)
157 | ctr.convert_from_pinhole_camera_parameters(param, True)
158 |
159 | vis.add_geometry(
160 | gt_mesh,
161 | reset_bounding_box=True,
162 | )
163 | ctr.convert_from_pinhole_camera_parameters(param, True)
164 | vis.poll_events()
165 | vis.update_renderer()
166 | gt_depth = vis.capture_depth_float_buffer(True)
167 | gt_depth = np.asarray(gt_depth)
168 | vis.remove_geometry(
169 | gt_mesh,
170 | reset_bounding_box=True,
171 | )
172 |
173 | vis.add_geometry(
174 | rec_mesh,
175 | reset_bounding_box=True,
176 | )
177 | ctr.convert_from_pinhole_camera_parameters(param, True)
178 | vis.poll_events()
179 | vis.update_renderer()
180 | ours_depth = vis.capture_depth_float_buffer(True)
181 | ours_depth = np.asarray(ours_depth)
182 | vis.remove_geometry(
183 | rec_mesh,
184 | reset_bounding_box=True,
185 | )
186 |
187 | # filter missing surfaces where depth is 0
188 | if (ours_depth > 0).sum() > 0:
189 | errors += [
190 | np.abs(gt_depth[ours_depth > 0] -
191 | ours_depth[ours_depth > 0]).mean()
192 | ]
193 | else:
194 | continue
195 |
196 | errors = np.array(errors)
197 | return {"depth_l1_sample_view": errors.mean() * 100}
198 |
199 |
200 | def clean_mesh(mesh):
201 | mesh_tri = trimesh.Trimesh(
202 | vertices=np.asarray(mesh.vertices),
203 | faces=np.asarray(mesh.triangles),
204 | vertex_colors=np.asarray(mesh.vertex_colors),
205 | )
206 | components = trimesh.graph.connected_components(
207 | edges=mesh_tri.edges_sorted)
208 |
209 | min_len = 200
210 | components_to_keep = [c for c in components if len(c) >= min_len]
211 |
212 | new_vertices = []
213 | new_faces = []
214 | new_colors = []
215 | vertex_count = 0
216 | for component in components_to_keep:
217 | vertices = mesh_tri.vertices[component]
218 | colors = mesh_tri.visual.vertex_colors[component]
219 |
220 | # Create a mapping from old vertex indices to new vertex indices
221 | index_mapping = {
222 | old_idx: vertex_count + new_idx for new_idx, old_idx in enumerate(component)
223 | }
224 | vertex_count += len(vertices)
225 |
226 | # Select faces that are part of the current connected component and update vertex indices
227 | faces_in_component = mesh_tri.faces[
228 | np.any(np.isin(mesh_tri.faces, component), axis=1)
229 | ]
230 | reindexed_faces = np.vectorize(index_mapping.get)(faces_in_component)
231 |
232 | new_vertices.extend(vertices)
233 | new_faces.extend(reindexed_faces)
234 | new_colors.extend(colors)
235 |
236 | cleaned_mesh_tri = trimesh.Trimesh(vertices=new_vertices, faces=new_faces)
237 | cleaned_mesh_tri.visual.vertex_colors = np.array(new_colors)
238 |
239 | cleaned_mesh_tri.update_faces(cleaned_mesh_tri.nondegenerate_faces())
240 | cleaned_mesh_tri.update_faces(cleaned_mesh_tri.unique_faces())
241 | print(
242 | f"Mesh cleaning (before/after), vertices: {len(mesh_tri.vertices)}/{len(cleaned_mesh_tri.vertices)}, faces: {len(mesh_tri.faces)}/{len(cleaned_mesh_tri.faces)}")
243 |
244 | cleaned_mesh = o3d.geometry.TriangleMesh(
245 | o3d.utility.Vector3dVector(cleaned_mesh_tri.vertices),
246 | o3d.utility.Vector3iVector(cleaned_mesh_tri.faces),
247 | )
248 | vertex_colors = np.asarray(cleaned_mesh_tri.visual.vertex_colors)[
249 | :, :3] / 255.0
250 | cleaned_mesh.vertex_colors = o3d.utility.Vector3dVector(
251 | vertex_colors.astype(np.float64)
252 | )
253 |
254 | return cleaned_mesh
255 |
256 |
257 | def evaluate_reconstruction(
258 | mesh_path: Path,
259 | gt_mesh_path: Path,
260 | unseen_pc_path: Path,
261 | output_path: Path,
262 | to_clean=True,
263 | ):
264 | if to_clean:
265 | mesh = o3d.io.read_triangle_mesh(str(mesh_path))
266 | print(mesh)
267 | cleaned_mesh = clean_mesh(mesh)
268 | cleaned_mesh_path = output_path / "mesh" / "cleaned_mesh.ply"
269 | o3d.io.write_triangle_mesh(str(cleaned_mesh_path), cleaned_mesh)
270 | mesh_path = cleaned_mesh_path
271 |
272 | result_3d = run_evaluation(
273 | str(mesh_path.parts[-1]),
274 | str(mesh_path.parent),
275 | str(gt_mesh_path).split("/")[-1].split(".")[0],
276 | distance_thresh=0.01,
277 | full_path_to_gt_ply=gt_mesh_path,
278 | icp_align=True,
279 | )
280 |
281 | try:
282 | result_2d = calc_2d_metric(str(mesh_path), str(gt_mesh_path), str(unseen_pc_path), align=True, n_imgs=1000)
283 | except Exception as e:
284 | print(e)
285 | result_2d = {"depth_l1_sample_view": None}
286 |
287 | result = {**result_3d, **result_2d}
288 | with open(str(output_path / "reconstruction_metrics.json"), "w") as f:
289 | json.dump(result, f)
290 |
--------------------------------------------------------------------------------
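A minimal, hypothetical driver for evaluate_reconstruction above; every path is a placeholder, the evaluate_3d_reconstruction package is assumed to be installed, and the output_path / "mesh" directory is created explicitly because clean_mesh's result is written there.

# Hypothetical usage sketch; all paths are placeholders.
from pathlib import Path
from src.evaluation.evaluate_reconstruction import evaluate_reconstruction

output_path = Path("output/Replica/office0")               # assumed run directory
(output_path / "mesh").mkdir(parents=True, exist_ok=True)  # cleaned mesh is written here

evaluate_reconstruction(
    mesh_path=output_path / "mesh" / "final_mesh.ply",         # reconstructed mesh (assumed name)
    gt_mesh_path=Path("data/Replica/office0_mesh.ply"),        # ground-truth mesh (assumed name)
    unseen_pc_path=Path("data/Replica/office0_pc_unseen.npy"), # unseen-region points (assumed name)
    output_path=output_path,
    to_clean=True,
)
# The merged 3D and 2D metrics land in output_path / "reconstruction_metrics.json".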
/src/evaluation/evaluate_trajectory.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 |
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 |
7 |
8 | class NumpyFloatValuesEncoder(json.JSONEncoder):
9 | def default(self, obj):
10 | if isinstance(obj, np.float32):
11 | return float(obj)
12 |         return super().default(obj)
13 |
14 |
15 | def align(model, data):
16 | """Align two trajectories using the method of Horn (closed-form).
17 |
18 | Input:
19 | model -- first trajectory (3xn)
20 | data -- second trajectory (3xn)
21 |
22 | Output:
23 | rot -- rotation matrix (3x3)
24 | trans -- translation vector (3x1)
25 | trans_error -- translational error per point (1xn)
26 |
27 | """
28 | np.set_printoptions(precision=3, suppress=True)
29 | model_zerocentered = model - model.mean(1)
30 | data_zerocentered = data - data.mean(1)
31 |
32 | W = np.zeros((3, 3))
33 | for column in range(model.shape[1]):
34 | W += np.outer(model_zerocentered[:,
35 | column], data_zerocentered[:, column])
36 |     U, d, Vh = np.linalg.svd(W.transpose())
37 | S = np.matrix(np.identity(3))
38 | if (np.linalg.det(U) * np.linalg.det(Vh) < 0):
39 | S[2, 2] = -1
40 | rot = U * S * Vh
41 | trans = data.mean(1) - rot * model.mean(1)
42 |
43 | model_aligned = rot * model + trans
44 | alignment_error = model_aligned - data
45 |
46 | trans_error = np.sqrt(
47 | np.sum(np.multiply(alignment_error, alignment_error), 0)).A[0]
48 |
49 | return rot, trans, trans_error
50 |
51 |
52 | def align_trajectories(t_pred: np.ndarray, t_gt: np.ndarray):
53 | """
54 | Args:
55 | t_pred: (n, 3) translations
56 | t_gt: (n, 3) translations
57 | Returns:
58 | t_align: (n, 3) aligned translations
59 | """
60 | t_align = np.matrix(t_pred).transpose()
61 | R, t, _ = align(t_align, np.matrix(t_gt).transpose())
62 | t_align = R * t_align + t
63 | t_align = np.asarray(t_align).T
64 | return t_align
65 |
66 |
67 | def pose_error(t_pred: np.ndarray, t_gt: np.ndarray, align=False):
68 | """
69 | Args:
70 | t_pred: (n, 3) translations
71 | t_gt: (n, 3) translations
72 | Returns:
73 | dict: error dict
74 | """
75 | n = t_pred.shape[0]
76 | trans_error = np.linalg.norm(t_pred - t_gt, axis=1)
77 | return {
78 | "compared_pose_pairs": n,
79 | "rmse": np.sqrt(np.dot(trans_error, trans_error) / n),
80 | "mean": np.mean(trans_error),
81 | "median": np.median(trans_error),
82 | "std": np.std(trans_error),
83 | "min": np.min(trans_error),
84 | "max": np.max(trans_error)
85 | }
86 |
87 |
88 | def plot_2d(pts, ax=None, color="green", label="None", title="3D Trajectory in 2D"):
89 | if ax is None:
90 | _, ax = plt.subplots()
91 | ax.scatter(pts[:, 0], pts[:, 1], color=color, label=label, s=0.7)
92 | ax.set_xlabel('X')
93 | ax.set_ylabel('Y')
94 | ax.set_title(title)
95 | return ax
96 |
97 |
98 | def evaluate_trajectory(estimated_poses: np.ndarray, gt_poses: np.ndarray, output_path: Path):
99 | output_path.mkdir(exist_ok=True, parents=True)
100 | # Truncate the ground truth trajectory if needed
101 | if gt_poses.shape[0] > estimated_poses.shape[0]:
102 | gt_poses = gt_poses[:estimated_poses.shape[0]]
103 | valid = ~np.any(np.isnan(gt_poses) |
104 | np.isinf(gt_poses), axis=(1, 2))
105 | gt_poses = gt_poses[valid]
106 | estimated_poses = estimated_poses[valid]
107 |
108 | gt_t = gt_poses[:, :3, 3]
109 | estimated_t = estimated_poses[:, :3, 3]
110 | estimated_t_aligned = align_trajectories(estimated_t, gt_t)
111 | ate = pose_error(estimated_t, gt_t)
112 | ate_aligned = pose_error(estimated_t_aligned, gt_t)
113 |
114 | with open(str(output_path / "ate.json"), "w") as f:
115 | f.write(json.dumps(ate, cls=NumpyFloatValuesEncoder))
116 |
117 | with open(str(output_path / "ate_aligned.json"), "w") as f:
118 | f.write(json.dumps(ate_aligned, cls=NumpyFloatValuesEncoder))
119 |
120 | ate_rmse, ate_rmse_aligned = ate["rmse"], ate_aligned["rmse"]
121 | ax = plot_2d(
122 | estimated_t, label=f"ate-rmse: {round(ate_rmse * 100, 2)} cm", color="orange")
123 | ax = plot_2d(estimated_t_aligned, ax,
124 |                  label=f"ate-rmse (aligned): {round(ate_rmse_aligned * 100, 2)} cm", color="lightskyblue")
125 | ax = plot_2d(gt_t, ax, label="GT", color="green")
126 | ax.legend()
127 | plt.savefig(str(output_path / "eval_trajectory.png"), dpi=300)
128 | plt.close()
129 | print(
130 | f"ATE-RMSE: {ate_rmse * 100:.2f} cm, ATE-RMSE (aligned): {ate_rmse_aligned * 100:.2f} cm")
131 |
--------------------------------------------------------------------------------
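A quick synthetic sanity check for the ATE utilities above: the estimate below is the ground truth under a made-up rigid transform, so the unaligned RMSE is large while the RMSE after Horn alignment collapses to (numerically) zero.

# Synthetic check: estimate = ground truth under a fixed rigid transform (toy values).
import numpy as np
from src.evaluation.evaluate_trajectory import align_trajectories, pose_error

rng = np.random.default_rng(0)
gt_t = rng.uniform(-1, 1, size=(100, 3))          # toy ground-truth translations

theta = np.deg2rad(10.0)                          # arbitrary rotation about z, plus an offset
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta),  np.cos(theta), 0.0],
              [0.0,            0.0,           1.0]])
est_t = gt_t @ R.T + np.array([0.5, -0.2, 0.1])

print(pose_error(est_t, gt_t)["rmse"])                            # large: not aligned
print(pose_error(align_trajectories(est_t, gt_t), gt_t)["rmse"])  # ~0 after Horn alignment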
/src/gsr/camera.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 | from PIL import Image
5 |
6 | from src.gsr.loss import image_gradient, image_gradient_mask
7 | from src.utils.graphics_utils import getProjectionMatrix2, getWorld2View2
8 |
9 |
10 |
11 | class Camera(nn.Module):
12 | def __init__(
13 | self,
14 | uid,
15 | color,
16 | depth,
17 | gt_T,
18 | projection_matrix,
19 | fx,
20 | fy,
21 | cx,
22 | cy,
23 | fovx,
24 | fovy,
25 | image_height,
26 | image_width,
27 | device="cuda:0",
28 | ):
29 | super(Camera, self).__init__()
30 | self.uid = uid
31 | self.device = device
32 |
33 | T = torch.eye(4, device=device)
34 | self.R = T[:3, :3]
35 | self.T = T[:3, 3]
36 | self.R_gt = gt_T[:3, :3]
37 | self.T_gt = gt_T[:3, 3]
38 |
39 | self.original_image = color
40 | self.depth = depth
41 | self.grad_mask = None
42 |
43 | self.fx = fx
44 | self.fy = fy
45 | self.cx = cx
46 | self.cy = cy
47 | self.FoVx = fovx
48 | self.FoVy = fovy
49 | self.image_height = image_height
50 | self.image_width = image_width
51 |
52 | self.cam_rot_delta = nn.Parameter(
53 | torch.zeros(3, requires_grad=True, device=device)
54 | )
55 | self.cam_trans_delta = nn.Parameter(
56 | torch.zeros(3, requires_grad=True, device=device)
57 | )
58 |
59 | self.exposure_a = nn.Parameter(
60 | torch.tensor([0.0], requires_grad=True, device=device)
61 | )
62 | self.exposure_b = nn.Parameter(
63 | torch.tensor([0.0], requires_grad=True, device=device)
64 | )
65 |
66 | self.projection_matrix = projection_matrix.to(device=device)
67 |
68 | @staticmethod
69 | def init_from_dataset(dataset, idx, projection_matrix):
70 | gt_color, gt_depth, gt_pose = dataset[idx]
71 | return Camera(
72 | idx,
73 | gt_color,
74 | gt_depth,
75 | gt_pose,
76 | projection_matrix,
77 | dataset.fx,
78 | dataset.fy,
79 | dataset.cx,
80 | dataset.cy,
81 | dataset.fovx,
82 | dataset.fovy,
83 | dataset.height,
84 | dataset.width,
85 | device=dataset.device,
86 | )
87 |
88 | @staticmethod
89 | def init_from_gui(uid, T, FoVx, FoVy, fx, fy, cx, cy, H, W):
90 | projection_matrix = getProjectionMatrix2(
91 | znear=0.01, zfar=100.0, fx=fx, fy=fy, cx=cx, cy=cy, W=W, H=H
92 | ).transpose(0, 1)
93 | return Camera(
94 | uid, None, None, T, projection_matrix, fx, fy, cx, cy, FoVx, FoVy, H, W
95 | )
96 |
97 | @property
98 | def world_view_transform(self):
99 | return getWorld2View2(self.R, self.T).transpose(0, 1)
100 |
101 | @property
102 | def full_proj_transform(self):
103 | return (
104 | self.world_view_transform.unsqueeze(0).bmm(
105 | self.projection_matrix.unsqueeze(0)
106 | )
107 | ).squeeze(0)
108 |
109 | @property
110 | def camera_center(self):
111 | return self.world_view_transform.inverse()[3, :3]
112 |
113 | def update_RT(self, R, t):
114 | self.R = R.to(device=self.device)
115 | self.T = t.to(device=self.device)
116 |
117 | def compute_grad_mask(self, config):
118 | edge_threshold = config["Training"]["edge_threshold"]
119 |
120 | gray_img = self.original_image.mean(dim=0, keepdim=True)
121 | gray_grad_v, gray_grad_h = image_gradient(gray_img)
122 | mask_v, mask_h = image_gradient_mask(gray_img)
123 | gray_grad_v = gray_grad_v * mask_v
124 | gray_grad_h = gray_grad_h * mask_h
125 | img_grad_intensity = torch.sqrt(gray_grad_v**2 + gray_grad_h**2)
126 |
127 | if config["Dataset"]["type"] == "replica":
128 | row, col = 32, 32
129 | multiplier = edge_threshold
130 | _, h, w = self.original_image.shape
131 | for r in range(row):
132 | for c in range(col):
133 | block = img_grad_intensity[
134 | :,
135 | r * int(h / row) : (r + 1) * int(h / row),
136 | c * int(w / col) : (c + 1) * int(w / col),
137 | ]
138 | th_median = block.median()
139 | block[block > (th_median * multiplier)] = 1
140 | block[block <= (th_median * multiplier)] = 0
141 | self.grad_mask = img_grad_intensity
142 | else:
143 | median_img_grad_intensity = img_grad_intensity.median()
144 | self.grad_mask = (
145 | img_grad_intensity > median_img_grad_intensity * edge_threshold
146 | )
147 |
148 | def clean(self):
149 | self.original_image = None
150 | self.depth = None
151 | self.grad_mask = None
152 |
153 | self.cam_rot_delta = None
154 | self.cam_trans_delta = None
155 |
156 | self.exposure_a = None
157 | self.exposure_b = None
158 |
159 | @property
160 | def get_T(self):
161 | T = torch.eye(4, device=self.device).float()
162 | T[:3, :3] = self.R
163 | T[:3, 3] = self.T
164 | return T
165 |
166 | @property
167 | def get_T_gt(self):
168 | T = torch.eye(4, device=self.device).float()
169 | T[:3, :3] = self.R_gt
170 | T[:3, 3] = self.T_gt
171 | return T
172 |
173 | def load_rgb(self, image=None):
174 |
175 |         if image is None and hasattr(self, "rgb_path"):
176 | self.original_image = torch.from_numpy(np.array(Image.open(self.rgb_path))).permute(2, 0, 1).cuda().float() / 255.0
177 | self.compute_grad_mask(self.config)
178 |
179 | if image is not None:
180 | self.original_image = image
181 | self.compute_grad_mask(self.config)
--------------------------------------------------------------------------------
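A self-contained sketch of building a Camera outside the SLAM pipeline via init_from_gui, with made-up pinhole intrinsics and an identity ground-truth pose; a CUDA device is assumed since the class defaults to cuda:0.

# Sketch: standalone Camera with toy intrinsics (values are made up; CUDA assumed).
import math
import torch
from src.gsr.camera import Camera

fx = fy = 600.0
cx, cy = 320.0, 240.0
W, H = 640, 480
fovx = 2 * math.atan(W / (2 * fx))
fovy = 2 * math.atan(H / (2 * fy))

gt_T = torch.eye(4, device="cuda:0")              # placeholder ground-truth world-to-camera pose
cam = Camera.init_from_gui(0, gt_T, fovx, fovy, fx, fy, cx, cy, H, W)

# The estimated pose starts at identity; update_RT overwrites it.
cam.update_RT(gt_T[:3, :3], gt_T[:3, 3])
print(cam.get_T)                                  # 4x4 estimated pose of this camera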
/src/gsr/descriptor.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import io
3 | import sys
4 |
5 | import torch
6 | from einops import parse_shape
7 |
8 | sys.path.append("thirdparty/Hierarchical-Localization")
9 |
10 | with contextlib.redirect_stderr(io.StringIO()):
11 | from hloc.extractors.netvlad import NetVLAD
12 |
13 |
14 | class GlobalDesc:
15 |
16 | def __init__(self):
17 | conf = {
18 | 'output': 'global-feats-netvlad',
19 | 'model': {'name': 'netvlad'},
20 | 'preprocessing': {'resize_max': 1024},
21 | }
22 | self.netvlad = NetVLAD(conf).to('cuda').eval()
23 |
24 | @torch.no_grad()
25 | def __call__(self, images):
26 | assert parse_shape(images, '_ rgb _ _') == dict(rgb=3)
27 | assert (images.dtype == torch.float) and (images.max() <= 1.0001), images.max()
28 | return self.netvlad({'image': images})['global_descriptor'] # B 4096
--------------------------------------------------------------------------------
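A usage sketch for GlobalDesc that mirrors the descriptor-similarity step in src/gsr/solver.py; the image batches below are random tensors standing in for keyframes, and a CUDA device plus the Hierarchical-Localization submodule (NetVLAD weights are fetched on first use) are assumed.

# Sketch: NetVLAD descriptors for two toy image batches and their cross-similarity.
import torch
from src.gsr.descriptor import GlobalDesc

global_desc = GlobalDesc()
imgs_a = torch.rand(4, 3, 480, 640, device="cuda")   # B x 3 x H x W, values in [0, 1]
imgs_b = torch.rand(4, 3, 480, 640, device="cuda")

desc_a = global_desc(imgs_a)                          # (4, 4096) global descriptors
desc_b = global_desc(imgs_b)

score = torch.einsum("id,jd->ij", desc_a, desc_b)     # pairwise descriptor similarities
print(score.argmax(dim=1))                            # most similar keyframe in B for each image in A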
/src/gsr/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/muskie82/MonoGS/blob/main/utils/slam_utils.py
3 | """
4 | import torch
5 |
6 | def image_gradient(image):
7 | # Compute image gradient using Scharr Filter
8 | c = image.shape[0]
9 | conv_y = torch.tensor(
10 | [[3, 0, -3], [10, 0, -10], [3, 0, -3]], dtype=torch.float32, device="cuda"
11 | )
12 | conv_x = torch.tensor(
13 | [[3, 10, 3], [0, 0, 0], [-3, -10, -3]], dtype=torch.float32, device="cuda"
14 | )
15 | normalizer = 1.0 / torch.abs(conv_y).sum()
16 | p_img = torch.nn.functional.pad(image, (1, 1, 1, 1), mode="reflect")[None]
17 | img_grad_v = normalizer * torch.nn.functional.conv2d(
18 | p_img, conv_x.view(1, 1, 3, 3).repeat(c, 1, 1, 1), groups=c
19 | )
20 | img_grad_h = normalizer * torch.nn.functional.conv2d(
21 | p_img, conv_y.view(1, 1, 3, 3).repeat(c, 1, 1, 1), groups=c
22 | )
23 | return img_grad_v[0], img_grad_h[0]
24 |
25 |
26 | def image_gradient_mask(image, eps=0.01):
27 | # Compute image gradient mask
28 | c = image.shape[0]
29 | conv_y = torch.ones((1, 1, 3, 3), dtype=torch.float32, device="cuda")
30 | conv_x = torch.ones((1, 1, 3, 3), dtype=torch.float32, device="cuda")
31 | p_img = torch.nn.functional.pad(image, (1, 1, 1, 1), mode="reflect")[None]
32 | p_img = torch.abs(p_img) > eps
33 | img_grad_v = torch.nn.functional.conv2d(
34 | p_img.float(), conv_x.repeat(c, 1, 1, 1), groups=c
35 | )
36 | img_grad_h = torch.nn.functional.conv2d(
37 | p_img.float(), conv_y.repeat(c, 1, 1, 1), groups=c
38 | )
39 |
40 | return img_grad_v[0] == torch.sum(conv_x), img_grad_h[0] == torch.sum(conv_y)
41 |
42 |
43 | def depth_reg(depth, gt_image, huber_eps=0.1, mask=None):
44 | mask_v, mask_h = image_gradient_mask(depth)
45 | gray_grad_v, gray_grad_h = image_gradient(gt_image.mean(dim=0, keepdim=True))
46 | depth_grad_v, depth_grad_h = image_gradient(depth)
47 | gray_grad_v, gray_grad_h = gray_grad_v[mask_v], gray_grad_h[mask_h]
48 | depth_grad_v, depth_grad_h = depth_grad_v[mask_v], depth_grad_h[mask_h]
49 |
50 | w_h = torch.exp(-10 * gray_grad_h**2)
51 | w_v = torch.exp(-10 * gray_grad_v**2)
52 | err = (w_h * torch.abs(depth_grad_h)).mean() + (
53 | w_v * torch.abs(depth_grad_v)
54 | ).mean()
55 | return err
56 |
57 |
58 | def get_loss_tracking(config, image, depth, opacity, viewpoint, initialization=False):
59 | image_ab = (torch.exp(viewpoint.exposure_a)) * image + viewpoint.exposure_b
60 | if config["Training"]["monocular"]:
61 | return get_loss_tracking_rgb(config, image_ab, depth, opacity, viewpoint)
62 | return get_loss_tracking_rgbd(config, image_ab, depth, opacity, viewpoint)
63 |
64 |
65 | def get_loss_tracking_rgb(config, image, depth, opacity, viewpoint):
66 | gt_image = viewpoint.original_image.cuda()
67 | _, h, w = gt_image.shape
68 | mask_shape = (1, h, w)
69 | rgb_boundary_threshold = config["Training"]["rgb_boundary_threshold"]
70 | rgb_pixel_mask = (gt_image.sum(dim=0) > rgb_boundary_threshold).view(*mask_shape)
71 | rgb_pixel_mask = rgb_pixel_mask * viewpoint.grad_mask
72 | l1 = opacity * torch.abs(image * rgb_pixel_mask - gt_image * rgb_pixel_mask)
73 | return l1.mean()
74 |
75 |
76 | def get_loss_tracking_rgbd(
77 | config, image, depth, opacity, viewpoint, initialization=False
78 | ):
79 | alpha = config["Training"]["alpha"] if "alpha" in config["Training"] else 0.95
80 |
81 | gt_depth = torch.from_numpy(viewpoint.depth).to(
82 | dtype=torch.float32, device=image.device
83 | )[None]
84 | depth_pixel_mask = (gt_depth > 0.01).view(*depth.shape)
85 | opacity_mask = (opacity > 0.95).view(*depth.shape)
86 |
87 | l1_rgb = get_loss_tracking_rgb(config, image, depth, opacity, viewpoint)
88 | depth_mask = depth_pixel_mask * opacity_mask
89 | l1_depth = torch.abs(depth * depth_mask - gt_depth * depth_mask)
90 | return alpha * l1_rgb + (1 - alpha) * l1_depth.mean()
--------------------------------------------------------------------------------
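The gradient helpers above are what Camera.compute_grad_mask uses to build its edge mask; a toy sketch on a random grayscale image, where the 4.0 multiplier is an arbitrary stand-in for the edge_threshold config value.

# Sketch: Scharr gradients and an edge mask on a random image (arbitrary threshold).
import torch
from src.gsr.loss import image_gradient, image_gradient_mask

gray = torch.rand(1, 120, 160, device="cuda")          # 1 x H x W grayscale image

grad_v, grad_h = image_gradient(gray)                  # vertical / horizontal Scharr responses
mask_v, mask_h = image_gradient_mask(gray)             # masks out invalid border responses

intensity = torch.sqrt((grad_v * mask_v) ** 2 + (grad_h * mask_h) ** 2)
edge_mask = intensity > intensity.median() * 4.0       # 4.0 stands in for edge_threshold
print(edge_mask.float().mean())                        # fraction of pixels flagged as edges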
/src/gsr/overlap.py:
--------------------------------------------------------------------------------
1 | import open3d as o3d
2 | import numpy as np
3 | import torch
4 | import faiss
5 | import faiss.contrib.torch_utils
6 | from pycg import vis
7 |
8 | from src.utils.utils import batch_search_faiss
9 |
10 | def get_correspondences(src_pcd, tgt_pcd, trans, search_voxel_size, K=None):
11 | src_pcd.transform(trans)
12 | pcd_tree = o3d.geometry.KDTreeFlann(tgt_pcd)
13 |
14 | correspondences = []
15 | for i, point in enumerate(src_pcd.points):
16 | [count, idx, _] = pcd_tree.search_radius_vector_3d(point, search_voxel_size)
17 | if K is not None:
18 | idx = idx[:K]
19 | for j in idx:
20 | correspondences.append([i, j])
21 |
22 | correspondences = np.array(correspondences)
23 | correspondences = torch.from_numpy(correspondences)
24 | return correspondences
25 |
26 | def get_overlap_ratio(source, target, threshold=0.03):
27 |     """
28 |     Compute the overlap ratio from the source point cloud to the target point cloud.
29 |     """
30 |     pcd_tree = o3d.geometry.KDTreeFlann(target)
31 |
32 |     match_count = 0
33 |     for i, point in enumerate(source.points):
34 |         [count, _, _] = pcd_tree.search_radius_vector_3d(point, threshold)
35 |         if count != 0:
36 |             match_count += 1
37 |
38 |     overlap_ratio = match_count / min(len(source.points), len(target.points))
39 |     return overlap_ratio
40 |
41 | def compute_overlap_gaussians(src_gs, tgt_gs, threshold=0.03):
42 |     """Compute the overlap ratio between two sets of Gaussians.
43 |
44 |     Args:
45 |         src_gs: source GaussianModel
46 |         tgt_gs: target GaussianModel
47 |         threshold (float, optional): nearest-neighbor distance threshold. Defaults to 0.03.
48 |     """
49 | src_tensor = src_gs.get_xyz().detach()
50 | tgt_tensor = tgt_gs.get_xyz().detach()
51 | cpu_index = faiss.IndexFlatL2(3)
52 | gpu_index = faiss.index_cpu_to_all_gpus(cpu_index)
53 | gpu_index.add(tgt_tensor)
54 |
55 | distances, _ = batch_search_faiss(gpu_index, src_tensor, 1)
56 | mask_src = distances < threshold
57 |
58 | cpu_index = faiss.IndexFlatL2(3)
59 | gpu_index = faiss.index_cpu_to_all_gpus(cpu_index)
60 | gpu_index.add(src_tensor)
61 |
62 | distances, _ = batch_search_faiss(gpu_index, tgt_tensor, 1)
63 | mask_tgt = distances < threshold
64 |
65 | faiss_overlap_ratio = min(mask_src.sum()/len(mask_src), mask_tgt.sum()/len(mask_tgt))
66 |
67 | return faiss_overlap_ratio
68 |
69 | def visualize_overlap(pc1, pc2, corr):
70 | import matplotlib.cm as cm
71 | src_pcd = o3d.geometry.PointCloud()
72 | src_pcd.points = o3d.utility.Vector3dVector(pc1.cpu().numpy())
73 | tgt_pcd = o3d.geometry.PointCloud()
74 | tgt_pcd.points = o3d.utility.Vector3dVector(pc2.cpu().numpy())
75 | # corr = get_correspondences(src_pcd, tgt_pcd, np.eye(4), 0.05)
76 |
77 | color_1 = cm.tab10(0)
78 | color_2 = cm.tab10(1)
79 | overlap_color = cm.tab10(2)
80 |
81 | color_src = np.ones_like(pc1.cpu().numpy())
82 | color_tgt = np.ones_like(pc2.cpu().numpy())
83 |
84 | if len(corr)>0:
85 | color_src[corr[:,0].cpu().numpy()] = np.array(color_1)[:3]
86 | color_tgt[corr[:,1].cpu().numpy()] = np.array(color_2)[:3]
87 |
88 | vis_src = vis.pointcloud(pc1.cpu().numpy(), color=color_src, is_sphere=True)
89 | vis_tgt = vis.pointcloud(pc2.cpu().numpy(), color=color_tgt, is_sphere=True)
90 | vis.show_3d([vis_src, vis_tgt],[vis_src], [vis_tgt], use_new_api=True)
--------------------------------------------------------------------------------
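A toy check of get_overlap_ratio: comparing a point cloud against a slightly jittered copy of itself, where a 5 mm jitter against a 3 cm search radius should give a ratio close to 1.0.

# Toy overlap check: a random cloud vs. a jittered copy of itself.
import numpy as np
import open3d as o3d
from src.gsr.overlap import get_overlap_ratio

rng = np.random.default_rng(0)
pts = rng.uniform(0, 1, size=(2000, 3))

src = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(pts))
tgt = o3d.geometry.PointCloud(
    o3d.utility.Vector3dVector(pts + rng.normal(scale=0.005, size=pts.shape)))

print(get_overlap_ratio(src, tgt, threshold=0.03))   # expected to be close to 1.0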
/src/gsr/pcr.py:
--------------------------------------------------------------------------------
1 | import open3d as o3d
2 |
3 | def preprocess_point_cloud(pcd, voxel_size, camera_location):
4 | pcd_down = pcd.voxel_down_sample(voxel_size)
5 | pcd_down.estimate_normals(
6 | o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size * 2.0,
7 | max_nn=30))
8 |
9 | pcd_down.orient_normals_towards_camera_location(
10 | camera_location=camera_location)
11 |
12 | pcd_fpfh = o3d.pipelines.registration.compute_fpfh_feature(
13 | pcd_down,
14 | o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size * 5.0,
15 | max_nn=100))
16 | return (pcd_down, pcd_fpfh)
17 |
18 |
19 | def execute_global_registration(source_down, target_down, source_fpfh,
20 | target_fpfh, voxel_size):
21 | distance_threshold = voxel_size * 1.5
22 | print(":: RANSAC registration on downsampled point clouds.")
23 | print(" Downsampling voxel size is %.3f," % voxel_size)
24 | print(" Using a liberal distance threshold %.3f." % distance_threshold)
25 | result = o3d.pipelines.registration.registration_ransac_based_on_feature_matching(
26 | source_down, target_down, source_fpfh, target_fpfh, True,
27 | distance_threshold,
28 | o3d.pipelines.registration.TransformationEstimationPointToPoint(False),
29 | 3, [
30 | o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(
31 | 0.9),
32 | o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(
33 | distance_threshold)
34 | ], o3d.pipelines.registration.RANSACConvergenceCriteria(10000000, 0.99999))
35 | return result
36 |
37 | def refine_registration(source, target, source_fpfh, target_fpfh, voxel_size, init_trans):
38 | distance_threshold = voxel_size
39 | print(":: Point-to-plane ICP registration is applied on original point")
40 | print(" clouds to refine the alignment. This time we use a strict")
41 | print(" distance threshold %.3f." % distance_threshold)
42 | result = o3d.pipelines.registration.registration_icp(
43 | source, target, distance_threshold, init_trans,
44 | o3d.pipelines.registration.TransformationEstimationPointToPlane(),
45 | o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=500))
46 | return result
--------------------------------------------------------------------------------
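The three helpers above form a coarse-to-fine pipeline: FPFH features on voxel-downsampled clouds, RANSAC for a global estimate, then point-to-plane ICP for refinement. A sketch under the assumption that two point-cloud files exist at the placeholder paths; the voxel size and camera location are arbitrary.

# Sketch: coarse-to-fine registration with the helpers above (placeholder paths).
import numpy as np
import open3d as o3d
from src.gsr.pcr import preprocess_point_cloud, execute_global_registration, refine_registration

voxel_size = 0.05
cam_loc = np.zeros(3)                              # arbitrary location for normal orientation

source = o3d.io.read_point_cloud("source.ply")     # placeholder input clouds
target = o3d.io.read_point_cloud("target.ply")

src_down, src_fpfh = preprocess_point_cloud(source, voxel_size, cam_loc)
tgt_down, tgt_fpfh = preprocess_point_cloud(target, voxel_size, cam_loc)

coarse = execute_global_registration(src_down, tgt_down, src_fpfh, tgt_fpfh, voxel_size)
# Refine on the downsampled clouds, which already carry the normals point-to-plane ICP needs.
fine = refine_registration(src_down, tgt_down, src_fpfh, tgt_fpfh, voxel_size, coarse.transformation)
print(fine.transformation)                         # 4x4 source-to-target transform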
/src/gsr/renderer.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import math
13 |
14 | import torch
15 | from diff_gaussian_rasterization import (
16 | GaussianRasterizationSettings,
17 | GaussianRasterizer,
18 | )
19 |
20 | from src.entities.gaussian_model import GaussianModel
21 | from src.utils.gaussian_model_utils import eval_sh
22 |
23 |
24 | def render(
25 | viewpoint_camera,
26 | pc: GaussianModel,
27 | pipe,
28 | bg_color: torch.Tensor,
29 | scaling_modifier=1.0,
30 | override_color=None,
31 | mask=None
32 | ):
33 | """
34 | Render the scene.
35 |
36 | Background tensor (bg_color) must be on GPU!
37 | """
38 |
39 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
40 | if pc.get_xyz().shape[0] == 0:
41 | return None
42 |
43 | screenspace_points = (
44 | torch.zeros_like(
45 | pc.get_xyz(), dtype=pc.get_xyz().dtype, requires_grad=True, device="cuda"
46 | )
47 | + 0
48 | )
49 | try:
50 | screenspace_points.retain_grad()
51 | except Exception:
52 | pass
53 |
54 | # Set up rasterization configuration
55 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
56 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
57 |
58 | raster_settings = GaussianRasterizationSettings(
59 | image_height=int(viewpoint_camera.image_height),
60 | image_width=int(viewpoint_camera.image_width),
61 | tanfovx=tanfovx,
62 | tanfovy=tanfovy,
63 | bg=bg_color,
64 | scale_modifier=scaling_modifier,
65 | viewmatrix=viewpoint_camera.world_view_transform,
66 | projmatrix=viewpoint_camera.full_proj_transform,
67 | projmatrix_raw=viewpoint_camera.projection_matrix,
68 | sh_degree=pc.active_sh_degree,
69 | campos=viewpoint_camera.camera_center,
70 | prefiltered=False,
71 | debug=False,
72 | )
73 |
74 | rasterizer = GaussianRasterizer(raster_settings=raster_settings)
75 |
76 | means3D = pc.get_xyz()
77 | means2D = screenspace_points
78 | opacity = pc.get_opacity()
79 |
80 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
81 | # scaling / rotation by the rasterizer.
82 | scales = None
83 | rotations = None
84 | cov3D_precomp = None
85 | if pipe.compute_cov3D_python:
86 | cov3D_precomp = pc.get_covariance(scaling_modifier)
87 | else:
88 | # check if the covariance is isotropic
89 | if pc.get_scaling().shape[-1] == 1:
90 | scales = pc.get_scaling().repeat(1, 3)
91 | else:
92 | scales = pc.get_scaling()
93 | rotations = pc.get_rotation()
94 |
95 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
96 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
97 | shs = None
98 | colors_precomp = None
99 | if colors_precomp is None:
100 | if pipe.convert_SHs_python:
101 | shs_view = pc.get_features().transpose(1, 2).view(
102 | -1, 3, (pc.max_sh_degree + 1) ** 2
103 | )
104 | dir_pp = pc.get_xyz() - viewpoint_camera.camera_center.repeat(
105 | pc.get_features().shape[0], 1
106 | )
107 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
108 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
109 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
110 | else:
111 | shs = pc.get_features()
112 | else:
113 | colors_precomp = override_color
114 |
115 | # Rasterize visible Gaussians to image, obtain their radii (on screen).
116 | if mask is not None:
117 | rendered_image, radii, depth, opacity = rasterizer(
118 | means3D=means3D[mask],
119 | means2D=means2D[mask],
120 | shs=shs[mask],
121 | colors_precomp=colors_precomp[mask] if colors_precomp is not None else None,
122 | opacities=opacity[mask],
123 | scales=scales[mask],
124 | rotations=rotations[mask],
125 | cov3D_precomp=cov3D_precomp[mask] if cov3D_precomp is not None else None,
126 | theta=viewpoint_camera.cam_rot_delta,
127 | rho=viewpoint_camera.cam_trans_delta,
128 | )
129 | else:
130 | rendered_image, radii, depth, opacity, n_touched = rasterizer(
131 | means3D=means3D,
132 | means2D=means2D,
133 | shs=shs,
134 | colors_precomp=colors_precomp,
135 | opacities=opacity,
136 | scales=scales,
137 | rotations=rotations,
138 | cov3D_precomp=cov3D_precomp,
139 | theta=viewpoint_camera.cam_rot_delta,
140 | rho=viewpoint_camera.cam_trans_delta,
141 | )
142 |
143 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
144 | # They will be excluded from value updates used in the splitting criteria.
145 | return {
146 | "render": rendered_image,
147 | "viewspace_points": screenspace_points,
148 | "visibility_filter": radii > 0,
149 | "radii": radii,
150 | "depth": depth,
151 | "opacity": opacity,
152 |         "n_touched": n_touched if mask is None else None,  # not produced by the masked rasterizer call
153 | }
154 |
--------------------------------------------------------------------------------
/src/gsr/se3/numpy_se3.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.spatial.transform import Rotation
3 |
4 |
5 | def identity():
6 | return np.eye(3, 4)
7 |
8 |
9 | def transform(g: np.ndarray, pts: np.ndarray):
10 | """ Applies the SE3 transform
11 |
12 | Args:
13 | g: SE3 transformation matrix of size ([B,] 3/4, 4)
14 | pts: Points to be transformed ([B,] N, 3)
15 |
16 | Returns:
17 | transformed points of size (N, 3)
18 | """
19 | rot = g[..., :3, :3] # (3, 3)
20 | trans = g[..., :3, 3] # (3)
21 |
22 | transformed = pts[..., :3] @ np.swapaxes(rot, -1, -2) + trans[..., None, :]
23 | return transformed
24 |
25 |
26 | def inverse(g: np.ndarray):
27 | """Returns the inverse of the SE3 transform
28 |
29 | Args:
30 | g: ([B,] 3/4, 4) transform
31 |
32 | Returns:
33 | ([B,] 3/4, 4) matrix containing the inverse
34 |
35 | """
36 | rot = g[..., :3, :3] # (3, 3)
37 | trans = g[..., :3, 3] # (3)
38 |
39 | inv_rot = np.swapaxes(rot, -1, -2)
40 | inverse_transform = np.concatenate([inv_rot, inv_rot @ -trans[..., None]], axis=-1)
41 | if g.shape[-2] == 4:
42 | inverse_transform = np.concatenate([inverse_transform, [[0.0, 0.0, 0.0, 1.0]]], axis=-2)
43 |
44 | return inverse_transform
45 |
46 |
47 | def concatenate(a: np.ndarray, b: np.ndarray):
48 | """ Concatenate two SE3 transforms
49 |
50 | Args:
51 | a: First transform ([B,] 3/4, 4)
52 | b: Second transform ([B,] 3/4, 4)
53 |
54 | Returns:
55 | a*b ([B, ] 3/4, 4)
56 |
57 | """
58 |
59 | r_a, t_a = a[..., :3, :3], a[..., :3, 3]
60 | r_b, t_b = b[..., :3, :3], b[..., :3, 3]
61 |
62 | r_ab = r_a @ r_b
63 | t_ab = r_a @ t_b[..., None] + t_a[..., None]
64 |
65 | concatenated = np.concatenate([r_ab, t_ab], axis=-1)
66 |
67 | if a.shape[-2] == 4:
68 | concatenated = np.concatenate([concatenated, [[0.0, 0.0, 0.0, 1.0]]], axis=-2)
69 |
70 | return concatenated
71 |
72 |
73 | def from_xyzquat(xyzquat):
74 | """Constructs SE3 matrix from x, y, z, qx, qy, qz, qw
75 |
76 | Args:
77 | xyzquat: np.array (7,) containing translation and quaterion
78 |
79 | Returns:
80 | SE3 matrix (4, 4)
81 | """
82 | rot = Rotation.from_quat(xyzquat[3:])
83 | trans = rot.apply(-xyzquat[:3])
84 |     transform = np.concatenate([rot.as_matrix(), trans[:, None]], axis=1)
85 | transform = np.concatenate([transform, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
86 |
87 | return transform
--------------------------------------------------------------------------------
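A round-trip sanity check for the numpy SE(3) helpers: with a random pose and random points, applying the inverse transform after the transform recovers the input, and composing a pose with its inverse yields the identity.

# Sanity check for transform / inverse / concatenate (random pose and points).
import numpy as np
from scipy.spatial.transform import Rotation
from src.gsr.se3.numpy_se3 import transform, inverse, concatenate

rng = np.random.default_rng(0)
g = np.eye(4)
g[:3, :3] = Rotation.random(random_state=0).as_matrix()
g[:3, 3] = rng.uniform(-1, 1, 3)

pts = rng.uniform(-1, 1, (100, 3))

assert np.allclose(transform(inverse(g), transform(g, pts)), pts)
assert np.allclose(concatenate(inverse(g), g), np.eye(4))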
/src/gsr/se3/torch_se3.py:
--------------------------------------------------------------------------------
1 | """ 3-d rigid body transformation group
2 | """
3 | import torch
4 |
5 |
6 | def identity(batch_size):
7 | return torch.eye(3, 4)[None, ...].repeat(batch_size, 1, 1)
8 |
9 |
10 | def inverse(g):
11 | """ Returns the inverse of the SE3 transform
12 |
13 | Args:
14 | g: (B, 3/4, 4) transform
15 |
16 | Returns:
17 | (B, 3, 4) matrix containing the inverse
18 |
19 | """
20 | # Compute inverse
21 | rot = g[..., 0:3, 0:3]
22 | trans = g[..., 0:3, 3]
23 | inverse_transform = torch.cat([rot.transpose(-1, -2), rot.transpose(-1, -2) @ -trans[..., None]], dim=-1)
24 |
25 | return inverse_transform
26 |
27 |
28 | def concatenate(a, b):
29 | """Concatenate two SE3 transforms,
30 | i.e. return a@b (but note that our SE3 is represented as a 3x4 matrix)
31 |
32 | Args:
33 | a: (B, 3/4, 4)
34 | b: (B, 3/4, 4)
35 |
36 | Returns:
37 | (B, 3/4, 4)
38 | """
39 |
40 | rot1 = a[..., :3, :3]
41 | trans1 = a[..., :3, 3]
42 | rot2 = b[..., :3, :3]
43 | trans2 = b[..., :3, 3]
44 |
45 | rot_cat = rot1 @ rot2
46 | trans_cat = rot1 @ trans2[..., None] + trans1[..., None]
47 | concatenated = torch.cat([rot_cat, trans_cat], dim=-1)
48 |
49 | return concatenated
50 |
51 |
52 | def transform(g, a, normals=None):
53 | """ Applies the SE3 transform
54 |
55 | Args:
56 | g: SE3 transformation matrix of size ([1,] 3/4, 4) or (B, 3/4, 4)
57 | a: Points to be transformed (N, 3) or (B, N, 3)
58 | normals: (Optional). If provided, normals will be transformed
59 |
60 | Returns:
61 | transformed points of size (N, 3) or (B, N, 3)
62 |
63 | """
64 | R = g[..., :3, :3] # (B, 3, 3)
65 | p = g[..., :3, 3] # (B, 3)
66 |
67 | if len(g.size()) == len(a.size()):
68 | b = torch.matmul(a, R.transpose(-1, -2)) + p[..., None, :]
69 | else:
70 | raise NotImplementedError
71 | b = R.matmul(a.unsqueeze(-1)).squeeze(-1) + p # No batch. Not checked
72 |
73 | if normals is not None:
74 | rotated_normals = normals @ R.transpose(-1, -2)
75 | return b, rotated_normals
76 |
77 | else:
78 | return b
79 |
80 |
81 | def Rt_to_SE3(R, t):
82 | '''
83 | Merge 3D rotation and translation into 4x4 SE transformation.
84 | Args:
85 | R: SO(3) rotation matrix (B, 3, 3)
86 | t: translation vector (B, 3, 1)
87 | '''
88 | B = R.shape[0]
89 | SE3 = torch.zeros((B, 4, 4)).to(R.device)
90 | SE3[:,3,3] = 1
91 | SE3[:,:3,:3] = R
92 | SE3[:,:3,[3]] = t
93 | return SE3
--------------------------------------------------------------------------------
/src/gsr/solver.py:
--------------------------------------------------------------------------------
1 | import torch, roma
2 | import numpy as np
3 | import copy
4 |
5 | from src.gsr.renderer import render
6 | from src.gsr.loss import get_loss_tracking
7 | from src.gsr.overlap import compute_overlap_gaussians
8 | from src.utils.pose_utils import update_pose
9 |
10 |
11 | class CustomPipeline:
12 | convert_SHs_python = False
13 | compute_cov3D_python = False
14 | debug = False
15 |
16 | def viewpoint_localizer(viewpoint, gaussians, base_lr: float=1e-3):
17 | """Localize a single viewpoint in a 3DGS
18 |
19 |     Args:
20 |         viewpoint (Camera): Camera instance to localize
21 |         gaussians (Gaussians): 3D Gaussians in which to localize the viewpoint
22 |         base_lr (float, optional): base learning rate for the pose optimizer. Defaults to 1e-3.
23 |
24 |     Returns:
25 |         tuple: (converged, rel_tsfm, loss_residual, loss_log)
26 |     """
27 | opt_params = []
28 | pipe = CustomPipeline()
29 | bg_color = torch.tensor([0, 0, 0], dtype=torch.float32, device="cuda", requires_grad=False)
30 | config = {
31 | 'Training': {
32 | 'monocular': False,
33 | "rgb_boundary_threshold": 0.01,
34 | }
35 | }
36 |
37 | init_T = viewpoint.get_T.detach()
38 |
39 | opt_params.append(
40 | {
41 | "params": [viewpoint.cam_rot_delta],
42 | "lr": 3*base_lr,
43 | "name": "rot_{}".format(viewpoint.uid),
44 | }
45 | )
46 | opt_params.append(
47 | {
48 | "params": [viewpoint.cam_trans_delta],
49 | "lr": base_lr,
50 | "name": "trans_{}".format(viewpoint.uid),
51 | }
52 | )
53 | opt_params.append(
54 | {
55 | "params": [viewpoint.exposure_a],
56 | "lr": 0.01,
57 | "name": "exposure_a_{}".format(viewpoint.uid),
58 | }
59 | )
60 | opt_params.append(
61 | {
62 | "params": [viewpoint.exposure_b],
63 | "lr": 0.01,
64 | "name": "exposure_b_{}".format(viewpoint.uid),
65 | }
66 | )
67 | optimizer = torch.optim.Adam(opt_params)
68 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", factor=0.98, patience=5, verbose=False)
69 |
70 | loss_log = []
71 | opt_iterations = 100
72 | for tracking_itr in range(opt_iterations):
73 | optimizer.zero_grad()
74 | render_pkg = render(
75 | viewpoint, gaussians, pipe, bg_color
76 | )
77 | image, depth, opacity = (
78 | render_pkg["render"],
79 | render_pkg["depth"],
80 | render_pkg["opacity"],
81 | )
82 |
83 | loss = get_loss_tracking(config, image, depth, opacity, viewpoint)
84 | loss.backward()
85 | loss_log.append(loss.item())
86 |
87 | with torch.no_grad():
88 | optimizer.step()
89 | scheduler.step(loss)
90 | converged = update_pose(viewpoint)
91 |
92 | if converged:
93 | break
94 |
95 | rel_tsfm = (init_T.inverse() @ viewpoint.get_T).inverse()
96 | loss_residual = loss.item()
97 |
98 | return converged, rel_tsfm, loss_residual, loss_log
99 |
100 | def gaussian_registration(src_dict, tgt_dict, config: dict, visualize=False):
101 |     """Register two submaps represented as 3D Gaussians.
102 |
103 |     Args:
104 |         src_dict (dict): dictionary with the source gaussians and their keyframes
105 |         tgt_dict (dict): dictionary with the target gaussians and their keyframes
106 |         config (dict): registration settings (e.g. "use_render", "base_lr")
107 |
108 |     Returns:
109 |         dict: dictionary with the registration result
110 |     """
111 |
112 | # print("Pairwise registration ...")
113 | init_overlap = compute_overlap_gaussians(src_dict['gaussians'], tgt_dict['gaussians'], 0.1)
114 |     if init_overlap < 0.2:
115 |         print("Initial overlap between the two submaps is too small, skipping ...")
116 | return {
117 | 'successful': False,
118 | "pred_tsfm": torch.eye(4).cuda(),
119 | "gt_tsfm": torch.eye(4).cuda(),
120 | "overlap": init_overlap.item()
121 | }
122 |
123 | src_3dgs, src_view_list = copy.deepcopy(src_dict['gaussians']), copy.deepcopy(src_dict['cameras'])
124 | tgt_3dgs, tgt_view_list = copy.deepcopy(tgt_dict['gaussians']), copy.deepcopy(tgt_dict['cameras'])
125 |
126 | # compute gt tsfm
127 |     src_keyframe = src_dict['cameras'][0].get_T.detach()
128 |     src_gt = src_dict['cameras'][0].get_T_gt.detach()
129 |     tgt_keyframe = tgt_dict['cameras'][0].get_T.detach()
130 |     tgt_gt = tgt_dict['cameras'][0].get_T_gt.detach()
131 | delta_src = src_gt.inverse() @ src_keyframe
132 | delta_tgt = tgt_gt.inverse() @ tgt_keyframe
133 | gt_tsfm = delta_tgt.inverse() @ delta_src
134 |
135 |     # select the most similar keyframe pairs based on global descriptor similarity
136 | device = "cuda" if torch.cuda.is_available() else "cpu"
137 | src_desc, tgt_desc = src_dict['kf_desc'], tgt_dict['kf_desc']
138 |
139 | score_cross = torch.einsum("id,jd->ij", src_desc.to(device), tgt_desc.to(device))
140 | score_best_src, _ = score_cross.topk(1)
141 | _, ii = score_best_src.view(-1).topk(2)
142 |
143 | score_best_tgt, _ = score_cross.T.topk(1)
144 | _, jj = score_best_tgt.view(-1).topk(2)
145 |
146 | src_view_list = [src_view_list[i.item()] for i in ii]
147 | tgt_view_list = [tgt_view_list[j.item()] for j in jj]
148 |
149 | pred_list, residual_list, converged_list, loss_log_list = [], [], [], []
150 |
151 | pipe = CustomPipeline()
152 | bg_color = torch.tensor([0, 0, 0], dtype=torch.float32, device="cuda", requires_grad=False)
153 | # per-cam
154 | for viewpoint in src_view_list:
155 |
156 |         # use the rendered image (not the raw observation) as the tracking target
157 | if config["use_render"]:
158 | render_pkg = render(viewpoint, src_3dgs, pipe, bg_color)
159 | viewpoint.load_rgb(render_pkg['render'].detach())
160 | viewpoint.depth = render_pkg['depth'].squeeze().detach().cpu().numpy()
161 | else:
162 | viewpoint.load_rgb()
163 | converged, pred_tsfm, residual, loss_log = viewpoint_localizer(viewpoint, tgt_3dgs, config["base_lr"])
164 | pred_list.append(pred_tsfm)
165 | residual_list.append(residual)
166 | converged_list.append(converged)
167 | loss_log_list.append(loss_log)
168 |
169 | for viewpoint in tgt_view_list:
170 | if config["use_render"]:
171 | render_pkg = render(viewpoint, tgt_3dgs, pipe, bg_color)
172 | viewpoint.load_rgb(render_pkg['render'].detach())
173 | viewpoint.depth = render_pkg['depth'].squeeze().detach().cpu().numpy()
174 | else:
175 | viewpoint.load_rgb()
176 | converged, pred_tsfm, residual, loss_log = viewpoint_localizer(viewpoint, src_3dgs, config["base_lr"])
177 | pred_list.append(pred_tsfm.inverse())
178 | residual_list.append(residual)
179 | converged_list.append(converged)
180 | loss_log_list.append(loss_log)
181 |
182 |
183 | pred_tsfms = torch.stack(pred_list)
184 | residuals = torch.Tensor(residual_list).cuda().float()
185 | # probability based on residuals
186 | prob = 1/residuals / (1/residuals).sum()
187 |
188 | M = torch.sum(prob[:, None, None] * pred_tsfms[:,:3,:3], dim=0)
189 | try:
190 | R_w = roma.special_procrustes(M)
191 | except Exception as e:
192 | print(f"Error in roma.special_procrustes: {e}")
193 | return {
194 | 'successful': False,
195 | "pred_tsfm": torch.eye(4).cuda(),
196 | "gt_tsfm": torch.eye(4).cuda(),
197 | "overlap": init_overlap.item()
198 | }
199 | t_w = torch.sum(prob[:, None] * pred_tsfms[:,:3, 3], dim=0)
200 |
201 | best_tsfm = torch.eye(4).cuda().float()
202 | best_tsfm[:3,:3] = R_w
203 | best_tsfm[:3, 3] = t_w
204 |
205 | result_dict = {
206 | "gt_tsfm": gt_tsfm,
207 | "pred_tsfm": best_tsfm,
208 | "successful": True,
209 | "best_viewpoint": src_view_list[0].get_T
210 | }
211 |
212 | if visualize:
213 | import matplotlib
214 | import matplotlib.pyplot as plt
215 | from src.gsr.utils import visualize_registration
216 | matplotlib.use('TkAgg')
217 | plt.figure(figsize=(10, 6))
218 | for log in loss_log_list:
219 | plt.plot(log)
220 |
221 | plt.xlabel('Epoch')
222 | plt.ylabel('Loss')
223 |         plt.title('Loss Curves of the Independent Per-View Optimizations')
224 | plt.legend([f'Run {i+1}' for i in range(len(loss_log_list))], loc='upper right')
225 | plt.grid(True)
226 | plt.show()
227 |
228 | visualize_registration(src_3dgs, tgt_3dgs, best_tsfm, gt_tsfm)
229 |
230 | del src_3dgs, src_view_list, tgt_3dgs, tgt_view_list
231 | return result_dict
232 |
233 |
234 |
235 |
--------------------------------------------------------------------------------
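The pose fusion at the end of gaussian_registration is a residual-weighted average: per-view rotation estimates are summed with weights inversely proportional to their tracking residuals and projected back onto SO(3) with roma.special_procrustes, and translations are averaged with the same weights. The sketch below isolates that step on toy inputs (random rotation vectors and made-up residuals; none of these values come from the pipeline).

# Sketch: residual-weighted fusion of per-view pose estimates (toy inputs).
import torch
import roma

pred_R = roma.rotvec_to_rotmat(0.1 * torch.randn(4, 3))   # four nearby candidate rotations
pred_t = 0.05 * torch.randn(4, 3, 1)                       # four candidate translations
pred_tsfms = torch.cat([pred_R, pred_t], dim=-1)           # (4, 3, 4) per-view estimates
residuals = torch.tensor([0.2, 0.05, 0.4, 0.1])            # per-view tracking residuals

prob = (1 / residuals) / (1 / residuals).sum()             # lower residual -> higher weight

M = torch.sum(prob[:, None, None] * pred_tsfms[:, :3, :3], dim=0)
R_w = roma.special_procrustes(M)                           # nearest rotation to the weighted sum
t_w = torch.sum(prob[:, None] * pred_tsfms[:, :3, 3], dim=0)

fused = torch.eye(4)
fused[:3, :3] = R_w
fused[:3, 3] = t_w
print(fused)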
/src/gsr/utils.py:
--------------------------------------------------------------------------------
1 | import json, yaml
2 | import numpy as np
3 | import torch
4 | from pycg import vis
5 | import matplotlib
6 | import matplotlib.pyplot as plt
7 | import open3d as o3d
8 |
9 |
10 | from src.utils.graphics_utils import getProjectionMatrix2, focal2fov
11 | from src.gsr.se3.numpy_se3 import transform
12 | from src.gsr.camera import Camera
13 |
14 | def read_json_data(file_path):
15 | try:
16 | with open(file_path, "r", encoding="utf-8") as f:
17 | data = json.load(f)
18 | return data
19 | except FileNotFoundError:
20 | print(f"File '{file_path}' not found.")
21 | return None
22 | except json.JSONDecodeError as e:
23 | print(f"Error decoding JSON: {e}")
24 | return None
25 |
26 | def read_trajectory(file_path, slam_config, scale=1., device="cuda:0"):
27 |     """Read trajectory data from the plot output into a dict of Cameras (for use in 3DGS).
28 |
29 |     Args:
30 |         file_path (string): path to the trajectory json file
31 |     """
32 | trj_data = dict()
33 | json_data = read_json_data(file_path)
34 | dataset = slam_config['Dataset']['type']
35 |
36 | if json_data:
37 | # Access the data as needed
38 | trj_data["trj_id"] = json_data["trj_id"]
39 | trj_data["trj_est"] = torch.Tensor(json_data["trj_est"]).to(device).float()
40 | trj_data["trj_gt"] = torch.Tensor(json_data["trj_gt"]).to(device).float()
41 | # print("Trajectory loaded successfully!")
42 |
43 | cam_dict = dict()
44 | calibration = slam_config["Dataset"]["Calibration"]
45 |
46 | proj_matrix = getProjectionMatrix2(
47 | znear=0.01,
48 | zfar=100.0,
49 | fx = calibration["fx"] / scale,
50 | fy = calibration["fy"] / scale,
51 | cx = calibration["cx"] / scale,
52 | cy = calibration["cy"] / scale,
53 | W = calibration["width"] / scale,
54 | H = calibration["height"] / scale,
55 | ).T
56 |
57 | fovx = focal2fov(calibration['fx'], calibration['width'])
58 | fovy = focal2fov(calibration['fy'], calibration['height'])
59 | for id, est_pose, gt_pose in zip(trj_data["trj_id"], trj_data["trj_est"], trj_data["trj_gt"]):
60 | T_gt = torch.linalg.inv(gt_pose)
61 | cam_i = Camera(id, None, None,
62 | T_gt,
63 | proj_matrix,
64 | calibration["fx"]/ scale,
65 | calibration["fy"]/ scale,
66 | calibration["cx"]/ scale,
67 | calibration["cy"]/ scale,
68 | fovx,
69 | fovy,
70 | calibration["height"]/ scale,
71 | calibration["width"]/ scale)
72 | est_T = torch.linalg.inv(est_pose)
73 | cam_i.R = est_T[:3, :3]
74 | cam_i.T = est_T[:3, 3]
75 |
76 | cam_dict[id] = cam_i
77 |
78 | return cam_dict
79 |
80 | def visualize_registration(src_3dgs, tgt_3dgs, pre_tsfm, gt_tsfm):
81 | """visualize registration in 3D Gaussians
82 |
83 | Args:
84 | gs3d_src (Gaussian Model): _description_
85 | gs3d_tgt (Gaussian Model): _description_
86 | tsfm (torch.Tensor): _description_
87 | """
88 | src_pc = src_3dgs.get_xyz().detach().cpu().numpy()
89 | tgt_pc = tgt_3dgs.get_xyz().detach().cpu().numpy()
90 | est_src_pc = transform(pre_tsfm.detach().cpu().numpy(), src_pc)
91 | gt_src_pc = transform(gt_tsfm.detach().cpu().numpy(), src_pc)
92 |
93 | src_vis = vis.pointcloud(src_pc[::2], ucid=0, cmap='tab10', is_sphere=True)
94 | tgt_vis = vis.pointcloud(tgt_pc[::2], ucid=1, cmap='tab10', is_sphere=True)
95 |
96 | est_vis = vis.pointcloud(est_src_pc[::2], ucid=0, cmap='tab10', is_sphere=True)
97 | gt_vis = vis.pointcloud(gt_src_pc[::2], ucid=0, cmap='tab10', is_sphere=True)
98 |
99 | try:
100 | vis.show_3d([src_vis, tgt_vis], [est_vis, tgt_vis], [gt_vis, tgt_vis], use_new_api=True, show=True)
101 |     except Exception:
102 | print("estimate is not a good transformation")
103 | vis.show_3d([src_vis, tgt_vis], [gt_vis, tgt_vis], use_new_api=True, show=True)
104 |
105 | def visualize_mv_registration(gaussian_list, pred_rel_tsfms, gt_rel_tsfms):
106 | vis_pred_list, vis_gt_list = [], []
107 | pred_abs_tsfm = torch.eye(4).cuda()
108 | gt_abs_tsfm = torch.eye(4).cuda()
109 | for i, gaussians in enumerate(gaussian_list[:5]):
110 | pc = gaussians.get_xyz().detach().cpu().numpy()
111 |
112 | pc_pred = transform(pred_abs_tsfm.detach().cpu().numpy(), pc)
113 | vis_pred_list.append(vis.pointcloud(pc_pred, ucid=i, cmap='tab10', is_sphere=True))
114 | if i>0: pred_abs_tsfm = pred_abs_tsfm @ pred_rel_tsfms[i-1]
115 |
116 | pc_gt = transform(gt_abs_tsfm.detach().cpu().numpy(), pc)
117 | vis_gt_list.append(vis.pointcloud(pc_gt, ucid=i, cmap='tab10', is_sphere=True))
118 | if i>0: gt_abs_tsfm = gt_abs_tsfm @ gt_rel_tsfms[i-1]
119 |
120 | vis.show_3d(vis_pred_list, vis_gt_list, use_new_api=True, show=True)
121 |
122 |
123 |
124 | def axis_angle_to_rot_mat(axes, thetas):
125 | """
126 |     Compute a rotation matrix from the axis-angle representation using the Rodrigues formula:
127 |     \mathbf{R} = \mathbf{I} + \sin(\theta)\mathbf{K} + (1 - \cos(\theta))\mathbf{K}^2, where \mathbf{K} = [\mathbf{k}]_\times is the skew-symmetric matrix of the unit axis \mathbf{k} = \mathbf{a} / ||\mathbf{a}||
128 |
129 | Args:
130 | axes (numpy array): array of axes used to compute the rotation matrices [b,3]
131 | thetas (numpy array): array of angles used to compute the rotation matrices [b,1]
132 |
133 | Returns:
134 | rot_matrices (numpy array): array of the rotation matrices computed from the angle, axis representation [b,3,3]
135 |
136 | borrowed from: https://github.com/zgojcic/3D_multiview_reg/blob/master/lib/utils.py
137 | """
138 |
139 | R = []
140 | for k in range(axes.shape[0]):
141 | K = np.cross(np.eye(3), axes[k,:]/np.linalg.norm(axes[k,:]))
142 | R.append( np.eye(3) + np.sin(thetas[k])*K + (1 - np.cos(thetas[k])) * np.matmul(K,K))
143 |
144 | rot_matrices = np.stack(R)
145 | return rot_matrices
146 |
147 | def sample_random_trans(pcd, randg=None, rotation_range=360):
148 |     """
149 |     Samples random transformation parameters with the rotations limited to the rotation range
150 |
151 |     Args:
152 |         pcd (numpy array): numpy array of coordinates for which the transformation parameters are sampled [n,3]
153 |         randg (numpy random generator): numpy random generator
154 |
155 |     Returns:
156 |         T (numpy array): sampled transformation parameters [4,4]
157 |
158 |     borrowed from: https://github.com/zgojcic/3D_multiview_reg/blob/master/lib/utils.py
159 |     """
160 |     if randg is None:
161 |         randg = np.random.default_rng(41)
162 |
163 |     # Create a 4x4 identity matrix
164 |     T = np.zeros((4, 4))
165 |     idx = np.arange(4)
166 |     T[idx, idx] = 1
167 |
168 |     axes = randg.random((1, 3)) - 0.5
169 |
170 |     angles = rotation_range * np.pi / 180.0 * (randg.random((1, 1)) - 0.5)
171 |
172 | R = axis_angle_to_rot_mat(axes, angles)
173 |
174 | T[:3, :3] = R
175 | # T[:3, 3] = np.random.rand(3)-0.5
176 | T[:3, 3] = np.matmul(R,-np.mean(pcd, axis=0))
177 |
178 | return T
179 |
180 | def visualize_mv_registration(data_list, pred_pose_list):
181 | n_view = len(pred_pose_list)
182 | vis_init_list, vis_pred_list, vis_gt_list = [], [], []
183 | for i in range(n_view):
184 | pc = data_list[i]['gaussians'].get_xyz().detach().cpu().numpy()
185 | pred_pc = transform(pred_pose_list[i].detach().cpu().numpy(), pc)
186 |
187 | gt_tsfm = data_list[0]['gt_tsfm'] @ data_list[i]['gt_tsfm'].inverse()
188 | gt_pc = transform(gt_tsfm.detach().cpu().numpy(), data_list[i]['gaussians'].get_xyz().detach().cpu().numpy())
189 |
190 | vis_pred_list.append(vis.pointcloud(pred_pc[::2], ucid=i, is_sphere=True))
191 | vis_gt_list.append(vis.pointcloud(gt_pc[::2], ucid=i, is_sphere=True))
192 | vis_init_list.append(vis.pointcloud(pc[::2], ucid=i, is_sphere=True))
193 |
194 | vis.show_3d(vis_init_list, vis_pred_list, vis_gt_list, use_new_api=True)
195 |
196 |
197 | @torch.no_grad()
198 | def plot_and_save(points, pngname, title='', axlim=None):
199 | points = points.detach().cpu().numpy()
200 | plt.figure(figsize=(7, 7))
201 | ax = plt.axes(projection='3d')
202 | ax.plot3D(points[:,0], points[:,1], points[:,2], 'b')
203 | plt.title(title)
204 | if axlim is not None:
205 | ax.set_xlim(axlim[0])
206 | ax.set_ylim(axlim[1])
207 | ax.set_zlim(axlim[2])
208 | plt.savefig(pngname)
209 | print('Saving to', pngname)
210 | return ax.get_xlim(), ax.get_ylim(), ax.get_zlim()
211 |
212 |
213 | def colorize_depth_maps(
214 | depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
215 | ):
216 | """
217 | Colorize depth maps.
218 | """
219 | assert len(depth_map.shape) >= 2, "Invalid dimension"
220 |
221 | if isinstance(depth_map, torch.Tensor):
222 |         depth = depth_map.detach().cpu().squeeze().numpy()
223 | elif isinstance(depth_map, np.ndarray):
224 | depth = depth_map.copy().squeeze()
225 | # reshape to [ (B,) H, W ]
226 | if depth.ndim < 3:
227 | depth = depth[np.newaxis, :, :]
228 |
229 | # colorize
230 | cm = matplotlib.colormaps[cmap]
231 | depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
232 | img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
233 | img_colored_np = np.rollaxis(img_colored_np, 3, 1)
234 |
235 | if valid_mask is not None:
236 | if isinstance(depth_map, torch.Tensor):
237 |             valid_mask = valid_mask.detach().cpu().numpy()
238 | valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
239 | if valid_mask.ndim < 3:
240 | valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
241 | else:
242 | valid_mask = valid_mask[:, np.newaxis, :, :]
243 | valid_mask = np.repeat(valid_mask, 3, axis=1)
244 | img_colored_np[~valid_mask] = 0
245 |
246 | if isinstance(depth_map, torch.Tensor):
247 | img_colored = torch.from_numpy(img_colored_np).float()
248 | elif isinstance(depth_map, np.ndarray):
249 | img_colored = img_colored_np
250 |
251 | return img_colored
252 |
253 |
254 | def chw2hwc(chw):
255 | assert 3 == len(chw.shape)
256 | if isinstance(chw, torch.Tensor):
257 | hwc = torch.permute(chw, (1, 2, 0))
258 | elif isinstance(chw, np.ndarray):
259 | hwc = np.moveaxis(chw, 0, -1)
260 | return hwc
261 |
262 | def visualize_camera_traj(cam_list):
263 | vis_cam_list = []
264 | for cam in cam_list:
265 | intrinsic = np.array([
266 | [cam.fx, 0.0, cam.cx],
267 | [0.0, cam.fy, cam.cy],
268 | [0.0, 0.0, 1.0]
269 | ])
270 | pred_extrinsic = cam.get_T.cpu().numpy()
271 | gt_extrinsic = cam.get_T_gt.cpu().numpy()
272 | vis_pred_cam = o3d.geometry.LineSet.create_camera_visualization(
273 | 640, 480, intrinsic, pred_extrinsic, scale=0.1)
274 | vis_gt_cam = o3d.geometry.LineSet.create_camera_visualization(
275 | 640, 480, intrinsic, gt_extrinsic, scale=0.1)
276 |
277 | # Set colors for predicted (blue) and ground truth (green) cameras
278 | vis_pred_cam.paint_uniform_color([1, 0, 0])
279 | vis_gt_cam.paint_uniform_color([0, 1, 0])
280 |
281 | vis_cam_list.append(vis_pred_cam)
282 | vis_cam_list.append(vis_gt_cam)
283 |
284 | return vis_cam_list
--------------------------------------------------------------------------------
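A quick cross-check that axis_angle_to_rot_mat matches SciPy's rotation-vector convention on random inputs; note that importing src.gsr.utils pulls in its visualization dependencies (e.g. pycg).

# Cross-check axis_angle_to_rot_mat against scipy (random axes and angles).
import numpy as np
from scipy.spatial.transform import Rotation
from src.gsr.utils import axis_angle_to_rot_mat

rng = np.random.default_rng(0)
axes = rng.normal(size=(5, 3))
thetas = rng.uniform(0, np.pi, size=(5, 1))

R_ours = axis_angle_to_rot_mat(axes, thetas)
unit_axes = axes / np.linalg.norm(axes, axis=1, keepdims=True)
R_scipy = Rotation.from_rotvec(unit_axes * thetas).as_matrix()

print(np.allclose(R_ours, R_scipy, atol=1e-8))   # True: both implement the Rodrigues formula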
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GradientSpaces/LoopSplat/676e18f950b2de6be39525a613120712b67768bb/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/eval_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import evo
4 | import numpy as np
5 |
6 | import torch
7 | from errno import EEXIST
8 | from os import makedirs, path
9 | from evo.core import metrics, trajectory
10 | from evo.core.metrics import PoseRelation, Unit
11 | from evo.core.trajectory import PosePath3D, PoseTrajectory3D
12 | from evo.tools import plot
13 | from evo.tools.plot import PlotMode
14 | from evo.tools.settings import SETTINGS
15 | from matplotlib import pyplot as plt
16 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
17 |
18 | from tqdm import tqdm
19 |
20 | import wandb
21 |
22 | import rich
23 |
24 | _log_styles = {
25 | "Eval": "bold red",
26 | }
27 |
28 | def get_style(tag):
29 | if tag in _log_styles.keys():
30 | return _log_styles[tag]
31 | return "bold blue"
32 |
33 |
34 | def Log(*args, tag="MonoGS"):
35 | style = get_style(tag)
36 | rich.print(f"[{style}]{tag}:[/{style}]", *args)
37 |
38 | def mkdir_p(folder_path):
39 | # Creates a directory. equivalent to using mkdir -p on the command line
40 | try:
41 | makedirs(folder_path)
42 | except OSError as exc: # Python >2.5
43 | if exc.errno == EEXIST and path.isdir(folder_path):
44 | pass
45 | else:
46 | raise
47 | def evaluate_evo(poses_gt, poses_est, plot_dir, label, monocular=False):
48 | ## Plot
49 | traj_ref = PosePath3D(poses_se3=poses_gt)
50 | traj_est = PosePath3D(poses_se3=poses_est)
51 | traj_est_aligned = trajectory.align_trajectory(
52 | traj_est, traj_ref, correct_scale=monocular
53 | )
54 |
55 | ## RMSE
56 | pose_relation = metrics.PoseRelation.translation_part
57 | data = (traj_ref, traj_est_aligned)
58 | ape_metric = metrics.APE(pose_relation)
59 | ape_metric.process_data(data)
60 | ape_stat = ape_metric.get_statistic(metrics.StatisticsType.rmse)
61 | ape_stats = ape_metric.get_all_statistics()
62 |     Log(r"RMSE ATE \[m]", ape_stat, tag="Eval")
63 |
64 | with open(
65 | os.path.join(plot_dir, "stats_{}.json".format(str(label))),
66 | "w",
67 | encoding="utf-8",
68 | ) as f:
69 | json.dump(ape_stats, f, indent=4)
70 |
71 | plot_mode = evo.tools.plot.PlotMode.xy
72 | fig = plt.figure()
73 | ax = evo.tools.plot.prepare_axis(fig, plot_mode)
74 | ax.set_title(f"ATE RMSE: {ape_stat}")
75 | evo.tools.plot.traj(ax, plot_mode, traj_ref, "--", "gray", "gt")
76 | evo.tools.plot.traj_colormap(
77 | ax,
78 | traj_est_aligned,
79 | ape_metric.error,
80 | plot_mode,
81 | min_map=ape_stats["min"],
82 | max_map=ape_stats["max"],
83 | )
84 | ax.legend()
85 | plt.savefig(os.path.join(plot_dir, "evo_2dplot_{}.png".format(str(label))), dpi=90)
86 |
87 | return ape_stat
88 |
89 |
90 | def eval_ate(frames, kf_ids, save_dir, iterations, final=False, monocular=False):
91 | trj_data = dict()
92 | latest_frame_idx = kf_ids[-1] + 2 if final else kf_ids[-1] + 1
93 | trj_id, trj_est, trj_gt = [], [], []
94 | trj_est_np, trj_gt_np = [], []
95 |
96 | def gen_pose_matrix(R, T):
97 | pose = np.eye(4)
98 | pose[0:3, 0:3] = R.cpu().numpy()
99 | pose[0:3, 3] = T.cpu().numpy()
100 | return pose
101 |
102 | for kf_id in kf_ids:
103 | kf = frames[kf_id]
104 | pose_est = np.linalg.inv(gen_pose_matrix(kf.R, kf.T))
105 | pose_gt = np.linalg.inv(gen_pose_matrix(kf.R_gt, kf.T_gt))
106 |
107 | trj_id.append(frames[kf_id].uid)
108 | trj_est.append(pose_est.tolist())
109 | trj_gt.append(pose_gt.tolist())
110 |
111 | trj_est_np.append(pose_est)
112 | trj_gt_np.append(pose_gt)
113 |
114 | trj_data["trj_id"] = trj_id
115 | trj_data["trj_est"] = trj_est
116 | trj_data["trj_gt"] = trj_gt
117 |
118 | plot_dir = os.path.join(save_dir, "plot")
119 | mkdir_p(plot_dir)
120 |
121 | label_evo = "final" if final else "{:04}".format(iterations)
122 | with open(
123 | os.path.join(plot_dir, f"trj_{label_evo}.json"), "w", encoding="utf-8"
124 | ) as f:
125 | json.dump(trj_data, f, indent=4)
126 |
127 | ate = evaluate_evo(
128 | poses_gt=trj_gt_np,
129 | poses_est=trj_est_np,
130 | plot_dir=plot_dir,
131 | label=label_evo,
132 | monocular=monocular,
133 | )
134 | return ate
--------------------------------------------------------------------------------
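A minimal sketch of driving `evaluate_evo` with a synthetic trajectory (assuming `evo` and `matplotlib` are installed and the repository root is on `PYTHONPATH`); the trajectory, noise level, and output directory are made up for illustration.

import numpy as np

from src.utils.eval_utils import evaluate_evo, mkdir_p

rng = np.random.default_rng(0)
poses_gt, poses_est = [], []
for i in range(20):
    T = np.eye(4)
    T[:3, 3] = [0.05 * i, 0.0, 0.0]                   # camera translating along x
    poses_gt.append(T)
    T_noisy = T.copy()
    T_noisy[:3, 3] += rng.normal(scale=1e-3, size=3)  # ~1 mm of noise per axis
    poses_est.append(T_noisy)

plot_dir = "output/ate_demo"                          # hypothetical output location
mkdir_p(plot_dir)
ate_rmse = evaluate_evo(poses_gt, poses_est, plot_dir=plot_dir, label="demo", monocular=False)
print(f"ATE RMSE [m]: {ate_rmse:.4f}")
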
/src/utils/gaussian_model_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 | from typing import NamedTuple
24 | import numpy as np
25 | import torch
26 | import torch.nn.functional as F
27 |
28 | C0 = 0.28209479177387814
29 | C1 = 0.4886025119029199
30 | C2 = [
31 | 1.0925484305920792,
32 | -1.0925484305920792,
33 | 0.31539156525252005,
34 | -1.0925484305920792,
35 | 0.5462742152960396
36 | ]
37 | C3 = [
38 | -0.5900435899266435,
39 | 2.890611442640554,
40 | -0.4570457994644658,
41 | 0.3731763325901154,
42 | -0.4570457994644658,
43 | 1.445305721320277,
44 | -0.5900435899266435
45 | ]
46 | C4 = [
47 | 2.5033429417967046,
48 | -1.7701307697799304,
49 | 0.9461746957575601,
50 | -0.6690465435572892,
51 | 0.10578554691520431,
52 | -0.6690465435572892,
53 | 0.47308734787878004,
54 | -1.7701307697799304,
55 | 0.6258357354491761,
56 | ]
57 |
58 |
59 | def eval_sh(deg, sh, dirs):
60 | """
61 | Evaluate spherical harmonics at unit directions
62 | using hardcoded SH polynomials.
63 | Works with torch/np/jnp.
64 | ... Can be 0 or more batch dimensions.
65 | Args:
66 |         deg: int SH deg. Currently, 0-4 supported
67 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
68 | dirs: jnp.ndarray unit directions [..., 3]
69 | Returns:
70 | [..., C]
71 | """
72 | assert deg <= 4 and deg >= 0
73 | coeff = (deg + 1) ** 2
74 | assert sh.shape[-1] >= coeff
75 |
76 | result = C0 * sh[..., 0]
77 | if deg > 0:
78 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
79 | result = (result -
80 | C1 * y * sh[..., 1] +
81 | C1 * z * sh[..., 2] -
82 | C1 * x * sh[..., 3])
83 |
84 | if deg > 1:
85 | xx, yy, zz = x * x, y * y, z * z
86 | xy, yz, xz = x * y, y * z, x * z
87 | result = (result +
88 | C2[0] * xy * sh[..., 4] +
89 | C2[1] * yz * sh[..., 5] +
90 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
91 | C2[3] * xz * sh[..., 7] +
92 | C2[4] * (xx - yy) * sh[..., 8])
93 |
94 | if deg > 2:
95 | result = (result +
96 | C3[0] * y * (3 * xx - yy) * sh[..., 9] +
97 | C3[1] * xy * z * sh[..., 10] +
98 | C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
99 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
100 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
101 | C3[5] * z * (xx - yy) * sh[..., 14] +
102 | C3[6] * x * (xx - 3 * yy) * sh[..., 15])
103 |
104 | if deg > 3:
105 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
106 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
107 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
108 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
109 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
110 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
111 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
112 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
113 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
114 | return result
115 |
116 |
117 | def RGB2SH(rgb):
118 | return (rgb - 0.5) / C0
119 |
120 |
121 | def SH2RGB(sh):
122 | return sh * C0 + 0.5
123 |
124 |
125 | def inverse_sigmoid(x):
126 | return torch.log(x/(1-x))
127 |
128 |
129 | def get_expon_lr_func(
130 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
131 | ):
132 | """
133 | Copied from Plenoxels
134 |
135 | Continuous learning rate decay function. Adapted from JaxNeRF
136 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
137 | is log-linearly interpolated elsewhere (equivalent to exponential decay).
138 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth
139 | function of lr_delay_mult, such that the initial learning rate is
140 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back
141 | to the normal learning rate when steps>lr_delay_steps.
142 |     :param lr_init, lr_final, lr_delay_steps, lr_delay_mult: learning-rate schedule parameters described above.
143 | :param max_steps: int, the number of steps during optimization.
144 | :return HoF which takes step as input
145 | """
146 |
147 | def helper(step):
148 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
149 | # Disable this parameter
150 | return 0.0
151 | if lr_delay_steps > 0:
152 | # A kind of reverse cosine decay.
153 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
154 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
155 | )
156 | else:
157 | delay_rate = 1.0
158 | t = np.clip(step / max_steps, 0, 1)
159 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
160 | return delay_rate * log_lerp
161 |
162 | return helper
163 |
164 |
165 | def strip_lowerdiag(L):
166 | uncertainty = torch.zeros(
167 | (L.shape[0], 6), dtype=torch.float, device="cuda")
168 |
169 | uncertainty[:, 0] = L[:, 0, 0]
170 | uncertainty[:, 1] = L[:, 0, 1]
171 | uncertainty[:, 2] = L[:, 0, 2]
172 | uncertainty[:, 3] = L[:, 1, 1]
173 | uncertainty[:, 4] = L[:, 1, 2]
174 | uncertainty[:, 5] = L[:, 2, 2]
175 | return uncertainty
176 |
177 |
178 | def strip_symmetric(sym):
179 | return strip_lowerdiag(sym)
180 |
181 |
182 | def build_rotation(r):
183 |
184 | q = F.normalize(r, p=2, dim=1)
185 | R = torch.zeros((q.size(0), 3, 3), device='cuda')
186 |
187 | r = q[:, 0]
188 | x = q[:, 1]
189 | y = q[:, 2]
190 | z = q[:, 3]
191 |
192 | R[:, 0, 0] = 1 - 2 * (y*y + z*z)
193 | R[:, 0, 1] = 2 * (x*y - r*z)
194 | R[:, 0, 2] = 2 * (x*z + r*y)
195 | R[:, 1, 0] = 2 * (x*y + r*z)
196 | R[:, 1, 1] = 1 - 2 * (x*x + z*z)
197 | R[:, 1, 2] = 2 * (y*z - r*x)
198 | R[:, 2, 0] = 2 * (x*z - r*y)
199 | R[:, 2, 1] = 2 * (y*z + r*x)
200 | R[:, 2, 2] = 1 - 2 * (x*x + y*y)
201 | return R
202 |
203 |
204 | def build_scaling_rotation(s, r):
205 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
206 | R = build_rotation(r)
207 |
208 | L[:, 0, 0] = s[:, 0]
209 | L[:, 1, 1] = s[:, 1]
210 | L[:, 2, 2] = s[:, 2]
211 |
212 | L = R @ L
213 | return L
214 |
215 |
216 | class BasicPointCloud(NamedTuple):
217 | points: np.array
218 | colors: np.array
--------------------------------------------------------------------------------
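A quick sanity check of the spherical-harmonics helpers above: at degree 0 the colour is view-independent, so `eval_sh` simply undoes `RGB2SH` up to the 0.5 offset removed by that conversion. A minimal sketch with arbitrary colour values, assuming the repository root is on `PYTHONPATH`.

import torch

from src.utils.gaussian_model_utils import RGB2SH, SH2RGB, eval_sh

rgb = torch.tensor([[0.2, 0.5, 0.8]])        # one colour, shape [1, 3]
sh = RGB2SH(rgb).unsqueeze(-1)               # degree-0 SH coefficients, shape [1, 3, 1]
dirs = torch.tensor([[0.0, 0.0, 1.0]])       # viewing direction is ignored at degree 0
recovered = eval_sh(0, sh, dirs) + 0.5       # add back the 0.5 offset removed by RGB2SH
assert torch.allclose(recovered, rgb)
assert torch.allclose(SH2RGB(RGB2SH(rgb)), rgb)
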
/src/utils/graphics_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import math
13 | from typing import NamedTuple
14 |
15 | import numpy as np
16 | import torch
17 |
18 |
19 | class BasicPointCloud(NamedTuple):
20 | points: np.array
21 | colors: np.array
22 | normals: np.array
23 |
24 |
25 | def getWorld2View(R, t):
26 | Rt = np.zeros((4, 4))
27 | Rt[:3, :3] = R.transpose()
28 | Rt[:3, 3] = t
29 | Rt[3, 3] = 1.0
30 | return np.float32(Rt)
31 |
32 |
33 | def getWorld2View2(R, t, translate=torch.tensor([0.0, 0.0, 0.0]), scale=1.0):
34 | translate = translate.to(R.device)
35 | Rt = torch.zeros((4, 4), device=R.device)
36 | # Rt[:3, :3] = R.transpose()
37 | Rt[:3, :3] = R
38 | Rt[:3, 3] = t
39 | Rt[3, 3] = 1.0
40 |
41 | C2W = torch.linalg.inv(Rt)
42 | cam_center = C2W[:3, 3]
43 | cam_center = (cam_center + translate) * scale
44 | C2W[:3, 3] = cam_center
45 | Rt = torch.linalg.inv(C2W)
46 | return Rt
47 |
48 |
49 | def getProjectionMatrix(znear, zfar, fovX, fovY):
50 | tanHalfFovY = math.tan((fovY / 2))
51 | tanHalfFovX = math.tan((fovX / 2))
52 |
53 | top = tanHalfFovY * znear
54 | bottom = -top
55 | right = tanHalfFovX * znear
56 | left = -right
57 |
58 | P = torch.zeros(4, 4)
59 |
60 | z_sign = 1.0
61 |
62 | P[0, 0] = 2.0 * znear / (right - left)
63 | P[1, 1] = 2.0 * znear / (top - bottom)
64 | P[0, 2] = (right + left) / (right - left)
65 | P[1, 2] = (top + bottom) / (top - bottom)
66 | P[3, 2] = z_sign
67 | P[2, 2] = -(zfar + znear) / (zfar - znear)
68 | P[2, 3] = -2 * (zfar * znear) / (zfar - znear)
69 | return P
70 |
71 |
72 | def getProjectionMatrix2(znear, zfar, cx, cy, fx, fy, W, H):
73 | left = ((2 * cx - W) / W - 1.0) * W / 2.0
74 | right = ((2 * cx - W) / W + 1.0) * W / 2.0
75 | top = ((2 * cy - H) / H + 1.0) * H / 2.0
76 | bottom = ((2 * cy - H) / H - 1.0) * H / 2.0
77 | left = znear / fx * left
78 | right = znear / fx * right
79 | top = znear / fy * top
80 | bottom = znear / fy * bottom
81 | P = torch.zeros(4, 4)
82 |
83 | z_sign = 1.0
84 |
85 | P[0, 0] = 2.0 * znear / (right - left)
86 | P[1, 1] = 2.0 * znear / (top - bottom)
87 | P[0, 2] = (right + left) / (right - left)
88 | P[1, 2] = (top + bottom) / (top - bottom)
89 | P[3, 2] = z_sign
90 | P[2, 2] = z_sign * zfar / (zfar - znear)
91 | P[2, 3] = -(zfar * znear) / (zfar - znear)
92 |
93 | return P
94 |
95 |
96 | def fov2focal(fov, pixels):
97 | return pixels / (2 * math.tan(fov / 2))
98 |
99 |
100 | def focal2fov(focal, pixels):
101 | return 2 * math.atan(pixels / (2 * focal))
102 |
--------------------------------------------------------------------------------
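A minimal sketch showing that `fov2focal` and `focal2fov` are inverses of each other and how they feed `getProjectionMatrix`; the intrinsics are made-up values and the repository root is assumed to be on `PYTHONPATH`.

from src.utils.graphics_utils import focal2fov, fov2focal, getProjectionMatrix

width, height, fx, fy = 640, 480, 600.0, 600.0
fov_x, fov_y = focal2fov(fx, width), focal2fov(fy, height)
assert abs(fov2focal(fov_x, width) - fx) < 1e-6   # round trip recovers the focal length

P = getProjectionMatrix(znear=0.01, zfar=100.0, fovX=fov_x, fovY=fov_y)
print(P.shape)                                    # torch.Size([4, 4])
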
/src/utils/io_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from pathlib import Path
4 | from typing import Union
5 |
6 | import open3d as o3d
7 | import torch
8 | import wandb
9 | import yaml
10 |
11 |
12 | def mkdir_decorator(func):
13 | """A decorator that creates the directory specified in the function's 'directory' keyword
14 | argument before calling the function.
15 | Args:
16 | func: The function to be decorated.
17 | Returns:
18 | The wrapper function.
19 | """
20 | def wrapper(*args, **kwargs):
21 | output_path = Path(kwargs["directory"])
22 | output_path.mkdir(parents=True, exist_ok=True)
23 | return func(*args, **kwargs)
24 | return wrapper
25 |
26 |
27 | @mkdir_decorator
28 | def save_clouds(clouds: list, cloud_names: list, *, directory: Union[str, Path]) -> None:
29 | """ Saves a list of point clouds to the specified directory, creating the directory if it does not exist.
30 | Args:
31 | clouds: A list of point cloud objects to be saved.
32 | cloud_names: A list of filenames for the point clouds, corresponding by index to the clouds.
33 | directory: The directory where the point clouds will be saved.
34 | """
35 | for cld_name, cloud in zip(cloud_names, clouds):
36 | o3d.io.write_point_cloud(str(directory / cld_name), cloud)
37 |
38 |
39 | @mkdir_decorator
40 | def save_dict_to_ckpt(dictionary, file_name: str, *, directory: Union[str, Path]) -> None:
41 | """ Saves a dictionary to a checkpoint file in the specified directory, creating the directory if it does not exist.
42 | Args:
43 | dictionary: The dictionary to be saved.
44 | file_name: The name of the checkpoint file.
45 | directory: The directory where the checkpoint file will be saved.
46 | """
47 | torch.save(dictionary, directory / file_name,
48 | _use_new_zipfile_serialization=False)
49 |
50 |
51 | @mkdir_decorator
52 | def save_dict_to_yaml(dictionary, file_name: str, *, directory: Union[str, Path]) -> None:
53 | """ Saves a dictionary to a YAML file in the specified directory, creating the directory if it does not exist.
54 | Args:
55 | dictionary: The dictionary to be saved.
56 | file_name: The name of the YAML file.
57 | directory: The directory where the YAML file will be saved.
58 | """
59 | with open(directory / file_name, "w") as f:
60 | yaml.dump(dictionary, f)
61 |
62 |
63 | @mkdir_decorator
64 | def save_dict_to_json(dictionary, file_name: str, *, directory: Union[str, Path]) -> None:
65 | """ Saves a dictionary to a JSON file in the specified directory, creating the directory if it does not exist.
66 | Args:
67 | dictionary: The dictionary to be saved.
68 | file_name: The name of the JSON file.
69 | directory: The directory where the JSON file will be saved.
70 | """
71 | with open(directory / file_name, "w") as f:
72 | json.dump(dictionary, f)
73 |
74 |
75 | def load_config(path: str, default_path: str = None) -> dict:
76 | """
77 | Loads a configuration file and optionally merges it with a default configuration file.
78 |
79 | This function loads a configuration from the given path. If the configuration specifies an inheritance
80 | path (`inherit_from`), or if a `default_path` is provided, it loads the base configuration and updates it
81 | with the specific configuration.
82 |
83 | Args:
84 | path: The path to the specific configuration file.
85 |         default_path: An optional path to a default configuration file. It is loaded as the base configuration when
86 |                       the specific configuration does not set `inherit_from`, and otherwise passed down the inheritance chain.
87 |
88 | Returns:
89 | A dictionary containing the merged configuration.
90 | """
91 | # load configuration from per scene/dataset cfg.
92 | with open(path, 'r') as f:
93 | cfg_special = yaml.full_load(f)
94 | inherit_from = cfg_special.get('inherit_from')
95 | cfg = dict()
96 | if inherit_from is not None:
97 | cfg = load_config(inherit_from, default_path)
98 | elif default_path is not None:
99 | with open(default_path, 'r') as f:
100 | cfg = yaml.full_load(f)
101 | update_recursive(cfg, cfg_special)
102 | return cfg
103 |
104 |
105 | def update_recursive(dict1: dict, dict2: dict) -> None:
106 | """ Recursively updates the first dictionary with the contents of the second dictionary.
107 |
108 | This function iterates through `dict2` and updates `dict1` with its contents. If a key from `dict2`
109 | exists in `dict1` and its value is also a dictionary, the function updates the value recursively.
110 | Otherwise, it overwrites the value in `dict1` with the value from `dict2`.
111 |
112 | Args:
113 | dict1: The dictionary to be updated.
114 | dict2: The dictionary whose entries are used to update `dict1`.
115 |
116 | Returns:
117 | None: The function modifies `dict1` in place.
118 | """
119 | for k, v in dict2.items():
120 | if k not in dict1:
121 | dict1[k] = dict()
122 | if isinstance(v, dict):
123 | update_recursive(dict1[k], v)
124 | else:
125 | dict1[k] = v
126 |
127 |
128 | def log_metrics_to_wandb(json_files: list, output_path: str, section: str = "Evaluation") -> None:
129 | """ Logs metrics from JSON files to Weights & Biases under a specified section.
130 |
131 | This function reads metrics from a list of JSON files and logs them to Weights & Biases (wandb).
132 | Each metric is prefixed with a specified section name for organized logging.
133 |
134 | Args:
135 | json_files: A list of filenames for JSON files containing metrics to be logged.
136 | output_path: The directory path where the JSON files are located.
137 | section: The section under which to log the metrics in wandb. Defaults to "Evaluation".
138 |
139 | Returns:
140 | None: Metrics are logged to wandb and the function does not return a value.
141 | """
142 | for json_file in json_files:
143 | file_path = os.path.join(output_path, json_file)
144 | if os.path.exists(file_path):
145 | with open(file_path, 'r') as file:
146 | metrics = json.load(file)
147 | prefixed_metrics = {
148 | f"{section}/{key}": value for key, value in metrics.items()}
149 | wandb.log(prefixed_metrics)
150 |
--------------------------------------------------------------------------------
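The `inherit_from` mechanism in `load_config` boils down to `update_recursive`: a scene YAML only needs to override the keys that differ from its dataset defaults. A minimal sketch with made-up keys, assuming the repository root is on `PYTHONPATH`.

from src.utils.io_utils import update_recursive

base = {"mapping": {"iterations": 60, "lr": 0.001}, "use_wandb": False}  # hypothetical dataset defaults
scene = {"mapping": {"iterations": 100}}                                 # hypothetical per-scene override
update_recursive(base, scene)
assert base == {"mapping": {"iterations": 100, "lr": 0.001}, "use_wandb": False}
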
/src/utils/pose_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def rt2mat(R, T):
6 | mat = np.eye(4)
7 | mat[0:3, 0:3] = R
8 | mat[0:3, 3] = T
9 | return mat
10 |
11 |
12 | def skew_sym_mat(x):
13 | device = x.device
14 | dtype = x.dtype
15 | ssm = torch.zeros(3, 3, device=device, dtype=dtype)
16 | ssm[0, 1] = -x[2]
17 | ssm[0, 2] = x[1]
18 | ssm[1, 0] = x[2]
19 | ssm[1, 2] = -x[0]
20 | ssm[2, 0] = -x[1]
21 | ssm[2, 1] = x[0]
22 | return ssm
23 |
24 |
25 | def SO3_exp(theta):
26 | device = theta.device
27 | dtype = theta.dtype
28 |
29 | W = skew_sym_mat(theta)
30 | W2 = W @ W
31 | angle = torch.norm(theta)
32 | I = torch.eye(3, device=device, dtype=dtype)
33 | if angle < 1e-5:
34 | return I + W + 0.5 * W2
35 | else:
36 | return (
37 | I
38 | + (torch.sin(angle) / angle) * W
39 | + ((1 - torch.cos(angle)) / (angle**2)) * W2
40 | )
41 |
42 |
43 | def V(theta):
44 | dtype = theta.dtype
45 | device = theta.device
46 | I = torch.eye(3, device=device, dtype=dtype)
47 | W = skew_sym_mat(theta)
48 | W2 = W @ W
49 | angle = torch.norm(theta)
50 | if angle < 1e-5:
51 | V = I + 0.5 * W + (1.0 / 6.0) * W2
52 | else:
53 | V = (
54 | I
55 | + W * ((1.0 - torch.cos(angle)) / (angle**2))
56 | + W2 * ((angle - torch.sin(angle)) / (angle**3))
57 | )
58 | return V
59 |
60 |
61 | def SE3_exp(tau):
62 | dtype = tau.dtype
63 | device = tau.device
64 |
65 | rho = tau[:3]
66 | theta = tau[3:]
67 | R = SO3_exp(theta)
68 | t = V(theta) @ rho
69 |
70 | T = torch.eye(4, device=device, dtype=dtype)
71 | T[:3, :3] = R
72 | T[:3, 3] = t
73 | return T
74 |
75 |
76 | def update_pose(camera, converged_threshold=1e-4):
77 | tau = torch.cat([camera.cam_trans_delta, camera.cam_rot_delta], axis=0)
78 |
79 | T_w2c = torch.eye(4, device=tau.device)
80 | T_w2c[0:3, 0:3] = camera.R
81 | T_w2c[0:3, 3] = camera.T
82 |
83 | new_w2c = SE3_exp(tau) @ T_w2c
84 |
85 | new_R = new_w2c[0:3, 0:3]
86 | new_T = new_w2c[0:3, 3]
87 |
88 | converged = tau.norm() < converged_threshold
89 | camera.update_RT(new_R, new_T)
90 |
91 | camera.cam_rot_delta.data.fill_(0)
92 | camera.cam_trans_delta.data.fill_(0)
93 | return converged
94 |
--------------------------------------------------------------------------------
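A minimal sketch checking the expected properties of the Lie-group helpers above: a zero twist maps to the identity, and a small twist still yields a proper rotation. The twist values are arbitrary; the repository root is assumed to be on `PYTHONPATH`.

import torch

from src.utils.pose_utils import SE3_exp

assert torch.allclose(SE3_exp(torch.zeros(6)), torch.eye(4))       # zero twist -> identity

tau = torch.tensor([0.01, 0.0, 0.0, 0.0, 0.0, 0.02])               # translation first, then rotation (yaw)
R = SE3_exp(tau)[:3, :3]
assert torch.allclose(R.T @ R, torch.eye(3), atol=1e-6)            # orthonormal
assert torch.allclose(torch.det(R), torch.tensor(1.0), atol=1e-6)  # determinant +1
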
/src/utils/tracker_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from scipy.spatial.transform import Rotation
4 | from typing import Union
5 | from src.utils.utils import np2torch
6 |
7 |
8 | def multiply_quaternions(q: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
9 | """Performs batch-wise quaternion multiplication.
10 |
11 | Given two quaternions, this function computes their product. The operation is
12 | vectorized and can be performed on batches of quaternions.
13 |
14 | Args:
15 | q: A tensor representing the first quaternion or a batch of quaternions.
16 | Expected shape is (... , 4), where the last dimension contains quaternion components (w, x, y, z).
17 | r: A tensor representing the second quaternion or a batch of quaternions with the same shape as q.
18 | Returns:
19 | A tensor of the same shape as the input tensors, representing the product of the input quaternions.
20 | """
21 | w0, x0, y0, z0 = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
22 | w1, x1, y1, z1 = r[..., 0], r[..., 1], r[..., 2], r[..., 3]
23 |
24 | w = -x1 * x0 - y1 * y0 - z1 * z0 + w1 * w0
25 | x = x1 * w0 + y1 * z0 - z1 * y0 + w1 * x0
26 | y = -x1 * z0 + y1 * w0 + z1 * x0 + w1 * y0
27 | z = x1 * y0 - y1 * x0 + z1 * w0 + w1 * z0
28 | return torch.stack((w, x, y, z), dim=-1)
29 |
30 |
31 | def transformation_to_quaternion(RT: Union[torch.Tensor, np.ndarray]):
32 | """ Converts a rotation-translation matrix to a tensor representing quaternion and translation.
33 |
34 | This function takes a 3x4 transformation matrix (rotation and translation) and converts it
35 | into a tensor that combines the quaternion representation of the rotation and the translation vector.
36 |
37 | Args:
38 | RT: A 3x4 matrix representing the rotation and translation. This can be a NumPy array
39 | or a torch.Tensor. If it's a torch.Tensor and resides on a GPU, it will be moved to CPU.
40 |
41 | Returns:
42 | A tensor combining the quaternion (in w, x, y, z order) and translation vector. The tensor
43 | will be moved to the original device if the input was a GPU tensor.
44 | """
45 | gpu_id = -1
46 | if isinstance(RT, torch.Tensor):
47 | if RT.get_device() != -1:
48 | RT = RT.detach().cpu()
49 | gpu_id = RT.get_device()
50 | RT = RT.numpy()
51 | R, T = RT[:3, :3], RT[:3, 3]
52 |
53 | rot = Rotation.from_matrix(R)
54 | quad = rot.as_quat(canonical=True)
55 | quad = np.roll(quad, 1)
56 | tensor = np.concatenate([quad, T], 0)
57 | tensor = torch.from_numpy(tensor).float()
58 | if gpu_id != -1:
59 | tensor = tensor.to(gpu_id)
60 | return tensor
61 |
62 |
63 | def extrapolate_poses(poses: np.ndarray) -> np.ndarray:
64 |     """ Extrapolates the next pose by applying the relative motion between the first two poses to the second pose.
65 | Args:
66 | poses: An array of poses, where each pose is represented by a 4x4 transformation matrix.
67 | Returns:
68 |         A 4x4 numpy ndarray representing the extrapolated transformation matrix.
69 | """
70 | return poses[1, :] @ np.linalg.inv(poses[0, :]) @ poses[1, :]
71 |
72 |
73 | def compute_camera_opt_params(estimate_rel_w2c: np.ndarray) -> tuple:
74 |     """ Computes the camera's rotation and translation parameters from a world-to-camera transformation matrix.
75 | This function extracts the rotation component of the transformation matrix, converts it to a quaternion,
76 | and reorders it to match a specific convention. Both rotation and translation parameters are converted
77 | to torch Parameters and intended to be optimized in a PyTorch model.
78 | Args:
79 | estimate_rel_w2c: A 4x4 numpy ndarray representing the estimated world-to-camera transformation matrix.
80 | Returns:
81 | A tuple containing two torch.nn.Parameters: camera's rotation and camera's translation.
82 | """
83 | quaternion = Rotation.from_matrix(estimate_rel_w2c[:3, :3]).as_quat(canonical=True)
84 | quaternion = quaternion[[3, 0, 1, 2]]
85 | opt_cam_rot = torch.nn.Parameter(np2torch(quaternion, "cuda"))
86 | opt_cam_trans = torch.nn.Parameter(np2torch(estimate_rel_w2c[:3, 3], "cuda"))
87 | return opt_cam_rot, opt_cam_trans
88 |
--------------------------------------------------------------------------------
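Two quick checks of the helpers above with toy inputs: `extrapolate_poses` continues the motion between the two given poses (a constant-velocity style prediction), and `transformation_to_quaternion` maps the identity to the unit quaternion. This assumes the repository environment is installed (the module pulls in `src.utils.utils`, and with it the CUDA rasterizer, at import time).

import numpy as np

from src.utils.tracker_utils import extrapolate_poses, transformation_to_quaternion

T0 = np.eye(4)
T1 = np.eye(4)
T1[0, 3] = 0.05                                 # 5 cm of motion between the two poses
T2 = extrapolate_poses(np.stack([T0, T1]))
assert np.isclose(T2[0, 3], 0.10)               # another 5 cm in the same direction

qt = transformation_to_quaternion(np.eye(4))
assert np.allclose(qt.numpy(), [1, 0, 0, 0, 0, 0, 0])  # (w, x, y, z, tx, ty, tz)
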
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | from scipy.ndimage import median_filter
2 | import os
3 | import random
4 |
5 | import numpy as np
6 | import open3d as o3d
7 | import torch
8 | from gaussian_rasterizer import GaussianRasterizationSettings, GaussianRasterizer
9 |
10 |
11 | def setup_seed(seed: int) -> None:
12 | """ Sets the seed for generating random numbers to ensure reproducibility across multiple runs.
13 | Args:
14 | seed: The seed value to set for random number generators in torch, numpy, and random.
15 | """
16 | torch.manual_seed(seed)
17 | torch.cuda.manual_seed_all(seed)
18 | os.environ["PYTHONHASHSEED"] = str(seed)
19 | np.random.seed(seed)
20 | random.seed(seed)
21 | torch.backends.cudnn.deterministic = True
22 | torch.backends.cudnn.benchmark = False
23 |
24 |
25 | def torch2np(tensor: torch.Tensor) -> np.ndarray:
26 | """ Converts a PyTorch tensor to a NumPy ndarray.
27 | Args:
28 | tensor: The PyTorch tensor to convert.
29 | Returns:
30 | A NumPy ndarray with the same data and dtype as the input tensor.
31 | """
32 | return tensor.detach().cpu().numpy()
33 |
34 |
35 | def np2torch(array: np.ndarray, device: str = "cpu") -> torch.Tensor:
36 | """Converts a NumPy ndarray to a PyTorch tensor.
37 | Args:
38 | array: The NumPy ndarray to convert.
39 | device: The device to which the tensor is sent. Defaults to 'cpu'.
40 |
41 | Returns:
42 | A PyTorch tensor with the same data as the input array.
43 | """
44 | return torch.from_numpy(array).float().to(device)
45 |
46 |
47 | def np2ptcloud(pts: np.ndarray, rgb=None) -> o3d.geometry.PointCloud:
48 | """converts numpy array to point cloud
49 | Args:
50 | pts (ndarray): point cloud
51 | Returns:
52 | (PointCloud): resulting point cloud
53 | """
54 | cloud = o3d.geometry.PointCloud()
55 | cloud.points = o3d.utility.Vector3dVector(pts)
56 | if rgb is not None:
57 | cloud.colors = o3d.utility.Vector3dVector(rgb)
58 | return cloud
59 |
60 |
61 | def dict2device(dict: dict, device: str = "cpu") -> dict:
62 | """Sends all tensors in a dictionary to a specified device.
63 | Args:
64 | dict: The dictionary containing tensors.
65 | device: The device to send the tensors to. Defaults to 'cpu'.
66 | Returns:
67 | The dictionary with all tensors sent to the specified device.
68 | """
69 | for k, v in dict.items():
70 | if isinstance(v, torch.Tensor):
71 | dict[k] = v.to(device)
72 | return dict
73 |
74 |
75 | def get_render_settings(w, h, intrinsics, w2c, near=0.01, far=100, sh_degree=0):
76 | """
77 | Constructs and returns a GaussianRasterizationSettings object for rendering,
78 | configured with given camera parameters.
79 |
80 | Args:
81 |         w (int): The width of the image.
82 |         h (int): The height of the image.
83 |         intrinsics (array): 3x3 intrinsic camera matrix.
84 | w2c (array): World to camera transformation matrix.
85 | near (float, optional): The near plane for the camera. Defaults to 0.01.
86 | far (float, optional): The far plane for the camera. Defaults to 100.
87 |
88 | Returns:
89 | GaussianRasterizationSettings: Configured settings for Gaussian rasterization.
90 | """
91 | fx, fy, cx, cy = intrinsics[0, 0], intrinsics[1,
92 | 1], intrinsics[0, 2], intrinsics[1, 2]
93 | w2c = torch.tensor(w2c).cuda().float()
94 | cam_center = torch.inverse(w2c)[:3, 3]
95 | viewmatrix = w2c.transpose(0, 1)
96 | opengl_proj = torch.tensor([[2 * fx / w, 0.0, -(w - 2 * cx) / w, 0.0],
97 | [0.0, 2 * fy / h, -(h - 2 * cy) / h, 0.0],
98 | [0.0, 0.0, far /
99 | (far - near), -(far * near) / (far - near)],
100 | [0.0, 0.0, 1.0, 0.0]], device='cuda').float().transpose(0, 1)
101 | full_proj_matrix = viewmatrix.unsqueeze(
102 | 0).bmm(opengl_proj.unsqueeze(0)).squeeze(0)
103 | return GaussianRasterizationSettings(
104 | image_height=h,
105 | image_width=w,
106 | tanfovx=w / (2 * fx),
107 | tanfovy=h / (2 * fy),
108 | bg=torch.tensor([0, 0, 0], device='cuda').float(),
109 | scale_modifier=1.0,
110 | viewmatrix=viewmatrix,
111 | projmatrix=full_proj_matrix,
112 | sh_degree=sh_degree,
113 | campos=cam_center,
114 | prefiltered=False,
115 | debug=False)
116 |
117 |
118 | def render_gaussian_model(gaussian_model, render_settings,
119 | override_means_3d=None, override_means_2d=None,
120 | override_scales=None, override_rotations=None,
121 | override_opacities=None, override_colors=None):
122 | """
123 | Renders a Gaussian model with specified rendering settings, allowing for
124 | optional overrides of various model parameters.
125 |
126 | Args:
127 | gaussian_model: A Gaussian model object that provides methods to get
128 | various properties like xyz coordinates, opacity, features, etc.
129 | render_settings: Configuration settings for the GaussianRasterizer.
130 | override_means_3d (Optional): If provided, these values will override
131 | the 3D mean values from the Gaussian model.
132 | override_means_2d (Optional): If provided, these values will override
133 | the 2D mean values. Defaults to zeros if not provided.
134 | override_scales (Optional): If provided, these values will override the
135 | scale values from the Gaussian model.
136 | override_rotations (Optional): If provided, these values will override
137 | the rotation values from the Gaussian model.
138 | override_opacities (Optional): If provided, these values will override
139 | the opacity values from the Gaussian model.
140 | override_colors (Optional): If provided, these values will override the
141 | color values from the Gaussian model.
142 | Returns:
143 | A dictionary containing the rendered color, depth, radii, and 2D means
144 | of the Gaussian model. The keys of this dictionary are 'color', 'depth',
145 | 'radii', and 'means2D', each mapping to their respective rendered values.
146 | """
147 | renderer = GaussianRasterizer(raster_settings=render_settings)
148 |
149 | if override_means_3d is None:
150 | means3D = gaussian_model.get_xyz()
151 | else:
152 | means3D = override_means_3d
153 |
154 | if override_means_2d is None:
155 | means2D = torch.zeros_like(
156 | means3D, dtype=means3D.dtype, requires_grad=True, device="cuda")
157 | means2D.retain_grad()
158 | else:
159 | means2D = override_means_2d
160 |
161 | if override_opacities is None:
162 | opacities = gaussian_model.get_opacity()
163 | else:
164 | opacities = override_opacities
165 |
166 | shs, colors_precomp = None, None
167 | if override_colors is not None:
168 | colors_precomp = override_colors
169 | else:
170 | shs = gaussian_model.get_features()
171 |
172 | render_args = {
173 | "means3D": means3D,
174 | "means2D": means2D,
175 | "opacities": opacities,
176 | "colors_precomp": colors_precomp,
177 | "shs": shs,
178 | "scales": gaussian_model.get_scaling() if override_scales is None else override_scales,
179 | "rotations": gaussian_model.get_rotation() if override_rotations is None else override_rotations,
180 | "cov3D_precomp": None
181 | }
182 | color, depth, alpha, radii = renderer(**render_args)
183 |
184 | return {"color": color, "depth": depth, "radii": radii, "means2D": means2D, "alpha": alpha}
185 |
186 |
187 | def batch_search_faiss(indexer, query_points, k):
188 | """
189 |     Perform a batch search on an IndexIVFFlat indexer to circumvent the search size limit of 65535.
190 |
191 | Args:
192 | indexer: The FAISS indexer object.
193 | query_points: A tensor of query points.
194 | k (int): The number of nearest neighbors to find.
195 |
196 | Returns:
197 | distances (torch.Tensor): The distances of the nearest neighbors.
198 | ids (torch.Tensor): The indices of the nearest neighbors.
199 | """
200 | split_pos = torch.split(query_points, 65535, dim=0)
201 | distances_list, ids_list = [], []
202 |
203 | for split_p in split_pos:
204 |         distance, idx = indexer.search(split_p.float(), k)
205 |         distances_list.append(distance.clone())
206 |         ids_list.append(idx.clone())
207 | distances = torch.cat(distances_list, dim=0)
208 | ids = torch.cat(ids_list, dim=0)
209 |
210 | return distances, ids
211 |
212 |
213 | def filter_depth_outliers(depth_map, kernel_size=3, threshold=1.0):
214 | median_filtered = median_filter(depth_map, size=kernel_size)
215 | abs_diff = np.abs(depth_map - median_filtered)
216 | outlier_mask = abs_diff > threshold
217 | depth_map_filtered = np.where(outlier_mask, median_filtered, depth_map)
218 | return depth_map_filtered
219 |
--------------------------------------------------------------------------------
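A minimal sketch of `filter_depth_outliers` on a synthetic depth map: a single spike far from its neighbours is replaced by the local median, while well-behaved pixels are left untouched. Importing `src.utils.utils` pulls in the CUDA rasterizer, so this assumes the full environment is installed; the sizes and threshold are arbitrary.

import numpy as np

from src.utils.utils import filter_depth_outliers

depth = np.full((5, 5), 1.0, dtype=np.float32)   # flat 1 m depth map
depth[2, 2] = 10.0                               # one spurious 10 m reading
cleaned = filter_depth_outliers(depth, kernel_size=3, threshold=1.0)
assert np.isclose(cleaned[2, 2], 1.0)            # spike replaced by the local median
assert np.isclose(cleaned[0, 0], 1.0)            # inliers left unchanged
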
/src/utils/vis_utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from copy import deepcopy
3 | from typing import List, Union
4 |
5 | import numpy as np
6 | import open3d as o3d
7 | from matplotlib import colors
8 |
9 | COLORS_ANSI = OrderedDict({
10 | "blue": "\033[94m",
11 | "orange": "\033[93m",
12 | "green": "\033[92m",
13 | "red": "\033[91m",
14 | "purple": "\033[95m",
15 | "brown": "\033[93m", # No exact match, using yellow
16 | "pink": "\033[95m",
17 | "gray": "\033[90m",
18 | "olive": "\033[93m", # No exact match, using yellow
19 | "cyan": "\033[96m",
20 | "end": "\033[0m", # Reset color
21 | })
22 |
23 |
24 | COLORS_MATPLOTLIB = OrderedDict({
25 | 'blue': '#1f77b4',
26 | 'orange': '#ff7f0e',
27 | 'green': '#2ca02c',
28 | 'red': '#d62728',
29 | 'purple': '#9467bd',
30 | 'brown': '#8c564b',
31 | 'pink': '#e377c2',
32 | 'gray': '#7f7f7f',
33 | 'yellow-green': '#bcbd22',
34 | 'cyan': '#17becf'
35 | })
36 |
37 |
38 | COLORS_MATPLOTLIB_RGB = OrderedDict({
39 | 'blue': np.array([31, 119, 180]) / 255.0,
40 | 'orange': np.array([255, 127, 14]) / 255.0,
41 | 'green': np.array([44, 160, 44]) / 255.0,
42 | 'red': np.array([214, 39, 40]) / 255.0,
43 | 'purple': np.array([148, 103, 189]) / 255.0,
44 | 'brown': np.array([140, 86, 75]) / 255.0,
45 | 'pink': np.array([227, 119, 194]) / 255.0,
46 | 'gray': np.array([127, 127, 127]) / 255.0,
47 | 'yellow-green': np.array([188, 189, 34]) / 255.0,
48 | 'cyan': np.array([23, 190, 207]) / 255.0
49 | })
50 |
51 |
52 | def get_color(color_name: str):
53 | """ Returns the RGB values of a given color name as a normalized numpy array.
54 | Args:
55 | color_name: The name of the color. Can be any color name from CSS4_COLORS.
56 | Returns:
57 | A numpy array representing the RGB values of the specified color, normalized to the range [0, 1].
58 | """
59 | if color_name == "custom_yellow":
60 | return np.asarray([255.0, 204.0, 102.0]) / 255.0
61 | if color_name == "custom_blue":
62 | return np.asarray([102.0, 153.0, 255.0]) / 255.0
63 | assert color_name in colors.CSS4_COLORS
64 | return np.asarray(colors.to_rgb(colors.CSS4_COLORS[color_name]))
65 |
66 |
67 | def plot_ptcloud(point_clouds: Union[List, o3d.geometry.PointCloud], show_frame: bool = True):
68 | """ Visualizes one or more point clouds, optionally showing the coordinate frame.
69 | Args:
70 | point_clouds: A single point cloud or a list of point clouds to be visualized.
71 | show_frame: If True, displays the coordinate frame in the visualization. Defaults to True.
72 | """
73 | # rotate down up
74 | if not isinstance(point_clouds, list):
75 | point_clouds = [point_clouds]
76 | if show_frame:
77 | mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=1, origin=[0, 0, 0])
78 | point_clouds = point_clouds + [mesh_frame]
79 | o3d.visualization.draw_geometries(point_clouds)
80 |
81 |
82 | def draw_registration_result_original_color(source: o3d.geometry.PointCloud, target: o3d.geometry.PointCloud,
83 | transformation: np.ndarray):
84 | """ Visualizes the result of a point cloud registration, keeping the original color of the source point cloud.
85 | Args:
86 | source: The source point cloud.
87 | target: The target point cloud.
88 | transformation: The transformation matrix applied to the source point cloud.
89 | """
90 | source_temp = deepcopy(source)
91 | source_temp.transform(transformation)
92 | o3d.visualization.draw_geometries([source_temp, target])
93 |
94 |
95 | def draw_registration_result(source: o3d.geometry.PointCloud, target: o3d.geometry.PointCloud,
96 | transformation: np.ndarray, source_color: str = "blue", target_color: str = "orange"):
97 | """ Visualizes the result of a point cloud registration, coloring the source and target point clouds.
98 | Args:
99 | source: The source point cloud.
100 | target: The target point cloud.
101 | transformation: The transformation matrix applied to the source point cloud.
102 | source_color: The color to apply to the source point cloud. Defaults to "blue".
103 | target_color: The color to apply to the target point cloud. Defaults to "orange".
104 | """
105 | source_temp = deepcopy(source)
106 | source_temp.paint_uniform_color(COLORS_MATPLOTLIB_RGB[source_color])
107 |
108 | target_temp = deepcopy(target)
109 | target_temp.paint_uniform_color(COLORS_MATPLOTLIB_RGB[target_color])
110 |
111 | source_temp.transform(transformation)
112 | o3d.visualization.draw_geometries([source_temp, target_temp])
113 |
--------------------------------------------------------------------------------
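A minimal sketch of `draw_registration_result`: two copies of a random cloud offset by 5 cm, drawn with the colour scheme defined above once an alignment transform is known. It needs a display; the offset and colours are arbitrary and the repository root is assumed to be on `PYTHONPATH`.

from copy import deepcopy

import numpy as np
import open3d as o3d

from src.utils.vis_utils import draw_registration_result

source = o3d.geometry.PointCloud()
source.points = o3d.utility.Vector3dVector(np.random.default_rng(0).uniform(size=(500, 3)))
target = deepcopy(source)
target.translate([0.05, 0.0, 0.0])               # target is the source shifted 5 cm along x

T = np.eye(4)
T[0, 3] = 0.05                                   # transform that maps source onto target
draw_registration_result(source, target, T, source_color="blue", target_color="orange")
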