├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── assets ├── architecture.png └── loopsplat.gif ├── configs ├── Replica │ ├── office0.yaml │ ├── office1.yaml │ ├── office2.yaml │ ├── office3.yaml │ ├── office4.yaml │ ├── replica.yaml │ ├── room0.yaml │ ├── room1.yaml │ └── room2.yaml ├── ScanNet │ ├── scannet.yaml │ ├── scene0000_00.yaml │ ├── scene0054_00.yaml │ ├── scene0059_00.yaml │ ├── scene0106_00.yaml │ ├── scene0169_00.yaml │ ├── scene0181_00.yaml │ ├── scene0207_00.yaml │ ├── scene0233_00.yaml │ └── scene0465_00.yaml ├── TUM_RGBD │ ├── rgbd_dataset_freiburg1_desk.yaml │ ├── rgbd_dataset_freiburg1_desk2.yaml │ ├── rgbd_dataset_freiburg1_room.yaml │ ├── rgbd_dataset_freiburg2_xyz.yaml │ ├── rgbd_dataset_freiburg3_long_office_household.yaml │ └── tum_rgbd.yaml └── scannetpp │ ├── 281bc17764.yaml │ ├── 2e74812d00.yaml │ ├── 8b5caf3398.yaml │ ├── b20a261fdf.yaml │ ├── fb05e13ad1.yaml │ └── scannetpp.yaml ├── environment.yml ├── requirements.txt ├── run_evaluation.py ├── run_slam.py ├── scripts ├── download_replica.sh ├── download_tum.sh ├── reproduce_sbatch.sh └── scannet_preprocess.ipynb └── src ├── entities ├── __init__.py ├── arguments.py ├── datasets.py ├── gaussian_model.py ├── gaussian_slam.py ├── lc.py ├── logger.py ├── losses.py ├── mapper.py ├── tracker.py └── visual_odometer.py ├── evaluation ├── __init__.py ├── evaluate_merged_map.py ├── evaluate_reconstruction.py ├── evaluate_trajectory.py └── evaluator.py ├── gsr ├── camera.py ├── descriptor.py ├── loss.py ├── overlap.py ├── pcr.py ├── renderer.py ├── se3 │ ├── numpy_se3.py │ └── torch_se3.py ├── solver.py └── utils.py └── utils ├── __init__.py ├── eval_utils.py ├── gaussian_model_utils.py ├── graphics_utils.py ├── io_utils.py ├── mapper_utils.py ├── pose_utils.py ├── tracker_utils.py ├── utils.py └── vis_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | output 4 | build 5 | diff_rasterization/diff_rast.egg-info 6 | diff_rasterization/dist 7 | tensorboard_3d 8 | screenshots 9 | debug 10 | wandb 11 | data 12 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "thirdparty/Hierarchical-Localization"] 2 | path = thirdparty/Hierarchical-Localization 3 | url = https://github.com/cvg/Hierarchical-Localization.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# LoopSplat: Loop Closure by Registering 3D Gaussian Splats

<p align="center">
Liyuan Zhu<sup>1</sup> · Yue Li<sup>2</sup> · Erik Sandström<sup>3</sup> · Shengyu Huang<sup>3</sup> · Konrad Schindler<sup>3</sup> · Iro Armeni<sup>1</sup>
</p>
<p align="center">International Conference on 3D Vision (3DV) 2025</p>
<p align="center"><sup>1</sup>Stanford University · <sup>2</sup>University of Amsterdam · <sup>3</sup>ETH Zurich</p>

[![arXiv](https://img.shields.io/badge/arXiv-2408.10154-blue?logo=arxiv&color=%23B31B1B)](https://arxiv.org/abs/2408.10154) [![ProjectPage](https://img.shields.io/badge/Project_Page-LoopSplat-blue)](https://loopsplat.github.io/) [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)

![LoopSplat](assets/loopsplat.gif)

## 📃 Description

![LoopSplat architecture](assets/architecture.png)
**LoopSplat** is a coupled RGB-D SLAM system that uses Gaussian splats as a unified scene representation for tracking, mapping, and maintaining global consistency. The front-end continuously estimates the camera pose while reconstructing the scene as Gaussian-splat submaps; when the camera moves beyond a predefined threshold, the current submap is finalized and a new one is initiated. Concurrently, the back-end loop closure module monitors for revisited locations. Upon detecting a loop, the system builds a pose graph with loop-edge constraints derived from our proposed 3DGS registration. Pose graph optimization (PGO) is then executed to refine both camera poses and submaps, ensuring overall spatial coherence.

# 🛠️ Setup
The code has been tested on:

- Ubuntu 22.04 LTS, Python 3.10.14, CUDA 12.2, GeForce RTX 4090 / RTX 3090
- CentOS Linux 7, Python 3.12.1, CUDA 12.4, A100 / A6000

## 📦 Repository

Clone the repo with `--recursive`, since it contains submodules:

```
git clone --recursive git@github.com:GradientSpaces/LoopSplat.git
cd LoopSplat
```

## 💻 Installation
Make sure the gcc and g++ paths on your system are exported:

```
export CC=<path_to_gcc>
export CXX=<path_to_g++>
```

To find these paths on your machine, run `which gcc` and `which g++`.

Then set up the environment from the provided conda environment file:

```
conda create -n loop_splat -c nvidia/label/cuda-12.1.0 cuda=12.1 cuda-toolkit=12.1 cuda-nvcc=12.1
conda env update --file environment.yml --prune
conda activate loop_splat
pip install -r requirements.txt
```

You will also need to install hloc for loop detection and 3DGS registration:
```
cd thirdparty/Hierarchical-Localization
python -m pip install -e .
cd ../..
```

We tested the code on RTX 4090 and RTX A6000 GPUs under Ubuntu 22.04 and CentOS 7.5.

## 🚀 Usage

Here we explain how to load the necessary data, configure LoopSplat for your use case, debug it, and reproduce the results reported in the paper.

### Downloading the Datasets
We tested our code on the Replica, TUM_RGBD, ScanNet, and ScanNet++ datasets. We also provide scripts for downloading Replica and TUM_RGBD in the `scripts` folder. Install git lfs before using the scripts by running `git lfs install`.

For reconstruction evaluation on Replica, we follow the [Co-SLAM](https://github.com/JingwenWang95/neural_slam_eval?tab=readme-ov-file#datasets) mesh culling protocol; please use their code to process the meshes first.

To download ScanNet, follow the procedure described on the official ScanNet website.
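For orientation, the sketch below shows the per-scene layout that the preprocessing step (described next) reads and writes. The `scans/` and `processed/` folder names are simply the defaults hard-coded in `scripts/scannet_preprocess.ipynb` and may need to be adapted to wherever you exported the scenes.

```
# raw ScanNet export, one folder per scene (input to the preprocessing notebook)
scans/scene0000_00/data/{color,depth,pose,intrinsic}
# processed output written by the notebook (what the ScanNet dataloader expects)
processed/scene0000_00/{rgb,depth,gt_pose.txt,intrinsic.txt}
```

The `input_path` of a ScanNet config should then point to such a processed scene folder.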
Pay attention! Some frames in ScanNet come with `inf` poses; we filter them out using the Jupyter notebook `scripts/scannet_preprocess.ipynb`. Please change the path to your ScanNet data in the notebook and run its cells.

To download ScanNet++, follow the procedure described on the official ScanNet++ website.
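As a quick reference, the commands below sketch one way to arrive at the default `data/` layout assumed by the provided configs and `scripts/reproduce_sbatch.sh`. The ScanNet and ScanNet++ roots in the comments are just the defaults from that script and can always be overridden with `--input_path` or in the config files.

```
# Replica and TUM_RGBD via the provided scripts (requires git lfs)
git lfs install
bash scripts/download_replica.sh   # clones into data/Replica-SLAM
bash scripts/download_tum.sh       # clones into data/TUM_RGBD-SLAM

# default dataset roots used by scripts/reproduce_sbatch.sh
#   ScanNet   : data/scannet/scans      (processed scenes, see the sketch above)
#   ScanNet++ : data/scannetpp/data
```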
The config files are named after the sequences we used for our method.


### Running the code
Start the system with:

```
python run_slam.py configs/<dataset>/<config_file>.yaml --input_path <path_to_scene> --output_path <output_path>
```

You can also configure the input and output paths in the config yaml file.


### Reproducing Results

You can reproduce the results for a single scene by running:

```
python run_slam.py configs/<dataset>/<config_file>.yaml --input_path <path_to_scene> --output_path <output_path>
```

If you are running on a SLURM cluster, you can reproduce the results for all scenes of a dataset with the provided batch script:
```
./scripts/reproduce_sbatch.sh
```
Please note that evaluating the `depth_L1` metric requires reconstructing the mesh, which in turn requires a headless installation of open3d if you are running on a cluster.

## 📧 Contact
If you have any questions regarding this project, please contact Liyuan Zhu (liyzhu@stanford.edu). If you want to use our intermediate results for qualitative comparisons, please reach out to the same email.

# ✏️ Acknowledgement
Our implementation is heavily based on Gaussian-SLAM and MonoGS. We thank the authors for their open-source contributions. If you use code that builds on their contributions, please cite them as well. We thank [Jianhao Zheng](https://jianhao-zheng.github.io/) for the help with datasets and [Yue Pan](https://github.com/YuePanEdward) for the fruitful discussions.
137 | 138 | # 🎓 Citation 139 | 140 | If you find our paper and code useful, please cite us: 141 | 142 | ```bib 143 | @inproceedings{zhu2025_loopsplat, 144 | title={LoopSplat: Loop Closure by Registering 3D Gaussian Splats}, 145 | author={Liyuan Zhu and Yue Li and Erik Sandström and Shengyu Huang and Konrad Schindler and Iro Armeni}, 146 | year={2025}, 147 | booktitle = {International Conference on 3D Vision (3DV)}, 148 | } 149 | -------------------------------------------------------------------------------- /assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GradientSpaces/LoopSplat/676e18f950b2de6be39525a613120712b67768bb/assets/architecture.png -------------------------------------------------------------------------------- /assets/loopsplat.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GradientSpaces/LoopSplat/676e18f950b2de6be39525a613120712b67768bb/assets/loopsplat.gif -------------------------------------------------------------------------------- /configs/Replica/office0.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | data: 3 | scene_name: office0 4 | input_path: data/Replica-SLAM/Replica/office0/ 5 | output_path: output/Replica/office0/ 6 | -------------------------------------------------------------------------------- /configs/Replica/office1.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | data: 3 | scene_name: office1 4 | input_path: data/Replica-SLAM/Replica/office1/ 5 | output_path: output/Replica/office1/ 6 | -------------------------------------------------------------------------------- /configs/Replica/office2.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | data: 3 | scene_name: office2 4 | input_path: data/Replica-SLAM/Replica/office2/ 5 | output_path: output/Replica/office2/ 6 | -------------------------------------------------------------------------------- /configs/Replica/office3.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | data: 3 | scene_name: office3 4 | input_path: data/Replica-SLAM/Replica/office3/ 5 | output_path: output/Replica/office3/ 6 | -------------------------------------------------------------------------------- /configs/Replica/office4.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | data: 3 | scene_name: office4 4 | input_path: data/Replica-SLAM/Replica/office4/ 5 | output_path: output/Replica/office4/ 6 | -------------------------------------------------------------------------------- /configs/Replica/replica.yaml: -------------------------------------------------------------------------------- 1 | project_name: "LoopSplat_SLAM_replica" 2 | dataset_name: "replica" 3 | checkpoint_path: null 4 | use_wandb: True 5 | frame_limit: -1 # for debugging, set to -1 to disable 6 | seed: 1 7 | mapping: 8 | new_submap_every: 50 9 | map_every: 5 10 | iterations: 100 11 | new_submap_iterations: 1000 12 | new_submap_points_num: 600000 13 | new_submap_gradient_points_num: 50000 14 | new_frame_sample_size: -1 15 | new_points_radius: 0.0000001 16 | 
current_view_opt_iterations: 0.4 # What portion of iterations to spend on the current view 17 | alpha_thre: 0.6 18 | pruning_thre: 0.1 19 | submap_using_motion_heuristic: True 20 | tracking: 21 | gt_camera: False 22 | w_color_loss: 0.95 23 | iterations: 60 24 | cam_rot_lr: 0.0002 25 | cam_trans_lr: 0.002 26 | odometry_type: "const_speed" # gt, const_speed, odometer 27 | help_camera_initialization: False # temp option to help const_init 28 | init_err_ratio: 5 29 | odometer_method: "point_to_plane" # hybrid or point_to_plane 30 | filter_alpha: False 31 | filter_outlier_depth: True 32 | alpha_thre: 0.98 33 | soft_alpha: True 34 | mask_invalid_depth: False 35 | enable_exposure: False 36 | cam: 37 | H: 680 38 | W: 1200 39 | fx: 600.0 40 | fy: 600.0 41 | cx: 599.5 42 | cy: 339.5 43 | depth_scale: 6553.5 44 | lc: 45 | min_similarity: 0.5 46 | pgo_edge_prune_thres: 0.25 47 | voxel_size: 0.02 48 | pgo_max_iterations: 500 49 | registration: 50 | method: "gs_reg" 51 | base_lr: 1.e-3 52 | min_overlap_ratio: 0.1 53 | use_render: False 54 | min_interval: 2 55 | final: False -------------------------------------------------------------------------------- /configs/Replica/room0.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | data: 3 | scene_name: room0 4 | input_path: data/Replica-SLAM/room0/ 5 | output_path: output/Replica/room0/ 6 | -------------------------------------------------------------------------------- /configs/Replica/room1.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | data: 3 | scene_name: room1 4 | input_path: data/Replica-SLAM/Replica/room1/ 5 | output_path: output/Replica/room1/ 6 | -------------------------------------------------------------------------------- /configs/Replica/room2.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | data: 3 | scene_name: room2 4 | input_path: data/Replica-SLAM/Replica/room2/ 5 | output_path: output/Replica/room2/ 6 | -------------------------------------------------------------------------------- /configs/ScanNet/scannet.yaml: -------------------------------------------------------------------------------- 1 | project_name: "LoopSplat_SLAM_scannet" 2 | dataset_name: "scan_net" 3 | checkpoint_path: null 4 | use_wandb: True 5 | frame_limit: -1 # for debugging, set to -1 to disable 6 | seed: 0 7 | mapping: 8 | new_submap_every: 50 9 | map_every: 1 10 | iterations: 100 11 | new_submap_iterations: 100 12 | new_submap_points_num: 100000 13 | new_submap_gradient_points_num: 50000 14 | new_frame_sample_size: 30000 15 | new_points_radius: 0.0001 16 | current_view_opt_iterations: 0.4 # What portion of iterations to spend on the current view 17 | alpha_thre: 0.6 18 | pruning_thre: 0.5 19 | submap_using_motion_heuristic: False 20 | tracking: 21 | gt_camera: False 22 | w_color_loss: 0.6 23 | iterations: 200 24 | cam_rot_lr: 0.002 25 | cam_trans_lr: 0.01 26 | odometry_type: "const_speed" # gt, const_speed, odometer 27 | help_camera_initialization: False # temp option to help const_init 28 | init_err_ratio: 5 29 | odometer_method: "hybrid" # hybrid or point_to_plane 30 | filter_alpha: True 31 | filter_outlier_depth: True 32 | alpha_thre: 0.98 33 | soft_alpha: True 34 | mask_invalid_depth: True 35 | enable_exposure: True 36 | cam: 37 | H: 480 38 | W: 640 39 | fx: 577.590698 40 | fy: 578.729797 41 | cx: 
318.905426 42 | cy: 242.683609 43 | depth_scale: 1. 44 | crop_edge: 12 45 | lc: 46 | min_similarity: 0.5 47 | pgo_edge_prune_thres: 0.25 48 | voxel_size: 0.02 49 | pgo_max_iterations: 500 50 | registration: 51 | method: "gs_reg" 52 | base_lr: 5.e-3 53 | min_overlap_ratio: 0.2 54 | use_render: True # use rendered image as target for registration 55 | min_interval: 4 56 | final: False 57 | -------------------------------------------------------------------------------- /configs/ScanNet/scene0000_00.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | data: 3 | input_path: /home/liyuanzhu/projects/GSR/MonoGS/datasets/scannet/scene0000_00 4 | output_path: output/ScanNet/scene0000 5 | scene_name: scene0000_00 6 | cam: 7 | H: 480 8 | W: 640 9 | fx: 577.6 10 | fy: 578.7 11 | cx: 318.9 12 | cy: 242.7 13 | depth_scale: 1. 14 | crop_edge: 12 -------------------------------------------------------------------------------- /configs/ScanNet/scene0054_00.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | data: 3 | input_path: /home/liyuanzhu/projects/GSR/MonoGS/datasets/scannet/scene0054_00 4 | output_path: output/ScanNet/scene0054 5 | scene_name: scene0054_00 6 | cam: 7 | H: 480 8 | W: 640 9 | fx: 578.0 10 | fy: 578.0 11 | cx: 319.5 12 | cy: 239.5 13 | depth_scale: 1. 14 | crop_edge: 12 -------------------------------------------------------------------------------- /configs/ScanNet/scene0059_00.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | data: 3 | input_path: /home/liyuanzhu/projects/GSR/MonoGS/datasets/scannet/scene0059_00 4 | output_path: output/ScanNet/scene0059 5 | scene_name: scene0059_00 6 | cam: 7 | H: 480 8 | W: 640 9 | fx: 577.6 10 | fy: 578.7 11 | cx: 318.9 12 | cy: 242.7 13 | depth_scale: 1. 14 | crop_edge: 12 -------------------------------------------------------------------------------- /configs/ScanNet/scene0106_00.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | data: 3 | input_path: /home/liyuanzhu/projects/GSR/MonoGS/datasets/scannet/scene0106_00 4 | output_path: output/ScanNet/scene0106 5 | scene_name: scene0106_00 6 | cam: 7 | H: 480 8 | W: 640 9 | fx: 577.6 10 | fy: 578.7 11 | cx: 318.9 12 | cy: 242.7 13 | depth_scale: 1. 14 | crop_edge: 12 -------------------------------------------------------------------------------- /configs/ScanNet/scene0169_00.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | data: 3 | input_path: /home/liyuanzhu/projects/GSR/MonoGS/datasets/scannet/scene0169_00 4 | output_path: output/ScanNet/scene0169 5 | scene_name: scene0169_00 6 | cam: 7 | H: 480 8 | W: 640 9 | fx: 574.5 10 | fy: 577.6 11 | cx: 322.5 12 | cy: 238.6 13 | depth_scale: 1. 
14 | crop_edge: 12 -------------------------------------------------------------------------------- /configs/ScanNet/scene0181_00.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | data: 3 | input_path: /home/liyuanzhu/projects/GSR/MonoGS/datasets/scannet/scene0181_00 4 | output_path: output/ScanNet/scene0181 5 | scene_name: scene0181_00 6 | cam: 7 | fx: 575.547668 8 | fy: 577.459778 9 | cx: 323.171967 10 | cy: 236.417465 11 | -------------------------------------------------------------------------------- /configs/ScanNet/scene0207_00.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | data: 3 | input_path: /home/liyuanzhu/projects/GSR/MonoGS/datasets/scannet/scene0207_00 4 | output_path: output/ScanNet/scene0207 5 | scene_name: scene0207_00 6 | cam: 7 | fx: 577.6 8 | fy: 578.7 9 | cx: 318.9 10 | cy: 242.7 -------------------------------------------------------------------------------- /configs/ScanNet/scene0233_00.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | data: 3 | input_path: /home/liyuanzhu/projects/GSR/MonoGS/datasets/scannet/scene0233_00 4 | output_path: output/ScanNet/scene0233 5 | scene_name: scene0233_00 6 | cam: 7 | fx: 577.9 8 | fy: 577.9 9 | cx: 319.5 10 | cy: 239.5 -------------------------------------------------------------------------------- /configs/ScanNet/scene0465_00.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | data: 3 | input_path: /home/liyuanzhu/projects/GSR/MonoGS/datasets/scannet/scene0465_00 4 | output_path: output/ScanNet/scene0465 5 | scene_name: scene0465_00 6 | cam: 7 | fx: 577.9 8 | fy: 577.9 9 | cx: 319.5 10 | cy: 239.5 -------------------------------------------------------------------------------- /configs/TUM_RGBD/rgbd_dataset_freiburg1_desk.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/TUM_RGBD/tum_rgbd.yaml 2 | data: 3 | input_path: data/TUM_RGBD-SLAM/rgbd_dataset_freiburg1_desk 4 | output_path: output/TUM_RGBD/rgbd_dataset_freiburg1_desk/ 5 | scene_name: rgbd_dataset_freiburg1_desk 6 | cam: #intrinsic is different per scene in TUM 7 | H: 480 8 | W: 640 9 | fx: 517.3 10 | fy: 516.5 11 | cx: 318.6 12 | cy: 255.3 13 | crop_edge: 50 14 | distortion: [0.2624, -0.9531, -0.0054, 0.0026, 1.1633] -------------------------------------------------------------------------------- /configs/TUM_RGBD/rgbd_dataset_freiburg1_desk2.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/TUM_RGBD/tum_rgbd.yaml 2 | data: 3 | input_path: data/TUM_RGBD-SLAM/rgbd_dataset_freiburg1_desk2 4 | output_path: output/TUM_RGBD/rgbd_dataset_freiburg1_desk2/ 5 | scene_name: rgbd_dataset_freiburg1_desk2 6 | cam: #intrinsic is different per scene in TUM 7 | H: 480 8 | W: 640 9 | fx: 517.3 10 | fy: 516.5 11 | cx: 318.6 12 | cy: 255.3 13 | crop_edge: 50 14 | distortion: [0.2624, -0.9531, -0.0054, 0.0026, 1.1633] -------------------------------------------------------------------------------- /configs/TUM_RGBD/rgbd_dataset_freiburg1_room.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/TUM_RGBD/tum_rgbd.yaml 2 | data: 3 | input_path: 
data/TUM_RGBD-SLAM/rgbd_dataset_freiburg1_room 4 | output_path: output/TUM_RGBD/rgbd_dataset_freiburg1_room/ 5 | scene_name: rgbd_dataset_freiburg1_room 6 | cam: #intrinsic is different per scene in TUM 7 | H: 480 8 | W: 640 9 | fx: 517.3 10 | fy: 516.5 11 | cx: 318.6 12 | cy: 255.3 13 | crop_edge: 50 14 | distortion: [0.2624, -0.9531, -0.0054, 0.0026, 1.1633] -------------------------------------------------------------------------------- /configs/TUM_RGBD/rgbd_dataset_freiburg2_xyz.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/TUM_RGBD/tum_rgbd.yaml 2 | data: 3 | input_path: data/TUM_RGBD-SLAM/rgbd_dataset_freiburg2_xyz 4 | output_path: output/TUM_RGBD/rgbd_dataset_freiburg2_xyz/ 5 | scene_name: rgbd_dataset_freiburg2_xyz 6 | cam: #intrinsic is different per scene in TUM 7 | H: 480 8 | W: 640 9 | fx: 520.9 10 | fy: 521.0 11 | cx: 325.1 12 | cy: 249.7 13 | crop_edge: 8 14 | distortion: [0.2312, -0.7849, -0.0033, -0.0001, 0.9172] -------------------------------------------------------------------------------- /configs/TUM_RGBD/rgbd_dataset_freiburg3_long_office_household.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/TUM_RGBD/tum_rgbd.yaml 2 | data: 3 | input_path: data/TUM_RGBD-SLAM/rgbd_dataset_freiburg3_long_office_household/ 4 | output_path: output/TUM_RGBD/rgbd_dataset_freiburg3_long_office_household/ 5 | scene_name: rgbd_dataset_freiburg3_long_office_household 6 | cam: #intrinsic is different per scene in TUM 7 | H: 480 8 | W: 640 9 | fx: 517.3 10 | fy: 516.5 11 | cx: 318.6 12 | cy: 255.3 13 | crop_edge: 50 14 | distortion: [0.2624, -0.9531, -0.0054, 0.0026, 1.1633] -------------------------------------------------------------------------------- /configs/TUM_RGBD/tum_rgbd.yaml: -------------------------------------------------------------------------------- 1 | project_name: "LoopSplat_SLAM_tumrgbd" 2 | dataset_name: "tum_rgbd" 3 | checkpoint_path: null 4 | use_wandb: True 5 | frame_limit: -1 # for debugging, set to -1 to disable 6 | seed: 0 7 | mapping: 8 | new_submap_every: 50 9 | map_every: 1 10 | iterations: 100 11 | new_submap_iterations: 100 12 | new_submap_points_num: 100000 13 | new_submap_gradient_points_num: 50000 14 | new_frame_sample_size: 30000 15 | new_points_radius: 0.0001 16 | current_view_opt_iterations: 0.4 # What portion of iterations to spend on the current view 17 | alpha_thre: 0.6 18 | pruning_thre: 0.5 19 | submap_using_motion_heuristic: True 20 | tracking: 21 | gt_camera: False 22 | w_color_loss: 0.6 23 | iterations: 200 24 | cam_rot_lr: 0.002 25 | cam_trans_lr: 0.01 26 | odometry_type: "const_speed" # gt, const_speed, odometer 27 | help_camera_initialization: False # temp option to help const_init 28 | init_err_ratio: 5 29 | odometer_method: "hybrid" # hybrid or point_to_plane 30 | filter_alpha: False 31 | filter_outlier_depth: False 32 | alpha_thre: 0.98 33 | soft_alpha: True 34 | mask_invalid_depth: True 35 | enable_exposure: False 36 | cam: 37 | crop_edge: 16 38 | depth_scale: 5000.0 39 | lc: 40 | min_similarity: 0.5 41 | pgo_edge_prune_thres: 0.25 42 | voxel_size: 0.02 43 | pgo_max_iterations: 500 44 | registration: 45 | method: "gs_reg" 46 | base_lr: 5.e-3 47 | min_overlap_ratio: 0.2 48 | use_render: False 49 | min_interval: 3 50 | final: False -------------------------------------------------------------------------------- /configs/scannetpp/281bc17764.yaml: 
-------------------------------------------------------------------------------- 1 | inherit_from: configs/scannetpp/scannetpp.yaml 2 | data: 3 | input_path: data/scannetpp/data/281bc17764 4 | output_path: output/ScanNetPP/281bc17764 5 | scene_name: "281bc17764" 6 | use_train_split: True 7 | frame_limit: 250 8 | cam: 9 | H: 584 10 | W: 876 11 | fx: 312.79197434640764 12 | fy: 313.48022477591036 13 | cx: 438.0 14 | cy: 292.0 -------------------------------------------------------------------------------- /configs/scannetpp/2e74812d00.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/scannetpp/scannetpp.yaml 2 | data: 3 | input_path: data/scannetpp/data/2e74812d00 4 | output_path: output/ScanNetPP/2e74812d00 5 | scene_name: "2e74812d00" 6 | use_train_split: True 7 | frame_limit: 250 8 | cam: 9 | H: 584 10 | W: 876 11 | fx: 312.0984049779051 12 | fy: 312.4823067146056 13 | cx: 438.0 14 | cy: 292.0 -------------------------------------------------------------------------------- /configs/scannetpp/8b5caf3398.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/scannetpp/scannetpp.yaml 2 | data: 3 | input_path: data/scannetpp/data/8b5caf3398 4 | output_path: output/ScanNetPP/8b5caf3398 5 | scene_name: "8b5caf3398" 6 | use_train_split: True 7 | frame_limit: 250 8 | cam: 9 | H: 584 10 | W: 876 11 | fx: 316.3837659917395 12 | fy: 319.18649362678593 13 | cx: 438.0 14 | cy: 292.0 15 | -------------------------------------------------------------------------------- /configs/scannetpp/b20a261fdf.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/scannetpp/scannetpp.yaml 2 | data: 3 | input_path: data/scannetpp/data/b20a261fdf 4 | output_path: output/ScanNetPP/b20a261fdf 5 | scene_name: "b20a261fdf" 6 | use_train_split: True 7 | frame_limit: 250 8 | cam: 9 | H: 584 10 | W: 876 11 | fx: 312.7099188244687 12 | fy: 313.5121746848229 13 | cx: 438.0 14 | cy: 292.0 -------------------------------------------------------------------------------- /configs/scannetpp/fb05e13ad1.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/scannetpp/scannetpp.yaml 2 | data: 3 | input_path: data/scannetpp/data/fb05e13ad1 4 | output_path: output/ScanNetPP/fb05e13ad1 5 | scene_name: "fb05e13ad1" 6 | use_train_split: True 7 | frame_limit: 250 8 | cam: 9 | H: 584 10 | W: 876 11 | fx: 231.8197441948914 12 | fy: 231.9980523882361 13 | cx: 438.0 14 | cy: 292.0 -------------------------------------------------------------------------------- /configs/scannetpp/scannetpp.yaml: -------------------------------------------------------------------------------- 1 | project_name: "LoopSplat_SLAM_scannetpp" 2 | dataset_name: "scannetpp" 3 | checkpoint_path: null 4 | use_wandb: True 5 | frame_limit: -1 # set to -1 to disable 6 | seed: 0 7 | mapping: 8 | new_submap_every: 100 9 | map_every: 2 10 | iterations: 500 11 | new_submap_iterations: 500 12 | new_submap_points_num: 400000 13 | new_submap_gradient_points_num: 50000 14 | new_frame_sample_size: 100000 15 | new_points_radius: 0.00000001 16 | current_view_opt_iterations: 0.4 # What portion of iterations to spend on the current view 17 | alpha_thre: 0.6 18 | pruning_thre: 0.5 19 | submap_using_motion_heuristic: False 20 | tracking: 21 | gt_camera: False 22 | w_color_loss: 0.5 23 | iterations: 300 24 | cam_rot_lr: 0.002 25 | cam_trans_lr: 0.01 26 | 
odometry_type: "const_speed" # gt, const_speed, odometer 27 | help_camera_initialization: True 28 | init_err_ratio: 50 29 | odometer_method: "point_to_plane" # hybrid or point_to_plane 30 | filter_alpha: True 31 | filter_outlier_depth: True 32 | alpha_thre: 0.98 33 | soft_alpha: False 34 | mask_invalid_depth: True 35 | enable_exposure: False 36 | cam: 37 | crop_edge: 0 38 | depth_scale: 1000.0 39 | lc: 40 | min_similarity: 0.34 41 | pgo_edge_prune_thres: 0.25 42 | voxel_size: 0.02 43 | pgo_max_iterations: 500 44 | registration: 45 | method: "gs_reg" 46 | base_lr: 5.e-3 47 | min_overlap_ratio: 0.2 48 | min_interval: 0 49 | final: False -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: loop_splat 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - nvidia/label/cuda-12.1.0 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - python=3.10 10 | - faiss-gpu=1.8.0 11 | - cuda-toolkit=12.1 12 | - pytorch=2.1.2 13 | - pytorch-cuda=12.1 14 | - torchvision=0.16.2 15 | - pip 16 | - pip: 17 | - open3d==0.18.0 18 | - wandb 19 | - trimesh==4.0.10 20 | - pytorch_msssim 21 | - torchmetrics 22 | - tqdm 23 | - imageio 24 | - opencv-python 25 | - plyfile 26 | - roma 27 | - einops==0.8.0 28 | - numpy==1.26.4 29 | - PyQt5==5.15.11 30 | - matplotlib==3.5.1 31 | - evo==1.11.0 32 | - python-pycg 33 | - einops 34 | - git+https://github.com/eriksandstroem/evaluate_3d_reconstruction_lib.git@9b3cc08be5440db9c375cc21e3bd65bb4a337db7 35 | - git+https://github.com/VladimirYugay/simple-knn.git@c7e51a06a4cd84c25e769fee29ab391fe5d5ff8d 36 | - git+https://github.com/VladimirYugay/gaussian_rasterizer.git@9c40173fcc8d9b16778a1a8040295bc2f9ebf129 37 | - git+https://github.com/rmurai0610/diff-gaussian-rasterization-w-pose.git@43e21bff91cd24986ee3dd52fe0bb06952e50ec7 38 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | open3d==0.18.0 2 | wandb 3 | trimesh==4.0.10 4 | pytorch_msssim 5 | torchmetrics 6 | tqdm 7 | imageio 8 | opencv-python 9 | plyfile 10 | roma 11 | einops==0.8.0 12 | numpy==1.26.4 13 | PyQt5==5.15.11 14 | matplotlib==3.5.1 15 | evo==1.11.0 16 | python-pycg 17 | einops 18 | git+https://github.com/VladimirYugay/simple-knn.git@c7e51a06a4cd84c25e769fee29ab391fe5d5ff8d 19 | git+https://github.com/eriksandstroem/evaluate_3d_reconstruction_lib.git@9b3cc08be5440db9c375cc21e3bd65bb4a337db7 20 | git+https://github.com/VladimirYugay/gaussian_rasterizer.git@9c40173fcc8d9b16778a1a8040295bc2f9ebf129 21 | git+https://github.com/rmurai0610/diff-gaussian-rasterization-w-pose.git@43e21bff91cd24986ee3dd52fe0bb06952e50ec7 -------------------------------------------------------------------------------- /run_evaluation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from src.evaluation.evaluator import Evaluator 4 | 5 | 6 | def get_args(): 7 | parser = argparse.ArgumentParser(description='Arguments to compute the mesh') 8 | parser.add_argument('--checkpoint_path', type=str, help='SLAM checkpoint path', default="output/slam/full_experiment/") 9 | parser.add_argument('--config_path', type=str, help='Config path', default="") 10 | return parser.parse_args() 11 | 12 | 13 | if __name__ == "__main__": 14 | args = get_args() 15 | if args.config_path == "": 16 | args.config_path = 
Path(args.checkpoint_path) / "config.yaml" 17 | 18 | evaluator = Evaluator(Path(args.checkpoint_path), Path(args.config_path)) 19 | evaluator.run() 20 | -------------------------------------------------------------------------------- /run_slam.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import uuid 5 | 6 | import wandb 7 | 8 | from src.entities.gaussian_slam import GaussianSLAM 9 | from src.evaluation.evaluator import Evaluator 10 | from src.utils.io_utils import load_config 11 | from src.utils.utils import setup_seed 12 | 13 | 14 | def get_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Arguments to compute the mesh') 17 | parser.add_argument('config_path', type=str, 18 | help='Path to the configuration yaml file') 19 | parser.add_argument('--input_path', default="") 20 | parser.add_argument('--output_path', default="") 21 | parser.add_argument('--track_w_color_loss', type=float) 22 | parser.add_argument('--track_alpha_thre', type=float) 23 | parser.add_argument('--track_iters', type=int) 24 | parser.add_argument('--track_filter_alpha', action='store_true') 25 | parser.add_argument('--track_filter_outlier', action='store_true') 26 | parser.add_argument('--track_wo_filter_alpha', action='store_true') 27 | parser.add_argument("--track_wo_filter_outlier", action="store_true") 28 | parser.add_argument("--track_cam_trans_lr", type=float) 29 | parser.add_argument('--alpha_seeding_thre', type=float) 30 | parser.add_argument('--map_every', type=int) 31 | parser.add_argument("--map_iters", type=int) 32 | parser.add_argument('--new_submap_every', type=int) 33 | parser.add_argument('--project_name', type=str) 34 | parser.add_argument('--group_name', type=str) 35 | parser.add_argument('--gt_camera', action='store_true') 36 | parser.add_argument('--help_camera_initialization', action='store_true') 37 | parser.add_argument('--soft_alpha', action='store_true') 38 | parser.add_argument('--seed', type=int) 39 | parser.add_argument('--submap_using_motion_heuristic', action='store_true') 40 | parser.add_argument('--new_submap_points_num', type=int) 41 | return parser.parse_args() 42 | 43 | 44 | def update_config_with_args(config, args): 45 | if args.input_path: 46 | config["data"]["input_path"] = args.input_path 47 | if args.output_path: 48 | config["data"]["output_path"] = args.output_path 49 | if args.track_w_color_loss is not None: 50 | config["tracking"]["w_color_loss"] = args.track_w_color_loss 51 | if args.track_iters is not None: 52 | config["tracking"]["iterations"] = args.track_iters 53 | if args.track_filter_alpha: 54 | config["tracking"]["filter_alpha"] = True 55 | if args.track_wo_filter_alpha: 56 | config["tracking"]["filter_alpha"] = False 57 | if args.track_filter_outlier: 58 | config["tracking"]["filter_outlier_depth"] = True 59 | if args.track_wo_filter_outlier: 60 | config["tracking"]["filter_outlier_depth"] = False 61 | if args.track_alpha_thre is not None: 62 | config["tracking"]["alpha_thre"] = args.track_alpha_thre 63 | if args.map_every: 64 | config["mapping"]["map_every"] = args.map_every 65 | if args.map_iters: 66 | config["mapping"]["iterations"] = args.map_iters 67 | if args.new_submap_every: 68 | config["mapping"]["new_submap_every"] = args.new_submap_every 69 | if args.project_name: 70 | config["project_name"] = args.project_name 71 | if args.alpha_seeding_thre is not None: 72 | config["mapping"]["alpha_thre"] = args.alpha_seeding_thre 73 | if args.seed: 74 | 
config["seed"] = args.seed 75 | if args.help_camera_initialization: 76 | config["tracking"]["help_camera_initialization"] = True 77 | if args.soft_alpha: 78 | config["tracking"]["soft_alpha"] = True 79 | if args.submap_using_motion_heuristic: 80 | config["mapping"]["submap_using_motion_heuristic"] = True 81 | if args.new_submap_points_num: 82 | config["mapping"]["new_submap_points_num"] = args.new_submap_points_num 83 | if args.track_cam_trans_lr: 84 | config["tracking"]["cam_trans_lr"] = args.track_cam_trans_lr 85 | return config 86 | 87 | 88 | if __name__ == "__main__": 89 | args = get_args() 90 | config = load_config(args.config_path) 91 | config = update_config_with_args(config, args) 92 | 93 | if os.getenv('DISABLE_WANDB') == 'true': 94 | config["use_wandb"] = False 95 | if config["use_wandb"]: 96 | wandb.init( 97 | project=config["project_name"], 98 | config=config, 99 | dir="/home/yli3/scratch/outputs/slam/wandb", 100 | group=config["data"]["scene_name"] 101 | if not args.group_name 102 | else args.group_name, 103 | name=f'{config["data"]["scene_name"]}_{time.strftime("%Y%m%d_%H%M%S", time.localtime())}_{str(uuid.uuid4())[:5]}', 104 | ) 105 | wandb.run.log_code(".", include_fn=lambda path: path.endswith(".py")) 106 | 107 | setup_seed(config["seed"]) 108 | gslam = GaussianSLAM(config) 109 | gslam.run() 110 | 111 | evaluator = Evaluator(gslam.output_path, gslam.output_path / "config.yaml") 112 | evaluator.run() 113 | if config["use_wandb"]: 114 | wandb.finish() 115 | print("All done.✨") 116 | -------------------------------------------------------------------------------- /scripts/download_replica.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data 2 | cd data 3 | git clone https://huggingface.co/datasets/voviktyl/Replica-SLAM -------------------------------------------------------------------------------- /scripts/download_tum.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data 2 | cd data 3 | git clone https://huggingface.co/datasets/voviktyl/TUM_RGBD-SLAM -------------------------------------------------------------------------------- /scripts/reproduce_sbatch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --output=output/logs/%A_%a.log # please change accordingly 3 | #SBATCH --error=output/logs/%A_%a.log # please change accordingly 4 | #SBATCH -N 1 5 | #SBATCH -n 1 6 | #SBATCH --gpus-per-node=1 7 | #SBATCH --partition=gpu 8 | #SBATCH --cpus-per-task=12 9 | #SBATCH --time=24:00:00 10 | #SBATCH --array=0-4 # number of scenes, 0-7 for Replica, 0-2 for TUM_RGBD, 0-5 for ScanNet, 0-4 for ScanNet++ 11 | 12 | dataset="Replica" # set dataset 13 | if [ "$dataset" == "Replica" ]; then 14 | scenes=("room0" "room1" "room2" "office0" "office1" "office2" "office3" "office4") 15 | INPUT_PATH="data/Replica-SLAM" 16 | elif [ "$dataset" == "TUM_RGBD" ]; then 17 | scenes=("rgbd_dataset_freiburg1_desk" "rgbd_dataset_freiburg2_xyz" "rgbd_dataset_freiburg3_long_office_household") 18 | INPUT_PATH="data/TUM_RGBD-SLAM" 19 | elif [ "$dataset" == "ScanNet" ]; then 20 | scenes=("scene0000_00" "scene0059_00" "scene0106_00" "scene0169_00" "scene0181_00" "scene0207_00") 21 | INPUT_PATH="data/scannet/scans" 22 | elif [ "$dataset" == "ScanNetPP" ]; then 23 | scenes=("b20a261fdf" "8b5caf3398" "fb05e13ad1" "2e74812d00" "281bc17764") 24 | INPUT_PATH="data/scannetpp/data" 25 | else 26 | echo "Dataset not recognized!" 
27 | exit 1 28 | fi 29 | 30 | OUTPUT_PATH="output" 31 | CONFIG_PATH="configs/${dataset}" 32 | EXPERIMENT_NAME="reproduce" 33 | SCENE_NAME=${scenes[$SLURM_ARRAY_TASK_ID]} 34 | 35 | source # please change accordingly 36 | conda activate gslam 37 | 38 | echo "Job for dataset: $dataset, scene: $SCENE_NAME" 39 | echo "Starting on: $(date)" 40 | echo "Running on node: $(hostname)" 41 | 42 | # Your command to run the experiment 43 | python run_slam.py "${CONFIG_PATH}/${SCENE_NAME}.yaml" \ 44 | --input_path "${INPUT_PATH}/${SCENE_NAME}" \ 45 | --output_path "${OUTPUT_PATH}/${dataset}/${EXPERIMENT_NAME}/${SCENE_NAME}" \ 46 | --group_name "${EXPERIMENT_NAME}" \ 47 | 48 | echo "Job for scene $SCENE_NAME completed." 49 | echo "Started at: $START_TIME" 50 | echo "Finished at: $(date)" -------------------------------------------------------------------------------- /scripts/scannet_preprocess.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "8e3f10a3-fca1-4e2a-b39e-7cd8b44c724a", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "Jupyter environment detected. Enabling Open3D WebVisualizer.\n", 14 | "[Open3D INFO] WebRTC GUI backend enabled.\n", 15 | "[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.\n" 16 | ] 17 | } 18 | ], 19 | "source": [ 20 | "import numpy as np\n", 21 | "import open3d as o3d\n", 22 | "import os\n", 23 | "import time\n", 24 | "import json\n", 25 | "import cv2\n", 26 | "from tqdm import tqdm\n", 27 | "import math\n", 28 | "from scipy.spatial.transform import Rotation as R\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "import scipy\n", 31 | "import shutil\n", 32 | "\n", 33 | "from plyfile import PlyData, PlyElement\n", 34 | "import pandas as pd" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 3, 40 | "id": "3db0619f-a9bb-4aea-890d-d8c8cb3f7e18", 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "def read_intrinsic(data_folder):\n", 45 | " with open(os.path.join(data_folder,'data/intrinsic/intrinsic_depth.txt'), 'r') as f:\n", 46 | " lines = f.readlines()\n", 47 | " intrinsic = np.zeros((4,4))\n", 48 | " for i, line in enumerate(lines):\n", 49 | " for j, content in enumerate(line.split(' ')):\n", 50 | " intrinsic[i][j] = float(content)\n", 51 | " return intrinsic" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 4, 57 | "id": "9574fb19-fe0c-45ae-9e59-3125a17e3566", 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "name": "stderr", 62 | "output_type": "stream", 63 | "text": [ 64 | "100%|███████████████████████████████████████| 5578/5578 [01:33<00:00, 59.58it/s]\n", 65 | "100%|███████████████████████████████████████| 1807/1807 [00:29<00:00, 61.95it/s]\n", 66 | "100%|███████████████████████████████████████| 2324/2324 [00:35<00:00, 65.18it/s]\n", 67 | "100%|███████████████████████████████████████| 2034/2034 [00:31<00:00, 63.85it/s]\n", 68 | "100%|███████████████████████████████████████| 2349/2349 [00:35<00:00, 65.97it/s]\n", 69 | "100%|███████████████████████████████████████| 1988/1988 [00:31<00:00, 63.08it/s]\n", 70 | "100%|███████████████████████████████████████| 7643/7643 [01:52<00:00, 68.11it/s]\n", 71 | "100%|███████████████████████████████████████| 6306/6306 [01:31<00:00, 69.10it/s]\n" 72 | ] 73 | } 74 | ], 75 | "source": [ 76 | "fps_fake = 20 # fake camera frequency for offline ORB-SLAM\n", 77 | "raw_folder = 
\"./scans\"\n", 78 | "processed_folder = \"./processed\"\n", 79 | "\n", 80 | "scenes = ['scene0000_00', 'scene0054_00', 'scene0059_00', 'scene0106_00', 'scene0169_00', 'scene0181_00', 'scene0207_00', 'scene0233_00']\n", 81 | "\n", 82 | "for scene_idx, scene in enumerate(scenes):\n", 83 | " save_folder = os.path.join(processed_folder,scene)\n", 84 | " data_folder = os.path.join(raw_folder,scene)\n", 85 | " \n", 86 | " os.makedirs(save_folder)\n", 87 | " os.makedirs(os.path.join(save_folder,\"rgb\"))\n", 88 | " os.makedirs(os.path.join(save_folder,\"depth\"))\n", 89 | " \n", 90 | " shutil.copy(os.path.join(data_folder, 'data/intrinsic/intrinsic_depth.txt'), \n", 91 | " os.path.join(save_folder, 'intrinsic.txt'))\n", 92 | " \n", 93 | " with open(os.path.join(save_folder,'gt_pose.txt'), 'w') as f:\n", 94 | " f.write('# timestamp tx ty tz qx qy qz qw\\n')\n", 95 | " \n", 96 | " initial_time_stamp = time.time() \n", 97 | " \n", 98 | " color_folder = os.path.join(data_folder,\"data/color\")\n", 99 | " depth_folder = os.path.join(data_folder,\"data/depth\")\n", 100 | " pose_folder = os.path.join(data_folder,\"data/pose\")\n", 101 | " \n", 102 | " num_frames = len(os.listdir(color_folder))\n", 103 | " \n", 104 | " frame_idx = 0\n", 105 | " for raw_idx in tqdm(range(num_frames)):\n", 106 | " with open(os.path.join(pose_folder,\"{}.txt\".format(raw_idx)), \"r\") as f:\n", 107 | " lines = f.readlines()\n", 108 | " M_w_c = np.zeros((4,4))\n", 109 | " for i in range(4):\n", 110 | " content = lines[i].split(\" \")\n", 111 | " for j in range(4):\n", 112 | " M_w_c[i,j] = float(content[j])\n", 113 | " \n", 114 | " if \"inf\" in lines[0]:\n", 115 | " # invalid gt poses, skip this frame\n", 116 | " continue\n", 117 | "\n", 118 | " ######## convert depth to [m] and float type #########\n", 119 | " depth = cv2.imread(os.path.join(depth_folder,\"{}.png\".format(raw_idx)),cv2.IMREAD_UNCHANGED)\n", 120 | " depth = depth.astype(\"float32\")/1000.0\n", 121 | "\n", 122 | " ######## resize rgb to the same size of depth #########\n", 123 | " rgb = cv2.imread(os.path.join(color_folder,\"{}.jpg\".format(raw_idx)))\n", 124 | " rgb = cv2.resize(rgb,(depth.shape[1],depth.shape[0]),interpolation=cv2.INTER_CUBIC)\n", 125 | "\n", 126 | " cv2.imwrite(os.path.join(save_folder,\"rgb/frame_{}.png\".format(str(frame_idx).zfill(5))),rgb)\n", 127 | " cv2.imwrite(os.path.join(save_folder,\"depth/frame_{}.TIFF\".format(str(frame_idx).zfill(5))),depth)\n", 128 | "\n", 129 | " content = \"{:.4f}\".format(initial_time_stamp + frame_idx*1.0/fps_fake)\n", 130 | " for t in M_w_c[:3,3]:\n", 131 | " content += \" {:.9f}\".format(t)\n", 132 | " for q in R.from_matrix(M_w_c[:3,:3]).as_quat():\n", 133 | " content += \" {:.9f}\".format(q)\n", 134 | " \n", 135 | " with open(os.path.join(save_folder,'gt_pose.txt'), 'a') as f:\n", 136 | " f.write(content + '\\n')\n", 137 | " \n", 138 | " frame_idx += 1" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "id": "1d1cc4cd-9868-4830-a849-d6e5702c5f20", 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [] 148 | } 149 | ], 150 | "metadata": { 151 | "kernelspec": { 152 | "display_name": "Python 3 (ipykernel)", 153 | "language": "python", 154 | "name": "python3" 155 | }, 156 | "language_info": { 157 | "codemirror_mode": { 158 | "name": "ipython", 159 | "version": 3 160 | }, 161 | "file_extension": ".py", 162 | "mimetype": "text/x-python", 163 | "name": "python", 164 | "nbconvert_exporter": "python", 165 | "pygments_lexer": "ipython3", 166 | "version": "3.11.4" 
167 | } 168 | }, 169 | "nbformat": 4, 170 | "nbformat_minor": 5 171 | } 172 | -------------------------------------------------------------------------------- /src/entities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GradientSpaces/LoopSplat/676e18f950b2de6be39525a613120712b67768bb/src/entities/__init__.py -------------------------------------------------------------------------------- /src/entities/arguments.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from argparse import ArgumentParser, Namespace 15 | 16 | 17 | class GroupParams: 18 | pass 19 | 20 | 21 | class ParamGroup: 22 | def __init__(self, parser: ArgumentParser, name: str, fill_none=False): 23 | group = parser.add_argument_group(name) 24 | for key, value in vars(self).items(): 25 | shorthand = False 26 | if key.startswith("_"): 27 | shorthand = True 28 | key = key[1:] 29 | t = type(value) 30 | value = value if not fill_none else None 31 | if shorthand: 32 | if t == bool: 33 | group.add_argument( 34 | "--" + key, ("-" + key[0:1]), default=value, action="store_true") 35 | else: 36 | group.add_argument( 37 | "--" + key, ("-" + key[0:1]), default=value, type=t) 38 | else: 39 | if t == bool: 40 | group.add_argument( 41 | "--" + key, default=value, action="store_true") 42 | else: 43 | group.add_argument("--" + key, default=value, type=t) 44 | 45 | def extract(self, args): 46 | group = GroupParams() 47 | for arg in vars(args).items(): 48 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 49 | setattr(group, arg[0], arg[1]) 50 | return group 51 | 52 | 53 | class OptimizationParams(ParamGroup): 54 | def __init__(self, parser): 55 | self.iterations = 30_000 56 | self.position_lr_init = 0.0001 57 | self.position_lr_final = 0.0000016 58 | self.position_lr_delay_mult = 0.01 59 | self.position_lr_max_steps = 30_000 60 | self.feature_lr = 0.0025 61 | self.opacity_lr = 0.05 62 | self.scaling_lr = 0.005 # before 0.005 63 | self.rotation_lr = 0.001 64 | self.percent_dense = 0.01 65 | self.lambda_dssim = 0.2 66 | self.densification_interval = 100 67 | self.opacity_reset_interval = 3000 68 | self.densify_from_iter = 500 69 | self.densify_until_iter = 15_000 70 | self.densify_grad_threshold = 0.0002 71 | super().__init__(parser, "Optimization Parameters") 72 | 73 | 74 | def get_combined_args(parser: ArgumentParser): 75 | cmdlne_string = sys.argv[1:] 76 | cfgfile_string = "Namespace()" 77 | args_cmdline = parser.parse_args(cmdlne_string) 78 | 79 | try: 80 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 81 | print("Looking for config file in", cfgfilepath) 82 | with open(cfgfilepath) as cfg_file: 83 | print("Config file found: {}".format(cfgfilepath)) 84 | cfgfile_string = cfg_file.read() 85 | except TypeError: 86 | print("Config file not found at") 87 | pass 88 | args_cfgfile = eval(cfgfile_string) 89 | 90 | merged_dict = vars(args_cfgfile).copy() 91 | for k, v in vars(args_cmdline).items(): 92 | if v is not None: 93 | merged_dict[k] = v 94 | return Namespace(**merged_dict) 95 | 
-------------------------------------------------------------------------------- /src/entities/datasets.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from pathlib import Path 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import json 9 | import imageio 10 | import trimesh 11 | 12 | 13 | class BaseDataset(torch.utils.data.Dataset): 14 | 15 | def __init__(self, dataset_config: dict): 16 | self.dataset_path = Path(dataset_config["input_path"]) 17 | self.frame_limit = dataset_config.get("frame_limit", -1) 18 | self.dataset_config = dataset_config 19 | self.height = dataset_config["H"] 20 | self.width = dataset_config["W"] 21 | self.fx = dataset_config["fx"] 22 | self.fy = dataset_config["fy"] 23 | self.cx = dataset_config["cx"] 24 | self.cy = dataset_config["cy"] 25 | 26 | self.depth_scale = dataset_config["depth_scale"] 27 | self.distortion = np.array( 28 | dataset_config['distortion']) if 'distortion' in dataset_config else None 29 | self.crop_edge = dataset_config['crop_edge'] if 'crop_edge' in dataset_config else 0 30 | if self.crop_edge: 31 | self.height -= 2 * self.crop_edge 32 | self.width -= 2 * self.crop_edge 33 | self.cx -= self.crop_edge 34 | self.cy -= self.crop_edge 35 | 36 | self.fovx = 2 * math.atan(self.width / (2 * self.fx)) 37 | self.fovy = 2 * math.atan(self.height / (2 * self.fy)) 38 | self.intrinsics = np.array( 39 | [[self.fx, 0, self.cx], [0, self.fy, self.cy], [0, 0, 1]]) 40 | 41 | self.color_paths = [] 42 | self.depth_paths = [] 43 | 44 | def __len__(self): 45 | return len(self.color_paths) if self.frame_limit < 0 else int(self.frame_limit) 46 | 47 | 48 | class Replica(BaseDataset): 49 | 50 | def __init__(self, dataset_config: dict): 51 | super().__init__(dataset_config) 52 | self.color_paths = sorted( 53 | list((self.dataset_path / "results").glob("frame*.jpg"))) 54 | self.depth_paths = sorted( 55 | list((self.dataset_path / "results").glob("depth*.png"))) 56 | self.load_poses(self.dataset_path / "traj.txt") 57 | print(f"Loaded {len(self.color_paths)} frames") 58 | 59 | def load_poses(self, path): 60 | self.poses = [] 61 | with open(path, "r") as f: 62 | lines = f.readlines() 63 | for line in lines: 64 | c2w = np.array(list(map(float, line.split()))).reshape(4, 4) 65 | self.poses.append(c2w.astype(np.float32)) 66 | 67 | def __getitem__(self, index): 68 | color_data = cv2.imread(str(self.color_paths[index])) 69 | color_data = cv2.cvtColor(color_data, cv2.COLOR_BGR2RGB) 70 | depth_data = cv2.imread( 71 | str(self.depth_paths[index]), cv2.IMREAD_UNCHANGED) 72 | depth_data = depth_data.astype(np.float32) / self.depth_scale 73 | return index, color_data, depth_data, self.poses[index] 74 | 75 | 76 | class TUM_RGBD(BaseDataset): 77 | def __init__(self, dataset_config: dict): 78 | super().__init__(dataset_config) 79 | self.color_paths, self.depth_paths, self.poses = self.loadtum( 80 | self.dataset_path, frame_rate=32) 81 | 82 | def parse_list(self, filepath, skiprows=0): 83 | """ read list data """ 84 | return np.loadtxt(filepath, delimiter=' ', dtype=np.unicode_, skiprows=skiprows) 85 | 86 | def associate_frames(self, tstamp_image, tstamp_depth, tstamp_pose, max_dt=0.08): 87 | """ pair images, depths, and poses """ 88 | associations = [] 89 | for i, t in enumerate(tstamp_image): 90 | if tstamp_pose is None: 91 | j = np.argmin(np.abs(tstamp_depth - t)) 92 | if (np.abs(tstamp_depth[j] - t) < max_dt): 93 | associations.append((i, j)) 94 | else: 95 | j = np.argmin(np.abs(tstamp_depth - t)) 
96 | k = np.argmin(np.abs(tstamp_pose - t)) 97 | if (np.abs(tstamp_depth[j] - t) < max_dt) and (np.abs(tstamp_pose[k] - t) < max_dt): 98 | associations.append((i, j, k)) 99 | return associations 100 | 101 | def loadtum(self, datapath, frame_rate=-1): 102 | """ read video data in tum-rgbd format """ 103 | if os.path.isfile(os.path.join(datapath, 'groundtruth.txt')): 104 | pose_list = os.path.join(datapath, 'groundtruth.txt') 105 | elif os.path.isfile(os.path.join(datapath, 'pose.txt')): 106 | pose_list = os.path.join(datapath, 'pose.txt') 107 | 108 | image_list = os.path.join(datapath, 'rgb.txt') 109 | depth_list = os.path.join(datapath, 'depth.txt') 110 | 111 | image_data = self.parse_list(image_list) 112 | depth_data = self.parse_list(depth_list) 113 | pose_data = self.parse_list(pose_list, skiprows=1) 114 | pose_vecs = pose_data[:, 1:].astype(np.float64) 115 | 116 | tstamp_image = image_data[:, 0].astype(np.float64) 117 | tstamp_depth = depth_data[:, 0].astype(np.float64) 118 | tstamp_pose = pose_data[:, 0].astype(np.float64) 119 | associations = self.associate_frames( 120 | tstamp_image, tstamp_depth, tstamp_pose) 121 | 122 | indicies = [0] 123 | for i in range(1, len(associations)): 124 | t0 = tstamp_image[associations[indicies[-1]][0]] 125 | t1 = tstamp_image[associations[i][0]] 126 | if t1 - t0 > 1.0 / frame_rate: 127 | indicies += [i] 128 | 129 | images, poses, depths = [], [], [] 130 | inv_pose = None 131 | for ix in indicies: 132 | (i, j, k) = associations[ix] 133 | images += [os.path.join(datapath, image_data[i, 1])] 134 | depths += [os.path.join(datapath, depth_data[j, 1])] 135 | c2w = self.pose_matrix_from_quaternion(pose_vecs[k]) 136 | if inv_pose is None: 137 | inv_pose = np.linalg.inv(c2w) 138 | c2w = np.eye(4) 139 | else: 140 | c2w = inv_pose@c2w 141 | poses += [c2w.astype(np.float32)] 142 | 143 | return images, depths, poses 144 | 145 | def pose_matrix_from_quaternion(self, pvec): 146 | """ convert 4x4 pose matrix to (t, q) """ 147 | from scipy.spatial.transform import Rotation 148 | 149 | pose = np.eye(4) 150 | pose[:3, :3] = Rotation.from_quat(pvec[3:]).as_matrix() 151 | pose[:3, 3] = pvec[:3] 152 | return pose 153 | 154 | def __getitem__(self, index): 155 | color_data = cv2.imread(str(self.color_paths[index])) 156 | if self.distortion is not None: 157 | color_data = cv2.undistort( 158 | color_data, self.intrinsics, self.distortion) 159 | color_data = cv2.cvtColor(color_data, cv2.COLOR_BGR2RGB) 160 | 161 | depth_data = cv2.imread( 162 | str(self.depth_paths[index]), cv2.IMREAD_UNCHANGED) 163 | depth_data = depth_data.astype(np.float32) / self.depth_scale 164 | edge = self.crop_edge 165 | if edge > 0: 166 | color_data = color_data[edge:-edge, edge:-edge] 167 | depth_data = depth_data[edge:-edge, edge:-edge] 168 | # Interpolate depth values for splatting 169 | return index, color_data, depth_data, self.poses[index] 170 | 171 | 172 | class ScanNet(BaseDataset): 173 | def __init__(self, dataset_config: dict): 174 | super().__init__(dataset_config) 175 | self.color_paths = sorted(list( 176 | (self.dataset_path / "rgb").glob("*.png")), key=lambda x: int(os.path.basename(x)[-9:-4])) 177 | self.depth_paths = sorted(list( 178 | (self.dataset_path / "depth").glob("*.TIFF")), key=lambda x: int(os.path.basename(x)[-10:-5])) 179 | self.n_img = len(self.color_paths) 180 | self.load_poses(self.dataset_path / "gt_pose.txt") 181 | 182 | def load_poses(self, path): 183 | self.poses = [] 184 | pose_data = np.loadtxt(path, delimiter=" ", dtype=np.unicode_, skiprows=1) 185 | pose_vecs = 
pose_data[:, 0:].astype(np.float64) 186 | for i in range(self.n_img): 187 | quat = pose_vecs[i][4:] 188 | trans = pose_vecs[i][1:4] 189 | T = trimesh.transformations.quaternion_matrix(np.roll(quat, 1)) 190 | T[:3, 3] = trans 191 | pose = T 192 | self.poses.append(pose) 193 | 194 | def __getitem__(self, index): 195 | color_data = cv2.imread(str(self.color_paths[index])) 196 | if self.distortion is not None: 197 | color_data = cv2.undistort( 198 | color_data, self.intrinsics, self.distortion) 199 | color_data = cv2.cvtColor(color_data, cv2.COLOR_BGR2RGB) 200 | color_data = cv2.resize(color_data, (self.dataset_config["W"], self.dataset_config["H"])) 201 | 202 | depth_data = cv2.imread( 203 | str(self.depth_paths[index]), cv2.IMREAD_UNCHANGED) 204 | depth_data = depth_data.astype(np.float32) / self.depth_scale 205 | edge = self.crop_edge 206 | if edge > 0: 207 | color_data = color_data[edge:-edge, edge:-edge] 208 | depth_data = depth_data[edge:-edge, edge:-edge] 209 | # Interpolate depth values for splatting 210 | return index, color_data, depth_data, self.poses[index] 211 | 212 | 213 | class ScanNetPP(BaseDataset): 214 | def __init__(self, dataset_config: dict): 215 | super().__init__(dataset_config) 216 | self.use_train_split = dataset_config["use_train_split"] 217 | self.train_test_split = json.load(open(f"{self.dataset_path}/dslr/train_test_lists.json", "r")) 218 | if self.use_train_split: 219 | self.image_names = self.train_test_split["train"] 220 | else: 221 | self.image_names = self.train_test_split["test"] 222 | self.load_data() 223 | 224 | def load_data(self): 225 | self.poses = [] 226 | cams_path = self.dataset_path / "dslr" / "nerfstudio" / "transforms_undistorted.json" 227 | cams_metadata = json.load(open(str(cams_path), "r")) 228 | frames_key = "frames" if self.use_train_split else "test_frames" 229 | frames_metadata = cams_metadata[frames_key] 230 | frame2idx = {frame["file_path"]: index for index, frame in enumerate(frames_metadata)} 231 | P = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]).astype(np.float32) 232 | for image_name in self.image_names: 233 | frame_metadata = frames_metadata[frame2idx[image_name]] 234 | # if self.ignore_bad and frame_metadata['is_bad']: 235 | # continue 236 | color_path = str(self.dataset_path / "dslr" / "undistorted_images" / image_name) 237 | depth_path = str(self.dataset_path / "dslr" / "undistorted_depths" / image_name.replace('.JPG', '.png')) 238 | self.color_paths.append(color_path) 239 | self.depth_paths.append(depth_path) 240 | c2w = np.array(frame_metadata["transform_matrix"]).astype(np.float32) 241 | c2w = P @ c2w @ P.T 242 | self.poses.append(c2w) 243 | 244 | def __len__(self): 245 | if self.use_train_split: 246 | return len(self.image_names) if self.frame_limit < 0 else int(self.frame_limit) 247 | else: 248 | return len(self.image_names) 249 | 250 | def __getitem__(self, index): 251 | 252 | color_data = np.asarray(imageio.imread(self.color_paths[index]), dtype=float) 253 | color_data = cv2.resize(color_data, (self.width, self.height), interpolation=cv2.INTER_LINEAR) 254 | color_data = color_data.astype(np.uint8) 255 | 256 | depth_data = np.asarray(imageio.imread(self.depth_paths[index]), dtype=np.int64) 257 | depth_data = cv2.resize(depth_data.astype(float), (self.width, self.height), interpolation=cv2.INTER_NEAREST) 258 | depth_data = depth_data.astype(np.float32) / self.depth_scale 259 | return index, color_data, depth_data, self.poses[index] 260 | 261 | 262 | def get_dataset(dataset_name: str): 263 | if 
dataset_name == "replica": 264 | return Replica 265 | elif dataset_name == "tum_rgbd": 266 | return TUM_RGBD 267 | elif dataset_name == "scan_net": 268 | return ScanNet 269 | elif dataset_name == "scannetpp": 270 | return ScanNetPP 271 | raise NotImplementedError(f"Dataset {dataset_name} not implemented") 272 | -------------------------------------------------------------------------------- /src/entities/gaussian_slam.py: -------------------------------------------------------------------------------- 1 | """ This module includes the Gaussian-SLAM class, which is responsible for controlling Mapper and Tracker 2 | It also decides when to start a new submap and when to update the estimated camera poses. 3 | """ 4 | import os 5 | import pprint 6 | from argparse import ArgumentParser 7 | from datetime import datetime 8 | from pathlib import Path 9 | 10 | import numpy as np 11 | import torch 12 | import roma 13 | 14 | from src.entities.arguments import OptimizationParams 15 | from src.entities.datasets import get_dataset 16 | from src.entities.gaussian_model import GaussianModel 17 | from src.entities.mapper import Mapper 18 | from src.entities.tracker import Tracker 19 | from src.entities.lc import Loop_closure 20 | from src.entities.logger import Logger 21 | from src.utils.io_utils import save_dict_to_ckpt, save_dict_to_yaml 22 | from src.utils.mapper_utils import exceeds_motion_thresholds 23 | from src.utils.utils import np2torch, setup_seed, torch2np 24 | from src.utils.vis_utils import * # noqa - needed for debugging 25 | 26 | 27 | class GaussianSLAM(object): 28 | 29 | def __init__(self, config: dict) -> None: 30 | 31 | self._setup_output_path(config) 32 | self.device = "cuda" 33 | self.config = config 34 | 35 | self.scene_name = config["data"]["scene_name"] 36 | self.dataset_name = config["dataset_name"] 37 | self.dataset = get_dataset(config["dataset_name"])({**config["data"], **config["cam"]}) 38 | 39 | n_frames = len(self.dataset) 40 | frame_ids = list(range(n_frames)) 41 | self.mapping_frame_ids = frame_ids[::config["mapping"]["map_every"]] + [n_frames - 1] 42 | 43 | self.estimated_c2ws = torch.empty(len(self.dataset), 4, 4) 44 | self.estimated_c2ws[0] = torch.from_numpy(self.dataset[0][3]) 45 | self.exposures_ab = torch.zeros(len(self.dataset), 2) 46 | 47 | save_dict_to_yaml(config, "config.yaml", directory=self.output_path) 48 | 49 | self.submap_using_motion_heuristic = config["mapping"]["submap_using_motion_heuristic"] 50 | 51 | self.keyframes_info = {} 52 | self.opt = OptimizationParams(ArgumentParser(description="Training script parameters")) 53 | 54 | if self.submap_using_motion_heuristic: 55 | self.new_submap_frame_ids = [0] 56 | else: 57 | self.new_submap_frame_ids = frame_ids[::config["mapping"]["new_submap_every"]] + [n_frames - 1] 58 | self.new_submap_frame_ids.pop(0) 59 | 60 | self.logger = Logger(self.output_path, config["use_wandb"]) 61 | self.mapper = Mapper(config["mapping"], self.dataset, self.logger) 62 | self.tracker = Tracker(config["tracking"], self.dataset, self.logger) 63 | self.enable_exposure = self.tracker.enable_exposure 64 | self.loop_closer = Loop_closure(config, self.dataset, self.logger) 65 | self.loop_closer.submap_path = self.output_path / "submaps" 66 | 67 | print('Tracking config') 68 | pprint.PrettyPrinter().pprint(config["tracking"]) 69 | print('Mapping config') 70 | pprint.PrettyPrinter().pprint(config["mapping"]) 71 | print('Loop closure config') 72 | pprint.PrettyPrinter().pprint(config["lc"]) 73 | 74 | 75 | def _setup_output_path(self, config: 
dict) -> None: 76 | """ Sets up the output path for saving results based on the provided configuration. If the output path is not 77 | specified in the configuration, it creates a new directory with a timestamp. 78 | Args: 79 | config: A dictionary containing the experiment configuration including data and output path information. 80 | """ 81 | if "output_path" not in config["data"]: 82 | output_path = Path(config["data"]["output_path"]) 83 | self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 84 | self.output_path = output_path / self.timestamp 85 | else: 86 | self.output_path = Path(config["data"]["output_path"]) 87 | self.output_path.mkdir(exist_ok=True, parents=True) 88 | 89 | os.makedirs(self.output_path / "mapping_vis", exist_ok=True) 90 | os.makedirs(self.output_path / "tracking_vis", exist_ok=True) 91 | 92 | def should_start_new_submap(self, frame_id: int) -> bool: 93 | """ Determines whether a new submap should be started based on the motion heuristic or specific frame IDs. 94 | Args: 95 | frame_id: The ID of the current frame being processed. 96 | Returns: 97 | A boolean indicating whether to start a new submap. 98 | """ 99 | if self.submap_using_motion_heuristic: 100 | if exceeds_motion_thresholds( 101 | self.estimated_c2ws[frame_id], self.estimated_c2ws[self.new_submap_frame_ids[-1]], 102 | rot_thre=50, trans_thre=0.5): 103 | print(f"\nNew submap at {frame_id}") 104 | return True 105 | elif frame_id in self.new_submap_frame_ids: 106 | return True 107 | return False 108 | 109 | def save_current_submap(self, gaussian_model: GaussianModel): 110 | """Saving the current submap's checkpoint and resetting the Gaussian model 111 | 112 | Args: 113 | gaussian_model (GaussianModel): The current GaussianModel instance to capture and reset for the new submap. 114 | """ 115 | 116 | gaussian_params = gaussian_model.capture_dict() 117 | submap_ckpt_name = str(self.submap_id).zfill(6) 118 | submap_ckpt = { 119 | "gaussian_params": gaussian_params, 120 | "submap_keyframes": sorted(list(self.keyframes_info.keys())) 121 | } 122 | save_dict_to_ckpt( 123 | submap_ckpt, f"{submap_ckpt_name}.ckpt", directory=self.output_path / "submaps") 124 | 125 | def start_new_submap(self, frame_id: int, gaussian_model: GaussianModel) -> None: 126 | """ Initializes a new submap. 127 | This function updates the submap count and optionally marks the current frame ID for new submap initiation. 128 | Args: 129 | frame_id: The ID of the current frame at which the new submap is started. 130 | gaussian_model: The current GaussianModel instance to capture and reset for the new submap. 131 | Returns: 132 | A new, reset GaussianModel instance for the new submap. 133 | """ 134 | 135 | gaussian_model = GaussianModel(0) 136 | gaussian_model.training_setup(self.opt) 137 | self.mapper.keyframes = [] 138 | self.keyframes_info = {} 139 | if self.submap_using_motion_heuristic: 140 | self.new_submap_frame_ids.append(frame_id) 141 | self.mapping_frame_ids.append(frame_id) if frame_id not in self.mapping_frame_ids else self.mapping_frame_ids 142 | self.submap_id += 1 143 | self.loop_closer.submap_id += 1 144 | return gaussian_model 145 | 146 | def rigid_transform_gaussians(self, gaussian_params, tsfm_matrix): 147 | ''' 148 | Apply a rigid transformation to the Gaussian parameters. 149 | 150 | Args: 151 | gaussian_params (dict): Dictionary containing Gaussian parameters. 152 | tsfm_matrix (torch.Tensor): 4x4 rigid transformation matrix. 153 | 154 | Returns: 155 | dict: Updated Gaussian parameters after applying the transformation. 
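        Example (editorial, illustrative only; assumes an existing GaussianSLAM instance `slam`,
        CPU tensors, and `import numpy as np`, `import torch`):

            params = {"xyz": torch.rand(100, 3),
                      "rotation": torch.nn.functional.normalize(torch.rand(100, 4), dim=-1)}
            out = slam.rigid_transform_gaussians(params, np.eye(4))
            # identity transform: centres unchanged, orientations identical up to quaternion sign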
156 | ''' 157 | # Transform Gaussian centers (xyz) 158 | tsfm_matrix = torch.from_numpy(tsfm_matrix).float() 159 | xyz = gaussian_params['xyz'] 160 | pts_ones = torch.ones((xyz.shape[0], 1)) 161 | pts_homo = torch.cat([xyz, pts_ones], dim=1) 162 | transformed_xyz = (tsfm_matrix @ pts_homo.T).T[:, :3] 163 | gaussian_params['xyz'] = transformed_xyz 164 | 165 | # Rotate covariance matrix (rotation) 166 | rotation = gaussian_params['rotation'] 167 | cur_rot = roma.unitquat_to_rotmat(rotation) 168 | rot_mat = tsfm_matrix[:3, :3].unsqueeze(0) # Adding batch dimension 169 | new_rot = rot_mat @ cur_rot 170 | new_quat = roma.rotmat_to_unitquat(new_rot) 171 | gaussian_params['rotation'] = new_quat.squeeze() 172 | 173 | return gaussian_params 174 | 175 | def update_keyframe_poses(self, lc_output, submaps_kf_ids, cur_frame_id): 176 | ''' 177 | Update the keyframe poses using the correction from pgo, currently update the frame range that covered by the keyframes. 178 | 179 | ''' 180 | for correction in lc_output: 181 | submap_id = correction['submap_id'] 182 | correct_tsfm = correction['correct_tsfm'] 183 | submap_kf_ids = submaps_kf_ids[submap_id] 184 | min_id, max_id = min(submap_kf_ids), max(submap_kf_ids) 185 | self.estimated_c2ws[min_id:max_id + 1] = torch.from_numpy(correct_tsfm).float() @ self.estimated_c2ws[min_id:max_id + 1] 186 | 187 | # last tracked frame is based on last submap, update it as well 188 | self.estimated_c2ws[cur_frame_id] = torch.from_numpy(lc_output[-1]['correct_tsfm']).float() @ self.estimated_c2ws[cur_frame_id] 189 | 190 | 191 | def apply_correction_to_submaps(self, correction_list): 192 | submaps_kf_ids= {} 193 | for correction in correction_list: 194 | submap_id = correction['submap_id'] 195 | correct_tsfm = correction['correct_tsfm'] 196 | 197 | submap_ckpt_name = str(submap_id).zfill(6) + ".ckpt" 198 | submap_ckpt = torch.load(self.output_path / "submaps" / submap_ckpt_name) 199 | submaps_kf_ids[submap_id] = submap_ckpt["submap_keyframes"] 200 | 201 | gaussian_params = submap_ckpt["gaussian_params"] 202 | updated_gaussian_params = self.rigid_transform_gaussians( 203 | gaussian_params, correct_tsfm) 204 | 205 | submap_ckpt["gaussian_params"] = updated_gaussian_params 206 | torch.save(submap_ckpt, self.output_path / "submaps" / submap_ckpt_name) 207 | return submaps_kf_ids 208 | 209 | def run(self) -> None: 210 | """ Starts the main program flow for Gaussian-SLAM, including tracking and mapping. 
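        Typical usage (editorial sketch; the actual entry point is run_slam.py, whose argument
        handling is not reproduced here, and the YAML file is assumed to already contain every
        key this class reads, including any dataset-level defaults):

            import yaml
            with open("configs/Replica/office0.yaml") as f:
                config = yaml.safe_load(f)
            GaussianSLAM(config).run()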
""" 211 | setup_seed(self.config["seed"]) 212 | gaussian_model = GaussianModel(0) 213 | gaussian_model.training_setup(self.opt) 214 | self.submap_id = 0 215 | 216 | for frame_id in range(len(self.dataset)): 217 | 218 | if frame_id in [0, 1]: 219 | estimated_c2w = self.dataset[frame_id][-1] 220 | exposure_ab = torch.nn.Parameter(torch.tensor( 221 | 0.0, device="cuda")), torch.nn.Parameter(torch.tensor(0.0, device="cuda")) 222 | else: 223 | estimated_c2w, exposure_ab = self.tracker.track( 224 | frame_id, gaussian_model, 225 | torch2np(self.estimated_c2ws[torch.tensor([0, frame_id - 2, frame_id - 1])])) 226 | exposure_ab = exposure_ab if self.enable_exposure else None 227 | self.estimated_c2ws[frame_id] = np2torch(estimated_c2w) 228 | 229 | # Reinitialize gaussian model for new segment 230 | if self.should_start_new_submap(frame_id): 231 | # first save current submap and its keyframe info 232 | self.save_current_submap(gaussian_model) 233 | 234 | # update submap infomation for loop closer 235 | self.loop_closer.update_submaps_info(self.keyframes_info) 236 | 237 | # apply loop closure 238 | lc_output = self.loop_closer.loop_closure(self.estimated_c2ws) 239 | 240 | if len(lc_output) > 0: 241 | submaps_kf_ids = self.apply_correction_to_submaps(lc_output) 242 | self.update_keyframe_poses(lc_output, submaps_kf_ids, frame_id) 243 | 244 | save_dict_to_ckpt(self.estimated_c2ws[:frame_id + 1], "estimated_c2w.ckpt", directory=self.output_path) 245 | 246 | gaussian_model = self.start_new_submap(frame_id, gaussian_model) 247 | 248 | if frame_id in self.mapping_frame_ids: 249 | print("\nMapping frame", frame_id) 250 | gaussian_model.training_setup(self.opt, exposure_ab) 251 | estimate_c2w = torch2np(self.estimated_c2ws[frame_id]) 252 | new_submap = not bool(self.keyframes_info) 253 | opt_dict = self.mapper.map( 254 | frame_id, estimate_c2w, gaussian_model, new_submap, exposure_ab) 255 | 256 | # Keyframes info update 257 | self.keyframes_info[frame_id] = { 258 | "keyframe_id": frame_id, 259 | "opt_dict": opt_dict, 260 | } 261 | if self.enable_exposure: 262 | self.keyframes_info[frame_id]["exposure_a"] = exposure_ab[0].item() 263 | self.keyframes_info[frame_id]["exposure_b"] = exposure_ab[1].item() 264 | 265 | if frame_id == len(self.dataset) - 1 and self.config['lc']['final']: 266 | print("\n Final loop closure ...") 267 | self.loop_closer.update_submaps_info(self.keyframes_info) 268 | lc_output = self.loop_closer.loop_closure(self.estimated_c2ws, final=True) 269 | if len(lc_output) > 0: 270 | submaps_kf_ids = self.apply_correction_to_submaps(lc_output) 271 | self.update_keyframe_poses(lc_output, submaps_kf_ids, frame_id) 272 | if self.enable_exposure: 273 | self.exposures_ab[frame_id] = torch.tensor([exposure_ab[0].item(), exposure_ab[1].item()]) 274 | 275 | save_dict_to_ckpt(self.estimated_c2ws[:frame_id + 1], "estimated_c2w.ckpt", directory=self.output_path) 276 | if self.enable_exposure: 277 | save_dict_to_ckpt(self.exposures_ab, "exposures_ab.ckpt", directory=self.output_path) 278 | -------------------------------------------------------------------------------- /src/entities/logger.py: -------------------------------------------------------------------------------- 1 | """ This module includes the Logger class, which is responsible for logging for both Mapper and the Tracker """ 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import torch 8 | import wandb 9 | 10 | 11 | class Logger(object): 12 | 13 | def __init__(self, 
output_path: Union[Path, str], use_wandb=False) -> None: 14 | self.output_path = Path(output_path) 15 | (self.output_path / "mapping_vis").mkdir(exist_ok=True, parents=True) 16 | self.use_wandb = use_wandb 17 | 18 | def log_tracking_iteration(self, frame_id, cur_pose, gt_quat, gt_trans, total_loss, 19 | color_loss, depth_loss, iter, num_iters, 20 | wandb_output=False, print_output=False) -> None: 21 | """ Logs tracking iteration metrics including pose error, losses, and optionally reports to Weights & Biases. 22 | Logs the error between the current pose estimate and ground truth quaternion and translation, 23 | as well as various loss metrics. Can output to wandb if enabled and specified, and print to console. 24 | Args: 25 | frame_id: Identifier for the current frame. 26 | cur_pose: The current estimated pose as a tensor (quaternion + translation). 27 | gt_quat: Ground truth quaternion. 28 | gt_trans: Ground truth translation. 29 | total_loss: Total computed loss for the current iteration. 30 | color_loss: Computed color loss for the current iteration. 31 | depth_loss: Computed depth loss for the current iteration. 32 | iter: The current iteration number. 33 | num_iters: The total number of iterations planned. 34 | wandb_output: Whether to output the log to wandb. 35 | print_output: Whether to print the log output. 36 | """ 37 | 38 | quad_err = torch.abs(cur_pose[:4] - gt_quat).mean().item() 39 | trans_err = torch.abs(cur_pose[4:] - gt_trans).mean().item() 40 | if self.use_wandb and wandb_output: 41 | wandb.log( 42 | { 43 | "Tracking/idx": frame_id, 44 | "Tracking/cam_quad_err": quad_err, 45 | "Tracking/cam_position_err": trans_err, 46 | "Tracking/total_loss": total_loss.item(), 47 | "Tracking/color_loss": color_loss.item(), 48 | "Tracking/depth_loss": depth_loss.item(), 49 | "Tracking/num_iters": num_iters, 50 | }) 51 | if iter == num_iters - 1: 52 | msg = f"frame_id: {frame_id}, cam_quad_err: {quad_err:.5f}, cam_trans_err: {trans_err:.5f} " 53 | else: 54 | msg = f"iter: {iter}, color_loss: {color_loss.item():.5f}, depth_loss: {depth_loss.item():.5f} " 55 | msg = msg + f", cam_quad_err: {quad_err:.5f}, cam_trans_err: {trans_err:.5f}" 56 | if print_output: 57 | print(msg, flush=True) 58 | 59 | def log_mapping_iteration(self, frame_id, new_pts_num, model_size, iter_opt_time, opt_dict: dict) -> None: 60 | """ Logs mapping iteration metrics including the number of new points, model size, and optimization times, 61 | and optionally reports to Weights & Biases (wandb). 62 | Args: 63 | frame_id: Identifier for the current frame. 64 | new_pts_num: The number of new points added in the current mapping iteration. 65 | model_size: The total size of the model after the current mapping iteration. 66 | iter_opt_time: Time taken per optimization iteration. 67 | opt_dict: A dictionary containing optimization metrics such as PSNR, color loss, and depth loss. 68 | """ 69 | if self.use_wandb: 70 | wandb.log({"Mapping/idx": frame_id, 71 | "Mapping/num_total_gs": model_size, 72 | "Mapping/num_new_gs": new_pts_num, 73 | "Mapping/per_iteration_time": iter_opt_time, 74 | "Mapping/psnr_render": opt_dict["psnr_render"], 75 | "Mapping/color_loss": opt_dict[frame_id]["color_loss"], 76 | "Mapping/depth_loss": opt_dict[frame_id]["depth_loss"]}) 77 | 78 | def vis_mapping_iteration(self, frame_id, iter, color, depth, gt_color, gt_depth, seeding_mask=None, interval=10) -> None: 79 | """ 80 | Visualization of depth, color images and save to file. 81 | 82 | Args: 83 | frame_id (int): current frame index. 
84 | iter (int): the iteration number. 85 | save_rendered_image (bool): whether to save the rgb image in separate folder 86 | img_dir (str): the directory to save the visualization. 87 | seeding_mask: used in mapper when adding gaussians, if not none. 88 | """ 89 | if frame_id % interval != 0: 90 | return 91 | gt_depth_np = gt_depth.cpu().numpy() 92 | gt_color_np = gt_color.cpu().numpy() 93 | 94 | depth_np = depth.detach().cpu().numpy() 95 | color = torch.round(color * 255.0) / 255.0 96 | color_np = color.detach().cpu().numpy() 97 | depth_residual = np.abs(gt_depth_np - depth_np) 98 | depth_residual[gt_depth_np == 0.0] = 0.0 99 | # make errors >=5cm noticeable 100 | depth_residual = np.clip(depth_residual, 0.0, 0.05) 101 | 102 | color_residual = np.abs(gt_color_np - color_np) 103 | color_residual[np.squeeze(gt_depth_np == 0.0)] = 0.0 104 | 105 | # Determine Aspect Ratio and Figure Size 106 | aspect_ratio = color.shape[1] / color.shape[0] 107 | fig_height = 8 108 | # Adjust the multiplier as needed for better spacing 109 | fig_width = fig_height * aspect_ratio * 1.2 110 | 111 | fig, axs = plt.subplots(2, 3, figsize=(fig_width, fig_height)) 112 | axs[0, 0].imshow(gt_depth_np, cmap="jet", vmin=0, vmax=6) 113 | axs[0, 0].set_title('Input Depth', fontsize=16) 114 | axs[0, 0].set_xticks([]) 115 | axs[0, 0].set_yticks([]) 116 | axs[0, 1].imshow(depth_np, cmap="jet", vmin=0, vmax=6) 117 | axs[0, 1].set_title('Rendered Depth', fontsize=16) 118 | axs[0, 1].set_xticks([]) 119 | axs[0, 1].set_yticks([]) 120 | axs[0, 2].imshow(depth_residual, cmap="plasma") 121 | axs[0, 2].set_title('Depth Residual', fontsize=16) 122 | axs[0, 2].set_xticks([]) 123 | axs[0, 2].set_yticks([]) 124 | gt_color_np = np.clip(gt_color_np, 0, 1) 125 | color_np = np.clip(color_np, 0, 1) 126 | color_residual = np.clip(color_residual, 0, 1) 127 | axs[1, 0].imshow(gt_color_np, cmap="plasma") 128 | axs[1, 0].set_title('Input RGB', fontsize=16) 129 | axs[1, 0].set_xticks([]) 130 | axs[1, 0].set_yticks([]) 131 | axs[1, 1].imshow(color_np, cmap="plasma") 132 | axs[1, 1].set_title('Rendered RGB', fontsize=16) 133 | axs[1, 1].set_xticks([]) 134 | axs[1, 1].set_yticks([]) 135 | if seeding_mask is not None: 136 | axs[1, 2].imshow(seeding_mask, cmap="gray") 137 | axs[1, 2].set_title('Densification Mask', fontsize=16) 138 | axs[1, 2].set_xticks([]) 139 | axs[1, 2].set_yticks([]) 140 | else: 141 | axs[1, 2].imshow(color_residual, cmap="plasma") 142 | axs[1, 2].set_title('RGB Residual', fontsize=16) 143 | axs[1, 2].set_xticks([]) 144 | axs[1, 2].set_yticks([]) 145 | 146 | for ax in axs.flatten(): 147 | ax.axis('off') 148 | fig.tight_layout() 149 | plt.subplots_adjust(top=0.90) # Adjust top margin 150 | fig_name = str(self.output_path / "mapping_vis" / f'{frame_id:04d}_{iter:04d}.jpg') 151 | fig_title = f"Mapper Color/Depth at frame {frame_id:04d} iters {iter:04d}" 152 | plt.suptitle(fig_title, y=0.98, fontsize=20) 153 | plt.savefig(fig_name, dpi=250, bbox_inches='tight') 154 | plt.clf() 155 | plt.close() 156 | if self.use_wandb: 157 | log_title = "Mapping_vis/" + f'{frame_id:04d}_{iter:04d}' 158 | wandb.log({log_title: [wandb.Image(fig_name)]}) 159 | print(f"Saved rendering vis of color/depth at {frame_id:04d}_{iter:04d}.jpg") 160 | -------------------------------------------------------------------------------- /src/entities/losses.py: -------------------------------------------------------------------------------- 1 | from math import exp 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | 
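# (editorial note) A minimal sketch of how the helpers below are combined during tracking
# (cf. Tracker.compute_losses further down); `mask` and `w_color_loss` are illustrative names
# borrowed from that class, not definitions in this module:
#
#     color_loss = (l1_loss(rendered_color, gt_color, agg="none") * mask).sum()
#     depth_loss = (l1_loss(rendered_depth, gt_depth, agg="none") * mask).sum()
#     total_loss = w_color_loss * color_loss + (1 - w_color_loss) * depth_loss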
8 | def l1_loss(network_output: torch.Tensor, gt: torch.Tensor, agg="mean") -> torch.Tensor: 9 | """ 10 | Computes the L1 loss, which is the mean absolute error between the network output and the ground truth. 11 | 12 | Args: 13 | network_output: The output from the network. 14 | gt: The ground truth tensor. 15 | agg: The aggregation method to be used. Defaults to "mean". 16 | Returns: 17 | The computed L1 loss. 18 | """ 19 | l1_loss = torch.abs(network_output - gt) 20 | if agg == "mean": 21 | return l1_loss.mean() 22 | elif agg == "sum": 23 | return l1_loss.sum() 24 | elif agg == "none": 25 | return l1_loss 26 | else: 27 | raise ValueError("Invalid aggregation method.") 28 | 29 | 30 | def gaussian(window_size: int, sigma: float) -> torch.Tensor: 31 | """ 32 | Creates a 1D Gaussian kernel. 33 | 34 | Args: 35 | window_size: The size of the window for the Gaussian kernel. 36 | sigma: The standard deviation of the Gaussian kernel. 37 | 38 | Returns: 39 | The 1D Gaussian kernel. 40 | """ 41 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / 42 | float(2 * sigma ** 2)) for x in range(window_size)]) 43 | return gauss / gauss.sum() 44 | 45 | 46 | def create_window(window_size: int, channel: int) -> Variable: 47 | """ 48 | Creates a 2D Gaussian window/kernel for SSIM computation. 49 | 50 | Args: 51 | window_size: The size of the window to be created. 52 | channel: The number of channels in the image. 53 | 54 | Returns: 55 | A 2D Gaussian window expanded to match the number of channels. 56 | """ 57 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 58 | _2D_window = _1D_window.mm( 59 | _1D_window.t()).float().unsqueeze(0).unsqueeze(0) 60 | window = Variable(_2D_window.expand( 61 | channel, 1, window_size, window_size).contiguous()) 62 | return window 63 | 64 | 65 | def ssim(img1: torch.Tensor, img2: torch.Tensor, window_size: int = 11, size_average: bool = True) -> torch.Tensor: 66 | """ 67 | Computes the Structural Similarity Index (SSIM) between two images. 68 | 69 | Args: 70 | img1: The first image. 71 | img2: The second image. 72 | window_size: The size of the window to be used in SSIM computation. Defaults to 11. 73 | size_average: If True, averages the SSIM over all pixels. Defaults to True. 74 | 75 | Returns: 76 | The computed SSIM value. 77 | """ 78 | channel = img1.size(-3) 79 | window = create_window(window_size, channel) 80 | 81 | if img1.is_cuda: 82 | window = window.cuda(img1.get_device()) 83 | window = window.type_as(img1) 84 | 85 | return _ssim(img1, img2, window, window_size, channel, size_average) 86 | 87 | 88 | def _ssim(img1: torch.Tensor, img2: torch.Tensor, window: Variable, window_size: int, 89 | channel: int, size_average: bool = True) -> torch.Tensor: 90 | """ 91 | Internal function to compute the Structural Similarity Index (SSIM) between two images. 92 | 93 | Args: 94 | img1: The first image. 95 | img2: The second image. 96 | window: The Gaussian window/kernel for SSIM computation. 97 | window_size: The size of the window to be used in SSIM computation. 98 | channel: The number of channels in the image. 99 | size_average: If True, averages the SSIM over all pixels. 100 | 101 | Returns: 102 | The computed SSIM value. 
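    Example (editorial, illustrative only; mirrors how `ssim` wires this helper up):

        img = torch.rand(1, 3, 64, 64)
        window = create_window(11, 3)
        value = _ssim(img, img, window, 11, 3)   # identical inputs give a value of ~1.0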
103 | """ 104 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 105 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 106 | 107 | mu1_sq = mu1.pow(2) 108 | mu2_sq = mu2.pow(2) 109 | mu1_mu2 = mu1 * mu2 110 | 111 | sigma1_sq = F.conv2d(img1 * img1, window, 112 | padding=window_size // 2, groups=channel) - mu1_sq 113 | sigma2_sq = F.conv2d(img2 * img2, window, 114 | padding=window_size // 2, groups=channel) - mu2_sq 115 | sigma12 = F.conv2d(img1 * img2, window, 116 | padding=window_size // 2, groups=channel) - mu1_mu2 117 | 118 | C1 = 0.01 ** 2 119 | C2 = 0.03 ** 2 120 | 121 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ 122 | ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 123 | 124 | if size_average: 125 | return ssim_map.mean() 126 | else: 127 | return ssim_map.mean(1).mean(1).mean(1) 128 | 129 | 130 | def isotropic_loss(scaling: torch.Tensor) -> torch.Tensor: 131 | """ 132 | Computes loss enforcing isotropic scaling for the 3D Gaussians 133 | Args: 134 | scaling: scaling tensor of 3D Gaussians of shape (n, 3) 135 | Returns: 136 | The computed isotropic loss 137 | """ 138 | mean_scaling = scaling.mean(dim=1, keepdim=True) 139 | isotropic_diff = torch.abs(scaling - mean_scaling * torch.ones_like(scaling)) 140 | return isotropic_diff.mean() 141 | -------------------------------------------------------------------------------- /src/entities/tracker.py: -------------------------------------------------------------------------------- 1 | """ This module includes the Mapper class, which is responsible scene mapping: Paper Section 3.4 """ 2 | from argparse import ArgumentParser 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | import torchvision 8 | from scipy.spatial.transform import Rotation as R 9 | 10 | from src.entities.arguments import OptimizationParams 11 | from src.entities.losses import l1_loss 12 | from src.entities.gaussian_model import GaussianModel 13 | from src.entities.logger import Logger 14 | from src.entities.datasets import BaseDataset 15 | from src.entities.visual_odometer import VisualOdometer 16 | from src.utils.gaussian_model_utils import build_rotation 17 | from src.utils.tracker_utils import (compute_camera_opt_params, 18 | extrapolate_poses, multiply_quaternions, 19 | transformation_to_quaternion) 20 | from src.utils.utils import (get_render_settings, np2torch, 21 | render_gaussian_model, torch2np) 22 | 23 | 24 | class Tracker(object): 25 | def __init__(self, config: dict, dataset: BaseDataset, logger: Logger) -> None: 26 | """ Initializes the Tracker with a given configuration, dataset, and logger. 27 | Args: 28 | config: Configuration dictionary specifying hyperparameters and operational settings. 29 | dataset: The dataset object providing access to the sequence of frames. 30 | logger: Logger object for logging the tracking process. 
31 | """ 32 | self.dataset = dataset 33 | self.logger = logger 34 | self.config = config 35 | self.filter_alpha = self.config["filter_alpha"] 36 | self.filter_outlier_depth = self.config["filter_outlier_depth"] 37 | self.alpha_thre = self.config["alpha_thre"] 38 | self.soft_alpha = self.config["soft_alpha"] 39 | self.mask_invalid_depth_in_color_loss = self.config["mask_invalid_depth"] 40 | self.w_color_loss = self.config["w_color_loss"] 41 | self.transform = torchvision.transforms.ToTensor() 42 | self.opt = OptimizationParams(ArgumentParser(description="Training script parameters")) 43 | self.frame_depth_loss = [] 44 | self.frame_color_loss = [] 45 | self.odometry_type = self.config["odometry_type"] 46 | self.help_camera_initialization = self.config["help_camera_initialization"] 47 | self.init_err_ratio = self.config["init_err_ratio"] 48 | self.enable_exposure = self.config["enable_exposure"] 49 | self.odometer = VisualOdometer(self.dataset.intrinsics, self.config["odometer_method"]) 50 | 51 | def compute_losses(self, gaussian_model: GaussianModel, render_settings: dict, 52 | opt_cam_rot: torch.Tensor, opt_cam_trans: torch.Tensor, 53 | gt_color: torch.Tensor, gt_depth: torch.Tensor, depth_mask: torch.Tensor, 54 | exposure_ab=None) -> tuple: 55 | """ Computes the tracking losses with respect to ground truth color and depth. 56 | Args: 57 | gaussian_model: The current state of the Gaussian model of the scene. 58 | render_settings: Dictionary containing rendering settings such as image dimensions and camera intrinsics. 59 | opt_cam_rot: Optimizable tensor representing the camera's rotation. 60 | opt_cam_trans: Optimizable tensor representing the camera's translation. 61 | gt_color: Ground truth color image tensor. 62 | gt_depth: Ground truth depth image tensor. 63 | depth_mask: Binary mask indicating valid depth values in the ground truth depth image. 64 | Returns: 65 | A tuple containing losses and renders 66 | """ 67 | rel_transform = torch.eye(4).cuda().float() 68 | rel_transform[:3, :3] = build_rotation(F.normalize(opt_cam_rot[None]))[0] 69 | rel_transform[:3, 3] = opt_cam_trans 70 | 71 | pts = gaussian_model.get_xyz() 72 | pts_ones = torch.ones(pts.shape[0], 1).cuda().float() 73 | pts4 = torch.cat((pts, pts_ones), dim=1) 74 | transformed_pts = (rel_transform @ pts4.T).T[:, :3] 75 | 76 | quat = F.normalize(opt_cam_rot[None]) 77 | _rotations = multiply_quaternions(gaussian_model.get_rotation(), quat.unsqueeze(0)).squeeze(0) 78 | 79 | render_dict = render_gaussian_model(gaussian_model, render_settings, 80 | override_means_3d=transformed_pts, override_rotations=_rotations) 81 | rendered_color, rendered_depth = render_dict["color"], render_dict["depth"] 82 | if self.enable_exposure: 83 | rendered_color = torch.clamp(torch.exp(exposure_ab[0]) * rendered_color + exposure_ab[1], 0, 1.) 
84 | alpha_mask = render_dict["alpha"] > self.alpha_thre 85 | 86 | tracking_mask = torch.ones_like(alpha_mask).bool() 87 | tracking_mask &= depth_mask 88 | depth_err = torch.abs(rendered_depth - gt_depth) * depth_mask 89 | 90 | if self.filter_alpha: 91 | tracking_mask &= alpha_mask 92 | if self.filter_outlier_depth and torch.median(depth_err) > 0: 93 | tracking_mask &= depth_err < 50 * torch.median(depth_err) 94 | 95 | color_loss = l1_loss(rendered_color, gt_color, agg="none") 96 | depth_loss = l1_loss(rendered_depth, gt_depth, agg="none") * tracking_mask 97 | 98 | if self.soft_alpha: 99 | alpha = render_dict["alpha"] ** 3 100 | color_loss *= alpha 101 | depth_loss *= alpha 102 | if self.mask_invalid_depth_in_color_loss: 103 | color_loss *= tracking_mask 104 | else: 105 | color_loss *= tracking_mask 106 | 107 | color_loss = color_loss.sum() 108 | depth_loss = depth_loss.sum() 109 | 110 | return color_loss, depth_loss, rendered_color, rendered_depth, alpha_mask 111 | 112 | def track(self, frame_id: int, gaussian_model: GaussianModel, prev_c2ws: np.ndarray) -> np.ndarray: 113 | """ 114 | Updates the camera pose estimation for the current frame based on the provided image and depth, using either ground truth poses, 115 | constant speed assumption, or visual odometry. 116 | Args: 117 | frame_id: Index of the current frame being processed. 118 | gaussian_model: The current Gaussian model of the scene. 119 | prev_c2ws: Array containing the camera-to-world transformation matrices for the frames (0, i - 2, i - 1) 120 | Returns: 121 | The updated camera-to-world transformation matrix for the current frame. 122 | """ 123 | _, image, depth, gt_c2w = self.dataset[frame_id] 124 | 125 | if (self.help_camera_initialization or self.odometry_type == "odometer") and self.odometer.last_rgbd is None: 126 | _, last_image, last_depth, _ = self.dataset[frame_id - 1] 127 | self.odometer.update_last_rgbd(last_image, last_depth) 128 | 129 | if self.odometry_type == "gt": 130 | return gt_c2w 131 | elif self.odometry_type == "const_speed": 132 | init_c2w = extrapolate_poses(prev_c2ws[1:]) 133 | elif self.odometry_type == "odometer": 134 | odometer_rel = self.odometer.estimate_rel_pose(image, depth) 135 | init_c2w = prev_c2ws[-1] @ odometer_rel 136 | elif self.odometry_type == "previous": 137 | init_c2w = prev_c2ws[-1] 138 | 139 | last_c2w = prev_c2ws[-1] 140 | last_w2c = np.linalg.inv(last_c2w) 141 | init_rel = init_c2w @ np.linalg.inv(last_c2w) 142 | init_rel_w2c = np.linalg.inv(init_rel) 143 | reference_w2c = last_w2c 144 | render_settings = get_render_settings( 145 | self.dataset.width, self.dataset.height, self.dataset.intrinsics, reference_w2c) 146 | opt_cam_rot, opt_cam_trans = compute_camera_opt_params(init_rel_w2c) 147 | if self.enable_exposure: 148 | exposure_ab = torch.nn.Parameter(torch.tensor( 149 | 0.0, device="cuda")), torch.nn.Parameter(torch.tensor(0.0, device="cuda")) 150 | else: 151 | exposure_ab = None 152 | gaussian_model.training_setup_camera(opt_cam_rot, opt_cam_trans, self.config, exposure_ab) 153 | 154 | gt_color = self.transform(image).cuda() 155 | gt_depth = np2torch(depth, "cuda") 156 | depth_mask = gt_depth > 0.0 157 | gt_trans = np2torch(gt_c2w[:3, 3]) 158 | gt_quat = np2torch(R.from_matrix(gt_c2w[:3, :3]).as_quat(canonical=True)[[3, 0, 1, 2]]) 159 | num_iters = self.config["iterations"] 160 | current_min_loss = float("inf") 161 | 162 | print(f"\nTracking frame {frame_id}") 163 | # Initial loss check 164 | color_loss, depth_loss, _, _, _ = self.compute_losses(gaussian_model, 
render_settings, opt_cam_rot, 165 | opt_cam_trans, gt_color, gt_depth, depth_mask, 166 | exposure_ab) 167 | if len(self.frame_color_loss) > 0 and ( 168 | color_loss.item() > self.init_err_ratio * np.median(self.frame_color_loss) 169 | or depth_loss.item() > self.init_err_ratio * np.median(self.frame_depth_loss) 170 | ): 171 | num_iters *= 2 172 | print(f"Higher initial loss, increasing num_iters to {num_iters}") 173 | if self.help_camera_initialization and self.odometry_type != "odometer": 174 | _, last_image, last_depth, _ = self.dataset[frame_id - 1] 175 | self.odometer.update_last_rgbd(last_image, last_depth) 176 | odometer_rel = self.odometer.estimate_rel_pose(image, depth) 177 | init_c2w = last_c2w @ odometer_rel 178 | init_rel = init_c2w @ np.linalg.inv(last_c2w) 179 | init_rel_w2c = np.linalg.inv(init_rel) 180 | opt_cam_rot, opt_cam_trans = compute_camera_opt_params(init_rel_w2c) 181 | gaussian_model.training_setup_camera(opt_cam_rot, opt_cam_trans, self.config, exposure_ab) 182 | render_settings = get_render_settings( 183 | self.dataset.width, self.dataset.height, self.dataset.intrinsics, last_w2c) 184 | print(f"re-init with odometer for frame {frame_id}") 185 | 186 | for iter in range(num_iters): 187 | color_loss, depth_loss, _, _, _, = self.compute_losses( 188 | gaussian_model, render_settings, opt_cam_rot, opt_cam_trans, gt_color, gt_depth, depth_mask, exposure_ab) 189 | 190 | total_loss = (self.w_color_loss * color_loss + (1 - self.w_color_loss) * depth_loss) 191 | total_loss.backward() 192 | gaussian_model.optimizer.step() 193 | # gaussian_model.scheduler.step(total_loss, epoch=iter) 194 | gaussian_model.optimizer.zero_grad(set_to_none=True) 195 | 196 | with torch.no_grad(): 197 | if total_loss.item() < current_min_loss: 198 | current_min_loss = total_loss.item() 199 | best_w2c = torch.eye(4) 200 | best_w2c[:3, :3] = build_rotation(F.normalize(opt_cam_rot[None].clone().detach().cpu()))[0] 201 | best_w2c[:3, 3] = opt_cam_trans.clone().detach().cpu() 202 | 203 | cur_quat, cur_trans = F.normalize(opt_cam_rot[None].clone().detach()), opt_cam_trans.clone().detach() 204 | cur_rel_w2c = torch.eye(4) 205 | cur_rel_w2c[:3, :3] = build_rotation(cur_quat)[0] 206 | cur_rel_w2c[:3, 3] = cur_trans 207 | if iter == num_iters - 1: 208 | cur_w2c = torch.from_numpy(reference_w2c) @ best_w2c 209 | else: 210 | cur_w2c = torch.from_numpy(reference_w2c) @ cur_rel_w2c 211 | cur_c2w = torch.inverse(cur_w2c) 212 | cur_cam = transformation_to_quaternion(cur_c2w) 213 | if (gt_quat * cur_cam[:4]).sum() < 0: # for logging purpose 214 | gt_quat *= -1 215 | if iter == num_iters - 1: 216 | self.frame_color_loss.append(color_loss.item()) 217 | self.frame_depth_loss.append(depth_loss.item()) 218 | self.logger.log_tracking_iteration( 219 | frame_id, cur_cam, gt_quat, gt_trans, total_loss, color_loss, depth_loss, iter, num_iters, 220 | wandb_output=True, print_output=True) 221 | elif iter % 20 == 0: 222 | self.logger.log_tracking_iteration( 223 | frame_id, cur_cam, gt_quat, gt_trans, total_loss, color_loss, depth_loss, iter, num_iters, 224 | wandb_output=False, print_output=True) 225 | 226 | final_c2w = torch.inverse(torch.from_numpy(reference_w2c) @ best_w2c) 227 | final_c2w[-1, :] = torch.tensor([0., 0., 0., 1.], dtype=final_c2w.dtype, device=final_c2w.device) 228 | return torch2np(final_c2w), exposure_ab 229 | -------------------------------------------------------------------------------- /src/entities/visual_odometer.py: -------------------------------------------------------------------------------- 1 | 
""" This module includes the Odometer class, which is allows for fast pose estimation from RGBD neighbor frames """ 2 | import numpy as np 3 | import open3d as o3d 4 | import open3d.core as o3c 5 | 6 | 7 | class VisualOdometer(object): 8 | 9 | def __init__(self, intrinsics: np.ndarray, method_name="hybrid", device="cuda"): 10 | """ Initializes the visual odometry system with specified intrinsics, method, and device. 11 | Args: 12 | intrinsics: Camera intrinsic parameters. 13 | method_name: The name of the odometry computation method to use ('hybrid' or 'point_to_plane'). 14 | device: The computation device ('cuda' or 'cpu'). 15 | """ 16 | device = "CUDA:0" if device == "cuda" else "CPU:0" 17 | self.device = o3c.Device(device) 18 | self.intrinsics = o3d.core.Tensor(intrinsics, o3d.core.Dtype.Float64) 19 | self.last_abs_pose = None 20 | self.last_frame = None 21 | self.criteria_list = [ 22 | o3d.t.pipelines.odometry.OdometryConvergenceCriteria(500), 23 | o3d.t.pipelines.odometry.OdometryConvergenceCriteria(500), 24 | o3d.t.pipelines.odometry.OdometryConvergenceCriteria(500)] 25 | self.setup_method(method_name) 26 | self.max_depth = 10.0 27 | self.depth_scale = 1.0 28 | self.last_rgbd = None 29 | 30 | def setup_method(self, method_name: str) -> None: 31 | """ Sets up the odometry computation method based on the provided method name. 32 | Args: 33 | method_name: The name of the odometry method to use ('hybrid' or 'point_to_plane'). 34 | """ 35 | if method_name == "hybrid": 36 | self.method = o3d.t.pipelines.odometry.Method.Hybrid 37 | elif method_name == "point_to_plane": 38 | self.method = o3d.t.pipelines.odometry.Method.PointToPlane 39 | else: 40 | raise ValueError("Odometry method does not exist!") 41 | 42 | def update_last_rgbd(self, image: np.ndarray, depth: np.ndarray) -> None: 43 | """ Updates the last RGB-D frame stored in the system with a new RGB-D frame constructed from provided image and depth. 44 | Args: 45 | image: The new RGB image as a numpy ndarray. 46 | depth: The new depth image as a numpy ndarray. 47 | """ 48 | self.last_rgbd = o3d.t.geometry.RGBDImage( 49 | o3d.t.geometry.Image(np.ascontiguousarray( 50 | image).astype(np.float32)).to(self.device), 51 | o3d.t.geometry.Image(np.ascontiguousarray(depth).astype(np.float32)).to(self.device)) 52 | 53 | def estimate_rel_pose(self, image: np.ndarray, depth: np.ndarray, init_transform=np.eye(4)): 54 | """ Estimates the relative pose of the current frame with respect to the last frame using RGB-D odometry. 55 | Args: 56 | image: The current RGB image as a numpy ndarray. 57 | depth: The current depth image as a numpy ndarray. 58 | init_transform: An initial transformation guess as a numpy ndarray. Defaults to the identity matrix. 59 | Returns: 60 | The relative transformation matrix as a numpy ndarray. 
61 | """ 62 | rgbd = o3d.t.geometry.RGBDImage( 63 | o3d.t.geometry.Image(np.ascontiguousarray(image).astype(np.float32)).to(self.device), 64 | o3d.t.geometry.Image(np.ascontiguousarray(depth).astype(np.float32)).to(self.device)) 65 | rel_transform = o3d.t.pipelines.odometry.rgbd_odometry_multi_scale( 66 | self.last_rgbd, rgbd, self.intrinsics, o3c.Tensor(init_transform), 67 | self.depth_scale, self.max_depth, self.criteria_list, self.method) 68 | self.last_rgbd = rgbd.clone() 69 | 70 | # Adjust for the coordinate system difference 71 | rel_transform = rel_transform.transformation.cpu().numpy() 72 | rel_transform[0, [1, 2, 3]] *= -1 73 | rel_transform[1, [0, 2, 3]] *= -1 74 | rel_transform[2, [0, 1, 3]] *= -1 75 | 76 | return rel_transform 77 | -------------------------------------------------------------------------------- /src/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GradientSpaces/LoopSplat/676e18f950b2de6be39525a613120712b67768bb/src/evaluation/__init__.py -------------------------------------------------------------------------------- /src/evaluation/evaluate_merged_map.py: -------------------------------------------------------------------------------- 1 | """ This module is responsible for merging submaps. """ 2 | from argparse import ArgumentParser 3 | 4 | import faiss 5 | import numpy as np 6 | import open3d as o3d 7 | import torch 8 | from torch.utils.data import Dataset 9 | from tqdm import tqdm 10 | 11 | from src.entities.arguments import OptimizationParams 12 | from src.entities.gaussian_model import GaussianModel 13 | from src.entities.losses import isotropic_loss, l1_loss, ssim 14 | from src.utils.utils import (batch_search_faiss, get_render_settings, 15 | np2ptcloud, render_gaussian_model, torch2np) 16 | from src.utils.gaussian_model_utils import BasicPointCloud 17 | 18 | 19 | class RenderFrames(Dataset): 20 | """A dataset class for loading keyframes along with their estimated camera poses and render settings.""" 21 | def __init__(self, dataset, render_poses: np.ndarray, height: int, width: int, fx: float, fy: float, exposures_ab=None): 22 | self.dataset = dataset 23 | self.render_poses = render_poses 24 | self.height = height 25 | self.width = width 26 | self.fx = fx 27 | self.fy = fy 28 | self.device = "cuda" 29 | self.stride = 1 30 | self.exposures_ab = exposures_ab 31 | if len(dataset) > 1000: 32 | self.stride = len(dataset) // 1000 33 | 34 | def __len__(self) -> int: 35 | return len(self.dataset) // self.stride 36 | 37 | def __getitem__(self, idx): 38 | idx = idx * self.stride 39 | color = (torch.from_numpy( 40 | self.dataset[idx][1]) / 255.0).float().to(self.device) 41 | depth = torch.from_numpy(self.dataset[idx][2]).float().to(self.device) 42 | estimate_c2w = self.render_poses[idx] 43 | estimate_w2c = np.linalg.inv(estimate_c2w) 44 | frame = { 45 | "frame_id": idx, 46 | "color": color, 47 | "depth": depth, 48 | "render_settings": get_render_settings( 49 | self.width, self.height, self.dataset.intrinsics, estimate_w2c) 50 | } 51 | if self.exposures_ab is not None: 52 | frame["exposure_ab"] = self.exposures_ab[idx] 53 | return frame 54 | 55 | 56 | def merge_submaps(submaps_paths: list, radius: float = 0.0001, device: str = "cuda") -> o3d.geometry.PointCloud: 57 | """ Merge submaps into a single point cloud, which is then used for global map refinement. 58 | Args: 59 | segments_paths (list): Folder path of the submaps. 
60 | radius (float, optional): Nearest neighbor distance threshold for adding a point. Defaults to 0.0001. 61 | device (str, optional): Defaults to "cuda". 62 | 63 | Returns: 64 | o3d.geometry.PointCloud: merged point cloud 65 | """ 66 | pts_index = faiss.IndexFlatL2(3) 67 | if device == "cuda": 68 | pts_index = faiss.index_cpu_to_gpu( 69 | faiss.StandardGpuResources(), 70 | 0, 71 | faiss.IndexIVFFlat(faiss.IndexFlatL2(3), 3, 500, faiss.METRIC_L2)) 72 | pts_index.nprobe = 5 73 | merged_pts = [] 74 | for submap_path in tqdm(submaps_paths, desc="Merging submaps"): 75 | gaussian_params = torch.load(submap_path)["gaussian_params"] 76 | current_pts = gaussian_params["xyz"].to(device).float().contiguous() 77 | pts_index.train(current_pts) 78 | distances, _ = batch_search_faiss(pts_index, current_pts, 8) 79 | neighbor_num = (distances < radius).sum(axis=1).int() 80 | ids_to_include = torch.where(neighbor_num == 0)[0] 81 | pts_index.add(current_pts[ids_to_include]) 82 | merged_pts.append(current_pts[ids_to_include]) 83 | pts = torch2np(torch.vstack(merged_pts)) 84 | pt_cloud = np2ptcloud(pts, np.zeros_like(pts)) 85 | 86 | # Downsampling if the total number of points is too large 87 | if len(pt_cloud.points) > 1_000_000: 88 | voxel_size = 0.02 89 | pt_cloud = pt_cloud.voxel_down_sample(voxel_size) 90 | print(f"Downsampled point cloud to {len(pt_cloud.points)} points") 91 | filtered_pt_cloud, _ = pt_cloud.remove_statistical_outlier(nb_neighbors=40, std_ratio=3.0) 92 | del pts_index 93 | return filtered_pt_cloud 94 | 95 | 96 | def refine_global_map(pt_cloud: o3d.geometry.PointCloud, training_frames: list, max_iterations: int, 97 | export_refine_mesh=False, output_dir=".", 98 | len_frames=None, o3d_intrinsic=None, enable_sh=True, enable_exposure=False) -> GaussianModel: 99 | """Refines a global map based on the merged point cloud and training keyframes frames. 100 | Args: 101 | pt_cloud (o3d.geometry.PointCloud): The merged point cloud used for refinement. 102 | training_frames (list): A list of training frames for map refinement. 103 | max_iterations (int): The maximum number of iterations to perform for refinement. 104 | Returns: 105 | GaussianModel: The refined global map as a Gaussian model. 
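    Note (editorial): `training_frames` is consumed with `next(...)` on every refinement step and
    the returned tensors are `.squeeze(0)`-ed, so the caller is expected to pass an endlessly
    cycling iterator over batch-size-1 samples of RenderFrames rather than the dataset itself,
    roughly (illustrative):

        loader = iter(itertools.cycle(torch.utils.data.DataLoader(render_frames, batch_size=1)))
        refined_model = refine_global_map(merged_pcd, loader, max_iterations=10000)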
106 | """ 107 | opt_params = OptimizationParams(ArgumentParser(description="Training script parameters")) 108 | 109 | gaussian_model = GaussianModel(3) 110 | gaussian_model.active_sh_degree = 0 111 | if pt_cloud is None: 112 | output_mesh = output_dir / "mesh" / "cleaned_mesh.ply" 113 | output_mesh = o3d.io.read_triangle_mesh(str(output_mesh)) 114 | pcd = o3d.geometry.PointCloud() 115 | pcd.points = output_mesh.vertices 116 | pcd.colors = output_mesh.vertex_colors 117 | pcd = pcd.voxel_down_sample(voxel_size=0.02) 118 | pcd = BasicPointCloud(points=np.asarray(pcd.points), 119 | colors=np.asarray(pcd.colors)) 120 | gaussian_model.create_from_pcd(pcd, 1.0) 121 | gaussian_model.training_setup(opt_params) 122 | else: 123 | gaussian_model.training_setup(opt_params) 124 | gaussian_model.add_points(pt_cloud) 125 | 126 | iteration = 0 127 | for iteration in tqdm(range(max_iterations), desc="Refinement"): 128 | training_frame = next(training_frames) 129 | gaussian_model.update_learning_rate(iteration) 130 | if enable_sh and iteration > 0 and iteration % 1000 == 0: 131 | gaussian_model.oneupSHdegree() 132 | gt_color, gt_depth, render_settings = ( 133 | training_frame["color"].squeeze(0), 134 | training_frame["depth"].squeeze(0), 135 | training_frame["render_settings"]) 136 | 137 | render_dict = render_gaussian_model(gaussian_model, render_settings) 138 | rendered_color, rendered_depth = (render_dict["color"].permute(1, 2, 0), render_dict["depth"]) 139 | if enable_exposure and training_frame.get("exposure_ab") is not None: 140 | rendered_color = torch.clamp( 141 | rendered_color * torch.exp(training_frame["exposure_ab"][0,0]) + training_frame["exposure_ab"][0,1], 0, 1.) 142 | 143 | reg_loss = isotropic_loss(gaussian_model.get_scaling()) 144 | depth_mask = (gt_depth > 0) 145 | color_loss = (1.0 - opt_params.lambda_dssim) * l1_loss( 146 | rendered_color[depth_mask, :], gt_color[depth_mask, :] 147 | ) + opt_params.lambda_dssim * (1.0 - ssim(rendered_color, gt_color)) 148 | depth_loss = l1_loss( 149 | rendered_depth[:, depth_mask], gt_depth[depth_mask]) 150 | 151 | total_loss = color_loss + depth_loss + reg_loss 152 | total_loss.backward() 153 | 154 | with torch.no_grad(): 155 | if iteration % 500 == 0: 156 | prune_mask = (gaussian_model.get_opacity() < 0.005).squeeze() 157 | gaussian_model.prune_points(prune_mask) 158 | 159 | # Optimizer step 160 | gaussian_model.optimizer.step() 161 | gaussian_model.optimizer.zero_grad(set_to_none=True) 162 | iteration += 1 163 | 164 | try: 165 | if export_refine_mesh: 166 | output_dir = output_dir / "mesh" / "refined_mesh.ply" 167 | scale = 1.0 168 | volume = o3d.pipelines.integration.ScalableTSDFVolume( 169 | voxel_length=5.0 * scale / 512.0, 170 | sdf_trunc=0.04 * scale, 171 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8) 172 | for i in tqdm(range(len_frames), desc="Integrating mesh"): # one cycle 173 | training_frame = next(training_frames) 174 | gt_color, gt_depth, render_settings, estimate_w2c = ( 175 | training_frame["color"].squeeze(0), 176 | training_frame["depth"].squeeze(0), 177 | training_frame["render_settings"], 178 | training_frame["estimate_w2c"]) 179 | 180 | render_dict = render_gaussian_model(gaussian_model, render_settings) 181 | rendered_color, rendered_depth = ( 182 | render_dict["color"].permute(1, 2, 0), render_dict["depth"]) 183 | rendered_color = torch.clamp(rendered_color, min=0.0, max=1.0) 184 | 185 | rendered_color = ( 186 | torch2np(rendered_color) * 255).astype(np.uint8) 187 | rendered_depth = 
torch2np(rendered_depth.squeeze()) 188 | # rendered_depth = filter_depth_outliers( 189 | # rendered_depth, kernel_size=20, threshold=0.1) 190 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( 191 | o3d.geometry.Image(np.ascontiguousarray(rendered_color)), 192 | o3d.geometry.Image(rendered_depth), 193 | depth_scale=scale, 194 | depth_trunc=30, 195 | convert_rgb_to_intensity=False) 196 | volume.integrate( 197 | rgbd, o3d_intrinsic, estimate_w2c.squeeze().cpu().numpy().astype(np.float64)) 198 | 199 | o3d_mesh = volume.extract_triangle_mesh() 200 | o3d.io.write_triangle_mesh(str(output_dir), o3d_mesh) 201 | print(f"Refined mesh saved to {output_dir}") 202 | 203 | except Exception as e: 204 | print(f"Error export_refine_mesh in refine_global_map:\n {e}") 205 | 206 | return gaussian_model 207 | -------------------------------------------------------------------------------- /src/evaluation/evaluate_reconstruction.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import open3d as o3d 7 | import torch 8 | import trimesh 9 | from evaluate_3d_reconstruction import run_evaluation 10 | from tqdm import tqdm 11 | 12 | 13 | def normalize(x): 14 | return x / np.linalg.norm(x) 15 | 16 | 17 | def get_align_transformation(rec_meshfile, gt_meshfile): 18 | """ 19 | Get the transformation matrix to align the reconstructed mesh to the ground truth mesh. 20 | """ 21 | o3d_rec_mesh = o3d.io.read_triangle_mesh(rec_meshfile) 22 | o3d_gt_mesh = o3d.io.read_triangle_mesh(gt_meshfile) 23 | o3d_rec_pc = o3d.geometry.PointCloud(points=o3d_rec_mesh.vertices) 24 | o3d_gt_pc = o3d.geometry.PointCloud(points=o3d_gt_mesh.vertices) 25 | trans_init = np.eye(4) 26 | threshold = 0.1 27 | reg_p2p = o3d.pipelines.registration.registration_icp( 28 | o3d_rec_pc, 29 | o3d_gt_pc, 30 | threshold, 31 | trans_init, 32 | o3d.pipelines.registration.TransformationEstimationPointToPoint(), 33 | ) 34 | transformation = reg_p2p.transformation 35 | return transformation 36 | 37 | 38 | def check_proj(points, W, H, fx, fy, cx, cy, c2w): 39 | """ 40 | Check if points can be projected into the camera view. 
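    (Editorial clarification) The pose is first converted from the convention used here, in which
    the camera's y and z axes are flipped; the points are then transformed into the camera frame,
    projected with the pinhole intrinsics, and the function answers whether at least one point
    lands inside the W x H image in front of the camera.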
41 | 42 | Returns: 43 | bool: True if there are points can be projected 44 | 45 | """ 46 | c2w = c2w.copy() 47 | c2w[:3, 1] *= -1.0 48 | c2w[:3, 2] *= -1.0 49 | points = torch.from_numpy(points).cuda().clone() 50 | w2c = np.linalg.inv(c2w) 51 | w2c = torch.from_numpy(w2c).cuda().float() 52 | K = torch.from_numpy( 53 | np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]]).reshape(3, 3) 54 | ).cuda() 55 | ones = torch.ones_like(points[:, 0]).reshape(-1, 1).cuda() 56 | homo_points = ( 57 | torch.cat([points, ones], dim=1).reshape(-1, 4, 1).cuda().float() 58 | ) # (N, 4) 59 | cam_cord_homo = w2c @ homo_points # (N, 4, 1)=(4,4)*(N, 4, 1) 60 | cam_cord = cam_cord_homo[:, :3] # (N, 3, 1) 61 | cam_cord[:, 0] *= -1 62 | uv = K.float() @ cam_cord.float() 63 | z = uv[:, -1:] + 1e-5 64 | uv = uv[:, :2] / z 65 | uv = uv.float().squeeze(-1).cpu().numpy() 66 | edge = 0 67 | mask = ( 68 | (0 <= -z[:, 0, 0].cpu().numpy()) 69 | & (uv[:, 0] < W - edge) 70 | & (uv[:, 0] > edge) 71 | & (uv[:, 1] < H - edge) 72 | & (uv[:, 1] > edge) 73 | ) 74 | return mask.sum() > 0 75 | 76 | 77 | def get_cam_position(gt_meshfile): 78 | mesh_gt = trimesh.load(gt_meshfile) 79 | to_origin, extents = trimesh.bounds.oriented_bounds(mesh_gt) 80 | extents[2] *= 0.7 81 | extents[1] *= 0.7 82 | extents[0] *= 0.3 83 | transform = np.linalg.inv(to_origin) 84 | transform[2, 3] += 0.4 85 | return extents, transform 86 | 87 | 88 | def viewmatrix(z, up, pos): 89 | vec2 = normalize(z) 90 | vec1_avg = up 91 | vec0 = normalize(np.cross(vec1_avg, vec2)) 92 | vec1 = normalize(np.cross(vec2, vec0)) 93 | m = np.stack([vec0, vec1, vec2, pos], 1) 94 | return m 95 | 96 | 97 | def calc_2d_metric( 98 | rec_meshfile, gt_meshfile, unseen_gt_pointcloud_file, align=True, n_imgs=1000 99 | ): 100 | """ 101 | 2D reconstruction metric, depth L1 loss. 
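    (Editorial clarification) Depth maps of the ground-truth and reconstructed meshes are rendered
    from n_imgs randomly sampled virtual cameras inside the room; views that would see the unseen
    ground-truth region are re-sampled, and the mean absolute depth difference over pixels where
    the reconstruction has depth is reported, multiplied by 100 (centimetres for meshes in metres).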
102 | 103 | """ 104 | H = 500 105 | W = 500 106 | focal = 300 107 | fx = focal 108 | fy = focal 109 | cx = H / 2.0 - 0.5 110 | cy = W / 2.0 - 0.5 111 | 112 | gt_mesh = o3d.io.read_triangle_mesh(gt_meshfile) 113 | rec_mesh = o3d.io.read_triangle_mesh(rec_meshfile) 114 | pc_unseen = np.load(unseen_gt_pointcloud_file) 115 | if align: 116 | transformation = get_align_transformation(rec_meshfile, gt_meshfile) 117 | rec_mesh = rec_mesh.transform(transformation) 118 | 119 | # get vacant area inside the room 120 | extents, transform = get_cam_position(gt_meshfile) 121 | 122 | vis = o3d.visualization.Visualizer() 123 | vis.create_window(width=W, height=H, visible=False) 124 | vis.get_render_option().mesh_show_back_face = True 125 | errors = [] 126 | for i in tqdm(range(n_imgs)): 127 | while True: 128 | # sample view, and check if unseen region is not inside the camera view 129 | # if inside, then needs to resample 130 | up = [0, 0, -1] 131 | origin = trimesh.sample.volume_rectangular( 132 | extents, 1, transform=transform) 133 | origin = origin.reshape(-1) 134 | tx = round(random.uniform(-10000, +10000), 2) 135 | ty = round(random.uniform(-10000, +10000), 2) 136 | tz = round(random.uniform(-10000, +10000), 2) 137 | # will be normalized, so sample from range [0.0,1.0] 138 | target = [tx, ty, tz] 139 | target = np.array(target) - np.array(origin) 140 | c2w = viewmatrix(target, up, origin) 141 | tmp = np.eye(4) 142 | tmp[:3, :] = c2w # sample translations 143 | c2w = tmp 144 | # if unseen points are projected into current view (c2w) 145 | seen = check_proj(pc_unseen, W, H, fx, fy, cx, cy, c2w) 146 | if ~seen: 147 | break 148 | 149 | param = o3d.camera.PinholeCameraParameters() 150 | param.extrinsic = np.linalg.inv(c2w) # 4x4 numpy array 151 | 152 | param.intrinsic = o3d.camera.PinholeCameraIntrinsic( 153 | W, H, fx, fy, cx, cy) 154 | 155 | ctr = vis.get_view_control() 156 | ctr.set_constant_z_far(20) 157 | ctr.convert_from_pinhole_camera_parameters(param, True) 158 | 159 | vis.add_geometry( 160 | gt_mesh, 161 | reset_bounding_box=True, 162 | ) 163 | ctr.convert_from_pinhole_camera_parameters(param, True) 164 | vis.poll_events() 165 | vis.update_renderer() 166 | gt_depth = vis.capture_depth_float_buffer(True) 167 | gt_depth = np.asarray(gt_depth) 168 | vis.remove_geometry( 169 | gt_mesh, 170 | reset_bounding_box=True, 171 | ) 172 | 173 | vis.add_geometry( 174 | rec_mesh, 175 | reset_bounding_box=True, 176 | ) 177 | ctr.convert_from_pinhole_camera_parameters(param, True) 178 | vis.poll_events() 179 | vis.update_renderer() 180 | ours_depth = vis.capture_depth_float_buffer(True) 181 | ours_depth = np.asarray(ours_depth) 182 | vis.remove_geometry( 183 | rec_mesh, 184 | reset_bounding_box=True, 185 | ) 186 | 187 | # filter missing surfaces where depth is 0 188 | if (ours_depth > 0).sum() > 0: 189 | errors += [ 190 | np.abs(gt_depth[ours_depth > 0] - 191 | ours_depth[ours_depth > 0]).mean() 192 | ] 193 | else: 194 | continue 195 | 196 | errors = np.array(errors) 197 | return {"depth_l1_sample_view": errors.mean() * 100} 198 | 199 | 200 | def clean_mesh(mesh): 201 | mesh_tri = trimesh.Trimesh( 202 | vertices=np.asarray(mesh.vertices), 203 | faces=np.asarray(mesh.triangles), 204 | vertex_colors=np.asarray(mesh.vertex_colors), 205 | ) 206 | components = trimesh.graph.connected_components( 207 | edges=mesh_tri.edges_sorted) 208 | 209 | min_len = 200 210 | components_to_keep = [c for c in components if len(c) >= min_len] 211 | 212 | new_vertices = [] 213 | new_faces = [] 214 | new_colors = [] 215 | vertex_count = 
0 216 | for component in components_to_keep: 217 | vertices = mesh_tri.vertices[component] 218 | colors = mesh_tri.visual.vertex_colors[component] 219 | 220 | # Create a mapping from old vertex indices to new vertex indices 221 | index_mapping = { 222 | old_idx: vertex_count + new_idx for new_idx, old_idx in enumerate(component) 223 | } 224 | vertex_count += len(vertices) 225 | 226 | # Select faces that are part of the current connected component and update vertex indices 227 | faces_in_component = mesh_tri.faces[ 228 | np.any(np.isin(mesh_tri.faces, component), axis=1) 229 | ] 230 | reindexed_faces = np.vectorize(index_mapping.get)(faces_in_component) 231 | 232 | new_vertices.extend(vertices) 233 | new_faces.extend(reindexed_faces) 234 | new_colors.extend(colors) 235 | 236 | cleaned_mesh_tri = trimesh.Trimesh(vertices=new_vertices, faces=new_faces) 237 | cleaned_mesh_tri.visual.vertex_colors = np.array(new_colors) 238 | 239 | cleaned_mesh_tri.update_faces(cleaned_mesh_tri.nondegenerate_faces()) 240 | cleaned_mesh_tri.update_faces(cleaned_mesh_tri.unique_faces()) 241 | print( 242 | f"Mesh cleaning (before/after), vertices: {len(mesh_tri.vertices)}/{len(cleaned_mesh_tri.vertices)}, faces: {len(mesh_tri.faces)}/{len(cleaned_mesh_tri.faces)}") 243 | 244 | cleaned_mesh = o3d.geometry.TriangleMesh( 245 | o3d.utility.Vector3dVector(cleaned_mesh_tri.vertices), 246 | o3d.utility.Vector3iVector(cleaned_mesh_tri.faces), 247 | ) 248 | vertex_colors = np.asarray(cleaned_mesh_tri.visual.vertex_colors)[ 249 | :, :3] / 255.0 250 | cleaned_mesh.vertex_colors = o3d.utility.Vector3dVector( 251 | vertex_colors.astype(np.float64) 252 | ) 253 | 254 | return cleaned_mesh 255 | 256 | 257 | def evaluate_reconstruction( 258 | mesh_path: Path, 259 | gt_mesh_path: Path, 260 | unseen_pc_path: Path, 261 | output_path: Path, 262 | to_clean=True, 263 | ): 264 | if to_clean: 265 | mesh = o3d.io.read_triangle_mesh(str(mesh_path)) 266 | print(mesh) 267 | cleaned_mesh = clean_mesh(mesh) 268 | cleaned_mesh_path = output_path / "mesh" / "cleaned_mesh.ply" 269 | o3d.io.write_triangle_mesh(str(cleaned_mesh_path), cleaned_mesh) 270 | mesh_path = cleaned_mesh_path 271 | 272 | result_3d = run_evaluation( 273 | str(mesh_path.parts[-1]), 274 | str(mesh_path.parent), 275 | str(gt_mesh_path).split("/")[-1].split(".")[0], 276 | distance_thresh=0.01, 277 | full_path_to_gt_ply=gt_mesh_path, 278 | icp_align=True, 279 | ) 280 | 281 | try: 282 | result_2d = calc_2d_metric(str(mesh_path), str(gt_mesh_path), str(unseen_pc_path), align=True, n_imgs=1000) 283 | except Exception as e: 284 | print(e) 285 | result_2d = {"depth_l1_sample_view": None} 286 | 287 | result = {**result_3d, **result_2d} 288 | with open(str(output_path / "reconstruction_metrics.json"), "w") as f: 289 | json.dump(result, f) 290 | -------------------------------------------------------------------------------- /src/evaluation/evaluate_trajectory.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | 8 | class NumpyFloatValuesEncoder(json.JSONEncoder): 9 | def default(self, obj): 10 | if isinstance(obj, np.float32): 11 | return float(obj) 12 | return JSONEncoder.default(self, obj) 13 | 14 | 15 | def align(model, data): 16 | """Align two trajectories using the method of Horn (closed-form). 
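The rotation is recovered from an SVD of the 3x3 correlation matrix of the zero-centered trajectories: with W = sum_i (model_i - model_mean)(data_i - data_mean)^T and U S V^T = SVD(W^T), the solution is rot = U diag(1, 1, det(U) det(V^T)) V^T and trans = data_mean - rot * model_mean, which is what the code below computes.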
17 | 18 | Input: 19 | model -- first trajectory (3xn) 20 | data -- second trajectory (3xn) 21 | 22 | Output: 23 | rot -- rotation matrix (3x3) 24 | trans -- translation vector (3x1) 25 | trans_error -- translational error per point (1xn) 26 | 27 | """ 28 | np.set_printoptions(precision=3, suppress=True) 29 | model_zerocentered = model - model.mean(1) 30 | data_zerocentered = data - data.mean(1) 31 | 32 | W = np.zeros((3, 3)) 33 | for column in range(model.shape[1]): 34 | W += np.outer(model_zerocentered[:, 35 | column], data_zerocentered[:, column]) 36 | U, d, Vh = np.linalg.linalg.svd(W.transpose()) 37 | S = np.matrix(np.identity(3)) 38 | if (np.linalg.det(U) * np.linalg.det(Vh) < 0): 39 | S[2, 2] = -1 40 | rot = U * S * Vh 41 | trans = data.mean(1) - rot * model.mean(1) 42 | 43 | model_aligned = rot * model + trans 44 | alignment_error = model_aligned - data 45 | 46 | trans_error = np.sqrt( 47 | np.sum(np.multiply(alignment_error, alignment_error), 0)).A[0] 48 | 49 | return rot, trans, trans_error 50 | 51 | 52 | def align_trajectories(t_pred: np.ndarray, t_gt: np.ndarray): 53 | """ 54 | Args: 55 | t_pred: (n, 3) translations 56 | t_gt: (n, 3) translations 57 | Returns: 58 | t_align: (n, 3) aligned translations 59 | """ 60 | t_align = np.matrix(t_pred).transpose() 61 | R, t, _ = align(t_align, np.matrix(t_gt).transpose()) 62 | t_align = R * t_align + t 63 | t_align = np.asarray(t_align).T 64 | return t_align 65 | 66 | 67 | def pose_error(t_pred: np.ndarray, t_gt: np.ndarray, align=False): 68 | """ 69 | Args: 70 | t_pred: (n, 3) translations 71 | t_gt: (n, 3) translations 72 | Returns: 73 | dict: error dict 74 | """ 75 | n = t_pred.shape[0] 76 | trans_error = np.linalg.norm(t_pred - t_gt, axis=1) 77 | return { 78 | "compared_pose_pairs": n, 79 | "rmse": np.sqrt(np.dot(trans_error, trans_error) / n), 80 | "mean": np.mean(trans_error), 81 | "median": np.median(trans_error), 82 | "std": np.std(trans_error), 83 | "min": np.min(trans_error), 84 | "max": np.max(trans_error) 85 | } 86 | 87 | 88 | def plot_2d(pts, ax=None, color="green", label="None", title="3D Trajectory in 2D"): 89 | if ax is None: 90 | _, ax = plt.subplots() 91 | ax.scatter(pts[:, 0], pts[:, 1], color=color, label=label, s=0.7) 92 | ax.set_xlabel('X') 93 | ax.set_ylabel('Y') 94 | ax.set_title(title) 95 | return ax 96 | 97 | 98 | def evaluate_trajectory(estimated_poses: np.ndarray, gt_poses: np.ndarray, output_path: Path): 99 | output_path.mkdir(exist_ok=True, parents=True) 100 | # Truncate the ground truth trajectory if needed 101 | if gt_poses.shape[0] > estimated_poses.shape[0]: 102 | gt_poses = gt_poses[:estimated_poses.shape[0]] 103 | valid = ~np.any(np.isnan(gt_poses) | 104 | np.isinf(gt_poses), axis=(1, 2)) 105 | gt_poses = gt_poses[valid] 106 | estimated_poses = estimated_poses[valid] 107 | 108 | gt_t = gt_poses[:, :3, 3] 109 | estimated_t = estimated_poses[:, :3, 3] 110 | estimated_t_aligned = align_trajectories(estimated_t, gt_t) 111 | ate = pose_error(estimated_t, gt_t) 112 | ate_aligned = pose_error(estimated_t_aligned, gt_t) 113 | 114 | with open(str(output_path / "ate.json"), "w") as f: 115 | f.write(json.dumps(ate, cls=NumpyFloatValuesEncoder)) 116 | 117 | with open(str(output_path / "ate_aligned.json"), "w") as f: 118 | f.write(json.dumps(ate_aligned, cls=NumpyFloatValuesEncoder)) 119 | 120 | ate_rmse, ate_rmse_aligned = ate["rmse"], ate_aligned["rmse"] 121 | ax = plot_2d( 122 | estimated_t, label=f"ate-rmse: {round(ate_rmse * 100, 2)} cm", color="orange") 123 | ax = plot_2d(estimated_t_aligned, ax, 124 | 
label=f"ate-rsme (aligned): {round(ate_rmse_aligned * 100, 2)} cm", color="lightskyblue") 125 | ax = plot_2d(gt_t, ax, label="GT", color="green") 126 | ax.legend() 127 | plt.savefig(str(output_path / "eval_trajectory.png"), dpi=300) 128 | plt.close() 129 | print( 130 | f"ATE-RMSE: {ate_rmse * 100:.2f} cm, ATE-RMSE (aligned): {ate_rmse_aligned * 100:.2f} cm") 131 | -------------------------------------------------------------------------------- /src/gsr/camera.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | from PIL import Image 5 | 6 | from src.gsr.loss import image_gradient, image_gradient_mask 7 | from src.utils.graphics_utils import getProjectionMatrix2, getWorld2View2 8 | 9 | 10 | 11 | class Camera(nn.Module): 12 | def __init__( 13 | self, 14 | uid, 15 | color, 16 | depth, 17 | gt_T, 18 | projection_matrix, 19 | fx, 20 | fy, 21 | cx, 22 | cy, 23 | fovx, 24 | fovy, 25 | image_height, 26 | image_width, 27 | device="cuda:0", 28 | ): 29 | super(Camera, self).__init__() 30 | self.uid = uid 31 | self.device = device 32 | 33 | T = torch.eye(4, device=device) 34 | self.R = T[:3, :3] 35 | self.T = T[:3, 3] 36 | self.R_gt = gt_T[:3, :3] 37 | self.T_gt = gt_T[:3, 3] 38 | 39 | self.original_image = color 40 | self.depth = depth 41 | self.grad_mask = None 42 | 43 | self.fx = fx 44 | self.fy = fy 45 | self.cx = cx 46 | self.cy = cy 47 | self.FoVx = fovx 48 | self.FoVy = fovy 49 | self.image_height = image_height 50 | self.image_width = image_width 51 | 52 | self.cam_rot_delta = nn.Parameter( 53 | torch.zeros(3, requires_grad=True, device=device) 54 | ) 55 | self.cam_trans_delta = nn.Parameter( 56 | torch.zeros(3, requires_grad=True, device=device) 57 | ) 58 | 59 | self.exposure_a = nn.Parameter( 60 | torch.tensor([0.0], requires_grad=True, device=device) 61 | ) 62 | self.exposure_b = nn.Parameter( 63 | torch.tensor([0.0], requires_grad=True, device=device) 64 | ) 65 | 66 | self.projection_matrix = projection_matrix.to(device=device) 67 | 68 | @staticmethod 69 | def init_from_dataset(dataset, idx, projection_matrix): 70 | gt_color, gt_depth, gt_pose = dataset[idx] 71 | return Camera( 72 | idx, 73 | gt_color, 74 | gt_depth, 75 | gt_pose, 76 | projection_matrix, 77 | dataset.fx, 78 | dataset.fy, 79 | dataset.cx, 80 | dataset.cy, 81 | dataset.fovx, 82 | dataset.fovy, 83 | dataset.height, 84 | dataset.width, 85 | device=dataset.device, 86 | ) 87 | 88 | @staticmethod 89 | def init_from_gui(uid, T, FoVx, FoVy, fx, fy, cx, cy, H, W): 90 | projection_matrix = getProjectionMatrix2( 91 | znear=0.01, zfar=100.0, fx=fx, fy=fy, cx=cx, cy=cy, W=W, H=H 92 | ).transpose(0, 1) 93 | return Camera( 94 | uid, None, None, T, projection_matrix, fx, fy, cx, cy, FoVx, FoVy, H, W 95 | ) 96 | 97 | @property 98 | def world_view_transform(self): 99 | return getWorld2View2(self.R, self.T).transpose(0, 1) 100 | 101 | @property 102 | def full_proj_transform(self): 103 | return ( 104 | self.world_view_transform.unsqueeze(0).bmm( 105 | self.projection_matrix.unsqueeze(0) 106 | ) 107 | ).squeeze(0) 108 | 109 | @property 110 | def camera_center(self): 111 | return self.world_view_transform.inverse()[3, :3] 112 | 113 | def update_RT(self, R, t): 114 | self.R = R.to(device=self.device) 115 | self.T = t.to(device=self.device) 116 | 117 | def compute_grad_mask(self, config): 118 | edge_threshold = config["Training"]["edge_threshold"] 119 | 120 | gray_img = self.original_image.mean(dim=0, keepdim=True) 121 | gray_grad_v, gray_grad_h = 
image_gradient(gray_img) 122 | mask_v, mask_h = image_gradient_mask(gray_img) 123 | gray_grad_v = gray_grad_v * mask_v 124 | gray_grad_h = gray_grad_h * mask_h 125 | img_grad_intensity = torch.sqrt(gray_grad_v**2 + gray_grad_h**2) 126 | 127 | if config["Dataset"]["type"] == "replica": 128 | row, col = 32, 32 129 | multiplier = edge_threshold 130 | _, h, w = self.original_image.shape 131 | for r in range(row): 132 | for c in range(col): 133 | block = img_grad_intensity[ 134 | :, 135 | r * int(h / row) : (r + 1) * int(h / row), 136 | c * int(w / col) : (c + 1) * int(w / col), 137 | ] 138 | th_median = block.median() 139 | block[block > (th_median * multiplier)] = 1 140 | block[block <= (th_median * multiplier)] = 0 141 | self.grad_mask = img_grad_intensity 142 | else: 143 | median_img_grad_intensity = img_grad_intensity.median() 144 | self.grad_mask = ( 145 | img_grad_intensity > median_img_grad_intensity * edge_threshold 146 | ) 147 | 148 | def clean(self): 149 | self.original_image = None 150 | self.depth = None 151 | self.grad_mask = None 152 | 153 | self.cam_rot_delta = None 154 | self.cam_trans_delta = None 155 | 156 | self.exposure_a = None 157 | self.exposure_b = None 158 | 159 | @property 160 | def get_T(self): 161 | T = torch.eye(4, device=self.device).float() 162 | T[:3, :3] = self.R 163 | T[:3, 3] = self.T 164 | return T 165 | 166 | @property 167 | def get_T_gt(self): 168 | T = torch.eye(4, device=self.device).float() 169 | T[:3, :3] = self.R_gt 170 | T[:3, 3] = self.T_gt 171 | return T 172 | 173 | def load_rgb(self, image=None): 174 | 175 | if image==None and hasattr(self, "rgb_path"): 176 | self.original_image = torch.from_numpy(np.array(Image.open(self.rgb_path))).permute(2, 0, 1).cuda().float() / 255.0 177 | self.compute_grad_mask(self.config) 178 | 179 | if image is not None: 180 | self.original_image = image 181 | self.compute_grad_mask(self.config) -------------------------------------------------------------------------------- /src/gsr/descriptor.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import io 3 | import sys 4 | 5 | import torch 6 | from einops import * 7 | 8 | sys.path.append("thirdparty/Hierarchical-Localization") 9 | 10 | with contextlib.redirect_stderr(io.StringIO()): 11 | from hloc.extractors.netvlad import NetVLAD 12 | 13 | 14 | class GlobalDesc: 15 | 16 | def __init__(self): 17 | conf = { 18 | 'output': 'global-feats-netvlad', 19 | 'model': {'name': 'netvlad'}, 20 | 'preprocessing': {'resize_max': 1024}, 21 | } 22 | self.netvlad = NetVLAD(conf).to('cuda').eval() 23 | 24 | @torch.no_grad() 25 | def __call__(self, images): 26 | assert parse_shape(images, '_ rgb _ _') == dict(rgb=3) 27 | assert (images.dtype == torch.float) and (images.max() <= 1.0001), images.max() 28 | return self.netvlad({'image': images})['global_descriptor'] # B 4096 -------------------------------------------------------------------------------- /src/gsr/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | """ 3 | https://github.com/muskie82/MonoGS/blob/main/utils/slam_utils.py 4 | """ 5 | 6 | def image_gradient(image): 7 | # Compute image gradient using Scharr Filter 8 | c = image.shape[0] 9 | conv_y = torch.tensor( 10 | [[3, 0, -3], [10, 0, -10], [3, 0, -3]], dtype=torch.float32, device="cuda" 11 | ) 12 | conv_x = torch.tensor( 13 | [[3, 10, 3], [0, 0, 0], [-3, -10, -3]], dtype=torch.float32, device="cuda" 14 | ) 15 | normalizer = 1.0 / torch.abs(conv_y).sum() 16 | 
p_img = torch.nn.functional.pad(image, (1, 1, 1, 1), mode="reflect")[None] 17 | img_grad_v = normalizer * torch.nn.functional.conv2d( 18 | p_img, conv_x.view(1, 1, 3, 3).repeat(c, 1, 1, 1), groups=c 19 | ) 20 | img_grad_h = normalizer * torch.nn.functional.conv2d( 21 | p_img, conv_y.view(1, 1, 3, 3).repeat(c, 1, 1, 1), groups=c 22 | ) 23 | return img_grad_v[0], img_grad_h[0] 24 | 25 | 26 | def image_gradient_mask(image, eps=0.01): 27 | # Compute image gradient mask 28 | c = image.shape[0] 29 | conv_y = torch.ones((1, 1, 3, 3), dtype=torch.float32, device="cuda") 30 | conv_x = torch.ones((1, 1, 3, 3), dtype=torch.float32, device="cuda") 31 | p_img = torch.nn.functional.pad(image, (1, 1, 1, 1), mode="reflect")[None] 32 | p_img = torch.abs(p_img) > eps 33 | img_grad_v = torch.nn.functional.conv2d( 34 | p_img.float(), conv_x.repeat(c, 1, 1, 1), groups=c 35 | ) 36 | img_grad_h = torch.nn.functional.conv2d( 37 | p_img.float(), conv_y.repeat(c, 1, 1, 1), groups=c 38 | ) 39 | 40 | return img_grad_v[0] == torch.sum(conv_x), img_grad_h[0] == torch.sum(conv_y) 41 | 42 | 43 | def depth_reg(depth, gt_image, huber_eps=0.1, mask=None): 44 | mask_v, mask_h = image_gradient_mask(depth) 45 | gray_grad_v, gray_grad_h = image_gradient(gt_image.mean(dim=0, keepdim=True)) 46 | depth_grad_v, depth_grad_h = image_gradient(depth) 47 | gray_grad_v, gray_grad_h = gray_grad_v[mask_v], gray_grad_h[mask_h] 48 | depth_grad_v, depth_grad_h = depth_grad_v[mask_v], depth_grad_h[mask_h] 49 | 50 | w_h = torch.exp(-10 * gray_grad_h**2) 51 | w_v = torch.exp(-10 * gray_grad_v**2) 52 | err = (w_h * torch.abs(depth_grad_h)).mean() + ( 53 | w_v * torch.abs(depth_grad_v) 54 | ).mean() 55 | return err 56 | 57 | 58 | def get_loss_tracking(config, image, depth, opacity, viewpoint, initialization=False): 59 | image_ab = (torch.exp(viewpoint.exposure_a)) * image + viewpoint.exposure_b 60 | if config["Training"]["monocular"]: 61 | return get_loss_tracking_rgb(config, image_ab, depth, opacity, viewpoint) 62 | return get_loss_tracking_rgbd(config, image_ab, depth, opacity, viewpoint) 63 | 64 | 65 | def get_loss_tracking_rgb(config, image, depth, opacity, viewpoint): 66 | gt_image = viewpoint.original_image.cuda() 67 | _, h, w = gt_image.shape 68 | mask_shape = (1, h, w) 69 | rgb_boundary_threshold = config["Training"]["rgb_boundary_threshold"] 70 | rgb_pixel_mask = (gt_image.sum(dim=0) > rgb_boundary_threshold).view(*mask_shape) 71 | rgb_pixel_mask = rgb_pixel_mask * viewpoint.grad_mask 72 | l1 = opacity * torch.abs(image * rgb_pixel_mask - gt_image * rgb_pixel_mask) 73 | return l1.mean() 74 | 75 | 76 | def get_loss_tracking_rgbd( 77 | config, image, depth, opacity, viewpoint, initialization=False 78 | ): 79 | alpha = config["Training"]["alpha"] if "alpha" in config["Training"] else 0.95 80 | 81 | gt_depth = torch.from_numpy(viewpoint.depth).to( 82 | dtype=torch.float32, device=image.device 83 | )[None] 84 | depth_pixel_mask = (gt_depth > 0.01).view(*depth.shape) 85 | opacity_mask = (opacity > 0.95).view(*depth.shape) 86 | 87 | l1_rgb = get_loss_tracking_rgb(config, image, depth, opacity, viewpoint) 88 | depth_mask = depth_pixel_mask * opacity_mask 89 | l1_depth = torch.abs(depth * depth_mask - gt_depth * depth_mask) 90 | return alpha * l1_rgb + (1 - alpha) * l1_depth.mean() -------------------------------------------------------------------------------- /src/gsr/overlap.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import numpy as np 3 | import torch 4 | import faiss 5 | 
import faiss.contrib.torch_utils 6 | from pycg import vis 7 | 8 | from src.utils.utils import batch_search_faiss 9 | 10 | def get_correspondences(src_pcd, tgt_pcd, trans, search_voxel_size, K=None): 11 | src_pcd.transform(trans) 12 | pcd_tree = o3d.geometry.KDTreeFlann(tgt_pcd) 13 | 14 | correspondences = [] 15 | for i, point in enumerate(src_pcd.points): 16 | [count, idx, _] = pcd_tree.search_radius_vector_3d(point, search_voxel_size) 17 | if K is not None: 18 | idx = idx[:K] 19 | for j in idx: 20 | correspondences.append([i, j]) 21 | 22 | correspondences = np.array(correspondences) 23 | correspondences = torch.from_numpy(correspondences) 24 | return correspondences 25 | 26 | def get_overlap_ratio(source,target,threshold=0.03): 27 | """ 28 | We compute overlap ratio from source point cloud to target point cloud 29 | """ 30 | pcd_tree = o3d.geometry.KDTreeFlann(target) 31 | 32 | match_count=0 33 | for i, point in enumerate(source.points): 34 | [count, _, _] = pcd_tree.search_radius_vector_3d(point, threshold) 35 | if(count!=0): 36 | match_count+=1 37 | 38 | overlap_ratio = match_count / min(len(source.points), len(target.points)) 39 | return overlap_ratio 40 | 41 | def compute_overlap_gaussians(src_gs, tgt_gs, threshold=0.03): 42 | """compute the overlap ratio and correspondences between two gaussians 43 | 44 | Args: 45 | src_gs: _description_ 46 | tgt_ts: _description_ 47 | threshold (float, optional): _description_. Defaults to 0.03. 48 | """ 49 | src_tensor = src_gs.get_xyz().detach() 50 | tgt_tensor = tgt_gs.get_xyz().detach() 51 | cpu_index = faiss.IndexFlatL2(3) 52 | gpu_index = faiss.index_cpu_to_all_gpus(cpu_index) 53 | gpu_index.add(tgt_tensor) 54 | 55 | distances, _ = batch_search_faiss(gpu_index, src_tensor, 1) 56 | mask_src = distances < threshold 57 | 58 | cpu_index = faiss.IndexFlatL2(3) 59 | gpu_index = faiss.index_cpu_to_all_gpus(cpu_index) 60 | gpu_index.add(src_tensor) 61 | 62 | distances, _ = batch_search_faiss(gpu_index, tgt_tensor, 1) 63 | mask_tgt = distances < threshold 64 | 65 | faiss_overlap_ratio = min(mask_src.sum()/len(mask_src), mask_tgt.sum()/len(mask_tgt)) 66 | 67 | return faiss_overlap_ratio 68 | 69 | def visualize_overlap(pc1, pc2, corr): 70 | import matplotlib.cm as cm 71 | src_pcd = o3d.geometry.PointCloud() 72 | src_pcd.points = o3d.utility.Vector3dVector(pc1.cpu().numpy()) 73 | tgt_pcd = o3d.geometry.PointCloud() 74 | tgt_pcd.points = o3d.utility.Vector3dVector(pc2.cpu().numpy()) 75 | # corr = get_correspondences(src_pcd, tgt_pcd, np.eye(4), 0.05) 76 | 77 | color_1 = cm.tab10(0) 78 | color_2 = cm.tab10(1) 79 | overlap_color = cm.tab10(2) 80 | 81 | color_src = np.ones_like(pc1.cpu().numpy()) 82 | color_tgt = np.ones_like(pc2.cpu().numpy()) 83 | 84 | if len(corr)>0: 85 | color_src[corr[:,0].cpu().numpy()] = np.array(color_1)[:3] 86 | color_tgt[corr[:,1].cpu().numpy()] = np.array(color_2)[:3] 87 | 88 | vis_src = vis.pointcloud(pc1.cpu().numpy(), color=color_src, is_sphere=True) 89 | vis_tgt = vis.pointcloud(pc2.cpu().numpy(), color=color_tgt, is_sphere=True) 90 | vis.show_3d([vis_src, vis_tgt],[vis_src], [vis_tgt], use_new_api=True) -------------------------------------------------------------------------------- /src/gsr/pcr.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | 3 | def preprocess_point_cloud(pcd, voxel_size, camera_location): 4 | pcd_down = pcd.voxel_down_sample(voxel_size) 5 | pcd_down.estimate_normals( 6 | o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size * 2.0, 7 | 
max_nn=30)) 8 | 9 | pcd_down.orient_normals_towards_camera_location( 10 | camera_location=camera_location) 11 | 12 | pcd_fpfh = o3d.pipelines.registration.compute_fpfh_feature( 13 | pcd_down, 14 | o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size * 5.0, 15 | max_nn=100)) 16 | return (pcd_down, pcd_fpfh) 17 | 18 | 19 | def execute_global_registration(source_down, target_down, source_fpfh, 20 | target_fpfh, voxel_size): 21 | distance_threshold = voxel_size * 1.5 22 | print(":: RANSAC registration on downsampled point clouds.") 23 | print(" Downsampling voxel size is %.3f," % voxel_size) 24 | print(" Using a liberal distance threshold %.3f." % distance_threshold) 25 | result = o3d.pipelines.registration.registration_ransac_based_on_feature_matching( 26 | source_down, target_down, source_fpfh, target_fpfh, True, 27 | distance_threshold, 28 | o3d.pipelines.registration.TransformationEstimationPointToPoint(False), 29 | 3, [ 30 | o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength( 31 | 0.9), 32 | o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance( 33 | distance_threshold) 34 | ], o3d.pipelines.registration.RANSACConvergenceCriteria(10000000, 0.99999)) 35 | return result 36 | 37 | def refine_registration(source, target, source_fpfh, target_fpfh, voxel_size, init_trans): 38 | distance_threshold = voxel_size 39 | print(":: Point-to-plane ICP registration is applied on original point") 40 | print(" clouds to refine the alignment. This time we use a strict") 41 | print(" distance threshold %.3f." % distance_threshold) 42 | result = o3d.pipelines.registration.registration_icp( 43 | source, target, distance_threshold, init_trans, 44 | o3d.pipelines.registration.TransformationEstimationPointToPlane(), 45 | o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=500)) 46 | return result -------------------------------------------------------------------------------- /src/gsr/renderer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import math 13 | 14 | import torch 15 | from diff_gaussian_rasterization import ( 16 | GaussianRasterizationSettings, 17 | GaussianRasterizer, 18 | ) 19 | 20 | from src.entities.gaussian_model import GaussianModel 21 | from src.utils.gaussian_model_utils import eval_sh 22 | 23 | 24 | def render( 25 | viewpoint_camera, 26 | pc: GaussianModel, 27 | pipe, 28 | bg_color: torch.Tensor, 29 | scaling_modifier=1.0, 30 | override_color=None, 31 | mask=None 32 | ): 33 | """ 34 | Render the scene. 35 | 36 | Background tensor (bg_color) must be on GPU! 37 | """ 38 | 39 | # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means 40 | if pc.get_xyz().shape[0] == 0: 41 | return None 42 | 43 | screenspace_points = ( 44 | torch.zeros_like( 45 | pc.get_xyz(), dtype=pc.get_xyz().dtype, requires_grad=True, device="cuda" 46 | ) 47 | + 0 48 | ) 49 | try: 50 | screenspace_points.retain_grad() 51 | except Exception: 52 | pass 53 | 54 | # Set up rasterization configuration 55 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 56 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 57 | 58 | raster_settings = GaussianRasterizationSettings( 59 | image_height=int(viewpoint_camera.image_height), 60 | image_width=int(viewpoint_camera.image_width), 61 | tanfovx=tanfovx, 62 | tanfovy=tanfovy, 63 | bg=bg_color, 64 | scale_modifier=scaling_modifier, 65 | viewmatrix=viewpoint_camera.world_view_transform, 66 | projmatrix=viewpoint_camera.full_proj_transform, 67 | projmatrix_raw=viewpoint_camera.projection_matrix, 68 | sh_degree=pc.active_sh_degree, 69 | campos=viewpoint_camera.camera_center, 70 | prefiltered=False, 71 | debug=False, 72 | ) 73 | 74 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 75 | 76 | means3D = pc.get_xyz() 77 | means2D = screenspace_points 78 | opacity = pc.get_opacity() 79 | 80 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 81 | # scaling / rotation by the rasterizer. 82 | scales = None 83 | rotations = None 84 | cov3D_precomp = None 85 | if pipe.compute_cov3D_python: 86 | cov3D_precomp = pc.get_covariance(scaling_modifier) 87 | else: 88 | # check if the covariance is isotropic 89 | if pc.get_scaling().shape[-1] == 1: 90 | scales = pc.get_scaling().repeat(1, 3) 91 | else: 92 | scales = pc.get_scaling() 93 | rotations = pc.get_rotation() 94 | 95 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 96 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 97 | shs = None 98 | colors_precomp = None 99 | if colors_precomp is None: 100 | if pipe.convert_SHs_python: 101 | shs_view = pc.get_features().transpose(1, 2).view( 102 | -1, 3, (pc.max_sh_degree + 1) ** 2 103 | ) 104 | dir_pp = pc.get_xyz() - viewpoint_camera.camera_center.repeat( 105 | pc.get_features().shape[0], 1 106 | ) 107 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) 108 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 109 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 110 | else: 111 | shs = pc.get_features() 112 | else: 113 | colors_precomp = override_color 114 | 115 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 
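# When a Gaussian mask is provided, only the selected subset is rasterized. Note that the masked call below unpacks four outputs, so n_touched is only produced by the unmasked branch.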
116 | if mask is not None: 117 | rendered_image, radii, depth, opacity = rasterizer( 118 | means3D=means3D[mask], 119 | means2D=means2D[mask], 120 | shs=shs[mask], 121 | colors_precomp=colors_precomp[mask] if colors_precomp is not None else None, 122 | opacities=opacity[mask], 123 | scales=scales[mask], 124 | rotations=rotations[mask], 125 | cov3D_precomp=cov3D_precomp[mask] if cov3D_precomp is not None else None, 126 | theta=viewpoint_camera.cam_rot_delta, 127 | rho=viewpoint_camera.cam_trans_delta, 128 | ) 129 | else: 130 | rendered_image, radii, depth, opacity, n_touched = rasterizer( 131 | means3D=means3D, 132 | means2D=means2D, 133 | shs=shs, 134 | colors_precomp=colors_precomp, 135 | opacities=opacity, 136 | scales=scales, 137 | rotations=rotations, 138 | cov3D_precomp=cov3D_precomp, 139 | theta=viewpoint_camera.cam_rot_delta, 140 | rho=viewpoint_camera.cam_trans_delta, 141 | ) 142 | 143 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 144 | # They will be excluded from value updates used in the splitting criteria. 145 | return { 146 | "render": rendered_image, 147 | "viewspace_points": screenspace_points, 148 | "visibility_filter": radii > 0, 149 | "radii": radii, 150 | "depth": depth, 151 | "opacity": opacity, 152 | "n_touched": n_touched, 153 | } 154 | -------------------------------------------------------------------------------- /src/gsr/se3/numpy_se3.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.transform import Rotation 3 | 4 | 5 | def identity(): 6 | return np.eye(3, 4) 7 | 8 | 9 | def transform(g: np.ndarray, pts: np.ndarray): 10 | """ Applies the SE3 transform 11 | 12 | Args: 13 | g: SE3 transformation matrix of size ([B,] 3/4, 4) 14 | pts: Points to be transformed ([B,] N, 3) 15 | 16 | Returns: 17 | transformed points of size (N, 3) 18 | """ 19 | rot = g[..., :3, :3] # (3, 3) 20 | trans = g[..., :3, 3] # (3) 21 | 22 | transformed = pts[..., :3] @ np.swapaxes(rot, -1, -2) + trans[..., None, :] 23 | return transformed 24 | 25 | 26 | def inverse(g: np.ndarray): 27 | """Returns the inverse of the SE3 transform 28 | 29 | Args: 30 | g: ([B,] 3/4, 4) transform 31 | 32 | Returns: 33 | ([B,] 3/4, 4) matrix containing the inverse 34 | 35 | """ 36 | rot = g[..., :3, :3] # (3, 3) 37 | trans = g[..., :3, 3] # (3) 38 | 39 | inv_rot = np.swapaxes(rot, -1, -2) 40 | inverse_transform = np.concatenate([inv_rot, inv_rot @ -trans[..., None]], axis=-1) 41 | if g.shape[-2] == 4: 42 | inverse_transform = np.concatenate([inverse_transform, [[0.0, 0.0, 0.0, 1.0]]], axis=-2) 43 | 44 | return inverse_transform 45 | 46 | 47 | def concatenate(a: np.ndarray, b: np.ndarray): 48 | """ Concatenate two SE3 transforms 49 | 50 | Args: 51 | a: First transform ([B,] 3/4, 4) 52 | b: Second transform ([B,] 3/4, 4) 53 | 54 | Returns: 55 | a*b ([B, ] 3/4, 4) 56 | 57 | """ 58 | 59 | r_a, t_a = a[..., :3, :3], a[..., :3, 3] 60 | r_b, t_b = b[..., :3, :3], b[..., :3, 3] 61 | 62 | r_ab = r_a @ r_b 63 | t_ab = r_a @ t_b[..., None] + t_a[..., None] 64 | 65 | concatenated = np.concatenate([r_ab, t_ab], axis=-1) 66 | 67 | if a.shape[-2] == 4: 68 | concatenated = np.concatenate([concatenated, [[0.0, 0.0, 0.0, 1.0]]], axis=-2) 69 | 70 | return concatenated 71 | 72 | 73 | def from_xyzquat(xyzquat): 74 | """Constructs SE3 matrix from x, y, z, qx, qy, qz, qw 75 | 76 | Args: 77 | xyzquat: np.array (7,) containing translation and quaterion 78 | 79 | Returns: 80 | SE3 matrix (4, 4) 81 | """ 82 | rot = 
Rotation.from_quat(xyzquat[3:]) 83 | trans = rot.apply(-xyzquat[:3]) 84 | transform = np.concatenate([rot.as_dcm(), trans[:, None]], axis=1) 85 | transform = np.concatenate([transform, [[0.0, 0.0, 0.0, 1.0]]], axis=0) 86 | 87 | return transform -------------------------------------------------------------------------------- /src/gsr/se3/torch_se3.py: -------------------------------------------------------------------------------- 1 | """ 3-d rigid body transformation group 2 | """ 3 | import torch 4 | 5 | 6 | def identity(batch_size): 7 | return torch.eye(3, 4)[None, ...].repeat(batch_size, 1, 1) 8 | 9 | 10 | def inverse(g): 11 | """ Returns the inverse of the SE3 transform 12 | 13 | Args: 14 | g: (B, 3/4, 4) transform 15 | 16 | Returns: 17 | (B, 3, 4) matrix containing the inverse 18 | 19 | """ 20 | # Compute inverse 21 | rot = g[..., 0:3, 0:3] 22 | trans = g[..., 0:3, 3] 23 | inverse_transform = torch.cat([rot.transpose(-1, -2), rot.transpose(-1, -2) @ -trans[..., None]], dim=-1) 24 | 25 | return inverse_transform 26 | 27 | 28 | def concatenate(a, b): 29 | """Concatenate two SE3 transforms, 30 | i.e. return a@b (but note that our SE3 is represented as a 3x4 matrix) 31 | 32 | Args: 33 | a: (B, 3/4, 4) 34 | b: (B, 3/4, 4) 35 | 36 | Returns: 37 | (B, 3/4, 4) 38 | """ 39 | 40 | rot1 = a[..., :3, :3] 41 | trans1 = a[..., :3, 3] 42 | rot2 = b[..., :3, :3] 43 | trans2 = b[..., :3, 3] 44 | 45 | rot_cat = rot1 @ rot2 46 | trans_cat = rot1 @ trans2[..., None] + trans1[..., None] 47 | concatenated = torch.cat([rot_cat, trans_cat], dim=-1) 48 | 49 | return concatenated 50 | 51 | 52 | def transform(g, a, normals=None): 53 | """ Applies the SE3 transform 54 | 55 | Args: 56 | g: SE3 transformation matrix of size ([1,] 3/4, 4) or (B, 3/4, 4) 57 | a: Points to be transformed (N, 3) or (B, N, 3) 58 | normals: (Optional). If provided, normals will be transformed 59 | 60 | Returns: 61 | transformed points of size (N, 3) or (B, N, 3) 62 | 63 | """ 64 | R = g[..., :3, :3] # (B, 3, 3) 65 | p = g[..., :3, 3] # (B, 3) 66 | 67 | if len(g.size()) == len(a.size()): 68 | b = torch.matmul(a, R.transpose(-1, -2)) + p[..., None, :] 69 | else: 70 | raise NotImplementedError 71 | b = R.matmul(a.unsqueeze(-1)).squeeze(-1) + p # No batch. Not checked 72 | 73 | if normals is not None: 74 | rotated_normals = normals @ R.transpose(-1, -2) 75 | return b, rotated_normals 76 | 77 | else: 78 | return b 79 | 80 | 81 | def Rt_to_SE3(R, t): 82 | ''' 83 | Merge 3D rotation and translation into 4x4 SE transformation. 
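The result has the homogeneous block form T = [[R, t], [0 0 0 1]].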
84 | Args: 85 | R: SO(3) rotation matrix (B, 3, 3) 86 | t: translation vector (B, 3, 1) 87 | ''' 88 | B = R.shape[0] 89 | SE3 = torch.zeros((B, 4, 4)).to(R.device) 90 | SE3[:,3,3] = 1 91 | SE3[:,:3,:3] = R 92 | SE3[:,:3,[3]] = t 93 | return SE3 -------------------------------------------------------------------------------- /src/gsr/solver.py: -------------------------------------------------------------------------------- 1 | import torch, roma 2 | import numpy as np 3 | import copy 4 | 5 | from src.gsr.renderer import render 6 | from src.gsr.loss import get_loss_tracking 7 | from src.gsr.overlap import compute_overlap_gaussians 8 | from src.utils.pose_utils import update_pose 9 | 10 | 11 | class CustomPipeline: 12 | convert_SHs_python = False 13 | compute_cov3D_python = False 14 | debug = False 15 | 16 | def viewpoint_localizer(viewpoint, gaussians, base_lr: float=1e-3): 17 | """Localize a single viewpoint in a 3DGS 18 | 19 | Args: 20 | viewpoint (Camera): Camera instance 21 | gaussians (Gaussians): 3D Gaussians to locate the viewpoint 22 | base_lr (float, optional). Defaults to 1e-3. 23 | 24 | Returns: 25 | _type_: _description_ 26 | """ 27 | opt_params = [] 28 | pipe = CustomPipeline() 29 | bg_color = torch.tensor([0, 0, 0], dtype=torch.float32, device="cuda", requires_grad=False) 30 | config = { 31 | 'Training': { 32 | 'monocular': False, 33 | "rgb_boundary_threshold": 0.01, 34 | } 35 | } 36 | 37 | init_T = viewpoint.get_T.detach() 38 | 39 | opt_params.append( 40 | { 41 | "params": [viewpoint.cam_rot_delta], 42 | "lr": 3*base_lr, 43 | "name": "rot_{}".format(viewpoint.uid), 44 | } 45 | ) 46 | opt_params.append( 47 | { 48 | "params": [viewpoint.cam_trans_delta], 49 | "lr": base_lr, 50 | "name": "trans_{}".format(viewpoint.uid), 51 | } 52 | ) 53 | opt_params.append( 54 | { 55 | "params": [viewpoint.exposure_a], 56 | "lr": 0.01, 57 | "name": "exposure_a_{}".format(viewpoint.uid), 58 | } 59 | ) 60 | opt_params.append( 61 | { 62 | "params": [viewpoint.exposure_b], 63 | "lr": 0.01, 64 | "name": "exposure_b_{}".format(viewpoint.uid), 65 | } 66 | ) 67 | optimizer = torch.optim.Adam(opt_params) 68 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", factor=0.98, patience=5, verbose=False) 69 | 70 | loss_log = [] 71 | opt_iterations = 100 72 | for tracking_itr in range(opt_iterations): 73 | optimizer.zero_grad() 74 | render_pkg = render( 75 | viewpoint, gaussians, pipe, bg_color 76 | ) 77 | image, depth, opacity = ( 78 | render_pkg["render"], 79 | render_pkg["depth"], 80 | render_pkg["opacity"], 81 | ) 82 | 83 | loss = get_loss_tracking(config, image, depth, opacity, viewpoint) 84 | loss.backward() 85 | loss_log.append(loss.item()) 86 | 87 | with torch.no_grad(): 88 | optimizer.step() 89 | scheduler.step(loss) 90 | converged = update_pose(viewpoint) 91 | 92 | if converged: 93 | break 94 | 95 | rel_tsfm = (init_T.inverse() @ viewpoint.get_T).inverse() 96 | loss_residual = loss.item() 97 | 98 | return converged, rel_tsfm, loss_residual, loss_log 99 | 100 | def gaussian_registration(src_dict, tgt_dict, config: dict, visualize=False): 101 | """_summary_ 102 | 103 | Args: 104 | src_dict (dict): dictionary of source gaussians and its keyframes 105 | tgt_dict (dict): dictionary of target gaussians and its keyframes 106 | base_lr (float, optional): the base learning rate for optimization. Defaults to 5e-3. 
107 | 108 | Returns: 109 | dict: dictionary of registration result 110 | """ 111 | 112 | # print("Pairwise registration ...") 113 | init_overlap = compute_overlap_gaussians(src_dict['gaussians'], tgt_dict['gaussians'], 0.1) 114 | if init_overlap< 0.2: 115 | print("Initial overlap between two submaps are too small, skipping ...") 116 | return { 117 | 'successful': False, 118 | "pred_tsfm": torch.eye(4).cuda(), 119 | "gt_tsfm": torch.eye(4).cuda(), 120 | "overlap": init_overlap.item() 121 | } 122 | 123 | src_3dgs, src_view_list = copy.deepcopy(src_dict['gaussians']), copy.deepcopy(src_dict['cameras']) 124 | tgt_3dgs, tgt_view_list = copy.deepcopy(tgt_dict['gaussians']), copy.deepcopy(tgt_dict['cameras']) 125 | 126 | # compute gt tsfm 127 | src_keyframe= src_dict['cameras'][0].get_T.detach() 128 | src_gt= src_dict['cameras'][0].get_T_gt.detach() 129 | tgt_keyframe= tgt_dict['cameras'][0].get_T.detach() 130 | tgt_gt= tgt_dict['cameras'][0].get_T_gt.detach() 131 | delta_src = src_gt.inverse() @ src_keyframe 132 | delta_tgt = tgt_gt.inverse() @ tgt_keyframe 133 | gt_tsfm = delta_tgt.inverse() @ delta_src 134 | 135 | # similarity choosing 136 | device = "cuda" if torch.cuda.is_available() else "cpu" 137 | src_desc, tgt_desc = src_dict['kf_desc'], tgt_dict['kf_desc'] 138 | 139 | score_cross = torch.einsum("id,jd->ij", src_desc.to(device), tgt_desc.to(device)) 140 | score_best_src, _ = score_cross.topk(1) 141 | _, ii = score_best_src.view(-1).topk(2) 142 | 143 | score_best_tgt, _ = score_cross.T.topk(1) 144 | _, jj = score_best_tgt.view(-1).topk(2) 145 | 146 | src_view_list = [src_view_list[i.item()] for i in ii] 147 | tgt_view_list = [tgt_view_list[j.item()] for j in jj] 148 | 149 | pred_list, residual_list, converged_list, loss_log_list = [], [], [], [] 150 | 151 | pipe = CustomPipeline() 152 | bg_color = torch.tensor([0, 0, 0], dtype=torch.float32, device="cuda", requires_grad=False) 153 | # per-cam 154 | for viewpoint in src_view_list: 155 | 156 | # use rendered image as target not the raw observation 157 | if config["use_render"]: 158 | render_pkg = render(viewpoint, src_3dgs, pipe, bg_color) 159 | viewpoint.load_rgb(render_pkg['render'].detach()) 160 | viewpoint.depth = render_pkg['depth'].squeeze().detach().cpu().numpy() 161 | else: 162 | viewpoint.load_rgb() 163 | converged, pred_tsfm, residual, loss_log = viewpoint_localizer(viewpoint, tgt_3dgs, config["base_lr"]) 164 | pred_list.append(pred_tsfm) 165 | residual_list.append(residual) 166 | converged_list.append(converged) 167 | loss_log_list.append(loss_log) 168 | 169 | for viewpoint in tgt_view_list: 170 | if config["use_render"]: 171 | render_pkg = render(viewpoint, tgt_3dgs, pipe, bg_color) 172 | viewpoint.load_rgb(render_pkg['render'].detach()) 173 | viewpoint.depth = render_pkg['depth'].squeeze().detach().cpu().numpy() 174 | else: 175 | viewpoint.load_rgb() 176 | converged, pred_tsfm, residual, loss_log = viewpoint_localizer(viewpoint, src_3dgs, config["base_lr"]) 177 | pred_list.append(pred_tsfm.inverse()) 178 | residual_list.append(residual) 179 | converged_list.append(converged) 180 | loss_log_list.append(loss_log) 181 | 182 | 183 | pred_tsfms = torch.stack(pred_list) 184 | residuals = torch.Tensor(residual_list).cuda().float() 185 | # probability based on residuals 186 | prob = 1/residuals / (1/residuals).sum() 187 | 188 | M = torch.sum(prob[:, None, None] * pred_tsfms[:,:3,:3], dim=0) 189 | try: 190 | R_w = roma.special_procrustes(M) 191 | except Exception as e: 192 | print(f"Error in roma.special_procrustes: {e}") 193 | return { 
194 | 'successful': False, 195 | "pred_tsfm": torch.eye(4).cuda(), 196 | "gt_tsfm": torch.eye(4).cuda(), 197 | "overlap": init_overlap.item() 198 | } 199 | t_w = torch.sum(prob[:, None] * pred_tsfms[:,:3, 3], dim=0) 200 | 201 | best_tsfm = torch.eye(4).cuda().float() 202 | best_tsfm[:3,:3] = R_w 203 | best_tsfm[:3, 3] = t_w 204 | 205 | result_dict = { 206 | "gt_tsfm": gt_tsfm, 207 | "pred_tsfm": best_tsfm, 208 | "successful": True, 209 | "best_viewpoint": src_view_list[0].get_T 210 | } 211 | 212 | if visualize: 213 | import matplotlib 214 | import matplotlib.pyplot as plt 215 | from src.gsr.utils import visualize_registration 216 | matplotlib.use('TkAgg') 217 | plt.figure(figsize=(10, 6)) 218 | for log in loss_log_list: 219 | plt.plot(log) 220 | 221 | plt.xlabel('Epoch') 222 | plt.ylabel('Loss') 223 | plt.title('Loss Curves of 10 Independent Optimizations') 224 | plt.legend([f'Run {i+1}' for i in range(len(loss_log_list))], loc='upper right') 225 | plt.grid(True) 226 | plt.show() 227 | 228 | visualize_registration(src_3dgs, tgt_3dgs, best_tsfm, gt_tsfm) 229 | 230 | del src_3dgs, src_view_list, tgt_3dgs, tgt_view_list 231 | return result_dict 232 | 233 | 234 | 235 | -------------------------------------------------------------------------------- /src/gsr/utils.py: -------------------------------------------------------------------------------- 1 | import json, yaml 2 | import numpy as np 3 | import torch 4 | from pycg import vis 5 | import matplotlib 6 | import matplotlib.pyplot as plt 7 | import open3d as o3d 8 | 9 | 10 | from src.utils.graphics_utils import getProjectionMatrix2, focal2fov 11 | from src.gsr.se3.numpy_se3 import transform 12 | from src.gsr.camera import Camera 13 | 14 | def read_json_data(file_path): 15 | try: 16 | with open(file_path, "r", encoding="utf-8") as f: 17 | data = json.load(f) 18 | return data 19 | except FileNotFoundError: 20 | print(f"File '{file_path}' not found.") 21 | return None 22 | except json.JSONDecodeError as e: 23 | print(f"Error decoding JSON: {e}") 24 | return None 25 | 26 | def read_trajectory(file_path, slam_config, scale = 1., device = "cuda:0"): 27 | """read trajectory data from plot output to list of Cameras [to use in 3DGS] 28 | 29 | Args: 30 | file_path (string): path 31 | """ 32 | trj_data = dict() 33 | json_data = read_json_data(file_path) 34 | dataset = slam_config['Dataset']['type'] 35 | 36 | if json_data: 37 | # Access the data as needed 38 | trj_data["trj_id"] = json_data["trj_id"] 39 | trj_data["trj_est"] = torch.Tensor(json_data["trj_est"]).to(device).float() 40 | trj_data["trj_gt"] = torch.Tensor(json_data["trj_gt"]).to(device).float() 41 | # print("Trajectory loaded successfully!") 42 | 43 | cam_dict = dict() 44 | calibration = slam_config["Dataset"]["Calibration"] 45 | 46 | proj_matrix = getProjectionMatrix2( 47 | znear=0.01, 48 | zfar=100.0, 49 | fx = calibration["fx"] / scale, 50 | fy = calibration["fy"] / scale, 51 | cx = calibration["cx"] / scale, 52 | cy = calibration["cy"] / scale, 53 | W = calibration["width"] / scale, 54 | H = calibration["height"] / scale, 55 | ).T 56 | 57 | fovx = focal2fov(calibration['fx'], calibration['width']) 58 | fovy = focal2fov(calibration['fy'], calibration['height']) 59 | for id, est_pose, gt_pose in zip(trj_data["trj_id"], trj_data["trj_est"], trj_data["trj_gt"]): 60 | T_gt = torch.linalg.inv(gt_pose) 61 | cam_i = Camera(id, None, None, 62 | T_gt, 63 | proj_matrix, 64 | calibration["fx"]/ scale, 65 | calibration["fy"]/ scale, 66 | calibration["cx"]/ scale, 67 | calibration["cy"]/ scale, 68 | 
fovx, 69 | fovy, 70 | calibration["height"]/ scale, 71 | calibration["width"]/ scale) 72 | est_T = torch.linalg.inv(est_pose) 73 | cam_i.R = est_T[:3, :3] 74 | cam_i.T = est_T[:3, 3] 75 | 76 | cam_dict[id] = cam_i 77 | 78 | return cam_dict 79 | 80 | def visualize_registration(src_3dgs, tgt_3dgs, pre_tsfm, gt_tsfm): 81 | """visualize registration in 3D Gaussians 82 | 83 | Args: 84 | gs3d_src (Gaussian Model): _description_ 85 | gs3d_tgt (Gaussian Model): _description_ 86 | tsfm (torch.Tensor): _description_ 87 | """ 88 | src_pc = src_3dgs.get_xyz().detach().cpu().numpy() 89 | tgt_pc = tgt_3dgs.get_xyz().detach().cpu().numpy() 90 | est_src_pc = transform(pre_tsfm.detach().cpu().numpy(), src_pc) 91 | gt_src_pc = transform(gt_tsfm.detach().cpu().numpy(), src_pc) 92 | 93 | src_vis = vis.pointcloud(src_pc[::2], ucid=0, cmap='tab10', is_sphere=True) 94 | tgt_vis = vis.pointcloud(tgt_pc[::2], ucid=1, cmap='tab10', is_sphere=True) 95 | 96 | est_vis = vis.pointcloud(est_src_pc[::2], ucid=0, cmap='tab10', is_sphere=True) 97 | gt_vis = vis.pointcloud(gt_src_pc[::2], ucid=0, cmap='tab10', is_sphere=True) 98 | 99 | try: 100 | vis.show_3d([src_vis, tgt_vis], [est_vis, tgt_vis], [gt_vis, tgt_vis], use_new_api=True, show=True) 101 | except: 102 | print("estimate is not a good transformation") 103 | vis.show_3d([src_vis, tgt_vis], [gt_vis, tgt_vis], use_new_api=True, show=True) 104 | 105 | def visualize_mv_registration(gaussian_list, pred_rel_tsfms, gt_rel_tsfms): 106 | vis_pred_list, vis_gt_list = [], [] 107 | pred_abs_tsfm = torch.eye(4).cuda() 108 | gt_abs_tsfm = torch.eye(4).cuda() 109 | for i, gaussians in enumerate(gaussian_list[:5]): 110 | pc = gaussians.get_xyz().detach().cpu().numpy() 111 | 112 | pc_pred = transform(pred_abs_tsfm.detach().cpu().numpy(), pc) 113 | vis_pred_list.append(vis.pointcloud(pc_pred, ucid=i, cmap='tab10', is_sphere=True)) 114 | if i>0: pred_abs_tsfm = pred_abs_tsfm @ pred_rel_tsfms[i-1] 115 | 116 | pc_gt = transform(gt_abs_tsfm.detach().cpu().numpy(), pc) 117 | vis_gt_list.append(vis.pointcloud(pc_gt, ucid=i, cmap='tab10', is_sphere=True)) 118 | if i>0: gt_abs_tsfm = gt_abs_tsfm @ gt_rel_tsfms[i-1] 119 | 120 | vis.show_3d(vis_pred_list, vis_gt_list, use_new_api=True, show=True) 121 | 122 | 123 | 124 | def axis_angle_to_rot_mat(axes, thetas): 125 | """ 126 | Computer a rotation matrix from the axis-angle representation using the Rodrigues formula. 
127 | \mathbf{R} = \mathbf{I} + \sin(\theta)\mathbf{K} + (1 - \cos(\theta))\mathbf{K}^2, where \mathbf{K} is the skew-symmetric cross-product matrix of the unit rotation axis 128 | 129 | Args: 130 | axes (numpy array): array of axes used to compute the rotation matrices [b,3] 131 | thetas (numpy array): array of angles used to compute the rotation matrices [b,1] 132 | 133 | Returns: 134 | rot_matrices (numpy array): array of the rotation matrices computed from the angle, axis representation [b,3,3] 135 | 136 | borrowed from: https://github.com/zgojcic/3D_multiview_reg/blob/master/lib/utils.py 137 | """ 138 | 139 | R = [] 140 | for k in range(axes.shape[0]): 141 | K = np.cross(np.eye(3), axes[k,:]/np.linalg.norm(axes[k,:])) 142 | R.append( np.eye(3) + np.sin(thetas[k])*K + (1 - np.cos(thetas[k])) * np.matmul(K,K)) 143 | 144 | rot_matrices = np.stack(R) 145 | return rot_matrices 146 | 147 | def sample_random_trans(pcd, randg=None, rotation_range=360): 148 | """ 149 | Samples random transformation parameters with the rotations limited to the rotation range 150 | 151 | Args: 152 | pcd (numpy array): numpy array of coordinates for which the transformation parameters are sampled [n,3] 153 | randg (numpy random generator): numpy random generator 154 | 155 | Returns: 156 | T (numpy array): sampled transformation parameters [4,4] 157 | 158 | borrowed from: https://github.com/zgojcic/3D_multiview_reg/blob/master/lib/utils.py 159 | """ 160 | if randg is None: 161 | randg = np.random.default_rng(41) 162 | 163 | # Create 4x4 homogeneous identity matrix 164 | T = np.zeros((4,4)) 165 | idx = np.arange(4) 166 | T[idx,idx] = 1 167 | 168 | axes = np.random.rand(1,3) - 0.5 169 | 170 | angles = rotation_range * np.pi / 180.0 * (np.random.rand(1,1) - 0.5) 171 | 172 | R = axis_angle_to_rot_mat(axes, angles) 173 | 174 | T[:3, :3] = R 175 | # T[:3, 3] = np.random.rand(3)-0.5 176 | T[:3, 3] = np.matmul(R,-np.mean(pcd, axis=0)) 177 | 178 | return T 179 | 180 | def visualize_mv_registration(data_list, pred_pose_list): 181 | n_view = len(pred_pose_list) 182 | vis_init_list, vis_pred_list, vis_gt_list = [], [], [] 183 | for i in range(n_view): 184 | pc = data_list[i]['gaussians'].get_xyz().detach().cpu().numpy() 185 | pred_pc = transform(pred_pose_list[i].detach().cpu().numpy(), pc) 186 | 187 | gt_tsfm = data_list[0]['gt_tsfm'] @ data_list[i]['gt_tsfm'].inverse() 188 | gt_pc = transform(gt_tsfm.detach().cpu().numpy(), data_list[i]['gaussians'].get_xyz().detach().cpu().numpy()) 189 | 190 | vis_pred_list.append(vis.pointcloud(pred_pc[::2], ucid=i, is_sphere=True)) 191 | vis_gt_list.append(vis.pointcloud(gt_pc[::2], ucid=i, is_sphere=True)) 192 | vis_init_list.append(vis.pointcloud(pc[::2], ucid=i, is_sphere=True)) 193 | 194 | vis.show_3d(vis_init_list, vis_pred_list, vis_gt_list, use_new_api=True) 195 | 196 | 197 | @torch.no_grad() 198 | def plot_and_save(points, pngname, title='', axlim=None): 199 | points = points.detach().cpu().numpy() 200 | plt.figure(figsize=(7, 7)) 201 | ax = plt.axes(projection='3d') 202 | ax.plot3D(points[:,0], points[:,1], points[:,2], 'b') 203 | plt.title(title) 204 | if axlim is not None: 205 | ax.set_xlim(axlim[0]) 206 | ax.set_ylim(axlim[1]) 207 | ax.set_zlim(axlim[2]) 208 | plt.savefig(pngname) 209 | print('Saving to', pngname) 210 | return ax.get_xlim(), ax.get_ylim(), ax.get_zlim() 211 | 212 | 213 | def colorize_depth_maps( 214 | depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None 215 | ): 216 | """ 217 | Colorize depth maps.
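Depth is normalized as (depth - min_depth) / (max_depth - min_depth), clipped to [0, 1] and mapped through the given matplotlib colormap; pixels outside valid_mask (if provided) are set to 0. Accepts a torch.Tensor or np.ndarray of shape [(B,) H, W] and returns channel-first colors of shape [B, 3, H, W] in [0, 1], as a tensor if the input was a tensor and as an array otherwise.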
218 | """ 219 | assert len(depth_map.shape) >= 2, "Invalid dimension" 220 | 221 | if isinstance(depth_map, torch.Tensor): 222 | depth = depth_map.detach().squeeze().numpy() 223 | elif isinstance(depth_map, np.ndarray): 224 | depth = depth_map.copy().squeeze() 225 | # reshape to [ (B,) H, W ] 226 | if depth.ndim < 3: 227 | depth = depth[np.newaxis, :, :] 228 | 229 | # colorize 230 | cm = matplotlib.colormaps[cmap] 231 | depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1) 232 | img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1 233 | img_colored_np = np.rollaxis(img_colored_np, 3, 1) 234 | 235 | if valid_mask is not None: 236 | if isinstance(depth_map, torch.Tensor): 237 | valid_mask = valid_mask.detach().numpy() 238 | valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W] 239 | if valid_mask.ndim < 3: 240 | valid_mask = valid_mask[np.newaxis, np.newaxis, :, :] 241 | else: 242 | valid_mask = valid_mask[:, np.newaxis, :, :] 243 | valid_mask = np.repeat(valid_mask, 3, axis=1) 244 | img_colored_np[~valid_mask] = 0 245 | 246 | if isinstance(depth_map, torch.Tensor): 247 | img_colored = torch.from_numpy(img_colored_np).float() 248 | elif isinstance(depth_map, np.ndarray): 249 | img_colored = img_colored_np 250 | 251 | return img_colored 252 | 253 | 254 | def chw2hwc(chw): 255 | assert 3 == len(chw.shape) 256 | if isinstance(chw, torch.Tensor): 257 | hwc = torch.permute(chw, (1, 2, 0)) 258 | elif isinstance(chw, np.ndarray): 259 | hwc = np.moveaxis(chw, 0, -1) 260 | return hwc 261 | 262 | def visualize_camera_traj(cam_list): 263 | vis_cam_list = [] 264 | for cam in cam_list: 265 | intrinsic = np.array([ 266 | [cam.fx, 0.0, cam.cx], 267 | [0.0, cam.fy, cam.cy], 268 | [0.0, 0.0, 1.0] 269 | ]) 270 | pred_extrinsic = cam.get_T.cpu().numpy() 271 | gt_extrinsic = cam.get_T_gt.cpu().numpy() 272 | vis_pred_cam = o3d.geometry.LineSet.create_camera_visualization( 273 | 640, 480, intrinsic, pred_extrinsic, scale=0.1) 274 | vis_gt_cam = o3d.geometry.LineSet.create_camera_visualization( 275 | 640, 480, intrinsic, gt_extrinsic, scale=0.1) 276 | 277 | # Set colors for predicted (blue) and ground truth (green) cameras 278 | vis_pred_cam.paint_uniform_color([1, 0, 0]) 279 | vis_gt_cam.paint_uniform_color([0, 1, 0]) 280 | 281 | vis_cam_list.append(vis_pred_cam) 282 | vis_cam_list.append(vis_gt_cam) 283 | 284 | return vis_cam_list -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GradientSpaces/LoopSplat/676e18f950b2de6be39525a613120712b67768bb/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import evo 4 | import numpy as np 5 | 6 | import torch 7 | from errno import EEXIST 8 | from os import makedirs, path 9 | from evo.core import metrics, trajectory 10 | from evo.core.metrics import PoseRelation, Unit 11 | from evo.core.trajectory import PosePath3D, PoseTrajectory3D 12 | from evo.tools import plot 13 | from evo.tools.plot import PlotMode 14 | from evo.tools.settings import SETTINGS 15 | from matplotlib import pyplot as plt 16 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity 17 | 18 | from tqdm import tqdm 19 | 20 | import wandb 21 | 22 | import rich 23 | 24 | _log_styles = { 25 | "Eval": 
"bold red", 26 | } 27 | 28 | def get_style(tag): 29 | if tag in _log_styles.keys(): 30 | return _log_styles[tag] 31 | return "bold blue" 32 | 33 | 34 | def Log(*args, tag="MonoGS"): 35 | style = get_style(tag) 36 | rich.print(f"[{style}]{tag}:[/{style}]", *args) 37 | 38 | def mkdir_p(folder_path): 39 | # Creates a directory. equivalent to using mkdir -p on the command line 40 | try: 41 | makedirs(folder_path) 42 | except OSError as exc: # Python >2.5 43 | if exc.errno == EEXIST and path.isdir(folder_path): 44 | pass 45 | else: 46 | raise 47 | def evaluate_evo(poses_gt, poses_est, plot_dir, label, monocular=False): 48 | ## Plot 49 | traj_ref = PosePath3D(poses_se3=poses_gt) 50 | traj_est = PosePath3D(poses_se3=poses_est) 51 | traj_est_aligned = trajectory.align_trajectory( 52 | traj_est, traj_ref, correct_scale=monocular 53 | ) 54 | 55 | ## RMSE 56 | pose_relation = metrics.PoseRelation.translation_part 57 | data = (traj_ref, traj_est_aligned) 58 | ape_metric = metrics.APE(pose_relation) 59 | ape_metric.process_data(data) 60 | ape_stat = ape_metric.get_statistic(metrics.StatisticsType.rmse) 61 | ape_stats = ape_metric.get_all_statistics() 62 | Log("RMSE ATE \[m]", ape_stat, tag="Eval") 63 | 64 | with open( 65 | os.path.join(plot_dir, "stats_{}.json".format(str(label))), 66 | "w", 67 | encoding="utf-8", 68 | ) as f: 69 | json.dump(ape_stats, f, indent=4) 70 | 71 | plot_mode = evo.tools.plot.PlotMode.xy 72 | fig = plt.figure() 73 | ax = evo.tools.plot.prepare_axis(fig, plot_mode) 74 | ax.set_title(f"ATE RMSE: {ape_stat}") 75 | evo.tools.plot.traj(ax, plot_mode, traj_ref, "--", "gray", "gt") 76 | evo.tools.plot.traj_colormap( 77 | ax, 78 | traj_est_aligned, 79 | ape_metric.error, 80 | plot_mode, 81 | min_map=ape_stats["min"], 82 | max_map=ape_stats["max"], 83 | ) 84 | ax.legend() 85 | plt.savefig(os.path.join(plot_dir, "evo_2dplot_{}.png".format(str(label))), dpi=90) 86 | 87 | return ape_stat 88 | 89 | 90 | def eval_ate(frames, kf_ids, save_dir, iterations, final=False, monocular=False): 91 | trj_data = dict() 92 | latest_frame_idx = kf_ids[-1] + 2 if final else kf_ids[-1] + 1 93 | trj_id, trj_est, trj_gt = [], [], [] 94 | trj_est_np, trj_gt_np = [], [] 95 | 96 | def gen_pose_matrix(R, T): 97 | pose = np.eye(4) 98 | pose[0:3, 0:3] = R.cpu().numpy() 99 | pose[0:3, 3] = T.cpu().numpy() 100 | return pose 101 | 102 | for kf_id in kf_ids: 103 | kf = frames[kf_id] 104 | pose_est = np.linalg.inv(gen_pose_matrix(kf.R, kf.T)) 105 | pose_gt = np.linalg.inv(gen_pose_matrix(kf.R_gt, kf.T_gt)) 106 | 107 | trj_id.append(frames[kf_id].uid) 108 | trj_est.append(pose_est.tolist()) 109 | trj_gt.append(pose_gt.tolist()) 110 | 111 | trj_est_np.append(pose_est) 112 | trj_gt_np.append(pose_gt) 113 | 114 | trj_data["trj_id"] = trj_id 115 | trj_data["trj_est"] = trj_est 116 | trj_data["trj_gt"] = trj_gt 117 | 118 | plot_dir = os.path.join(save_dir, "plot") 119 | mkdir_p(plot_dir) 120 | 121 | label_evo = "final" if final else "{:04}".format(iterations) 122 | with open( 123 | os.path.join(plot_dir, f"trj_{label_evo}.json"), "w", encoding="utf-8" 124 | ) as f: 125 | json.dump(trj_data, f, indent=4) 126 | 127 | ate = evaluate_evo( 128 | poses_gt=trj_gt_np, 129 | poses_est=trj_est_np, 130 | plot_dir=plot_dir, 131 | label=label_evo, 132 | monocular=monocular, 133 | ) 134 | return ate -------------------------------------------------------------------------------- /src/utils/gaussian_model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 
2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | from typing import NamedTuple 24 | import numpy as np 25 | import torch 26 | import torch.nn.functional as F 27 | 28 | C0 = 0.28209479177387814 29 | C1 = 0.4886025119029199 30 | C2 = [ 31 | 1.0925484305920792, 32 | -1.0925484305920792, 33 | 0.31539156525252005, 34 | -1.0925484305920792, 35 | 0.5462742152960396 36 | ] 37 | C3 = [ 38 | -0.5900435899266435, 39 | 2.890611442640554, 40 | -0.4570457994644658, 41 | 0.3731763325901154, 42 | -0.4570457994644658, 43 | 1.445305721320277, 44 | -0.5900435899266435 45 | ] 46 | C4 = [ 47 | 2.5033429417967046, 48 | -1.7701307697799304, 49 | 0.9461746957575601, 50 | -0.6690465435572892, 51 | 0.10578554691520431, 52 | -0.6690465435572892, 53 | 0.47308734787878004, 54 | -1.7701307697799304, 55 | 0.6258357354491761, 56 | ] 57 | 58 | 59 | def eval_sh(deg, sh, dirs): 60 | """ 61 | Evaluate spherical harmonics at unit directions 62 | using hardcoded SH polynomials. 63 | Works with torch/np/jnp. 64 | ... Can be 0 or more batch dimensions. 65 | Args: 66 | deg: int SH deg. 
Currently, 0-3 supported 67 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 68 | dirs: jnp.ndarray unit directions [..., 3] 69 | Returns: 70 | [..., C] 71 | """ 72 | assert deg <= 4 and deg >= 0 73 | coeff = (deg + 1) ** 2 74 | assert sh.shape[-1] >= coeff 75 | 76 | result = C0 * sh[..., 0] 77 | if deg > 0: 78 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 79 | result = (result - 80 | C1 * y * sh[..., 1] + 81 | C1 * z * sh[..., 2] - 82 | C1 * x * sh[..., 3]) 83 | 84 | if deg > 1: 85 | xx, yy, zz = x * x, y * y, z * z 86 | xy, yz, xz = x * y, y * z, x * z 87 | result = (result + 88 | C2[0] * xy * sh[..., 4] + 89 | C2[1] * yz * sh[..., 5] + 90 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 91 | C2[3] * xz * sh[..., 7] + 92 | C2[4] * (xx - yy) * sh[..., 8]) 93 | 94 | if deg > 2: 95 | result = (result + 96 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 97 | C3[1] * xy * z * sh[..., 10] + 98 | C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] + 99 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 100 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 101 | C3[5] * z * (xx - yy) * sh[..., 14] + 102 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 103 | 104 | if deg > 3: 105 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 106 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 107 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 108 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 109 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 110 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 111 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 112 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 113 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 114 | return result 115 | 116 | 117 | def RGB2SH(rgb): 118 | return (rgb - 0.5) / C0 119 | 120 | 121 | def SH2RGB(sh): 122 | return sh * C0 + 0.5 123 | 124 | 125 | def inverse_sigmoid(x): 126 | return torch.log(x/(1-x)) 127 | 128 | 129 | def get_expon_lr_func( 130 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 131 | ): 132 | """ 133 | Copied from Plenoxels 134 | 135 | Continuous learning rate decay function. Adapted from JaxNeRF 136 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 137 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 138 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 139 | function of lr_delay_mult, such that the initial learning rate is 140 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 141 | to the normal learning rate when steps>lr_delay_steps. 142 | :param conf: config subtree 'lr' or similar 143 | :param max_steps: int, the number of steps during optimization. 144 | :return HoF which takes step as input 145 | """ 146 | 147 | def helper(step): 148 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 149 | # Disable this parameter 150 | return 0.0 151 | if lr_delay_steps > 0: 152 | # A kind of reverse cosine decay. 
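            # The factor computed below is a warm-up: it equals lr_delay_mult at step 0 and
            # rises along a quarter sine wave to 1.0 once step >= lr_delay_steps, after which
            # the schedule is purely the log-linear interpolation further down.
            # Worked example (illustrative numbers): with lr_delay_mult = 0.1 and
            # step = 0.5 * lr_delay_steps, delay_rate = 0.1 + 0.9 * sin(pi / 4) ~= 0.74.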
153 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 154 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 155 | ) 156 | else: 157 | delay_rate = 1.0 158 | t = np.clip(step / max_steps, 0, 1) 159 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 160 | return delay_rate * log_lerp 161 | 162 | return helper 163 | 164 | 165 | def strip_lowerdiag(L): 166 | uncertainty = torch.zeros( 167 | (L.shape[0], 6), dtype=torch.float, device="cuda") 168 | 169 | uncertainty[:, 0] = L[:, 0, 0] 170 | uncertainty[:, 1] = L[:, 0, 1] 171 | uncertainty[:, 2] = L[:, 0, 2] 172 | uncertainty[:, 3] = L[:, 1, 1] 173 | uncertainty[:, 4] = L[:, 1, 2] 174 | uncertainty[:, 5] = L[:, 2, 2] 175 | return uncertainty 176 | 177 | 178 | def strip_symmetric(sym): 179 | return strip_lowerdiag(sym) 180 | 181 | 182 | def build_rotation(r): 183 | 184 | q = F.normalize(r, p=2, dim=1) 185 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 186 | 187 | r = q[:, 0] 188 | x = q[:, 1] 189 | y = q[:, 2] 190 | z = q[:, 3] 191 | 192 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 193 | R[:, 0, 1] = 2 * (x*y - r*z) 194 | R[:, 0, 2] = 2 * (x*z + r*y) 195 | R[:, 1, 0] = 2 * (x*y + r*z) 196 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 197 | R[:, 1, 2] = 2 * (y*z - r*x) 198 | R[:, 2, 0] = 2 * (x*z - r*y) 199 | R[:, 2, 1] = 2 * (y*z + r*x) 200 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 201 | return R 202 | 203 | 204 | def build_scaling_rotation(s, r): 205 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 206 | R = build_rotation(r) 207 | 208 | L[:, 0, 0] = s[:, 0] 209 | L[:, 1, 1] = s[:, 1] 210 | L[:, 2, 2] = s[:, 2] 211 | 212 | L = R @ L 213 | return L 214 | 215 | 216 | class BasicPointCloud(NamedTuple): 217 | points: np.array 218 | colors: np.array -------------------------------------------------------------------------------- /src/utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import math 13 | from typing import NamedTuple 14 | 15 | import numpy as np 16 | import torch 17 | 18 | 19 | class BasicPointCloud(NamedTuple): 20 | points: np.array 21 | colors: np.array 22 | normals: np.array 23 | 24 | 25 | def getWorld2View(R, t): 26 | Rt = np.zeros((4, 4)) 27 | Rt[:3, :3] = R.transpose() 28 | Rt[:3, 3] = t 29 | Rt[3, 3] = 1.0 30 | return np.float32(Rt) 31 | 32 | 33 | def getWorld2View2(R, t, translate=torch.tensor([0.0, 0.0, 0.0]), scale=1.0): 34 | translate = translate.to(R.device) 35 | Rt = torch.zeros((4, 4), device=R.device) 36 | # Rt[:3, :3] = R.transpose() 37 | Rt[:3, :3] = R 38 | Rt[:3, 3] = t 39 | Rt[3, 3] = 1.0 40 | 41 | C2W = torch.linalg.inv(Rt) 42 | cam_center = C2W[:3, 3] 43 | cam_center = (cam_center + translate) * scale 44 | C2W[:3, 3] = cam_center 45 | Rt = torch.linalg.inv(C2W) 46 | return Rt 47 | 48 | 49 | def getProjectionMatrix(znear, zfar, fovX, fovY): 50 | tanHalfFovY = math.tan((fovY / 2)) 51 | tanHalfFovX = math.tan((fovX / 2)) 52 | 53 | top = tanHalfFovY * znear 54 | bottom = -top 55 | right = tanHalfFovX * znear 56 | left = -right 57 | 58 | P = torch.zeros(4, 4) 59 | 60 | z_sign = 1.0 61 | 62 | P[0, 0] = 2.0 * znear / (right - left) 63 | P[1, 1] = 2.0 * znear / (top - bottom) 64 | P[0, 2] = (right + left) / (right - left) 65 | P[1, 2] = (top + bottom) / (top - bottom) 66 | P[3, 2] = z_sign 67 | P[2, 2] = -(zfar + znear) / (zfar - znear) 68 | P[2, 3] = -2 * (zfar * znear) / (zfar - znear) 69 | return P 70 | 71 | 72 | def getProjectionMatrix2(znear, zfar, cx, cy, fx, fy, W, H): 73 | left = ((2 * cx - W) / W - 1.0) * W / 2.0 74 | right = ((2 * cx - W) / W + 1.0) * W / 2.0 75 | top = ((2 * cy - H) / H + 1.0) * H / 2.0 76 | bottom = ((2 * cy - H) / H - 1.0) * H / 2.0 77 | left = znear / fx * left 78 | right = znear / fx * right 79 | top = znear / fy * top 80 | bottom = znear / fy * bottom 81 | P = torch.zeros(4, 4) 82 | 83 | z_sign = 1.0 84 | 85 | P[0, 0] = 2.0 * znear / (right - left) 86 | P[1, 1] = 2.0 * znear / (top - bottom) 87 | P[0, 2] = (right + left) / (right - left) 88 | P[1, 2] = (top + bottom) / (top - bottom) 89 | P[3, 2] = z_sign 90 | P[2, 2] = z_sign * zfar / (zfar - znear) 91 | P[2, 3] = -(zfar * znear) / (zfar - znear) 92 | 93 | return P 94 | 95 | 96 | def fov2focal(fov, pixels): 97 | return pixels / (2 * math.tan(fov / 2)) 98 | 99 | 100 | def focal2fov(focal, pixels): 101 | return 2 * math.atan(pixels / (2 * focal)) 102 | -------------------------------------------------------------------------------- /src/utils/io_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | from typing import Union 5 | 6 | import open3d as o3d 7 | import torch 8 | import wandb 9 | import yaml 10 | 11 | 12 | def mkdir_decorator(func): 13 | """A decorator that creates the directory specified in the function's 'directory' keyword 14 | argument before calling the function. 15 | Args: 16 | func: The function to be decorated. 17 | Returns: 18 | The wrapper function. 
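    Example (hypothetical usage; the decorated function and the path are illustrative only):
        @mkdir_decorator
        def save_text(text: str, *, directory: Union[str, Path]) -> None:
            (Path(directory) / "notes.txt").write_text(text)

        save_text("hello", directory="output/run_0")  # "output/run_0" is created first
    Note that 'directory' must be passed as a keyword argument, because the wrapper below
    reads kwargs["directory"].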
19 | """ 20 | def wrapper(*args, **kwargs): 21 | output_path = Path(kwargs["directory"]) 22 | output_path.mkdir(parents=True, exist_ok=True) 23 | return func(*args, **kwargs) 24 | return wrapper 25 | 26 | 27 | @mkdir_decorator 28 | def save_clouds(clouds: list, cloud_names: list, *, directory: Union[str, Path]) -> None: 29 | """ Saves a list of point clouds to the specified directory, creating the directory if it does not exist. 30 | Args: 31 | clouds: A list of point cloud objects to be saved. 32 | cloud_names: A list of filenames for the point clouds, corresponding by index to the clouds. 33 | directory: The directory where the point clouds will be saved. 34 | """ 35 | for cld_name, cloud in zip(cloud_names, clouds): 36 | o3d.io.write_point_cloud(str(directory / cld_name), cloud) 37 | 38 | 39 | @mkdir_decorator 40 | def save_dict_to_ckpt(dictionary, file_name: str, *, directory: Union[str, Path]) -> None: 41 | """ Saves a dictionary to a checkpoint file in the specified directory, creating the directory if it does not exist. 42 | Args: 43 | dictionary: The dictionary to be saved. 44 | file_name: The name of the checkpoint file. 45 | directory: The directory where the checkpoint file will be saved. 46 | """ 47 | torch.save(dictionary, directory / file_name, 48 | _use_new_zipfile_serialization=False) 49 | 50 | 51 | @mkdir_decorator 52 | def save_dict_to_yaml(dictionary, file_name: str, *, directory: Union[str, Path]) -> None: 53 | """ Saves a dictionary to a YAML file in the specified directory, creating the directory if it does not exist. 54 | Args: 55 | dictionary: The dictionary to be saved. 56 | file_name: The name of the YAML file. 57 | directory: The directory where the YAML file will be saved. 58 | """ 59 | with open(directory / file_name, "w") as f: 60 | yaml.dump(dictionary, f) 61 | 62 | 63 | @mkdir_decorator 64 | def save_dict_to_json(dictionary, file_name: str, *, directory: Union[str, Path]) -> None: 65 | """ Saves a dictionary to a JSON file in the specified directory, creating the directory if it does not exist. 66 | Args: 67 | dictionary: The dictionary to be saved. 68 | file_name: The name of the JSON file. 69 | directory: The directory where the JSON file will be saved. 70 | """ 71 | with open(directory / file_name, "w") as f: 72 | json.dump(dictionary, f) 73 | 74 | 75 | def load_config(path: str, default_path: str = None) -> dict: 76 | """ 77 | Loads a configuration file and optionally merges it with a default configuration file. 78 | 79 | This function loads a configuration from the given path. If the configuration specifies an inheritance 80 | path (`inherit_from`), or if a `default_path` is provided, it loads the base configuration and updates it 81 | with the specific configuration. 82 | 83 | Args: 84 | path: The path to the specific configuration file. 85 | default_path: An optional path to a default configuration file that is loaded if the specific configuration 86 | does not specify an inheritance or as a base for the inheritance. 87 | 88 | Returns: 89 | A dictionary containing the merged configuration. 90 | """ 91 | # load configuration from per scene/dataset cfg. 
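    # The scene-specific file is read first; if it declares an 'inherit_from' key, the
    # referenced base config is loaded recursively and the scene values are then merged
    # on top of it via update_recursive, so scene-level keys win over inherited ones.
    # Illustrative example (not verified against the shipped configs): a scene YAML with
    #   inherit_from: configs/Replica/replica.yaml
    #   data:
    #     scene_name: room0
    # starts from the dataset-level defaults and overrides only the listed keys.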
92 | with open(path, 'r') as f: 93 | cfg_special = yaml.full_load(f) 94 | inherit_from = cfg_special.get('inherit_from') 95 | cfg = dict() 96 | if inherit_from is not None: 97 | cfg = load_config(inherit_from, default_path) 98 | elif default_path is not None: 99 | with open(default_path, 'r') as f: 100 | cfg = yaml.full_load(f) 101 | update_recursive(cfg, cfg_special) 102 | return cfg 103 | 104 | 105 | def update_recursive(dict1: dict, dict2: dict) -> None: 106 | """ Recursively updates the first dictionary with the contents of the second dictionary. 107 | 108 | This function iterates through `dict2` and updates `dict1` with its contents. If a key from `dict2` 109 | exists in `dict1` and its value is also a dictionary, the function updates the value recursively. 110 | Otherwise, it overwrites the value in `dict1` with the value from `dict2`. 111 | 112 | Args: 113 | dict1: The dictionary to be updated. 114 | dict2: The dictionary whose entries are used to update `dict1`. 115 | 116 | Returns: 117 | None: The function modifies `dict1` in place. 118 | """ 119 | for k, v in dict2.items(): 120 | if k not in dict1: 121 | dict1[k] = dict() 122 | if isinstance(v, dict): 123 | update_recursive(dict1[k], v) 124 | else: 125 | dict1[k] = v 126 | 127 | 128 | def log_metrics_to_wandb(json_files: list, output_path: str, section: str = "Evaluation") -> None: 129 | """ Logs metrics from JSON files to Weights & Biases under a specified section. 130 | 131 | This function reads metrics from a list of JSON files and logs them to Weights & Biases (wandb). 132 | Each metric is prefixed with a specified section name for organized logging. 133 | 134 | Args: 135 | json_files: A list of filenames for JSON files containing metrics to be logged. 136 | output_path: The directory path where the JSON files are located. 137 | section: The section under which to log the metrics in wandb. Defaults to "Evaluation". 138 | 139 | Returns: 140 | None: Metrics are logged to wandb and the function does not return a value. 
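    Example (hypothetical file names, assuming a wandb run has already been initialized):
        log_metrics_to_wandb(["rendering_metrics.json", "ate.json"],
                             output_path="output/room0", section="Evaluation")
        # every key in those JSON files is logged as "Evaluation/<key>"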
141 | """ 142 | for json_file in json_files: 143 | file_path = os.path.join(output_path, json_file) 144 | if os.path.exists(file_path): 145 | with open(file_path, 'r') as file: 146 | metrics = json.load(file) 147 | prefixed_metrics = { 148 | f"{section}/{key}": value for key, value in metrics.items()} 149 | wandb.log(prefixed_metrics) 150 | -------------------------------------------------------------------------------- /src/utils/pose_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def rt2mat(R, T): 6 | mat = np.eye(4) 7 | mat[0:3, 0:3] = R 8 | mat[0:3, 3] = T 9 | return mat 10 | 11 | 12 | def skew_sym_mat(x): 13 | device = x.device 14 | dtype = x.dtype 15 | ssm = torch.zeros(3, 3, device=device, dtype=dtype) 16 | ssm[0, 1] = -x[2] 17 | ssm[0, 2] = x[1] 18 | ssm[1, 0] = x[2] 19 | ssm[1, 2] = -x[0] 20 | ssm[2, 0] = -x[1] 21 | ssm[2, 1] = x[0] 22 | return ssm 23 | 24 | 25 | def SO3_exp(theta): 26 | device = theta.device 27 | dtype = theta.dtype 28 | 29 | W = skew_sym_mat(theta) 30 | W2 = W @ W 31 | angle = torch.norm(theta) 32 | I = torch.eye(3, device=device, dtype=dtype) 33 | if angle < 1e-5: 34 | return I + W + 0.5 * W2 35 | else: 36 | return ( 37 | I 38 | + (torch.sin(angle) / angle) * W 39 | + ((1 - torch.cos(angle)) / (angle**2)) * W2 40 | ) 41 | 42 | 43 | def V(theta): 44 | dtype = theta.dtype 45 | device = theta.device 46 | I = torch.eye(3, device=device, dtype=dtype) 47 | W = skew_sym_mat(theta) 48 | W2 = W @ W 49 | angle = torch.norm(theta) 50 | if angle < 1e-5: 51 | V = I + 0.5 * W + (1.0 / 6.0) * W2 52 | else: 53 | V = ( 54 | I 55 | + W * ((1.0 - torch.cos(angle)) / (angle**2)) 56 | + W2 * ((angle - torch.sin(angle)) / (angle**3)) 57 | ) 58 | return V 59 | 60 | 61 | def SE3_exp(tau): 62 | dtype = tau.dtype 63 | device = tau.device 64 | 65 | rho = tau[:3] 66 | theta = tau[3:] 67 | R = SO3_exp(theta) 68 | t = V(theta) @ rho 69 | 70 | T = torch.eye(4, device=device, dtype=dtype) 71 | T[:3, :3] = R 72 | T[:3, 3] = t 73 | return T 74 | 75 | 76 | def update_pose(camera, converged_threshold=1e-4): 77 | tau = torch.cat([camera.cam_trans_delta, camera.cam_rot_delta], axis=0) 78 | 79 | T_w2c = torch.eye(4, device=tau.device) 80 | T_w2c[0:3, 0:3] = camera.R 81 | T_w2c[0:3, 3] = camera.T 82 | 83 | new_w2c = SE3_exp(tau) @ T_w2c 84 | 85 | new_R = new_w2c[0:3, 0:3] 86 | new_T = new_w2c[0:3, 3] 87 | 88 | converged = tau.norm() < converged_threshold 89 | camera.update_RT(new_R, new_T) 90 | 91 | camera.cam_rot_delta.data.fill_(0) 92 | camera.cam_trans_delta.data.fill_(0) 93 | return converged 94 | -------------------------------------------------------------------------------- /src/utils/tracker_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from scipy.spatial.transform import Rotation 4 | from typing import Union 5 | from src.utils.utils import np2torch 6 | 7 | 8 | def multiply_quaternions(q: torch.Tensor, r: torch.Tensor) -> torch.Tensor: 9 | """Performs batch-wise quaternion multiplication. 10 | 11 | Given two quaternions, this function computes their product. The operation is 12 | vectorized and can be performed on batches of quaternions. 13 | 14 | Args: 15 | q: A tensor representing the first quaternion or a batch of quaternions. 16 | Expected shape is (... , 4), where the last dimension contains quaternion components (w, x, y, z). 
17 | r: A tensor representing the second quaternion or a batch of quaternions with the same shape as q. 18 | Returns: 19 | A tensor of the same shape as the input tensors, representing the product of the input quaternions. 20 | """ 21 | w0, x0, y0, z0 = q[..., 0], q[..., 1], q[..., 2], q[..., 3] 22 | w1, x1, y1, z1 = r[..., 0], r[..., 1], r[..., 2], r[..., 3] 23 | 24 | w = -x1 * x0 - y1 * y0 - z1 * z0 + w1 * w0 25 | x = x1 * w0 + y1 * z0 - z1 * y0 + w1 * x0 26 | y = -x1 * z0 + y1 * w0 + z1 * x0 + w1 * y0 27 | z = x1 * y0 - y1 * x0 + z1 * w0 + w1 * z0 28 | return torch.stack((w, x, y, z), dim=-1) 29 | 30 | 31 | def transformation_to_quaternion(RT: Union[torch.Tensor, np.ndarray]): 32 | """ Converts a rotation-translation matrix to a tensor representing quaternion and translation. 33 | 34 | This function takes a 3x4 transformation matrix (rotation and translation) and converts it 35 | into a tensor that combines the quaternion representation of the rotation and the translation vector. 36 | 37 | Args: 38 | RT: A 3x4 matrix representing the rotation and translation. This can be a NumPy array 39 | or a torch.Tensor. If it's a torch.Tensor and resides on a GPU, it will be moved to CPU. 40 | 41 | Returns: 42 | A tensor combining the quaternion (in w, x, y, z order) and translation vector. The tensor 43 | will be moved to the original device if the input was a GPU tensor. 44 | """ 45 | gpu_id = -1 46 | if isinstance(RT, torch.Tensor): 47 | if RT.get_device() != -1: 48 | RT = RT.detach().cpu() 49 | gpu_id = RT.get_device() 50 | RT = RT.numpy() 51 | R, T = RT[:3, :3], RT[:3, 3] 52 | 53 | rot = Rotation.from_matrix(R) 54 | quad = rot.as_quat(canonical=True) 55 | quad = np.roll(quad, 1) 56 | tensor = np.concatenate([quad, T], 0) 57 | tensor = torch.from_numpy(tensor).float() 58 | if gpu_id != -1: 59 | tensor = tensor.to(gpu_id) 60 | return tensor 61 | 62 | 63 | def extrapolate_poses(poses: np.ndarray) -> np.ndarray: 64 | """ Generates an interpolated pose based on the first two poses in the given array. 65 | Args: 66 | poses: An array of poses, where each pose is represented by a 4x4 transformation matrix. 67 | Returns: 68 | A 4x4 numpy ndarray representing the interpolated transformation matrix. 69 | """ 70 | return poses[1, :] @ np.linalg.inv(poses[0, :]) @ poses[1, :] 71 | 72 | 73 | def compute_camera_opt_params(estimate_rel_w2c: np.ndarray) -> tuple: 74 | """ Computes the camera's rotation and translation parameters from an world-to-camera transformation matrix. 75 | This function extracts the rotation component of the transformation matrix, converts it to a quaternion, 76 | and reorders it to match a specific convention. Both rotation and translation parameters are converted 77 | to torch Parameters and intended to be optimized in a PyTorch model. 78 | Args: 79 | estimate_rel_w2c: A 4x4 numpy ndarray representing the estimated world-to-camera transformation matrix. 80 | Returns: 81 | A tuple containing two torch.nn.Parameters: camera's rotation and camera's translation. 
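    Note:
        scipy's Rotation.as_quat returns quaternions in (x, y, z, w) order; the
        reindexing with [3, 0, 1, 2] below converts the result to the (w, x, y, z)
        convention used by multiply_quaternions and transformation_to_quaternion above.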
82 | """ 83 | quaternion = Rotation.from_matrix(estimate_rel_w2c[:3, :3]).as_quat(canonical=True) 84 | quaternion = quaternion[[3, 0, 1, 2]] 85 | opt_cam_rot = torch.nn.Parameter(np2torch(quaternion, "cuda")) 86 | opt_cam_trans = torch.nn.Parameter(np2torch(estimate_rel_w2c[:3, 3], "cuda")) 87 | return opt_cam_rot, opt_cam_trans 88 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | from scipy.ndimage import median_filter 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import open3d as o3d 7 | import torch 8 | from gaussian_rasterizer import GaussianRasterizationSettings, GaussianRasterizer 9 | 10 | 11 | def setup_seed(seed: int) -> None: 12 | """ Sets the seed for generating random numbers to ensure reproducibility across multiple runs. 13 | Args: 14 | seed: The seed value to set for random number generators in torch, numpy, and random. 15 | """ 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | os.environ["PYTHONHASHSEED"] = str(seed) 19 | np.random.seed(seed) 20 | random.seed(seed) 21 | torch.backends.cudnn.deterministic = True 22 | torch.backends.cudnn.benchmark = False 23 | 24 | 25 | def torch2np(tensor: torch.Tensor) -> np.ndarray: 26 | """ Converts a PyTorch tensor to a NumPy ndarray. 27 | Args: 28 | tensor: The PyTorch tensor to convert. 29 | Returns: 30 | A NumPy ndarray with the same data and dtype as the input tensor. 31 | """ 32 | return tensor.detach().cpu().numpy() 33 | 34 | 35 | def np2torch(array: np.ndarray, device: str = "cpu") -> torch.Tensor: 36 | """Converts a NumPy ndarray to a PyTorch tensor. 37 | Args: 38 | array: The NumPy ndarray to convert. 39 | device: The device to which the tensor is sent. Defaults to 'cpu'. 40 | 41 | Returns: 42 | A PyTorch tensor with the same data as the input array. 43 | """ 44 | return torch.from_numpy(array).float().to(device) 45 | 46 | 47 | def np2ptcloud(pts: np.ndarray, rgb=None) -> o3d.geometry.PointCloud: 48 | """converts numpy array to point cloud 49 | Args: 50 | pts (ndarray): point cloud 51 | Returns: 52 | (PointCloud): resulting point cloud 53 | """ 54 | cloud = o3d.geometry.PointCloud() 55 | cloud.points = o3d.utility.Vector3dVector(pts) 56 | if rgb is not None: 57 | cloud.colors = o3d.utility.Vector3dVector(rgb) 58 | return cloud 59 | 60 | 61 | def dict2device(dict: dict, device: str = "cpu") -> dict: 62 | """Sends all tensors in a dictionary to a specified device. 63 | Args: 64 | dict: The dictionary containing tensors. 65 | device: The device to send the tensors to. Defaults to 'cpu'. 66 | Returns: 67 | The dictionary with all tensors sent to the specified device. 68 | """ 69 | for k, v in dict.items(): 70 | if isinstance(v, torch.Tensor): 71 | dict[k] = v.to(device) 72 | return dict 73 | 74 | 75 | def get_render_settings(w, h, intrinsics, w2c, near=0.01, far=100, sh_degree=0): 76 | """ 77 | Constructs and returns a GaussianRasterizationSettings object for rendering, 78 | configured with given camera parameters. 79 | 80 | Args: 81 | width (int): The width of the image. 82 | height (int): The height of the image. 83 | intrinsic (array): 3*3, Intrinsic camera matrix. 84 | w2c (array): World to camera transformation matrix. 85 | near (float, optional): The near plane for the camera. Defaults to 0.01. 86 | far (float, optional): The far plane for the camera. Defaults to 100. 
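        sh_degree (int, optional): Degree of the spherical harmonics the rasterizer
            evaluates for view-dependent color. Defaults to 0.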
87 | 88 | Returns: 89 | GaussianRasterizationSettings: Configured settings for Gaussian rasterization. 90 | """ 91 | fx, fy, cx, cy = intrinsics[0, 0], intrinsics[1, 92 | 1], intrinsics[0, 2], intrinsics[1, 2] 93 | w2c = torch.tensor(w2c).cuda().float() 94 | cam_center = torch.inverse(w2c)[:3, 3] 95 | viewmatrix = w2c.transpose(0, 1) 96 | opengl_proj = torch.tensor([[2 * fx / w, 0.0, -(w - 2 * cx) / w, 0.0], 97 | [0.0, 2 * fy / h, -(h - 2 * cy) / h, 0.0], 98 | [0.0, 0.0, far / 99 | (far - near), -(far * near) / (far - near)], 100 | [0.0, 0.0, 1.0, 0.0]], device='cuda').float().transpose(0, 1) 101 | full_proj_matrix = viewmatrix.unsqueeze( 102 | 0).bmm(opengl_proj.unsqueeze(0)).squeeze(0) 103 | return GaussianRasterizationSettings( 104 | image_height=h, 105 | image_width=w, 106 | tanfovx=w / (2 * fx), 107 | tanfovy=h / (2 * fy), 108 | bg=torch.tensor([0, 0, 0], device='cuda').float(), 109 | scale_modifier=1.0, 110 | viewmatrix=viewmatrix, 111 | projmatrix=full_proj_matrix, 112 | sh_degree=sh_degree, 113 | campos=cam_center, 114 | prefiltered=False, 115 | debug=False) 116 | 117 | 118 | def render_gaussian_model(gaussian_model, render_settings, 119 | override_means_3d=None, override_means_2d=None, 120 | override_scales=None, override_rotations=None, 121 | override_opacities=None, override_colors=None): 122 | """ 123 | Renders a Gaussian model with specified rendering settings, allowing for 124 | optional overrides of various model parameters. 125 | 126 | Args: 127 | gaussian_model: A Gaussian model object that provides methods to get 128 | various properties like xyz coordinates, opacity, features, etc. 129 | render_settings: Configuration settings for the GaussianRasterizer. 130 | override_means_3d (Optional): If provided, these values will override 131 | the 3D mean values from the Gaussian model. 132 | override_means_2d (Optional): If provided, these values will override 133 | the 2D mean values. Defaults to zeros if not provided. 134 | override_scales (Optional): If provided, these values will override the 135 | scale values from the Gaussian model. 136 | override_rotations (Optional): If provided, these values will override 137 | the rotation values from the Gaussian model. 138 | override_opacities (Optional): If provided, these values will override 139 | the opacity values from the Gaussian model. 140 | override_colors (Optional): If provided, these values will override the 141 | color values from the Gaussian model. 142 | Returns: 143 | A dictionary containing the rendered color, depth, radii, and 2D means 144 | of the Gaussian model. The keys of this dictionary are 'color', 'depth', 145 | 'radii', and 'means2D', each mapping to their respective rendered values. 
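        The returned dictionary additionally contains an 'alpha' entry with the rendered
        alpha map (see the return statement below).

        Example (illustrative only; gaussians is a populated Gaussian model, K a 3x3
        intrinsics matrix and w2c a 4x4 world-to-camera matrix for a w x h image):
            settings = get_render_settings(w, h, K, w2c)
            color = render_gaussian_model(gaussians, settings)["color"]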
146 | """ 147 | renderer = GaussianRasterizer(raster_settings=render_settings) 148 | 149 | if override_means_3d is None: 150 | means3D = gaussian_model.get_xyz() 151 | else: 152 | means3D = override_means_3d 153 | 154 | if override_means_2d is None: 155 | means2D = torch.zeros_like( 156 | means3D, dtype=means3D.dtype, requires_grad=True, device="cuda") 157 | means2D.retain_grad() 158 | else: 159 | means2D = override_means_2d 160 | 161 | if override_opacities is None: 162 | opacities = gaussian_model.get_opacity() 163 | else: 164 | opacities = override_opacities 165 | 166 | shs, colors_precomp = None, None 167 | if override_colors is not None: 168 | colors_precomp = override_colors 169 | else: 170 | shs = gaussian_model.get_features() 171 | 172 | render_args = { 173 | "means3D": means3D, 174 | "means2D": means2D, 175 | "opacities": opacities, 176 | "colors_precomp": colors_precomp, 177 | "shs": shs, 178 | "scales": gaussian_model.get_scaling() if override_scales is None else override_scales, 179 | "rotations": gaussian_model.get_rotation() if override_rotations is None else override_rotations, 180 | "cov3D_precomp": None 181 | } 182 | color, depth, alpha, radii = renderer(**render_args) 183 | 184 | return {"color": color, "depth": depth, "radii": radii, "means2D": means2D, "alpha": alpha} 185 | 186 | 187 | def batch_search_faiss(indexer, query_points, k): 188 | """ 189 | Perform a batch search on a IndexIVFFlat indexer to circumvent the search size limit of 65535. 190 | 191 | Args: 192 | indexer: The FAISS indexer object. 193 | query_points: A tensor of query points. 194 | k (int): The number of nearest neighbors to find. 195 | 196 | Returns: 197 | distances (torch.Tensor): The distances of the nearest neighbors. 198 | ids (torch.Tensor): The indices of the nearest neighbors. 
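        The query tensor is split into chunks of at most 65535 rows (torch.split below);
        each chunk is searched separately and the per-chunk distances and indices are
        concatenated, which keeps every individual call under the 65535-query limit
        mentioned above.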
199 | """ 200 | split_pos = torch.split(query_points, 65535, dim=0) 201 | distances_list, ids_list = [], [] 202 | 203 | for split_p in split_pos: 204 | distance, id = indexer.search(split_p.float(), k) 205 | distances_list.append(distance.clone()) 206 | ids_list.append(id.clone()) 207 | distances = torch.cat(distances_list, dim=0) 208 | ids = torch.cat(ids_list, dim=0) 209 | 210 | return distances, ids 211 | 212 | 213 | def filter_depth_outliers(depth_map, kernel_size=3, threshold=1.0): 214 | median_filtered = median_filter(depth_map, size=kernel_size) 215 | abs_diff = np.abs(depth_map - median_filtered) 216 | outlier_mask = abs_diff > threshold 217 | depth_map_filtered = np.where(outlier_mask, median_filtered, depth_map) 218 | return depth_map_filtered 219 | -------------------------------------------------------------------------------- /src/utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from copy import deepcopy 3 | from typing import List, Union 4 | 5 | import numpy as np 6 | import open3d as o3d 7 | from matplotlib import colors 8 | 9 | COLORS_ANSI = OrderedDict({ 10 | "blue": "\033[94m", 11 | "orange": "\033[93m", 12 | "green": "\033[92m", 13 | "red": "\033[91m", 14 | "purple": "\033[95m", 15 | "brown": "\033[93m", # No exact match, using yellow 16 | "pink": "\033[95m", 17 | "gray": "\033[90m", 18 | "olive": "\033[93m", # No exact match, using yellow 19 | "cyan": "\033[96m", 20 | "end": "\033[0m", # Reset color 21 | }) 22 | 23 | 24 | COLORS_MATPLOTLIB = OrderedDict({ 25 | 'blue': '#1f77b4', 26 | 'orange': '#ff7f0e', 27 | 'green': '#2ca02c', 28 | 'red': '#d62728', 29 | 'purple': '#9467bd', 30 | 'brown': '#8c564b', 31 | 'pink': '#e377c2', 32 | 'gray': '#7f7f7f', 33 | 'yellow-green': '#bcbd22', 34 | 'cyan': '#17becf' 35 | }) 36 | 37 | 38 | COLORS_MATPLOTLIB_RGB = OrderedDict({ 39 | 'blue': np.array([31, 119, 180]) / 255.0, 40 | 'orange': np.array([255, 127, 14]) / 255.0, 41 | 'green': np.array([44, 160, 44]) / 255.0, 42 | 'red': np.array([214, 39, 40]) / 255.0, 43 | 'purple': np.array([148, 103, 189]) / 255.0, 44 | 'brown': np.array([140, 86, 75]) / 255.0, 45 | 'pink': np.array([227, 119, 194]) / 255.0, 46 | 'gray': np.array([127, 127, 127]) / 255.0, 47 | 'yellow-green': np.array([188, 189, 34]) / 255.0, 48 | 'cyan': np.array([23, 190, 207]) / 255.0 49 | }) 50 | 51 | 52 | def get_color(color_name: str): 53 | """ Returns the RGB values of a given color name as a normalized numpy array. 54 | Args: 55 | color_name: The name of the color. Can be any color name from CSS4_COLORS. 56 | Returns: 57 | A numpy array representing the RGB values of the specified color, normalized to the range [0, 1]. 58 | """ 59 | if color_name == "custom_yellow": 60 | return np.asarray([255.0, 204.0, 102.0]) / 255.0 61 | if color_name == "custom_blue": 62 | return np.asarray([102.0, 153.0, 255.0]) / 255.0 63 | assert color_name in colors.CSS4_COLORS 64 | return np.asarray(colors.to_rgb(colors.CSS4_COLORS[color_name])) 65 | 66 | 67 | def plot_ptcloud(point_clouds: Union[List, o3d.geometry.PointCloud], show_frame: bool = True): 68 | """ Visualizes one or more point clouds, optionally showing the coordinate frame. 69 | Args: 70 | point_clouds: A single point cloud or a list of point clouds to be visualized. 71 | show_frame: If True, displays the coordinate frame in the visualization. Defaults to True. 
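    Example (hypothetical array; xyz is an (N, 3) float array of points):
        cloud = o3d.geometry.PointCloud()
        cloud.points = o3d.utility.Vector3dVector(xyz)
        plot_ptcloud(cloud, show_frame=False)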
72 | """ 73 | # rotate down up 74 | if not isinstance(point_clouds, list): 75 | point_clouds = [point_clouds] 76 | if show_frame: 77 | mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=1, origin=[0, 0, 0]) 78 | point_clouds = point_clouds + [mesh_frame] 79 | o3d.visualization.draw_geometries(point_clouds) 80 | 81 | 82 | def draw_registration_result_original_color(source: o3d.geometry.PointCloud, target: o3d.geometry.PointCloud, 83 | transformation: np.ndarray): 84 | """ Visualizes the result of a point cloud registration, keeping the original color of the source point cloud. 85 | Args: 86 | source: The source point cloud. 87 | target: The target point cloud. 88 | transformation: The transformation matrix applied to the source point cloud. 89 | """ 90 | source_temp = deepcopy(source) 91 | source_temp.transform(transformation) 92 | o3d.visualization.draw_geometries([source_temp, target]) 93 | 94 | 95 | def draw_registration_result(source: o3d.geometry.PointCloud, target: o3d.geometry.PointCloud, 96 | transformation: np.ndarray, source_color: str = "blue", target_color: str = "orange"): 97 | """ Visualizes the result of a point cloud registration, coloring the source and target point clouds. 98 | Args: 99 | source: The source point cloud. 100 | target: The target point cloud. 101 | transformation: The transformation matrix applied to the source point cloud. 102 | source_color: The color to apply to the source point cloud. Defaults to "blue". 103 | target_color: The color to apply to the target point cloud. Defaults to "orange". 104 | """ 105 | source_temp = deepcopy(source) 106 | source_temp.paint_uniform_color(COLORS_MATPLOTLIB_RGB[source_color]) 107 | 108 | target_temp = deepcopy(target) 109 | target_temp.paint_uniform_color(COLORS_MATPLOTLIB_RGB[target_color]) 110 | 111 | source_temp.transform(transformation) 112 | o3d.visualization.draw_geometries([source_temp, target_temp]) 113 | --------------------------------------------------------------------------------
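A minimal end-to-end sketch of how the registration visualizer above might be exercised.
This is an editor's illustration, not part of the repository: the point clouds, the rigid
transform and the import path (which follows the repository's own "from src.utils..."
convention) are synthetic assumptions.

    import numpy as np
    import open3d as o3d

    from src.utils.vis_utils import draw_registration_result

    # A random cloud and a translated copy of it standing in for a registered pair.
    rng = np.random.default_rng(0)
    xyz = rng.uniform(-1.0, 1.0, size=(2000, 3))

    source = o3d.geometry.PointCloud()
    source.points = o3d.utility.Vector3dVector(xyz)

    T = np.eye(4)
    T[:3, 3] = [0.3, 0.0, 0.1]  # known offset between source and target

    target = o3d.geometry.PointCloud()
    target.points = o3d.utility.Vector3dVector(xyz + T[:3, 3])

    # Applying T should bring the source onto the target; the clouds are drawn in the
    # default blue/orange pair in an interactive Open3D window.
    draw_registration_result(source, target, T)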