├── .gitignore ├── LICENSE ├── README.md ├── assets ├── gaussian_slam.gif └── gaussian_slam.mp4 ├── configs ├── Replica │ ├── office0.yaml │ ├── office1.yaml │ ├── office2.yaml │ ├── office3.yaml │ ├── office4.yaml │ ├── replica.yaml │ ├── room0.yaml │ ├── room1.yaml │ └── room2.yaml ├── ScanNet │ ├── scannet.yaml │ ├── scene0000_00.yaml │ ├── scene0059_00.yaml │ ├── scene0106_00.yaml │ ├── scene0169_00.yaml │ ├── scene0181_00.yaml │ └── scene0207_00.yaml ├── TUM_RGBD │ ├── rgbd_dataset_freiburg1_desk.yaml │ ├── rgbd_dataset_freiburg2_xyz.yaml │ ├── rgbd_dataset_freiburg3_long_office_household.yaml │ └── tum_rgbd.yaml └── scannetpp │ ├── 281bc17764.yaml │ ├── 2e74812d00.yaml │ ├── 8b5caf3398.yaml │ ├── b20a261fdf.yaml │ ├── fb05e13ad1.yaml │ └── scannetpp.yaml ├── environment.yml ├── run_evaluation.py ├── run_slam.py ├── scripts ├── download_replica.sh ├── download_tum.sh └── reproduce_sbatch.sh └── src ├── entities ├── __init__.py ├── arguments.py ├── datasets.py ├── gaussian_model.py ├── gaussian_slam.py ├── logger.py ├── losses.py ├── mapper.py ├── tracker.py └── visual_odometer.py ├── evaluation ├── __init__.py ├── evaluate_merged_map.py ├── evaluate_reconstruction.py ├── evaluate_trajectory.py └── evaluator.py └── utils ├── __init__.py ├── gaussian_model_utils.py ├── io_utils.py ├── mapper_utils.py ├── tracker_utils.py ├── utils.py └── vis_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | output 4 | build 5 | diff_rasterization/diff_rast.egg-info 6 | diff_rasterization/dist 7 | tensorboard_3d 8 | screenshots 9 | debug 10 | wandb 11 | data 12 | *.txt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

Gaussian-SLAM: Photo-realistic Dense SLAM with Gaussian Splatting

4 |

5 | Vladimir Yugay 6 | · 7 | Yue Li* 8 | · 9 | Theo Gevers 10 | · 11 | Martin Oswald 12 |

13 |
*Significant contribution
14 |

Project Page

15 |
16 |

17 | 18 |

19 | 20 | 21 | 22 |

23 | 24 | ## ⚙️ Setting Things Up 25 | 26 | Clone the repo: 27 | 28 | ``` 29 | git clone https://github.com/VladimirYugay/Gaussian-SLAM 30 | ``` 31 | 32 | Make sure that the gcc and g++ paths on your system are exported: 33 | 34 | ``` 35 | export CC= 36 | export CXX= 37 | ``` 38 | 39 | To find the gcc and g++ paths on your machine you can run `which gcc` and `which g++`, e.g. `export CC=$(which gcc)` and `export CXX=$(which g++)`. 40 | 41 | 42 | Then set up the environment from the provided conda environment file: 43 | 44 | ``` 45 | conda env create -f environment.yml 46 | conda activate gslam 47 | ``` 48 | We tested our code on RTX 3090 and RTX A6000 GPUs, on Ubuntu 22 and CentOS 7.5. 49 | 50 | ## 🔨 Running Gaussian-SLAM 51 | 52 | Here we elaborate on how to load the necessary data, configure Gaussian-SLAM for your use case, debug it, and reproduce the results reported in the paper. 53 | 54 |
55 | Downloading the Data 56 | We tested our code on Replica, TUM_RGBD, ScanNet, and ScanNet++ datasets. We also provide scripts for downloading Replica and TUM_RGBD. Install git lfs before using the scripts by running "git lfs install".
57 | For downloading ScanNet, follow the procedure described here.
58 | For downloading ScanNet++, follow the procedure described here.
59 | The config files are named after the sequences that we used for our method. 60 |
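For reference, a typical download sequence using the provided scripts might look like the following (run from the repository root; the scripts clone the data into `data/`):

```
git lfs install
bash scripts/download_replica.sh   # clones the Replica-SLAM dataset into data/
bash scripts/download_tum.sh       # clones the TUM_RGBD-SLAM dataset into data/
```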
61 | 62 |
63 | Running the code 64 | Start the system with the command: 65 | 66 | ``` 67 | python run_slam.py configs/<dataset_name>/<config_name> --input_path <path_to_scene> --output_path <output_path> 68 | ``` 69 | For example: 70 | ``` 71 | python run_slam.py configs/Replica/room0.yaml --input_path /home/datasets/Replica/room0 --output_path output/Replica/room0 72 | ``` 73 | You can also configure input and output paths in the config yaml file. 74 |
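Beyond the input and output paths, most tracking and mapping parameters can be overridden from the command line (see the argument parser in `run_slam.py` for the full list of flags). A hypothetical example that also changes the seed and the mapping frequency:

```
python run_slam.py configs/Replica/room0.yaml \
    --input_path /home/datasets/Replica/room0 \
    --output_path output/Replica/room0 \
    --seed 1 \
    --map_every 10
```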
75 | 76 |
77 | Reproducing Results 78 | While we made all parts of our code deterministic, the differentiable rasterizer of Gaussian Splatting is not, so the metrics can be slightly different from run to run. In the paper we report average metrics that were computed over three seeds: 0, 1, and 2. 79 | 80 | You can reproduce the results for a single scene by running: 81 | 82 | ``` 83 | python run_slam.py configs/<dataset_name>/<config_name> --input_path <path_to_scene> --output_path <output_path> 84 | ``` 85 | 86 | If you are running on a SLURM cluster, you can reproduce the results for all scenes in a dataset by running the script: 87 | ``` 88 | ./scripts/reproduce_sbatch.sh 89 | ``` 90 | Please note that evaluating the ```depth_L1``` metric requires reconstructing the mesh, which in turn requires a headless installation of open3d if you are running on a cluster. 91 |
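As a minimal sketch, the three-seed protocol for a single scene can be scripted as below (the paths are placeholders; note that `--seed 0` is effectively a no-op in `run_slam.py`, which is harmless since 0 is already the config default):

```
for seed in 0 1 2; do
    python run_slam.py configs/Replica/room0.yaml \
        --input_path /home/datasets/Replica/room0 \
        --output_path output/Replica/room0_seed${seed} \
        --seed ${seed}
done
```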
92 | 93 |
94 | Demo 95 | We used the camera path tool in the gaussian-splatting-lightning repo to create the fly-through video from the reconstructed scenes. We thank its author for the great work. 96 |
97 | 98 | ## 📌 Citation 99 | 100 | If you find our paper and code useful, please cite us: 101 | 102 | ```bib 103 | @misc{yugay2023gaussianslam, 104 | title={Gaussian-SLAM: Photo-realistic Dense SLAM with Gaussian Splatting}, 105 | author={Vladimir Yugay and Yue Li and Theo Gevers and Martin R. Oswald}, 106 | year={2023}, 107 | eprint={2312.10070}, 108 | archivePrefix={arXiv}, 109 | primaryClass={cs.CV} 110 | } 111 | -------------------------------------------------------------------------------- /assets/gaussian_slam.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VladimirYugay/Gaussian-SLAM/eaec10d73ce7511563882b8856896e06d1f804e3/assets/gaussian_slam.gif -------------------------------------------------------------------------------- /assets/gaussian_slam.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VladimirYugay/Gaussian-SLAM/eaec10d73ce7511563882b8856896e06d1f804e3/assets/gaussian_slam.mp4 -------------------------------------------------------------------------------- /configs/Replica/office0.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | data: 3 | scene_name: office0 4 | input_path: data/Replica-SLAM/Replica/office0/ 5 | output_path: output/Replica/office0/ 6 | -------------------------------------------------------------------------------- /configs/Replica/office1.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | data: 3 | scene_name: office1 4 | input_path: data/Replica-SLAM/Replica/office1/ 5 | output_path: output/Replica/office1/ 6 | -------------------------------------------------------------------------------- /configs/Replica/office2.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | data: 3 | scene_name: office2 4 | input_path: data/Replica-SLAM/Replica/office2/ 5 | output_path: output/Replica/office2/ 6 | -------------------------------------------------------------------------------- /configs/Replica/office3.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | data: 3 | scene_name: office3 4 | input_path: data/Replica-SLAM/Replica/office3/ 5 | output_path: output/Replica/office3/ 6 | -------------------------------------------------------------------------------- /configs/Replica/office4.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | data: 3 | scene_name: office4 4 | input_path: data/Replica-SLAM/Replica/office4/ 5 | output_path: output/Replica/office4/ 6 | -------------------------------------------------------------------------------- /configs/Replica/replica.yaml: -------------------------------------------------------------------------------- 1 | project_name: "Gaussian_SLAM_replica" 2 | dataset_name: "replica" 3 | checkpoint_path: null 4 | use_wandb: False 5 | frame_limit: -1 # for debugging, set to -1 to disable 6 | seed: 0 7 | mapping: 8 | new_submap_every: 50 9 | map_every: 5 10 | iterations: 100 11 | new_submap_iterations: 1000 12 | new_submap_points_num: 600000 13 | new_submap_gradient_points_num: 50000 14 | new_frame_sample_size: -1 15 | new_points_radius: 0.0000001 16 | 
current_view_opt_iterations: 0.4 # What portion of iterations to spend on the current view 17 | alpha_thre: 0.6 18 | pruning_thre: 0.1 19 | submap_using_motion_heuristic: True 20 | tracking: 21 | gt_camera: False 22 | w_color_loss: 0.95 23 | iterations: 60 24 | cam_rot_lr: 0.0002 25 | cam_trans_lr: 0.002 26 | odometry_type: "odometer" # gt, const_speed, odometer 27 | help_camera_initialization: False # temp option to help const_init 28 | init_err_ratio: 5 29 | odometer_method: "point_to_plane" # hybrid or point_to_plane 30 | filter_alpha: False 31 | filter_outlier_depth: True 32 | alpha_thre: 0.98 33 | soft_alpha: True 34 | mask_invalid_depth: False 35 | cam: 36 | H: 680 37 | W: 1200 38 | fx: 600.0 39 | fy: 600.0 40 | cx: 599.5 41 | cy: 339.5 42 | depth_scale: 6553.5 43 | -------------------------------------------------------------------------------- /configs/Replica/room0.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | data: 3 | scene_name: room0 4 | input_path: data/Replica-SLAM/room0/ 5 | output_path: output/Replica/room0/ 6 | -------------------------------------------------------------------------------- /configs/Replica/room1.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | data: 3 | scene_name: room1 4 | input_path: data/Replica-SLAM/Replica/room1/ 5 | output_path: output/Replica/room1/ 6 | -------------------------------------------------------------------------------- /configs/Replica/room2.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | data: 3 | scene_name: room2 4 | input_path: data/Replica-SLAM/Replica/room2/ 5 | output_path: output/Replica/room2/ 6 | -------------------------------------------------------------------------------- /configs/ScanNet/scannet.yaml: -------------------------------------------------------------------------------- 1 | project_name: "Gaussian_SLAM_scannet" 2 | dataset_name: "scan_net" 3 | checkpoint_path: null 4 | use_wandb: False 5 | frame_limit: -1 # for debugging, set to -1 to disable 6 | seed: 0 7 | mapping: 8 | new_submap_every: 50 9 | map_every: 1 10 | iterations: 100 11 | new_submap_iterations: 100 12 | new_submap_points_num: 100000 13 | new_submap_gradient_points_num: 50000 14 | new_frame_sample_size: 30000 15 | new_points_radius: 0.0001 16 | current_view_opt_iterations: 0.4 # What portion of iterations to spend on the current view 17 | alpha_thre: 0.6 18 | pruning_thre: 0.5 19 | submap_using_motion_heuristic: False 20 | tracking: 21 | gt_camera: False 22 | w_color_loss: 0.6 23 | iterations: 200 24 | cam_rot_lr: 0.002 25 | cam_trans_lr: 0.01 26 | odometry_type: "const_speed" # gt, const_speed, odometer 27 | help_camera_initialization: False # temp option to help const_init 28 | init_err_ratio: 5 29 | odometer_method: "hybrid" # hybrid or point_to_plane 30 | filter_alpha: True 31 | filter_outlier_depth: True 32 | alpha_thre: 0.98 33 | soft_alpha: True 34 | mask_invalid_depth: True 35 | cam: 36 | H: 480 37 | W: 640 38 | fx: 577.590698 39 | fy: 578.729797 40 | cx: 318.905426 41 | cy: 242.683609 42 | depth_scale: 1000. 
43 | crop_edge: 12 -------------------------------------------------------------------------------- /configs/ScanNet/scene0000_00.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | data: 3 | input_path: data/scannet/scans/scene0000_00 4 | output_path: output/ScanNet/scene0000 5 | scene_name: scene0000_00 -------------------------------------------------------------------------------- /configs/ScanNet/scene0059_00.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | data: 3 | input_path: data/scannet/scans/scene0059_00 4 | output_path: output/ScanNet/scene0059 5 | scene_name: scene0059_00 -------------------------------------------------------------------------------- /configs/ScanNet/scene0106_00.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | data: 3 | input_path: data/scannet/scans/scene0106_00 4 | output_path: output/ScanNet/scene0106 5 | scene_name: scene0106_00 -------------------------------------------------------------------------------- /configs/ScanNet/scene0169_00.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | data: 3 | input_path: data/scannet/scans/scene0169_00 4 | output_path: output/ScanNet/scene0169 5 | scene_name: scene0169_00 6 | cam: 7 | fx: 574.540771 8 | fy: 577.583740 9 | cx: 322.522827 10 | cy: 238.558853 -------------------------------------------------------------------------------- /configs/ScanNet/scene0181_00.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | data: 3 | input_path: data/scannet/scans/scene0181_00 4 | output_path: output/ScanNet/scene0181 5 | scene_name: scene0181_00 6 | cam: 7 | fx: 575.547668 8 | fy: 577.459778 9 | cx: 323.171967 10 | cy: 236.417465 11 | -------------------------------------------------------------------------------- /configs/ScanNet/scene0207_00.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | data: 3 | input_path: data/scannet/scans/scene0207_00 4 | output_path: output/ScanNet/scene0207 5 | scene_name: scene0207_00 -------------------------------------------------------------------------------- /configs/TUM_RGBD/rgbd_dataset_freiburg1_desk.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/TUM_RGBD/tum_rgbd.yaml 2 | data: 3 | input_path: data/TUM_RGBD-SLAM/rgbd_dataset_freiburg1_desk 4 | output_path: output/TUM_RGBD/rgbd_dataset_freiburg1_desk/ 5 | scene_name: rgbd_dataset_freiburg1_desk 6 | cam: #intrinsic is different per scene in TUM 7 | H: 480 8 | W: 640 9 | fx: 517.3 10 | fy: 516.5 11 | cx: 318.6 12 | cy: 255.3 13 | crop_edge: 50 14 | distortion: [0.2624, -0.9531, -0.0054, 0.0026, 1.1633] -------------------------------------------------------------------------------- /configs/TUM_RGBD/rgbd_dataset_freiburg2_xyz.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/TUM_RGBD/tum_rgbd.yaml 2 | data: 3 | input_path: data/TUM_RGBD-SLAM/rgbd_dataset_freiburg2_xyz 4 | output_path: output/TUM_RGBD/rgbd_dataset_freiburg2_xyz/ 5 | scene_name: rgbd_dataset_freiburg2_xyz 6 | cam: #intrinsic is 
different per scene in TUM 7 | H: 480 8 | W: 640 9 | fx: 520.9 10 | fy: 521.0 11 | cx: 325.1 12 | cy: 249.7 13 | crop_edge: 8 14 | distortion: [0.2312, -0.7849, -0.0033, -0.0001, 0.9172] -------------------------------------------------------------------------------- /configs/TUM_RGBD/rgbd_dataset_freiburg3_long_office_household.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/TUM_RGBD/tum_rgbd.yaml 2 | data: 3 | input_path: data/TUM_RGBD-SLAM/rgbd_dataset_freiburg3_long_office_household/ 4 | output_path: output/TUM_RGBD/rgbd_dataset_freiburg3_long_office_household/ 5 | scene_name: rgbd_dataset_freiburg3_long_office_household 6 | cam: #intrinsic is different per scene in TUM 7 | H: 480 8 | W: 640 9 | fx: 517.3 10 | fy: 516.5 11 | cx: 318.6 12 | cy: 255.3 13 | crop_edge: 50 14 | distortion: [0.2624, -0.9531, -0.0054, 0.0026, 1.1633] -------------------------------------------------------------------------------- /configs/TUM_RGBD/tum_rgbd.yaml: -------------------------------------------------------------------------------- 1 | project_name: "Gaussian_SLAM_tumrgbd" 2 | dataset_name: "tum_rgbd" 3 | checkpoint_path: null 4 | use_wandb: False 5 | frame_limit: -1 # for debugging, set to -1 to disable 6 | seed: 0 7 | mapping: 8 | new_submap_every: 50 9 | map_every: 1 10 | iterations: 100 11 | new_submap_iterations: 100 12 | new_submap_points_num: 100000 13 | new_submap_gradient_points_num: 50000 14 | new_frame_sample_size: 30000 15 | new_points_radius: 0.0001 16 | current_view_opt_iterations: 0.4 # What portion of iterations to spend on the current view 17 | alpha_thre: 0.6 18 | pruning_thre: 0.5 19 | submap_using_motion_heuristic: True 20 | tracking: 21 | gt_camera: False 22 | w_color_loss: 0.6 23 | iterations: 200 24 | cam_rot_lr: 0.002 25 | cam_trans_lr: 0.01 26 | odometry_type: "const_speed" # gt, const_speed, odometer 27 | help_camera_initialization: False # temp option to help const_init 28 | init_err_ratio: 5 29 | odometer_method: "hybrid" # hybrid or point_to_plane 30 | filter_alpha: False 31 | filter_outlier_depth: False 32 | alpha_thre: 0.98 33 | soft_alpha: True 34 | mask_invalid_depth: True 35 | cam: 36 | crop_edge: 16 37 | depth_scale: 5000.0 -------------------------------------------------------------------------------- /configs/scannetpp/281bc17764.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/scannetpp/scannetpp.yaml 2 | data: 3 | input_path: data/scannetpp/data/281bc17764 4 | output_path: output/ScanNetPP/281bc17764 5 | scene_name: "281bc17764" 6 | use_train_split: True 7 | frame_limit: 250 8 | cam: 9 | H: 584 10 | W: 876 11 | fx: 312.79197434640764 12 | fy: 313.48022477591036 13 | cx: 438.0 14 | cy: 292.0 -------------------------------------------------------------------------------- /configs/scannetpp/2e74812d00.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/scannetpp/scannetpp.yaml 2 | data: 3 | input_path: data/scannetpp/data/2e74812d00 4 | output_path: output/ScanNetPP/2e74812d00 5 | scene_name: "2e74812d00" 6 | use_train_split: True 7 | frame_limit: 250 8 | cam: 9 | H: 584 10 | W: 876 11 | fx: 312.0984049779051 12 | fy: 312.4823067146056 13 | cx: 438.0 14 | cy: 292.0 -------------------------------------------------------------------------------- /configs/scannetpp/8b5caf3398.yaml: -------------------------------------------------------------------------------- 1 | 
inherit_from: configs/scannetpp/scannetpp.yaml 2 | data: 3 | input_path: data/scannetpp/data/8b5caf3398 4 | output_path: output/ScanNetPP/8b5caf3398 5 | scene_name: "8b5caf3398" 6 | use_train_split: True 7 | frame_limit: 250 8 | cam: 9 | H: 584 10 | W: 876 11 | fx: 316.3837659917395 12 | fy: 319.18649362678593 13 | cx: 438.0 14 | cy: 292.0 15 | -------------------------------------------------------------------------------- /configs/scannetpp/b20a261fdf.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/scannetpp/scannetpp.yaml 2 | data: 3 | input_path: data/scannetpp/data/b20a261fdf 4 | output_path: output/ScanNetPP/b20a261fdf 5 | scene_name: "b20a261fdf" 6 | use_train_split: True 7 | frame_limit: 250 8 | cam: 9 | H: 584 10 | W: 876 11 | fx: 312.7099188244687 12 | fy: 313.5121746848229 13 | cx: 438.0 14 | cy: 292.0 -------------------------------------------------------------------------------- /configs/scannetpp/fb05e13ad1.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/scannetpp/scannetpp.yaml 2 | data: 3 | input_path: data/scannetpp/data/fb05e13ad1 4 | output_path: output/ScanNetPP/fb05e13ad1 5 | scene_name: "fb05e13ad1" 6 | use_train_split: True 7 | frame_limit: 250 8 | cam: 9 | H: 584 10 | W: 876 11 | fx: 231.8197441948914 12 | fy: 231.9980523882361 13 | cx: 438.0 14 | cy: 292.0 -------------------------------------------------------------------------------- /configs/scannetpp/scannetpp.yaml: -------------------------------------------------------------------------------- 1 | project_name: "Gaussian_SLAM_scannetpp" 2 | dataset_name: "scannetpp" 3 | checkpoint_path: null 4 | use_wandb: False 5 | frame_limit: -1 # set to -1 to disable 6 | seed: 0 7 | mapping: 8 | new_submap_every: 100 9 | map_every: 2 10 | iterations: 500 11 | new_submap_iterations: 500 12 | new_submap_points_num: 400000 13 | new_submap_gradient_points_num: 50000 14 | new_frame_sample_size: 100000 15 | new_points_radius: 0.00000001 16 | current_view_opt_iterations: 0.4 # What portion of iterations to spend on the current view 17 | alpha_thre: 0.6 18 | pruning_thre: 0.5 19 | submap_using_motion_heuristic: False 20 | tracking: 21 | gt_camera: False 22 | w_color_loss: 0.5 23 | iterations: 300 24 | cam_rot_lr: 0.002 25 | cam_trans_lr: 0.01 26 | odometry_type: "const_speed" # gt, const_speed, odometer 27 | help_camera_initialization: True 28 | init_err_ratio: 50 29 | odometer_method: "point_to_plane" # hybrid or point_to_plane 30 | filter_alpha: True 31 | filter_outlier_depth: True 32 | alpha_thre: 0.98 33 | soft_alpha: False 34 | mask_invalid_depth: True 35 | cam: 36 | crop_edge: 0 37 | depth_scale: 1000.0 -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: gslam 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - nvidia/label/cuda-12.1.0 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - python=3.10 10 | - faiss-gpu=1.8.0 11 | - cuda-toolkit=12.1 12 | - pytorch=2.1.2 13 | - pytorch-cuda=12.1 14 | - torchvision=0.16.2 15 | - pip 16 | - pip: 17 | - open3d==0.18.0 18 | - wandb 19 | - trimesh 20 | - pytorch_msssim 21 | - torchmetrics 22 | - tqdm 23 | - imageio 24 | - opencv-python 25 | - plyfile 26 | - git+https://github.com/eriksandstroem/evaluate_3d_reconstruction_lib.git@9b3cc08be5440db9c375cc21e3bd65bb4a337db7 27 | - 
git+https://github.com/VladimirYugay/simple-knn.git@c7e51a06a4cd84c25e769fee29ab391fe5d5ff8d 28 | - git+https://github.com/VladimirYugay/gaussian_rasterizer.git@9c40173fcc8d9b16778a1a8040295bc2f9ebf129 -------------------------------------------------------------------------------- /run_evaluation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from src.evaluation.evaluator import Evaluator 4 | 5 | 6 | def get_args(): 7 | parser = argparse.ArgumentParser(description='Arguments to compute the mesh') 8 | parser.add_argument('--checkpoint_path', type=str, help='SLAM checkpoint path', default="output/slam/full_experiment/") 9 | parser.add_argument('--config_path', type=str, help='Config path', default="") 10 | return parser.parse_args() 11 | 12 | 13 | if __name__ == "__main__": 14 | args = get_args() 15 | if args.config_path == "": 16 | args.config_path = Path(args.checkpoint_path) / "config.yaml" 17 | 18 | evaluator = Evaluator(Path(args.checkpoint_path), Path(args.config_path)) 19 | evaluator.run() 20 | -------------------------------------------------------------------------------- /run_slam.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import uuid 5 | 6 | import wandb 7 | 8 | from src.entities.gaussian_slam import GaussianSLAM 9 | from src.evaluation.evaluator import Evaluator 10 | from src.utils.io_utils import load_config, log_metrics_to_wandb 11 | from src.utils.utils import setup_seed 12 | 13 | 14 | def get_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Arguments to compute the mesh') 17 | parser.add_argument('config_path', type=str, 18 | help='Path to the configuration yaml file') 19 | parser.add_argument('--input_path', default="") 20 | parser.add_argument('--output_path', default="") 21 | parser.add_argument('--track_w_color_loss', type=float) 22 | parser.add_argument('--track_alpha_thre', type=float) 23 | parser.add_argument('--track_iters', type=int) 24 | parser.add_argument('--track_filter_alpha', action='store_true') 25 | parser.add_argument('--track_filter_outlier', action='store_true') 26 | parser.add_argument('--track_wo_filter_alpha', action='store_true') 27 | parser.add_argument("--track_wo_filter_outlier", action="store_true") 28 | parser.add_argument("--track_cam_trans_lr", type=float) 29 | parser.add_argument('--alpha_seeding_thre', type=float) 30 | parser.add_argument('--map_every', type=int) 31 | parser.add_argument("--map_iters", type=int) 32 | parser.add_argument('--new_submap_every', type=int) 33 | parser.add_argument('--project_name', type=str) 34 | parser.add_argument('--group_name', type=str) 35 | parser.add_argument('--gt_camera', action='store_true') 36 | parser.add_argument('--help_camera_initialization', action='store_true') 37 | parser.add_argument('--soft_alpha', action='store_true') 38 | parser.add_argument('--seed', type=int) 39 | parser.add_argument('--submap_using_motion_heuristic', action='store_true') 40 | parser.add_argument('--new_submap_points_num', type=int) 41 | return parser.parse_args() 42 | 43 | 44 | def update_config_with_args(config, args): 45 | if args.input_path: 46 | config["data"]["input_path"] = args.input_path 47 | if args.output_path: 48 | config["data"]["output_path"] = args.output_path 49 | if args.track_w_color_loss is not None: 50 | config["tracking"]["w_color_loss"] = args.track_w_color_loss 51 | if args.track_iters is not None: 52 
| config["tracking"]["iterations"] = args.track_iters 53 | if args.track_filter_alpha: 54 | config["tracking"]["filter_alpha"] = True 55 | if args.track_wo_filter_alpha: 56 | config["tracking"]["filter_alpha"] = False 57 | if args.track_filter_outlier: 58 | config["tracking"]["filter_outlier_depth"] = True 59 | if args.track_wo_filter_outlier: 60 | config["tracking"]["filter_outlier_depth"] = False 61 | if args.track_alpha_thre is not None: 62 | config["tracking"]["alpha_thre"] = args.track_alpha_thre 63 | if args.map_every: 64 | config["mapping"]["map_every"] = args.map_every 65 | if args.map_iters: 66 | config["mapping"]["iterations"] = args.map_iters 67 | if args.new_submap_every: 68 | config["mapping"]["new_submap_every"] = args.new_submap_every 69 | if args.project_name: 70 | config["project_name"] = args.project_name 71 | if args.alpha_seeding_thre is not None: 72 | config["mapping"]["alpha_thre"] = args.alpha_seeding_thre 73 | if args.seed: 74 | config["seed"] = args.seed 75 | if args.help_camera_initialization: 76 | config["tracking"]["help_camera_initialization"] = True 77 | if args.soft_alpha: 78 | config["tracking"]["soft_alpha"] = True 79 | if args.submap_using_motion_heuristic: 80 | config["mapping"]["submap_using_motion_heuristic"] = True 81 | if args.new_submap_points_num: 82 | config["mapping"]["new_submap_points_num"] = args.new_submap_points_num 83 | if args.track_cam_trans_lr: 84 | config["tracking"]["cam_trans_lr"] = args.track_cam_trans_lr 85 | return config 86 | 87 | 88 | if __name__ == "__main__": 89 | args = get_args() 90 | config = load_config(args.config_path) 91 | config = update_config_with_args(config, args) 92 | 93 | if os.getenv('DISABLE_WANDB') == 'true': 94 | config["use_wandb"] = False 95 | if config["use_wandb"]: 96 | wandb.init( 97 | project=config["project_name"], 98 | config=config, 99 | dir="/home/yli3/scratch/outputs/slam/wandb", 100 | group=config["data"]["scene_name"] 101 | if not args.group_name 102 | else args.group_name, 103 | name=f'{config["data"]["scene_name"]}_{time.strftime("%Y%m%d_%H%M%S", time.localtime())}_{str(uuid.uuid4())[:5]}', 104 | ) 105 | wandb.run.log_code(".", include_fn=lambda path: path.endswith(".py")) 106 | 107 | setup_seed(config["seed"]) 108 | gslam = GaussianSLAM(config) 109 | gslam.run() 110 | 111 | evaluator = Evaluator(gslam.output_path, gslam.output_path / "config.yaml") 112 | evaluator.run() 113 | if config["use_wandb"]: 114 | evals = ["rendering_metrics.json", 115 | "reconstruction_metrics.json", "ate_aligned.json"] 116 | log_metrics_to_wandb(evals, gslam.output_path, "Evaluation") 117 | wandb.finish() 118 | print("All done.✨") 119 | -------------------------------------------------------------------------------- /scripts/download_replica.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data 2 | cd data 3 | git clone https://huggingface.co/datasets/voviktyl/Replica-SLAM -------------------------------------------------------------------------------- /scripts/download_tum.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data 2 | cd data 3 | git clone https://huggingface.co/datasets/voviktyl/TUM_RGBD-SLAM -------------------------------------------------------------------------------- /scripts/reproduce_sbatch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --output=output/logs/%A_%a.log # please change accordingly 3 | #SBATCH --error=output/logs/%A_%a.log # 
please change accordingly 4 | #SBATCH -N 1 5 | #SBATCH -n 1 6 | #SBATCH --gpus-per-node=1 7 | #SBATCH --partition=gpu 8 | #SBATCH --cpus-per-task=12 9 | #SBATCH --time=24:00:00 10 | #SBATCH --array=0-4 # number of scenes, 0-7 for Replica, 0-2 for TUM_RGBD, 0-5 for ScanNet, 0-4 for ScanNet++ 11 | 12 | dataset="Replica" # set dataset 13 | if [ "$dataset" == "Replica" ]; then 14 | scenes=("room0" "room1" "room2" "office0" "office1" "office2" "office3" "office4") 15 | INPUT_PATH="data/Replica-SLAM" 16 | elif [ "$dataset" == "TUM_RGBD" ]; then 17 | scenes=("rgbd_dataset_freiburg1_desk" "rgbd_dataset_freiburg2_xyz" "rgbd_dataset_freiburg3_long_office_household") 18 | INPUT_PATH="data/TUM_RGBD-SLAM" 19 | elif [ "$dataset" == "ScanNet" ]; then 20 | scenes=("scene0000_00" "scene0059_00" "scene0106_00" "scene0169_00" "scene0181_00" "scene0207_00") 21 | INPUT_PATH="data/scannet/scans" 22 | elif [ "$dataset" == "ScanNetPP" ]; then 23 | scenes=("b20a261fdf" "8b5caf3398" "fb05e13ad1" "2e74812d00" "281bc17764") 24 | INPUT_PATH="data/scannetpp/data" 25 | else 26 | echo "Dataset not recognized!" 27 | exit 1 28 | fi 29 | 30 | OUTPUT_PATH="output" 31 | CONFIG_PATH="configs/${dataset}" 32 | EXPERIMENT_NAME="reproduce" 33 | SCENE_NAME=${scenes[$SLURM_ARRAY_TASK_ID]} 34 | 35 | source # please change accordingly 36 | conda activate gslam 37 | 38 | echo "Job for dataset: $dataset, scene: $SCENE_NAME" 39 | echo "Starting on: $(date)" 40 | echo "Running on node: $(hostname)" 41 | 42 | # Your command to run the experiment 43 | python run_slam.py "${CONFIG_PATH}/${SCENE_NAME}.yaml" \ 44 | --input_path "${INPUT_PATH}/${SCENE_NAME}" \ 45 | --output_path "${OUTPUT_PATH}/${dataset}/${EXPERIMENT_NAME}/${SCENE_NAME}" \ 46 | --group_name "${EXPERIMENT_NAME}" \ 47 | 48 | echo "Job for scene $SCENE_NAME completed." 49 | echo "Started at: $START_TIME" 50 | echo "Finished at: $(date)" -------------------------------------------------------------------------------- /src/entities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VladimirYugay/Gaussian-SLAM/eaec10d73ce7511563882b8856896e06d1f804e3/src/entities/__init__.py -------------------------------------------------------------------------------- /src/entities/arguments.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from argparse import ArgumentParser, Namespace 15 | 16 | 17 | class GroupParams: 18 | pass 19 | 20 | 21 | class ParamGroup: 22 | def __init__(self, parser: ArgumentParser, name: str, fill_none=False): 23 | group = parser.add_argument_group(name) 24 | for key, value in vars(self).items(): 25 | shorthand = False 26 | if key.startswith("_"): 27 | shorthand = True 28 | key = key[1:] 29 | t = type(value) 30 | value = value if not fill_none else None 31 | if shorthand: 32 | if t == bool: 33 | group.add_argument( 34 | "--" + key, ("-" + key[0:1]), default=value, action="store_true") 35 | else: 36 | group.add_argument( 37 | "--" + key, ("-" + key[0:1]), default=value, type=t) 38 | else: 39 | if t == bool: 40 | group.add_argument( 41 | "--" + key, default=value, action="store_true") 42 | else: 43 | group.add_argument("--" + key, default=value, type=t) 44 | 45 | def extract(self, args): 46 | group = GroupParams() 47 | for arg in vars(args).items(): 48 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 49 | setattr(group, arg[0], arg[1]) 50 | return group 51 | 52 | 53 | class OptimizationParams(ParamGroup): 54 | def __init__(self, parser): 55 | self.iterations = 30_000 56 | self.position_lr_init = 0.0001 57 | self.position_lr_final = 0.0000016 58 | self.position_lr_delay_mult = 0.01 59 | self.position_lr_max_steps = 30_000 60 | self.feature_lr = 0.0025 61 | self.opacity_lr = 0.05 62 | self.scaling_lr = 0.005 # before 0.005 63 | self.rotation_lr = 0.001 64 | self.percent_dense = 0.01 65 | self.lambda_dssim = 0.2 66 | self.densification_interval = 100 67 | self.opacity_reset_interval = 3000 68 | self.densify_from_iter = 500 69 | self.densify_until_iter = 15_000 70 | self.densify_grad_threshold = 0.0002 71 | super().__init__(parser, "Optimization Parameters") 72 | 73 | 74 | def get_combined_args(parser: ArgumentParser): 75 | cmdlne_string = sys.argv[1:] 76 | cfgfile_string = "Namespace()" 77 | args_cmdline = parser.parse_args(cmdlne_string) 78 | 79 | try: 80 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 81 | print("Looking for config file in", cfgfilepath) 82 | with open(cfgfilepath) as cfg_file: 83 | print("Config file found: {}".format(cfgfilepath)) 84 | cfgfile_string = cfg_file.read() 85 | except TypeError: 86 | print("Config file not found at") 87 | pass 88 | args_cfgfile = eval(cfgfile_string) 89 | 90 | merged_dict = vars(args_cfgfile).copy() 91 | for k, v in vars(args_cmdline).items(): 92 | if v is not None: 93 | merged_dict[k] = v 94 | return Namespace(**merged_dict) 95 | -------------------------------------------------------------------------------- /src/entities/datasets.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from pathlib import Path 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import json 9 | import imageio 10 | 11 | 12 | class BaseDataset(torch.utils.data.Dataset): 13 | 14 | def __init__(self, dataset_config: dict): 15 | self.dataset_path = Path(dataset_config["input_path"]) 16 | self.frame_limit = dataset_config.get("frame_limit", -1) 17 | self.dataset_config = dataset_config 18 | self.height = dataset_config["H"] 19 | self.width = dataset_config["W"] 20 | self.fx = dataset_config["fx"] 21 | self.fy = dataset_config["fy"] 22 | self.cx = dataset_config["cx"] 23 | self.cy = dataset_config["cy"] 24 | 25 | self.depth_scale = dataset_config["depth_scale"] 26 | 
self.distortion = np.array( 27 | dataset_config['distortion']) if 'distortion' in dataset_config else None 28 | self.crop_edge = dataset_config['crop_edge'] if 'crop_edge' in dataset_config else 0 29 | if self.crop_edge: 30 | self.height -= 2 * self.crop_edge 31 | self.width -= 2 * self.crop_edge 32 | self.cx -= self.crop_edge 33 | self.cy -= self.crop_edge 34 | 35 | self.fovx = 2 * math.atan(self.width / (2 * self.fx)) 36 | self.fovy = 2 * math.atan(self.height / (2 * self.fy)) 37 | self.intrinsics = np.array( 38 | [[self.fx, 0, self.cx], [0, self.fy, self.cy], [0, 0, 1]]) 39 | 40 | self.color_paths = [] 41 | self.depth_paths = [] 42 | 43 | def __len__(self): 44 | return len(self.color_paths) if self.frame_limit < 0 else int(self.frame_limit) 45 | 46 | 47 | class Replica(BaseDataset): 48 | 49 | def __init__(self, dataset_config: dict): 50 | super().__init__(dataset_config) 51 | self.color_paths = sorted( 52 | list((self.dataset_path / "results").glob("frame*.jpg"))) 53 | self.depth_paths = sorted( 54 | list((self.dataset_path / "results").glob("depth*.png"))) 55 | self.load_poses(self.dataset_path / "traj.txt") 56 | print(f"Loaded {len(self.color_paths)} frames") 57 | 58 | def load_poses(self, path): 59 | self.poses = [] 60 | with open(path, "r") as f: 61 | lines = f.readlines() 62 | for line in lines: 63 | c2w = np.array(list(map(float, line.split()))).reshape(4, 4) 64 | self.poses.append(c2w.astype(np.float32)) 65 | 66 | def __getitem__(self, index): 67 | color_data = cv2.imread(str(self.color_paths[index])) 68 | color_data = cv2.cvtColor(color_data, cv2.COLOR_BGR2RGB) 69 | depth_data = cv2.imread( 70 | str(self.depth_paths[index]), cv2.IMREAD_UNCHANGED) 71 | depth_data = depth_data.astype(np.float32) / self.depth_scale 72 | return index, color_data, depth_data, self.poses[index] 73 | 74 | 75 | class TUM_RGBD(BaseDataset): 76 | def __init__(self, dataset_config: dict): 77 | super().__init__(dataset_config) 78 | self.color_paths, self.depth_paths, self.poses = self.loadtum( 79 | self.dataset_path, frame_rate=32) 80 | 81 | def parse_list(self, filepath, skiprows=0): 82 | """ read list data """ 83 | return np.loadtxt(filepath, delimiter=' ', dtype=np.unicode_, skiprows=skiprows) 84 | 85 | def associate_frames(self, tstamp_image, tstamp_depth, tstamp_pose, max_dt=0.08): 86 | """ pair images, depths, and poses """ 87 | associations = [] 88 | for i, t in enumerate(tstamp_image): 89 | if tstamp_pose is None: 90 | j = np.argmin(np.abs(tstamp_depth - t)) 91 | if (np.abs(tstamp_depth[j] - t) < max_dt): 92 | associations.append((i, j)) 93 | else: 94 | j = np.argmin(np.abs(tstamp_depth - t)) 95 | k = np.argmin(np.abs(tstamp_pose - t)) 96 | if (np.abs(tstamp_depth[j] - t) < max_dt) and (np.abs(tstamp_pose[k] - t) < max_dt): 97 | associations.append((i, j, k)) 98 | return associations 99 | 100 | def loadtum(self, datapath, frame_rate=-1): 101 | """ read video data in tum-rgbd format """ 102 | if os.path.isfile(os.path.join(datapath, 'groundtruth.txt')): 103 | pose_list = os.path.join(datapath, 'groundtruth.txt') 104 | elif os.path.isfile(os.path.join(datapath, 'pose.txt')): 105 | pose_list = os.path.join(datapath, 'pose.txt') 106 | 107 | image_list = os.path.join(datapath, 'rgb.txt') 108 | depth_list = os.path.join(datapath, 'depth.txt') 109 | 110 | image_data = self.parse_list(image_list) 111 | depth_data = self.parse_list(depth_list) 112 | pose_data = self.parse_list(pose_list, skiprows=1) 113 | pose_vecs = pose_data[:, 1:].astype(np.float64) 114 | 115 | tstamp_image = image_data[:, 
0].astype(np.float64) 116 | tstamp_depth = depth_data[:, 0].astype(np.float64) 117 | tstamp_pose = pose_data[:, 0].astype(np.float64) 118 | associations = self.associate_frames( 119 | tstamp_image, tstamp_depth, tstamp_pose) 120 | 121 | indicies = [0] 122 | for i in range(1, len(associations)): 123 | t0 = tstamp_image[associations[indicies[-1]][0]] 124 | t1 = tstamp_image[associations[i][0]] 125 | if t1 - t0 > 1.0 / frame_rate: 126 | indicies += [i] 127 | 128 | images, poses, depths = [], [], [] 129 | inv_pose = None 130 | for ix in indicies: 131 | (i, j, k) = associations[ix] 132 | images += [os.path.join(datapath, image_data[i, 1])] 133 | depths += [os.path.join(datapath, depth_data[j, 1])] 134 | c2w = self.pose_matrix_from_quaternion(pose_vecs[k]) 135 | if inv_pose is None: 136 | inv_pose = np.linalg.inv(c2w) 137 | c2w = np.eye(4) 138 | else: 139 | c2w = inv_pose@c2w 140 | poses += [c2w.astype(np.float32)] 141 | 142 | return images, depths, poses 143 | 144 | def pose_matrix_from_quaternion(self, pvec): 145 | """ convert 4x4 pose matrix to (t, q) """ 146 | from scipy.spatial.transform import Rotation 147 | 148 | pose = np.eye(4) 149 | pose[:3, :3] = Rotation.from_quat(pvec[3:]).as_matrix() 150 | pose[:3, 3] = pvec[:3] 151 | return pose 152 | 153 | def __getitem__(self, index): 154 | color_data = cv2.imread(str(self.color_paths[index])) 155 | if self.distortion is not None: 156 | color_data = cv2.undistort( 157 | color_data, self.intrinsics, self.distortion) 158 | color_data = cv2.cvtColor(color_data, cv2.COLOR_BGR2RGB) 159 | 160 | depth_data = cv2.imread( 161 | str(self.depth_paths[index]), cv2.IMREAD_UNCHANGED) 162 | depth_data = depth_data.astype(np.float32) / self.depth_scale 163 | edge = self.crop_edge 164 | if edge > 0: 165 | color_data = color_data[edge:-edge, edge:-edge] 166 | depth_data = depth_data[edge:-edge, edge:-edge] 167 | # Interpolate depth values for splatting 168 | return index, color_data, depth_data, self.poses[index] 169 | 170 | 171 | class ScanNet(BaseDataset): 172 | def __init__(self, dataset_config: dict): 173 | super().__init__(dataset_config) 174 | self.color_paths = sorted(list( 175 | (self.dataset_path / "color").glob("*.jpg")), key=lambda x: int(os.path.basename(x)[:-4])) 176 | self.depth_paths = sorted(list( 177 | (self.dataset_path / "depth").glob("*.png")), key=lambda x: int(os.path.basename(x)[:-4])) 178 | self.load_poses(self.dataset_path / "pose") 179 | 180 | def load_poses(self, path): 181 | self.poses = [] 182 | pose_paths = sorted(path.glob('*.txt'), 183 | key=lambda x: int(os.path.basename(x)[:-4])) 184 | for pose_path in pose_paths: 185 | with open(pose_path, "r") as f: 186 | lines = f.readlines() 187 | ls = [] 188 | for line in lines: 189 | ls.append(list(map(float, line.split(' ')))) 190 | c2w = np.array(ls).reshape(4, 4).astype(np.float32) 191 | self.poses.append(c2w) 192 | 193 | def __getitem__(self, index): 194 | color_data = cv2.imread(str(self.color_paths[index])) 195 | if self.distortion is not None: 196 | color_data = cv2.undistort( 197 | color_data, self.intrinsics, self.distortion) 198 | color_data = cv2.cvtColor(color_data, cv2.COLOR_BGR2RGB) 199 | color_data = cv2.resize(color_data, (self.dataset_config["W"], self.dataset_config["H"])) 200 | 201 | depth_data = cv2.imread( 202 | str(self.depth_paths[index]), cv2.IMREAD_UNCHANGED) 203 | depth_data = depth_data.astype(np.float32) / self.depth_scale 204 | edge = self.crop_edge 205 | if edge > 0: 206 | color_data = color_data[edge:-edge, edge:-edge] 207 | depth_data = depth_data[edge:-edge, 
edge:-edge] 208 | # Interpolate depth values for splatting 209 | return index, color_data, depth_data, self.poses[index] 210 | 211 | 212 | class ScanNetPP(BaseDataset): 213 | def __init__(self, dataset_config: dict): 214 | super().__init__(dataset_config) 215 | self.use_train_split = dataset_config["use_train_split"] 216 | self.train_test_split = json.load(open(f"{self.dataset_path}/dslr/train_test_lists.json", "r")) 217 | if self.use_train_split: 218 | self.image_names = self.train_test_split["train"] 219 | else: 220 | self.image_names = self.train_test_split["test"] 221 | self.load_data() 222 | 223 | def load_data(self): 224 | self.poses = [] 225 | cams_path = self.dataset_path / "dslr" / "nerfstudio" / "transforms_undistorted.json" 226 | cams_metadata = json.load(open(str(cams_path), "r")) 227 | frames_key = "frames" if self.use_train_split else "test_frames" 228 | frames_metadata = cams_metadata[frames_key] 229 | frame2idx = {frame["file_path"]: index for index, frame in enumerate(frames_metadata)} 230 | P = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]).astype(np.float32) 231 | for image_name in self.image_names: 232 | frame_metadata = frames_metadata[frame2idx[image_name]] 233 | # if self.ignore_bad and frame_metadata['is_bad']: 234 | # continue 235 | color_path = str(self.dataset_path / "dslr" / "undistorted_images" / image_name) 236 | depth_path = str(self.dataset_path / "dslr" / "undistorted_depths" / image_name.replace('.JPG', '.png')) 237 | self.color_paths.append(color_path) 238 | self.depth_paths.append(depth_path) 239 | c2w = np.array(frame_metadata["transform_matrix"]).astype(np.float32) 240 | c2w = P @ c2w @ P.T 241 | self.poses.append(c2w) 242 | 243 | def __len__(self): 244 | if self.use_train_split: 245 | return len(self.image_names) if self.frame_limit < 0 else int(self.frame_limit) 246 | else: 247 | return len(self.image_names) 248 | 249 | def __getitem__(self, index): 250 | 251 | color_data = np.asarray(imageio.imread(self.color_paths[index]), dtype=float) 252 | color_data = cv2.resize(color_data, (self.width, self.height), interpolation=cv2.INTER_LINEAR) 253 | color_data = color_data.astype(np.uint8) 254 | 255 | depth_data = np.asarray(imageio.imread(self.depth_paths[index]), dtype=np.int64) 256 | depth_data = cv2.resize(depth_data.astype(float), (self.width, self.height), interpolation=cv2.INTER_NEAREST) 257 | depth_data = depth_data.astype(np.float32) / self.depth_scale 258 | return index, color_data, depth_data, self.poses[index] 259 | 260 | 261 | def get_dataset(dataset_name: str): 262 | if dataset_name == "replica": 263 | return Replica 264 | elif dataset_name == "tum_rgbd": 265 | return TUM_RGBD 266 | elif dataset_name == "scan_net": 267 | return ScanNet 268 | elif dataset_name == "scannetpp": 269 | return ScanNetPP 270 | raise NotImplementedError(f"Dataset {dataset_name} not implemented") 271 | -------------------------------------------------------------------------------- /src/entities/gaussian_model.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | from pathlib import Path 12 | 13 | import numpy as np 14 | import open3d as o3d 15 | import torch 16 | from plyfile import PlyData, PlyElement 17 | from simple_knn._C import distCUDA2 18 | from torch import nn 19 | 20 | from src.utils.gaussian_model_utils import (RGB2SH, build_scaling_rotation, 21 | get_expon_lr_func, inverse_sigmoid, 22 | strip_symmetric) 23 | 24 | 25 | class GaussianModel: 26 | def __init__(self, sh_degree: int = 3, isotropic=False): 27 | self.gaussian_param_names = [ 28 | "active_sh_degree", 29 | "xyz", 30 | "features_dc", 31 | "features_rest", 32 | "scaling", 33 | "rotation", 34 | "opacity", 35 | "max_radii2D", 36 | "xyz_gradient_accum", 37 | "denom", 38 | "spatial_lr_scale", 39 | "optimizer", 40 | ] 41 | self.max_sh_degree = sh_degree 42 | self.active_sh_degree = sh_degree # temp 43 | self._xyz = torch.empty(0).cuda() 44 | self._features_dc = torch.empty(0).cuda() 45 | self._features_rest = torch.empty(0).cuda() 46 | self._scaling = torch.empty(0).cuda() 47 | self._rotation = torch.empty(0, 4).cuda() 48 | self._opacity = torch.empty(0).cuda() 49 | self.max_radii2D = torch.empty(0) 50 | self.xyz_gradient_accum = torch.empty(0) 51 | self.denom = torch.empty(0) 52 | self.optimizer = None 53 | self.percent_dense = 0 54 | self.spatial_lr_scale = 1 55 | self.setup_functions() 56 | self.isotropic = isotropic 57 | 58 | def restore_from_params(self, params_dict, training_args): 59 | self.training_setup(training_args) 60 | self.densification_postfix( 61 | params_dict["xyz"], 62 | params_dict["features_dc"], 63 | params_dict["features_rest"], 64 | params_dict["opacity"], 65 | params_dict["scaling"], 66 | params_dict["rotation"]) 67 | 68 | def build_covariance_from_scaling_rotation(self, scaling, scaling_modifier, rotation): 69 | L = build_scaling_rotation(scaling_modifier * scaling, rotation) 70 | actual_covariance = L @ L.transpose(1, 2) 71 | symm = strip_symmetric(actual_covariance) 72 | return symm 73 | 74 | def setup_functions(self): 75 | self.scaling_activation = torch.exp 76 | self.scaling_inverse_activation = torch.log 77 | self.opacity_activation = torch.sigmoid 78 | self.inverse_opacity_activation = inverse_sigmoid 79 | self.rotation_activation = torch.nn.functional.normalize 80 | 81 | def capture_dict(self): 82 | return { 83 | "active_sh_degree": self.active_sh_degree, 84 | "xyz": self._xyz.clone().detach().cpu(), 85 | "features_dc": self._features_dc.clone().detach().cpu(), 86 | "features_rest": self._features_rest.clone().detach().cpu(), 87 | "scaling": self._scaling.clone().detach().cpu(), 88 | "rotation": self._rotation.clone().detach().cpu(), 89 | "opacity": self._opacity.clone().detach().cpu(), 90 | "max_radii2D": self.max_radii2D.clone().detach().cpu(), 91 | "xyz_gradient_accum": self.xyz_gradient_accum.clone().detach().cpu(), 92 | "denom": self.denom.clone().detach().cpu(), 93 | "spatial_lr_scale": self.spatial_lr_scale, 94 | "optimizer": self.optimizer.state_dict(), 95 | } 96 | 97 | def get_size(self): 98 | return self._xyz.shape[0] 99 | 100 | def get_scaling(self): 101 | if self.isotropic: 102 | scale = self.scaling_activation(self._scaling)[:, 0:1] # Extract the first column 103 | scales = scale.repeat(1, 3) # Replicate this column three times 104 | return scales 105 | return self.scaling_activation(self._scaling) 106 | 107 | def get_rotation(self): 108 | return self.rotation_activation(self._rotation) 109 | 110 | def get_xyz(self): 111 | return self._xyz 112 | 113 | def 
get_features(self): 114 | features_dc = self._features_dc 115 | features_rest = self._features_rest 116 | return torch.cat((features_dc, features_rest), dim=1) 117 | 118 | def get_opacity(self): 119 | return self.opacity_activation(self._opacity) 120 | 121 | def get_active_sh_degree(self): 122 | return self.active_sh_degree 123 | 124 | def get_covariance(self, scaling_modifier=1): 125 | return self.build_covariance_from_scaling_rotation(self.get_scaling(), scaling_modifier, self._rotation) 126 | 127 | def add_points(self, pcd: o3d.geometry.PointCloud, global_scale_init=True): 128 | fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda() 129 | fused_color = RGB2SH(torch.tensor( 130 | np.asarray(pcd.colors)).float().cuda()) 131 | features = (torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()) 132 | features[:, :3, 0] = fused_color 133 | features[:, 3:, 1:] = 0.0 134 | print("Number of added points: ", fused_point_cloud.shape[0]) 135 | 136 | if global_scale_init: 137 | global_points = torch.cat((self.get_xyz(),torch.from_numpy(np.asarray(pcd.points)).float().cuda())) 138 | dist2 = torch.clamp_min(distCUDA2(global_points), 0.0000001) 139 | dist2 = dist2[self.get_size():] 140 | else: 141 | dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001) 142 | scales = torch.log(1.0 * torch.sqrt(dist2))[..., None].repeat(1, 3) 143 | # scales = torch.log(0.001 * torch.ones_like(dist2))[..., None].repeat(1, 3) 144 | rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") 145 | rots[:, 0] = 1 146 | opacities = inverse_sigmoid(0.5 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) 147 | new_xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) 148 | new_features_dc = nn.Parameter(features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True)) 149 | new_features_rest = nn.Parameter(features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True)) 150 | new_scaling = nn.Parameter(scales.requires_grad_(True)) 151 | new_rotation = nn.Parameter(rots.requires_grad_(True)) 152 | new_opacities = nn.Parameter(opacities.requires_grad_(True)) 153 | self.densification_postfix( 154 | new_xyz, 155 | new_features_dc, 156 | new_features_rest, 157 | new_opacities, 158 | new_scaling, 159 | new_rotation, 160 | ) 161 | 162 | def training_setup(self, training_args): 163 | self.percent_dense = training_args.percent_dense 164 | self.xyz_gradient_accum = torch.zeros( 165 | (self.get_xyz().shape[0], 1), device="cuda" 166 | ) 167 | self.denom = torch.zeros((self.get_xyz().shape[0], 1), device="cuda") 168 | 169 | params = [ 170 | {"params": [self._xyz], "lr": training_args.position_lr_init, "name": "xyz"}, 171 | {"params": [self._features_dc], "lr": training_args.feature_lr, "name": "f_dc"}, 172 | {"params": [self._features_rest], "lr": training_args.feature_lr / 20.0, "name": "f_rest"}, 173 | {"params": [self._opacity], "lr": training_args.opacity_lr, "name": "opacity"}, 174 | {"params": [self._scaling], "lr": training_args.scaling_lr, "name": "scaling"}, 175 | {"params": [self._rotation], "lr": training_args.rotation_lr, "name": "rotation"}, 176 | ] 177 | 178 | self.optimizer = torch.optim.Adam(params, lr=0.0, eps=1e-15) 179 | self.xyz_scheduler_args = get_expon_lr_func( 180 | lr_init=training_args.position_lr_init * self.spatial_lr_scale, 181 | lr_final=training_args.position_lr_final * self.spatial_lr_scale, 182 | lr_delay_mult=training_args.position_lr_delay_mult, 183 | 
max_steps=training_args.position_lr_max_steps, 184 | ) 185 | 186 | def training_setup_camera(self, cam_rot, cam_trans, cfg): 187 | self.xyz_gradient_accum = torch.zeros( 188 | (self.get_xyz().shape[0], 1), device="cuda" 189 | ) 190 | self.denom = torch.zeros((self.get_xyz().shape[0], 1), device="cuda") 191 | params = [ 192 | {"params": [self._xyz], "lr": 0.0, "name": "xyz"}, 193 | {"params": [self._features_dc], "lr": 0.0, "name": "f_dc"}, 194 | {"params": [self._features_rest], "lr": 0.0, "name": "f_rest"}, 195 | {"params": [self._opacity], "lr": 0.0, "name": "opacity"}, 196 | {"params": [self._scaling], "lr": 0.0, "name": "scaling"}, 197 | {"params": [self._rotation], "lr": 0.0, "name": "rotation"}, 198 | {"params": [cam_rot], "lr": cfg["cam_rot_lr"], 199 | "name": "cam_unnorm_rot"}, 200 | {"params": [cam_trans], "lr": cfg["cam_trans_lr"], 201 | "name": "cam_trans"}, 202 | ] 203 | self.optimizer = torch.optim.Adam(params, amsgrad=True) 204 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 205 | self.optimizer, "min", factor=0.98, patience=10, verbose=False) 206 | 207 | def construct_list_of_attributes(self): 208 | l = ["x", "y", "z", "nx", "ny", "nz"] 209 | # All channels except the 3 DC 210 | for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]): 211 | l.append("f_dc_{}".format(i)) 212 | for i in range(self._features_rest.shape[1] * self._features_rest.shape[2]): 213 | l.append("f_rest_{}".format(i)) 214 | l.append("opacity") 215 | for i in range(self._scaling.shape[1]): 216 | l.append("scale_{}".format(i)) 217 | for i in range(self._rotation.shape[1]): 218 | l.append("rot_{}".format(i)) 219 | return l 220 | 221 | def save_ply(self, path): 222 | Path(path).parent.mkdir(parents=True, exist_ok=True) 223 | 224 | xyz = self._xyz.detach().cpu().numpy() 225 | normals = np.zeros_like(xyz) 226 | f_dc = ( 227 | self._features_dc.detach() 228 | .transpose(1, 2) 229 | .flatten(start_dim=1) 230 | .contiguous() 231 | .cpu() 232 | .numpy()) 233 | f_rest = ( 234 | self._features_rest.detach() 235 | .transpose(1, 2) 236 | .flatten(start_dim=1) 237 | .contiguous() 238 | .cpu() 239 | .numpy()) 240 | opacities = self._opacity.detach().cpu().numpy() 241 | if self.isotropic: 242 | # tile into shape (P, 3) 243 | scale = np.tile(self._scaling.detach().cpu().numpy()[:, 0].reshape(-1, 1), (1, 3)) 244 | else: 245 | scale = self._scaling.detach().cpu().numpy() 246 | rotation = self._rotation.detach().cpu().numpy() 247 | 248 | dtype_full = [(attribute, "f4") for attribute in self.construct_list_of_attributes()] 249 | 250 | elements = np.empty(xyz.shape[0], dtype=dtype_full) 251 | attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) 252 | elements[:] = list(map(tuple, attributes)) 253 | el = PlyElement.describe(elements, "vertex") 254 | PlyData([el]).write(path) 255 | 256 | def load_ply(self, path): 257 | plydata = PlyData.read(path) 258 | 259 | xyz = np.stack(( 260 | np.asarray(plydata.elements[0]["x"]), 261 | np.asarray(plydata.elements[0]["y"]), 262 | np.asarray(plydata.elements[0]["z"])), 263 | axis=1) 264 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] 265 | 266 | features_dc = np.zeros((xyz.shape[0], 3, 1)) 267 | features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) 268 | features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) 269 | features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) 270 | 271 | extra_f_names = [p.name for p in plydata.elements[0].properties if 
p.name.startswith("f_rest_")] 272 | extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split("_")[-1])) 273 | assert len(extra_f_names) == 3 * (self.max_sh_degree + 1) ** 2 - 3 274 | features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) 275 | for idx, attr_name in enumerate(extra_f_names): 276 | features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) 277 | # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) 278 | features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) 279 | 280 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] 281 | scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1])) 282 | scales = np.zeros((xyz.shape[0], len(scale_names))) 283 | for idx, attr_name in enumerate(scale_names): 284 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) 285 | 286 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] 287 | rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1])) 288 | rots = np.zeros((xyz.shape[0], len(rot_names))) 289 | for idx, attr_name in enumerate(rot_names): 290 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) 291 | 292 | self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) 293 | self._features_dc = nn.Parameter( 294 | torch.tensor(features_dc, dtype=torch.float, device="cuda") 295 | .transpose(1, 2).contiguous().requires_grad_(True)) 296 | self._features_rest = nn.Parameter( 297 | torch.tensor(features_extra, dtype=torch.float, device="cuda") 298 | .transpose(1, 2) 299 | .contiguous() 300 | .requires_grad_(True) 301 | ) 302 | self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) 303 | self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) 304 | self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) 305 | 306 | self.active_sh_degree = self.max_sh_degree 307 | 308 | def replace_tensor_to_optimizer(self, tensor, name): 309 | optimizable_tensors = {} 310 | for group in self.optimizer.param_groups: 311 | if group["name"] == name: 312 | stored_state = self.optimizer.state.get(group["params"][0], None) 313 | stored_state["exp_avg"] = torch.zeros_like(tensor) 314 | stored_state["exp_avg_sq"] = torch.zeros_like(tensor) 315 | 316 | del self.optimizer.state[group["params"][0]] 317 | group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) 318 | self.optimizer.state[group["params"][0]] = stored_state 319 | 320 | optimizable_tensors[group["name"]] = group["params"][0] 321 | return optimizable_tensors 322 | 323 | def _prune_optimizer(self, mask): 324 | optimizable_tensors = {} 325 | for group in self.optimizer.param_groups: 326 | stored_state = self.optimizer.state.get(group["params"][0], None) 327 | if stored_state is not None: 328 | stored_state["exp_avg"] = stored_state["exp_avg"][mask] 329 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] 330 | 331 | del self.optimizer.state[group["params"][0]] 332 | group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) 333 | self.optimizer.state[group["params"][0]] = stored_state 334 | optimizable_tensors[group["name"]] = group["params"][0] 335 | else: 336 | group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) 337 | optimizable_tensors[group["name"]] = 
group["params"][0] 338 | return optimizable_tensors 339 | 340 | def prune_points(self, mask): 341 | valid_points_mask = ~mask 342 | optimizable_tensors = self._prune_optimizer(valid_points_mask) 343 | 344 | self._xyz = optimizable_tensors["xyz"] 345 | self._features_dc = optimizable_tensors["f_dc"] 346 | self._features_rest = optimizable_tensors["f_rest"] 347 | self._opacity = optimizable_tensors["opacity"] 348 | self._scaling = optimizable_tensors["scaling"] 349 | self._rotation = optimizable_tensors["rotation"] 350 | 351 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] 352 | 353 | self.denom = self.denom[valid_points_mask] 354 | self.max_radii2D = self.max_radii2D[valid_points_mask] 355 | 356 | def cat_tensors_to_optimizer(self, tensors_dict): 357 | optimizable_tensors = {} 358 | for group in self.optimizer.param_groups: 359 | assert len(group["params"]) == 1 360 | extension_tensor = tensors_dict[group["name"]] 361 | stored_state = self.optimizer.state.get(group["params"][0], None) 362 | if stored_state is not None: 363 | stored_state["exp_avg"] = torch.cat( 364 | (stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0) 365 | stored_state["exp_avg_sq"] = torch.cat( 366 | (stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0) 367 | 368 | del self.optimizer.state[group["params"][0]] 369 | group["params"][0] = nn.Parameter( 370 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 371 | self.optimizer.state[group["params"][0]] = stored_state 372 | 373 | optimizable_tensors[group["name"]] = group["params"][0] 374 | else: 375 | group["params"][0] = nn.Parameter( 376 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 377 | optimizable_tensors[group["name"]] = group["params"][0] 378 | 379 | return optimizable_tensors 380 | 381 | def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, 382 | new_opacities, new_scaling, new_rotation): 383 | d = { 384 | "xyz": new_xyz, 385 | "f_dc": new_features_dc, 386 | "f_rest": new_features_rest, 387 | "opacity": new_opacities, 388 | "scaling": new_scaling, 389 | "rotation": new_rotation, 390 | } 391 | 392 | optimizable_tensors = self.cat_tensors_to_optimizer(d) 393 | self._xyz = optimizable_tensors["xyz"] 394 | self._features_dc = optimizable_tensors["f_dc"] 395 | self._features_rest = optimizable_tensors["f_rest"] 396 | self._opacity = optimizable_tensors["opacity"] 397 | self._scaling = optimizable_tensors["scaling"] 398 | self._rotation = optimizable_tensors["rotation"] 399 | 400 | self.xyz_gradient_accum = torch.zeros((self.get_xyz().shape[0], 1), device="cuda") 401 | self.denom = torch.zeros((self.get_xyz().shape[0], 1), device="cuda") 402 | self.max_radii2D = torch.zeros( 403 | (self.get_xyz().shape[0]), device="cuda") 404 | 405 | def add_densification_stats(self, viewspace_point_tensor, update_filter): 406 | self.xyz_gradient_accum[update_filter] += torch.norm( 407 | viewspace_point_tensor.grad[update_filter, :2], dim=-1, keepdim=True) 408 | self.denom[update_filter] += 1 409 | -------------------------------------------------------------------------------- /src/entities/gaussian_slam.py: -------------------------------------------------------------------------------- 1 | """ This module includes the Gaussian-SLAM class, which is responsible for controlling Mapper and Tracker 2 | It also decides when to start a new submap and when to update the estimated camera poses. 
3 | """ 4 | import os 5 | import pprint 6 | from argparse import ArgumentParser 7 | from datetime import datetime 8 | from pathlib import Path 9 | 10 | import numpy as np 11 | import torch 12 | 13 | from src.entities.arguments import OptimizationParams 14 | from src.entities.datasets import get_dataset 15 | from src.entities.gaussian_model import GaussianModel 16 | from src.entities.mapper import Mapper 17 | from src.entities.tracker import Tracker 18 | from src.entities.logger import Logger 19 | from src.utils.io_utils import save_dict_to_ckpt, save_dict_to_yaml 20 | from src.utils.mapper_utils import exceeds_motion_thresholds 21 | from src.utils.utils import np2torch, setup_seed, torch2np 22 | from src.utils.vis_utils import * # noqa - needed for debugging 23 | 24 | 25 | class GaussianSLAM(object): 26 | 27 | def __init__(self, config: dict) -> None: 28 | 29 | self._setup_output_path(config) 30 | self.device = "cuda" 31 | self.config = config 32 | 33 | self.scene_name = config["data"]["scene_name"] 34 | self.dataset_name = config["dataset_name"] 35 | self.dataset = get_dataset(config["dataset_name"])({**config["data"], **config["cam"]}) 36 | 37 | n_frames = len(self.dataset) 38 | frame_ids = list(range(n_frames)) 39 | self.mapping_frame_ids = frame_ids[::config["mapping"]["map_every"]] + [n_frames - 1] 40 | 41 | self.estimated_c2ws = torch.empty(len(self.dataset), 4, 4) 42 | self.estimated_c2ws[0] = torch.from_numpy(self.dataset[0][3]) 43 | 44 | save_dict_to_yaml(config, "config.yaml", directory=self.output_path) 45 | 46 | self.submap_using_motion_heuristic = config["mapping"]["submap_using_motion_heuristic"] 47 | 48 | self.keyframes_info = {} 49 | self.opt = OptimizationParams(ArgumentParser(description="Training script parameters")) 50 | 51 | if self.submap_using_motion_heuristic: 52 | self.new_submap_frame_ids = [0] 53 | else: 54 | self.new_submap_frame_ids = frame_ids[::config["mapping"]["new_submap_every"]] + [n_frames - 1] 55 | self.new_submap_frame_ids.pop(0) 56 | 57 | self.logger = Logger(self.output_path, config["use_wandb"]) 58 | self.mapper = Mapper(config["mapping"], self.dataset, self.logger) 59 | self.tracker = Tracker(config["tracking"], self.dataset, self.logger) 60 | 61 | print('Tracking config') 62 | pprint.PrettyPrinter().pprint(config["tracking"]) 63 | print('Mapping config') 64 | pprint.PrettyPrinter().pprint(config["mapping"]) 65 | 66 | def _setup_output_path(self, config: dict) -> None: 67 | """ Sets up the output path for saving results based on the provided configuration. If the output path is not 68 | specified in the configuration, it creates a new directory with a timestamp. 69 | Args: 70 | config: A dictionary containing the experiment configuration including data and output path information. 71 | """ 72 | if "output_path" not in config["data"]: 73 | output_path = Path(config["data"]["output_path"]) 74 | self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 75 | self.output_path = output_path / self.timestamp 76 | else: 77 | self.output_path = Path(config["data"]["output_path"]) 78 | self.output_path.mkdir(exist_ok=True, parents=True) 79 | os.makedirs(self.output_path / "mapping_vis", exist_ok=True) 80 | os.makedirs(self.output_path / "tracking_vis", exist_ok=True) 81 | 82 | def should_start_new_submap(self, frame_id: int) -> bool: 83 | """ Determines whether a new submap should be started based on the motion heuristic or specific frame IDs. 84 | Args: 85 | frame_id: The ID of the current frame being processed. 
86 | Returns: 87 | A boolean indicating whether to start a new submap. 88 | """ 89 | if self.submap_using_motion_heuristic: 90 | if exceeds_motion_thresholds( 91 | self.estimated_c2ws[frame_id], self.estimated_c2ws[self.new_submap_frame_ids[-1]], 92 | rot_thre=50, trans_thre=0.5): 93 | return True 94 | elif frame_id in self.new_submap_frame_ids: 95 | return True 96 | return False 97 | 98 | def start_new_submap(self, frame_id: int, gaussian_model: GaussianModel) -> None: 99 | """ Initializes a new submap, saving the current submap's checkpoint and resetting the Gaussian model. 100 | This function updates the submap count and optionally marks the current frame ID for new submap initiation. 101 | Args: 102 | frame_id: The ID of the current frame at which the new submap is started. 103 | gaussian_model: The current GaussianModel instance to capture and reset for the new submap. 104 | Returns: 105 | A new, reset GaussianModel instance for the new submap. 106 | """ 107 | gaussian_params = gaussian_model.capture_dict() 108 | submap_ckpt_name = str(self.submap_id).zfill(6) 109 | submap_ckpt = { 110 | "gaussian_params": gaussian_params, 111 | "submap_keyframes": sorted(list(self.keyframes_info.keys())) 112 | } 113 | save_dict_to_ckpt( 114 | submap_ckpt, f"{submap_ckpt_name}.ckpt", directory=self.output_path / "submaps") 115 | gaussian_model = GaussianModel(0) 116 | gaussian_model.training_setup(self.opt) 117 | self.mapper.keyframes = [] 118 | self.keyframes_info = {} 119 | if self.submap_using_motion_heuristic: 120 | self.new_submap_frame_ids.append(frame_id) 121 | self.mapping_frame_ids.append(frame_id) 122 | self.submap_id += 1 123 | return gaussian_model 124 | 125 | def run(self) -> None: 126 | """ Starts the main program flow for Gaussian-SLAM, including tracking and mapping. 
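        Per-frame flow as implemented below: estimate the camera pose with the Tracker (frames 0 and 1 take the
        pose provided by the dataset), start a new submap when the motion/schedule criterion is met, and on the
        scheduled mapping frames run the Mapper and record the keyframe, saving the estimated trajectory and
        submap checkpoints along the way.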
""" 127 | setup_seed(self.config["seed"]) 128 | gaussian_model = GaussianModel(0) 129 | gaussian_model.training_setup(self.opt) 130 | self.submap_id = 0 131 | 132 | for frame_id in range(len(self.dataset)): 133 | 134 | if frame_id in [0, 1]: 135 | estimated_c2w = self.dataset[frame_id][-1] 136 | else: 137 | estimated_c2w = self.tracker.track( 138 | frame_id, gaussian_model, 139 | torch2np(self.estimated_c2ws[torch.tensor([0, frame_id - 2, frame_id - 1])])) 140 | self.estimated_c2ws[frame_id] = np2torch(estimated_c2w) 141 | 142 | # Reinitialize gaussian model for new segment 143 | if self.should_start_new_submap(frame_id): 144 | save_dict_to_ckpt(self.estimated_c2ws[:frame_id + 1], "estimated_c2w.ckpt", directory=self.output_path) 145 | gaussian_model = self.start_new_submap(frame_id, gaussian_model) 146 | 147 | if frame_id in self.mapping_frame_ids: 148 | print("\nMapping frame", frame_id) 149 | gaussian_model.training_setup(self.opt) 150 | estimate_c2w = torch2np(self.estimated_c2ws[frame_id]) 151 | new_submap = not bool(self.keyframes_info) 152 | opt_dict = self.mapper.map(frame_id, estimate_c2w, gaussian_model, new_submap) 153 | 154 | # Keyframes info update 155 | self.keyframes_info[frame_id] = { 156 | "keyframe_id": len(self.keyframes_info.keys()), 157 | "opt_dict": opt_dict 158 | } 159 | save_dict_to_ckpt(self.estimated_c2ws[:frame_id + 1], "estimated_c2w.ckpt", directory=self.output_path) 160 | -------------------------------------------------------------------------------- /src/entities/logger.py: -------------------------------------------------------------------------------- 1 | """ This module includes the Logger class, which is responsible for logging for both Mapper and the Tracker """ 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import torch 8 | import wandb 9 | 10 | 11 | class Logger(object): 12 | 13 | def __init__(self, output_path: Union[Path, str], use_wandb=False) -> None: 14 | self.output_path = Path(output_path) 15 | (self.output_path / "mapping_vis").mkdir(exist_ok=True, parents=True) 16 | self.use_wandb = use_wandb 17 | 18 | def log_tracking_iteration(self, frame_id, cur_pose, gt_quat, gt_trans, total_loss, 19 | color_loss, depth_loss, iter, num_iters, 20 | wandb_output=False, print_output=False) -> None: 21 | """ Logs tracking iteration metrics including pose error, losses, and optionally reports to Weights & Biases. 22 | Logs the error between the current pose estimate and ground truth quaternion and translation, 23 | as well as various loss metrics. Can output to wandb if enabled and specified, and print to console. 24 | Args: 25 | frame_id: Identifier for the current frame. 26 | cur_pose: The current estimated pose as a tensor (quaternion + translation). 27 | gt_quat: Ground truth quaternion. 28 | gt_trans: Ground truth translation. 29 | total_loss: Total computed loss for the current iteration. 30 | color_loss: Computed color loss for the current iteration. 31 | depth_loss: Computed depth loss for the current iteration. 32 | iter: The current iteration number. 33 | num_iters: The total number of iterations planned. 34 | wandb_output: Whether to output the log to wandb. 35 | print_output: Whether to print the log output. 
36 | """ 37 | 38 | quad_err = torch.abs(cur_pose[:4] - gt_quat).mean().item() 39 | trans_err = torch.abs(cur_pose[4:] - gt_trans).mean().item() 40 | if self.use_wandb and wandb_output: 41 | wandb.log( 42 | { 43 | "Tracking/idx": frame_id, 44 | "Tracking/cam_quad_err": quad_err, 45 | "Tracking/cam_position_err": trans_err, 46 | "Tracking/total_loss": total_loss.item(), 47 | "Tracking/color_loss": color_loss.item(), 48 | "Tracking/depth_loss": depth_loss.item(), 49 | "Tracking/num_iters": num_iters, 50 | }) 51 | if iter == num_iters - 1: 52 | msg = f"frame_id: {frame_id}, cam_quad_err: {quad_err:.5f}, cam_trans_err: {trans_err:.5f} " 53 | else: 54 | msg = f"iter: {iter}, color_loss: {color_loss.item():.5f}, depth_loss: {depth_loss.item():.5f} " 55 | msg = msg + f", cam_quad_err: {quad_err:.5f}, cam_trans_err: {trans_err:.5f}" 56 | if print_output: 57 | print(msg, flush=True) 58 | 59 | def log_mapping_iteration(self, frame_id, new_pts_num, model_size, iter_opt_time, opt_dict: dict) -> None: 60 | """ Logs mapping iteration metrics including the number of new points, model size, and optimization times, 61 | and optionally reports to Weights & Biases (wandb). 62 | Args: 63 | frame_id: Identifier for the current frame. 64 | new_pts_num: The number of new points added in the current mapping iteration. 65 | model_size: The total size of the model after the current mapping iteration. 66 | iter_opt_time: Time taken per optimization iteration. 67 | opt_dict: A dictionary containing optimization metrics such as PSNR, color loss, and depth loss. 68 | """ 69 | if self.use_wandb: 70 | wandb.log({"Mapping/idx": frame_id, 71 | "Mapping/num_total_gs": model_size, 72 | "Mapping/num_new_gs": new_pts_num, 73 | "Mapping/per_iteration_time": iter_opt_time, 74 | "Mapping/psnr_render": opt_dict["psnr_render"], 75 | "Mapping/color_loss": opt_dict[frame_id]["color_loss"], 76 | "Mapping/depth_loss": opt_dict[frame_id]["depth_loss"]}) 77 | 78 | def vis_mapping_iteration(self, frame_id, iter, color, depth, gt_color, gt_depth, seeding_mask=None) -> None: 79 | """ 80 | Visualization of depth, color images and save to file. 81 | 82 | Args: 83 | frame_id (int): current frame index. 84 | iter (int): the iteration number. 85 | save_rendered_image (bool): whether to save the rgb image in separate folder 86 | img_dir (str): the directory to save the visualization. 87 | seeding_mask: used in mapper when adding gaussians, if not none. 
88 | """ 89 | gt_depth_np = gt_depth.cpu().numpy() 90 | gt_color_np = gt_color.cpu().numpy() 91 | 92 | depth_np = depth.detach().cpu().numpy() 93 | color = torch.round(color * 255.0) / 255.0 94 | color_np = color.detach().cpu().numpy() 95 | depth_residual = np.abs(gt_depth_np - depth_np) 96 | depth_residual[gt_depth_np == 0.0] = 0.0 97 | # make errors >=5cm noticeable 98 | depth_residual = np.clip(depth_residual, 0.0, 0.05) 99 | 100 | color_residual = np.abs(gt_color_np - color_np) 101 | color_residual[np.squeeze(gt_depth_np == 0.0)] = 0.0 102 | 103 | # Determine Aspect Ratio and Figure Size 104 | aspect_ratio = color.shape[1] / color.shape[0] 105 | fig_height = 8 106 | # Adjust the multiplier as needed for better spacing 107 | fig_width = fig_height * aspect_ratio * 1.2 108 | 109 | fig, axs = plt.subplots(2, 3, figsize=(fig_width, fig_height)) 110 | axs[0, 0].imshow(gt_depth_np, cmap="jet", vmin=0, vmax=6) 111 | axs[0, 0].set_title('Input Depth', fontsize=16) 112 | axs[0, 0].set_xticks([]) 113 | axs[0, 0].set_yticks([]) 114 | axs[0, 1].imshow(depth_np, cmap="jet", vmin=0, vmax=6) 115 | axs[0, 1].set_title('Rendered Depth', fontsize=16) 116 | axs[0, 1].set_xticks([]) 117 | axs[0, 1].set_yticks([]) 118 | axs[0, 2].imshow(depth_residual, cmap="plasma") 119 | axs[0, 2].set_title('Depth Residual', fontsize=16) 120 | axs[0, 2].set_xticks([]) 121 | axs[0, 2].set_yticks([]) 122 | gt_color_np = np.clip(gt_color_np, 0, 1) 123 | color_np = np.clip(color_np, 0, 1) 124 | color_residual = np.clip(color_residual, 0, 1) 125 | axs[1, 0].imshow(gt_color_np, cmap="plasma") 126 | axs[1, 0].set_title('Input RGB', fontsize=16) 127 | axs[1, 0].set_xticks([]) 128 | axs[1, 0].set_yticks([]) 129 | axs[1, 1].imshow(color_np, cmap="plasma") 130 | axs[1, 1].set_title('Rendered RGB', fontsize=16) 131 | axs[1, 1].set_xticks([]) 132 | axs[1, 1].set_yticks([]) 133 | if seeding_mask is not None: 134 | axs[1, 2].imshow(seeding_mask, cmap="gray") 135 | axs[1, 2].set_title('Densification Mask', fontsize=16) 136 | axs[1, 2].set_xticks([]) 137 | axs[1, 2].set_yticks([]) 138 | else: 139 | axs[1, 2].imshow(color_residual, cmap="plasma") 140 | axs[1, 2].set_title('RGB Residual', fontsize=16) 141 | axs[1, 2].set_xticks([]) 142 | axs[1, 2].set_yticks([]) 143 | 144 | for ax in axs.flatten(): 145 | ax.axis('off') 146 | fig.tight_layout() 147 | plt.subplots_adjust(top=0.90) # Adjust top margin 148 | fig_name = str(self.output_path / "mapping_vis" / f'{frame_id:04d}_{iter:04d}.jpg') 149 | fig_title = f"Mapper Color/Depth at frame {frame_id:04d} iters {iter:04d}" 150 | plt.suptitle(fig_title, y=0.98, fontsize=20) 151 | plt.savefig(fig_name, dpi=250, bbox_inches='tight') 152 | plt.clf() 153 | plt.close() 154 | if self.use_wandb: 155 | log_title = "Mapping_vis/" + f'{frame_id:04d}_{iter:04d}' 156 | wandb.log({log_title: [wandb.Image(fig_name)]}) 157 | print(f"Saved rendering vis of color/depth at {frame_id:04d}_{iter:04d}.jpg") 158 | -------------------------------------------------------------------------------- /src/entities/losses.py: -------------------------------------------------------------------------------- 1 | from math import exp 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | 8 | def l1_loss(network_output: torch.Tensor, gt: torch.Tensor, agg="mean") -> torch.Tensor: 9 | """ 10 | Computes the L1 loss, which is the mean absolute error between the network output and the ground truth. 11 | 12 | Args: 13 | network_output: The output from the network. 
14 | gt: The ground truth tensor. 15 | agg: The aggregation method to be used. Defaults to "mean". 16 | Returns: 17 | The computed L1 loss. 18 | """ 19 | l1_loss = torch.abs(network_output - gt) 20 | if agg == "mean": 21 | return l1_loss.mean() 22 | elif agg == "sum": 23 | return l1_loss.sum() 24 | elif agg == "none": 25 | return l1_loss 26 | else: 27 | raise ValueError("Invalid aggregation method.") 28 | 29 | 30 | def gaussian(window_size: int, sigma: float) -> torch.Tensor: 31 | """ 32 | Creates a 1D Gaussian kernel. 33 | 34 | Args: 35 | window_size: The size of the window for the Gaussian kernel. 36 | sigma: The standard deviation of the Gaussian kernel. 37 | 38 | Returns: 39 | The 1D Gaussian kernel. 40 | """ 41 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / 42 | float(2 * sigma ** 2)) for x in range(window_size)]) 43 | return gauss / gauss.sum() 44 | 45 | 46 | def create_window(window_size: int, channel: int) -> Variable: 47 | """ 48 | Creates a 2D Gaussian window/kernel for SSIM computation. 49 | 50 | Args: 51 | window_size: The size of the window to be created. 52 | channel: The number of channels in the image. 53 | 54 | Returns: 55 | A 2D Gaussian window expanded to match the number of channels. 56 | """ 57 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 58 | _2D_window = _1D_window.mm( 59 | _1D_window.t()).float().unsqueeze(0).unsqueeze(0) 60 | window = Variable(_2D_window.expand( 61 | channel, 1, window_size, window_size).contiguous()) 62 | return window 63 | 64 | 65 | def ssim(img1: torch.Tensor, img2: torch.Tensor, window_size: int = 11, size_average: bool = True) -> torch.Tensor: 66 | """ 67 | Computes the Structural Similarity Index (SSIM) between two images. 68 | 69 | Args: 70 | img1: The first image. 71 | img2: The second image. 72 | window_size: The size of the window to be used in SSIM computation. Defaults to 11. 73 | size_average: If True, averages the SSIM over all pixels. Defaults to True. 74 | 75 | Returns: 76 | The computed SSIM value. 77 | """ 78 | channel = img1.size(-3) 79 | window = create_window(window_size, channel) 80 | 81 | if img1.is_cuda: 82 | window = window.cuda(img1.get_device()) 83 | window = window.type_as(img1) 84 | 85 | return _ssim(img1, img2, window, window_size, channel, size_average) 86 | 87 | 88 | def _ssim(img1: torch.Tensor, img2: torch.Tensor, window: Variable, window_size: int, 89 | channel: int, size_average: bool = True) -> torch.Tensor: 90 | """ 91 | Internal function to compute the Structural Similarity Index (SSIM) between two images. 92 | 93 | Args: 94 | img1: The first image. 95 | img2: The second image. 96 | window: The Gaussian window/kernel for SSIM computation. 97 | window_size: The size of the window to be used in SSIM computation. 98 | channel: The number of channels in the image. 99 | size_average: If True, averages the SSIM over all pixels. 100 | 101 | Returns: 102 | The computed SSIM value. 
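    The returned value averages the per-pixel SSIM map
        SSIM(x, y) = ((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) / ((mu1^2 + mu2^2 + C1) * (sigma1^2 + sigma2^2 + C2)),
    with C1 = 0.01^2 and C2 = 0.03^2, where the means and (co)variances are local Gaussian-window statistics.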
103 | """ 104 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 105 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 106 | 107 | mu1_sq = mu1.pow(2) 108 | mu2_sq = mu2.pow(2) 109 | mu1_mu2 = mu1 * mu2 110 | 111 | sigma1_sq = F.conv2d(img1 * img1, window, 112 | padding=window_size // 2, groups=channel) - mu1_sq 113 | sigma2_sq = F.conv2d(img2 * img2, window, 114 | padding=window_size // 2, groups=channel) - mu2_sq 115 | sigma12 = F.conv2d(img1 * img2, window, 116 | padding=window_size // 2, groups=channel) - mu1_mu2 117 | 118 | C1 = 0.01 ** 2 119 | C2 = 0.03 ** 2 120 | 121 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ 122 | ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 123 | 124 | if size_average: 125 | return ssim_map.mean() 126 | else: 127 | return ssim_map.mean(1).mean(1).mean(1) 128 | 129 | 130 | def isotropic_loss(scaling: torch.Tensor) -> torch.Tensor: 131 | """ 132 | Computes loss enforcing isotropic scaling for the 3D Gaussians 133 | Args: 134 | scaling: scaling tensor of 3D Gaussians of shape (n, 3) 135 | Returns: 136 | The computed isotropic loss 137 | """ 138 | mean_scaling = scaling.mean(dim=1, keepdim=True) 139 | isotropic_diff = torch.abs(scaling - mean_scaling * torch.ones_like(scaling)) 140 | return isotropic_diff.mean() 141 | -------------------------------------------------------------------------------- /src/entities/mapper.py: -------------------------------------------------------------------------------- 1 | """ This module includes the Mapper class, which is responsible scene mapping: Paragraph 3.2 """ 2 | import time 3 | from argparse import ArgumentParser 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | 9 | from src.entities.arguments import OptimizationParams 10 | from src.entities.datasets import TUM_RGBD, BaseDataset, ScanNet 11 | from src.entities.gaussian_model import GaussianModel 12 | from src.entities.logger import Logger 13 | from src.entities.losses import isotropic_loss, l1_loss, ssim 14 | from src.utils.mapper_utils import (calc_psnr, compute_camera_frustum_corners, 15 | compute_frustum_point_ids, 16 | compute_new_points_ids, 17 | compute_opt_views_distribution, 18 | create_point_cloud, geometric_edge_mask, 19 | sample_pixels_based_on_gradient) 20 | from src.utils.utils import (get_render_settings, np2ptcloud, np2torch, 21 | render_gaussian_model, torch2np) 22 | from src.utils.vis_utils import * # noqa - needed for debugging 23 | 24 | 25 | class Mapper(object): 26 | def __init__(self, config: dict, dataset: BaseDataset, logger: Logger) -> None: 27 | """ Sets up the mapper parameters 28 | Args: 29 | config: configuration of the mapper 30 | dataset: The dataset object used for extracting camera parameters and reading the data 31 | logger: The logger object used for logging the mapping process and saving visualizations 32 | """ 33 | self.config = config 34 | self.logger = logger 35 | self.dataset = dataset 36 | self.iterations = config["iterations"] 37 | self.new_submap_iterations = config["new_submap_iterations"] 38 | self.new_submap_points_num = config["new_submap_points_num"] 39 | self.new_submap_gradient_points_num = config["new_submap_gradient_points_num"] 40 | self.new_frame_sample_size = config["new_frame_sample_size"] 41 | self.new_points_radius = config["new_points_radius"] 42 | self.alpha_thre = config["alpha_thre"] 43 | self.pruning_thre = config["pruning_thre"] 44 | self.current_view_opt_iterations = config["current_view_opt_iterations"] 45 | self.opt = 
OptimizationParams(ArgumentParser(description="Training script parameters")) 46 | self.keyframes = [] 47 | 48 | def compute_seeding_mask(self, gaussian_model: GaussianModel, keyframe: dict, new_submap: bool) -> np.ndarray: 49 | """ 50 | Computes a binary mask to identify regions within a keyframe where new Gaussian models should be seeded 51 | based on alpha masks or color gradient 52 | Args: 53 | gaussian_model: The current submap 54 | keyframe (dict): Keyframe dict containing color, depth, and render settings 55 | new_submap (bool): A boolean indicating whether the seeding is occurring in the current submap or a new submap 56 | Returns: 57 | np.ndarray: A binary mask of shape (H, W) indicating regions suitable for seeding new 3D Gaussian models 58 | """ 59 | seeding_mask = None 60 | if new_submap: 61 | color_for_mask = (torch2np(keyframe["color"].permute(1, 2, 0)) * 255).astype(np.uint8) 62 | seeding_mask = geometric_edge_mask(color_for_mask, RGB=True) 63 | else: 64 | render_dict = render_gaussian_model(gaussian_model, keyframe["render_settings"]) 65 | alpha_mask = (render_dict["alpha"] < self.alpha_thre) 66 | gt_depth_tensor = keyframe["depth"][None] 67 | depth_error = torch.abs(gt_depth_tensor - render_dict["depth"]) * (gt_depth_tensor > 0) 68 | depth_error_mask = (render_dict["depth"] > gt_depth_tensor) * (depth_error > 40 * depth_error.median()) 69 | seeding_mask = alpha_mask | depth_error_mask 70 | seeding_mask = torch2np(seeding_mask[0]) 71 | return seeding_mask 72 | 73 | def seed_new_gaussians(self, gt_color: np.ndarray, gt_depth: np.ndarray, intrinsics: np.ndarray, 74 | estimate_c2w: np.ndarray, seeding_mask: np.ndarray, is_new_submap: bool) -> np.ndarray: 75 | """ 76 | Seeds means for the new 3D Gaussians based on ground truth color and depth, camera intrinsics, 77 | estimated camera-to-world transformation, a seeding mask, and a flag indicating whether this is a new submap. 78 | Args: 79 | gt_color: The ground truth color image as a numpy array with shape (H, W, 3). 80 | gt_depth: The ground truth depth map as a numpy array with shape (H, W). 81 | intrinsics: The camera intrinsics matrix as a numpy array with shape (3, 3). 82 | estimate_c2w: The estimated camera-to-world transformation matrix as a numpy array with shape (4, 4). 83 | seeding_mask: A binary mask indicating where to seed new Gaussians, with shape (H, W). 84 | is_new_submap: Flag indicating whether the seeding is for a new submap (True) or an existing submap (False). 85 | Returns: 86 | np.ndarray: An array of 3D points where new Gaussians will be initialized, with shape (N, 3) 87 | 88 | """ 89 | pts = create_point_cloud(gt_color, 1.005 * gt_depth, intrinsics, estimate_c2w) 90 | flat_gt_depth = gt_depth.flatten() 91 | non_zero_depth_mask = flat_gt_depth > 0.
# need filter if zero depth pixels in gt_depth 92 | valid_ids = np.flatnonzero(seeding_mask) 93 | if is_new_submap: 94 | if self.new_submap_points_num < 0: 95 | uniform_ids = np.arange(pts.shape[0]) 96 | else: 97 | uniform_ids = np.random.choice(pts.shape[0], self.new_submap_points_num, replace=False) 98 | gradient_ids = sample_pixels_based_on_gradient(gt_color, self.new_submap_gradient_points_num) 99 | combined_ids = np.concatenate((uniform_ids, gradient_ids)) 100 | combined_ids = np.concatenate((combined_ids, valid_ids)) 101 | sample_ids = np.unique(combined_ids) 102 | else: 103 | if self.new_frame_sample_size < 0 or len(valid_ids) < self.new_frame_sample_size: 104 | sample_ids = valid_ids 105 | else: 106 | sample_ids = np.random.choice(valid_ids, size=self.new_frame_sample_size, replace=False) 107 | sample_ids = sample_ids[non_zero_depth_mask[sample_ids]] 108 | return pts[sample_ids, :].astype(np.float32) 109 | 110 | def optimize_submap(self, keyframes: list, gaussian_model: GaussianModel, iterations: int = 100) -> dict: 111 | """ 112 | Optimizes the submap by refining the parameters of the 3D Gaussian based on the observations 113 | from keyframes observing the submap. 114 | Args: 115 | keyframes: A list of tuples consisting of frame id and keyframe dictionary 116 | gaussian_model: An instance of the GaussianModel class representing the initial state 117 | of the Gaussian model to be optimized. 118 | iterations: The number of iterations to perform the optimization process. Defaults to 100. 119 | Returns: 120 | losses_dict: Dictionary with the optimization statistics 121 | """ 122 | 123 | iteration = 0 124 | losses_dict = {} 125 | 126 | current_frame_iters = self.current_view_opt_iterations * iterations 127 | distribution = compute_opt_views_distribution(len(keyframes), iterations, current_frame_iters) 128 | start_time = time.time() 129 | while iteration < iterations + 1: 130 | gaussian_model.optimizer.zero_grad(set_to_none=True) 131 | keyframe_id = np.random.choice(np.arange(len(keyframes)), p=distribution) 132 | 133 | frame_id, keyframe = keyframes[keyframe_id] 134 | render_pkg = render_gaussian_model(gaussian_model, keyframe["render_settings"]) 135 | 136 | image, depth = render_pkg["color"], render_pkg["depth"] 137 | gt_image = keyframe["color"] 138 | gt_depth = keyframe["depth"] 139 | 140 | mask = (gt_depth > 0) & (~torch.isnan(depth)).squeeze(0) 141 | color_loss = (1.0 - self.opt.lambda_dssim) * l1_loss( 142 | image[:, mask], gt_image[:, mask]) + self.opt.lambda_dssim * (1.0 - ssim(image, gt_image)) 143 | 144 | depth_loss = l1_loss(depth[:, mask], gt_depth[mask]) 145 | reg_loss = isotropic_loss(gaussian_model.get_scaling()) 146 | total_loss = color_loss + depth_loss + reg_loss 147 | total_loss.backward() 148 | 149 | losses_dict[frame_id] = {"color_loss": color_loss.item(), 150 | "depth_loss": depth_loss.item(), 151 | "total_loss": total_loss.item()} 152 | 153 | with torch.no_grad(): 154 | 155 | if iteration == iterations // 2 or iteration == iterations: 156 | prune_mask = (gaussian_model.get_opacity() 157 | < self.pruning_thre).squeeze() 158 | gaussian_model.prune_points(prune_mask) 159 | 160 | # Optimizer step 161 | if iteration < iterations: 162 | gaussian_model.optimizer.step() 163 | gaussian_model.optimizer.zero_grad(set_to_none=True) 164 | 165 | iteration += 1 166 | optimization_time = time.time() - start_time 167 | losses_dict["optimization_time"] = optimization_time 168 | losses_dict["optimization_iter_time"] = optimization_time / iterations 169 | return losses_dict 170 | 
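    # Note on optimize_submap above: each iteration renders a single keyframe sampled from
    # compute_opt_views_distribution, to which current_view_opt_iterations * iterations is passed as the share
    # of the budget intended for the most recent keyframe (placed first in the list by map()). Low-opacity
    # Gaussians are pruned at the halfway point and on the extra final pass, which takes no optimizer step.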
171 | def grow_submap(self, gt_depth: np.ndarray, estimate_c2w: np.ndarray, gaussian_model: GaussianModel, 172 | pts: np.ndarray, filter_cloud: bool) -> int: 173 | """ 174 | Expands the submap by integrating new points from the current keyframe 175 | Args: 176 | gt_depth: The ground truth depth map for the current keyframe, as a 2D numpy array. 177 | estimate_c2w: The estimated camera-to-world transformation matrix for the current keyframe of shape (4x4) 178 | gaussian_model (GaussianModel): The Gaussian model representing the current state of the submap. 179 | pts: The current set of 3D points in the keyframe of shape (N, 3) 180 | filter_cloud: A boolean flag indicating whether to apply filtering to the point cloud to remove 181 | outliers or noise before integrating it into the map. 182 | Returns: 183 | int: The number of points added to the submap 184 | """ 185 | gaussian_points = gaussian_model.get_xyz() 186 | camera_frustum_corners = compute_camera_frustum_corners(gt_depth, estimate_c2w, self.dataset.intrinsics) 187 | reused_pts_ids = compute_frustum_point_ids( 188 | gaussian_points, np2torch(camera_frustum_corners), device="cuda") 189 | new_pts_ids = compute_new_points_ids(gaussian_points[reused_pts_ids], np2torch(pts[:, :3]).contiguous(), 190 | radius=self.new_points_radius, device="cuda") 191 | new_pts_ids = torch2np(new_pts_ids) 192 | if new_pts_ids.shape[0] > 0: 193 | cloud_to_add = np2ptcloud(pts[new_pts_ids, :3], pts[new_pts_ids, 3:] / 255.0) 194 | if filter_cloud: 195 | cloud_to_add, _ = cloud_to_add.remove_statistical_outlier(nb_neighbors=40, std_ratio=2.0) 196 | gaussian_model.add_points(cloud_to_add) 197 | gaussian_model._features_dc.requires_grad = False 198 | gaussian_model._features_rest.requires_grad = False 199 | print("Gaussian model size", gaussian_model.get_size()) 200 | return new_pts_ids.shape[0] 201 | 202 | def map(self, frame_id: int, estimate_c2w: np.ndarray, gaussian_model: GaussianModel, is_new_submap: bool) -> dict: 203 | """ Calls out the mapping process described in paragraph 3.2 204 | The process goes as follows: seed new gaussians -> add to the submap -> optimize the submap 205 | Args: 206 | frame_id: current keyframe id 207 | estimate_c2w (np.ndarray): The estimated camera-to-world transformation matrix of shape (4x4) 208 | gaussian_model (GaussianModel): The current Gaussian model of the submap 209 | is_new_submap (bool): A boolean flag indicating whether the current frame initiates a new submap 210 | Returns: 211 | opt_dict: Dictionary with statistics about the optimization process 212 | """ 213 | 214 | _, gt_color, gt_depth, _ = self.dataset[frame_id] 215 | estimate_w2c = np.linalg.inv(estimate_c2w) 216 | 217 | color_transform = torchvision.transforms.ToTensor() 218 | keyframe = { 219 | "color": color_transform(gt_color).cuda(), 220 | "depth": np2torch(gt_depth, device="cuda"), 221 | "render_settings": get_render_settings( 222 | self.dataset.width, self.dataset.height, self.dataset.intrinsics, estimate_w2c)} 223 | 224 | seeding_mask = self.compute_seeding_mask(gaussian_model, keyframe, is_new_submap) 225 | pts = self.seed_new_gaussians( 226 | gt_color, gt_depth, self.dataset.intrinsics, estimate_c2w, seeding_mask, is_new_submap) 227 | 228 | filter_cloud = isinstance(self.dataset, (TUM_RGBD, ScanNet)) and not is_new_submap 229 | 230 | new_pts_num = self.grow_submap(gt_depth, estimate_c2w, gaussian_model, pts, filter_cloud) 231 | 232 | max_iterations = self.iterations 233 | if is_new_submap: 234 | max_iterations = self.new_submap_iterations 235 | 
start_time = time.time() 236 | opt_dict = self.optimize_submap([(frame_id, keyframe)] + self.keyframes, gaussian_model, max_iterations) 237 | optimization_time = time.time() - start_time 238 | print("Optimization time: ", optimization_time) 239 | 240 | self.keyframes.append((frame_id, keyframe)) 241 | 242 | # Visualise the mapping for the current frame 243 | with torch.no_grad(): 244 | render_pkg_vis = render_gaussian_model(gaussian_model, keyframe["render_settings"]) 245 | image_vis, depth_vis = render_pkg_vis["color"], render_pkg_vis["depth"] 246 | psnr_value = calc_psnr(image_vis, keyframe["color"]).mean().item() 247 | opt_dict["psnr_render"] = psnr_value 248 | print(f"PSNR this frame: {psnr_value}") 249 | self.logger.vis_mapping_iteration( 250 | frame_id, max_iterations, 251 | image_vis.clone().detach().permute(1, 2, 0), 252 | depth_vis.clone().detach().permute(1, 2, 0), 253 | keyframe["color"].permute(1, 2, 0), 254 | keyframe["depth"].unsqueeze(-1), 255 | seeding_mask=seeding_mask) 256 | 257 | # Log the mapping numbers for the current frame 258 | self.logger.log_mapping_iteration(frame_id, new_pts_num, gaussian_model.get_size(), 259 | optimization_time/max_iterations, opt_dict) 260 | return opt_dict 261 | -------------------------------------------------------------------------------- /src/entities/tracker.py: -------------------------------------------------------------------------------- 1 | """ This module includes the Tracker class, which is responsible for camera tracking: Paper Section 3.4 """ 2 | from argparse import ArgumentParser 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | import torchvision 8 | from scipy.spatial.transform import Rotation as R 9 | 10 | from src.entities.arguments import OptimizationParams 11 | from src.entities.losses import l1_loss 12 | from src.entities.gaussian_model import GaussianModel 13 | from src.entities.logger import Logger 14 | from src.entities.datasets import BaseDataset 15 | from src.entities.visual_odometer import VisualOdometer 16 | from src.utils.gaussian_model_utils import build_rotation 17 | from src.utils.tracker_utils import (compute_camera_opt_params, 18 | extrapolate_poses, multiply_quaternions, 19 | transformation_to_quaternion) 20 | from src.utils.utils import (get_render_settings, np2torch, 21 | render_gaussian_model, torch2np) 22 | 23 | 24 | class Tracker(object): 25 | def __init__(self, config: dict, dataset: BaseDataset, logger: Logger) -> None: 26 | """ Initializes the Tracker with a given configuration, dataset, and logger. 27 | Args: 28 | config: Configuration dictionary specifying hyperparameters and operational settings. 29 | dataset: The dataset object providing access to the sequence of frames. 30 | logger: Logger object for logging the tracking process.
31 | """ 32 | self.dataset = dataset 33 | self.logger = logger 34 | self.config = config 35 | self.filter_alpha = self.config["filter_alpha"] 36 | self.filter_outlier_depth = self.config["filter_outlier_depth"] 37 | self.alpha_thre = self.config["alpha_thre"] 38 | self.soft_alpha = self.config["soft_alpha"] 39 | self.mask_invalid_depth_in_color_loss = self.config["mask_invalid_depth"] 40 | self.w_color_loss = self.config["w_color_loss"] 41 | self.transform = torchvision.transforms.ToTensor() 42 | self.opt = OptimizationParams(ArgumentParser(description="Training script parameters")) 43 | self.frame_depth_loss = [] 44 | self.frame_color_loss = [] 45 | self.odometry_type = self.config["odometry_type"] 46 | self.help_camera_initialization = self.config["help_camera_initialization"] 47 | self.init_err_ratio = self.config["init_err_ratio"] 48 | self.odometer = VisualOdometer(self.dataset.intrinsics, self.config["odometer_method"]) 49 | 50 | def compute_losses(self, gaussian_model: GaussianModel, render_settings: dict, 51 | opt_cam_rot: torch.Tensor, opt_cam_trans: torch.Tensor, 52 | gt_color: torch.Tensor, gt_depth: torch.Tensor, depth_mask: torch.Tensor) -> tuple: 53 | """ Computes the tracking losses with respect to ground truth color and depth. 54 | Args: 55 | gaussian_model: The current state of the Gaussian model of the scene. 56 | render_settings: Dictionary containing rendering settings such as image dimensions and camera intrinsics. 57 | opt_cam_rot: Optimizable tensor representing the camera's rotation. 58 | opt_cam_trans: Optimizable tensor representing the camera's translation. 59 | gt_color: Ground truth color image tensor. 60 | gt_depth: Ground truth depth image tensor. 61 | depth_mask: Binary mask indicating valid depth values in the ground truth depth image. 
62 | Returns: 63 | A tuple containing losses and renders 64 | """ 65 | rel_transform = torch.eye(4).cuda().float() 66 | rel_transform[:3, :3] = build_rotation(F.normalize(opt_cam_rot[None]))[0] 67 | rel_transform[:3, 3] = opt_cam_trans 68 | 69 | pts = gaussian_model.get_xyz() 70 | pts_ones = torch.ones(pts.shape[0], 1).cuda().float() 71 | pts4 = torch.cat((pts, pts_ones), dim=1) 72 | transformed_pts = (rel_transform @ pts4.T).T[:, :3] 73 | 74 | quat = F.normalize(opt_cam_rot[None]) 75 | _rotations = multiply_quaternions(gaussian_model.get_rotation(), quat.unsqueeze(0)).squeeze(0) 76 | 77 | render_dict = render_gaussian_model(gaussian_model, render_settings, 78 | override_means_3d=transformed_pts, override_rotations=_rotations) 79 | rendered_color, rendered_depth = render_dict["color"], render_dict["depth"] 80 | alpha_mask = render_dict["alpha"] > self.alpha_thre 81 | 82 | tracking_mask = torch.ones_like(alpha_mask).bool() 83 | tracking_mask &= depth_mask 84 | depth_err = torch.abs(rendered_depth - gt_depth) * depth_mask 85 | 86 | if self.filter_alpha: 87 | tracking_mask &= alpha_mask 88 | if self.filter_outlier_depth and torch.median(depth_err) > 0: 89 | tracking_mask &= depth_err < 50 * torch.median(depth_err) 90 | 91 | color_loss = l1_loss(rendered_color, gt_color, agg="none") 92 | depth_loss = l1_loss(rendered_depth, gt_depth, agg="none") * tracking_mask 93 | 94 | if self.soft_alpha: 95 | alpha = render_dict["alpha"] ** 3 96 | color_loss *= alpha 97 | depth_loss *= alpha 98 | if self.mask_invalid_depth_in_color_loss: 99 | color_loss *= tracking_mask 100 | else: 101 | color_loss *= tracking_mask 102 | 103 | color_loss = color_loss.sum() 104 | depth_loss = depth_loss.sum() 105 | 106 | return color_loss, depth_loss, rendered_color, rendered_depth, alpha_mask 107 | 108 | def track(self, frame_id: int, gaussian_model: GaussianModel, prev_c2ws: np.ndarray) -> np.ndarray: 109 | """ 110 | Updates the camera pose estimation for the current frame based on the provided image and depth, using either ground truth poses, 111 | constant speed assumption, or visual odometry. 112 | Args: 113 | frame_id: Index of the current frame being processed. 114 | gaussian_model: The current Gaussian model of the scene. 115 | prev_c2ws: Array containing the camera-to-world transformation matrices for the frames (0, i - 2, i - 1) 116 | Returns: 117 | The updated camera-to-world transformation matrix for the current frame. 
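        The initial pose comes from the configured odometry: "gt" returns the dataset pose directly,
        "const_speed" extrapolates the two previous poses, and "odometer" chains an RGB-D odometry estimate onto
        the last pose. The non-"gt" initializations are then refined by optimizing a relative rotation and
        translation against the rendered color and depth, and the relative pose with the lowest loss is kept.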
118 | """ 119 | _, image, depth, gt_c2w = self.dataset[frame_id] 120 | 121 | if (self.help_camera_initialization or self.odometry_type == "odometer") and self.odometer.last_rgbd is None: 122 | _, last_image, last_depth, _ = self.dataset[frame_id - 1] 123 | self.odometer.update_last_rgbd(last_image, last_depth) 124 | 125 | if self.odometry_type == "gt": 126 | return gt_c2w 127 | elif self.odometry_type == "const_speed": 128 | init_c2w = extrapolate_poses(prev_c2ws[1:]) 129 | elif self.odometry_type == "odometer": 130 | odometer_rel = self.odometer.estimate_rel_pose(image, depth) 131 | init_c2w = prev_c2ws[-1] @ odometer_rel 132 | 133 | last_c2w = prev_c2ws[-1] 134 | last_w2c = np.linalg.inv(last_c2w) 135 | init_rel = init_c2w @ np.linalg.inv(last_c2w) 136 | init_rel_w2c = np.linalg.inv(init_rel) 137 | reference_w2c = last_w2c 138 | render_settings = get_render_settings( 139 | self.dataset.width, self.dataset.height, self.dataset.intrinsics, reference_w2c) 140 | opt_cam_rot, opt_cam_trans = compute_camera_opt_params(init_rel_w2c) 141 | gaussian_model.training_setup_camera(opt_cam_rot, opt_cam_trans, self.config) 142 | 143 | gt_color = self.transform(image).cuda() 144 | gt_depth = np2torch(depth, "cuda") 145 | depth_mask = gt_depth > 0.0 146 | gt_trans = np2torch(gt_c2w[:3, 3]) 147 | gt_quat = np2torch(R.from_matrix(gt_c2w[:3, :3]).as_quat(canonical=True)[[3, 0, 1, 2]]) 148 | num_iters = self.config["iterations"] 149 | current_min_loss = float("inf") 150 | 151 | print(f"\nTracking frame {frame_id}") 152 | # Initial loss check 153 | color_loss, depth_loss, _, _, _ = self.compute_losses(gaussian_model, render_settings, opt_cam_rot, 154 | opt_cam_trans, gt_color, gt_depth, depth_mask) 155 | if len(self.frame_color_loss) > 0 and ( 156 | color_loss.item() > self.init_err_ratio * np.median(self.frame_color_loss) 157 | or depth_loss.item() > self.init_err_ratio * np.median(self.frame_depth_loss) 158 | ): 159 | num_iters *= 2 160 | print(f"Higher initial loss, increasing num_iters to {num_iters}") 161 | if self.help_camera_initialization and self.odometry_type != "odometer": 162 | _, last_image, last_depth, _ = self.dataset[frame_id - 1] 163 | self.odometer.update_last_rgbd(last_image, last_depth) 164 | odometer_rel = self.odometer.estimate_rel_pose(image, depth) 165 | init_c2w = last_c2w @ odometer_rel 166 | init_rel = init_c2w @ np.linalg.inv(last_c2w) 167 | init_rel_w2c = np.linalg.inv(init_rel) 168 | opt_cam_rot, opt_cam_trans = compute_camera_opt_params(init_rel_w2c) 169 | gaussian_model.training_setup_camera(opt_cam_rot, opt_cam_trans, self.config) 170 | render_settings = get_render_settings( 171 | self.dataset.width, self.dataset.height, self.dataset.intrinsics, last_w2c) 172 | print(f"re-init with odometer for frame {frame_id}") 173 | 174 | for iter in range(num_iters): 175 | color_loss, depth_loss, _, _, _, = self.compute_losses( 176 | gaussian_model, render_settings, opt_cam_rot, opt_cam_trans, gt_color, gt_depth, depth_mask) 177 | 178 | total_loss = (self.w_color_loss * color_loss + (1 - self.w_color_loss) * depth_loss) 179 | total_loss.backward() 180 | gaussian_model.optimizer.step() 181 | gaussian_model.optimizer.zero_grad(set_to_none=True) 182 | 183 | with torch.no_grad(): 184 | if total_loss.item() < current_min_loss: 185 | current_min_loss = total_loss.item() 186 | best_w2c = torch.eye(4) 187 | best_w2c[:3, :3] = build_rotation(F.normalize(opt_cam_rot[None].clone().detach().cpu()))[0] 188 | best_w2c[:3, 3] = opt_cam_trans.clone().detach().cpu() 189 | 190 | cur_quat, cur_trans = 
F.normalize(opt_cam_rot[None].clone().detach()), opt_cam_trans.clone().detach() 191 | cur_rel_w2c = torch.eye(4) 192 | cur_rel_w2c[:3, :3] = build_rotation(cur_quat)[0] 193 | cur_rel_w2c[:3, 3] = cur_trans 194 | if iter == num_iters - 1: 195 | cur_w2c = torch.from_numpy(reference_w2c) @ best_w2c 196 | else: 197 | cur_w2c = torch.from_numpy(reference_w2c) @ cur_rel_w2c 198 | cur_c2w = torch.inverse(cur_w2c) 199 | cur_cam = transformation_to_quaternion(cur_c2w) 200 | if (gt_quat * cur_cam[:4]).sum() < 0: # for logging purpose 201 | gt_quat *= -1 202 | if iter == num_iters - 1: 203 | self.frame_color_loss.append(color_loss.item()) 204 | self.frame_depth_loss.append(depth_loss.item()) 205 | self.logger.log_tracking_iteration( 206 | frame_id, cur_cam, gt_quat, gt_trans, total_loss, color_loss, depth_loss, iter, num_iters, 207 | wandb_output=True, print_output=True) 208 | elif iter % 20 == 0: 209 | self.logger.log_tracking_iteration( 210 | frame_id, cur_cam, gt_quat, gt_trans, total_loss, color_loss, depth_loss, iter, num_iters, 211 | wandb_output=False, print_output=True) 212 | 213 | final_c2w = torch.inverse(torch.from_numpy(reference_w2c) @ best_w2c) 214 | final_c2w[-1, :] = torch.tensor([0., 0., 0., 1.], dtype=final_c2w.dtype, device=final_c2w.device) 215 | return torch2np(final_c2w) 216 | -------------------------------------------------------------------------------- /src/entities/visual_odometer.py: -------------------------------------------------------------------------------- 1 | """ This module includes the VisualOdometer class, which allows for fast pose estimation from RGBD neighbor frames """ 2 | import numpy as np 3 | import open3d as o3d 4 | import open3d.core as o3c 5 | 6 | 7 | class VisualOdometer(object): 8 | 9 | def __init__(self, intrinsics: np.ndarray, method_name="hybrid", device="cuda"): 10 | """ Initializes the visual odometry system with specified intrinsics, method, and device. 11 | Args: 12 | intrinsics: Camera intrinsic parameters. 13 | method_name: The name of the odometry computation method to use ('hybrid' or 'point_to_plane'). 14 | device: The computation device ('cuda' or 'cpu'). 15 | """ 16 | device = "CUDA:0" if device == "cuda" else "CPU:0" 17 | self.device = o3c.Device(device) 18 | self.intrinsics = o3d.core.Tensor(intrinsics, o3d.core.Dtype.Float64) 19 | self.last_abs_pose = None 20 | self.last_frame = None 21 | self.criteria_list = [ 22 | o3d.t.pipelines.odometry.OdometryConvergenceCriteria(500), 23 | o3d.t.pipelines.odometry.OdometryConvergenceCriteria(500), 24 | o3d.t.pipelines.odometry.OdometryConvergenceCriteria(500)] 25 | self.setup_method(method_name) 26 | self.max_depth = 10.0 27 | self.depth_scale = 1.0 28 | self.last_rgbd = None 29 | 30 | def setup_method(self, method_name: str) -> None: 31 | """ Sets up the odometry computation method based on the provided method name. 32 | Args: 33 | method_name: The name of the odometry method to use ('hybrid' or 'point_to_plane'). 34 | """ 35 | if method_name == "hybrid": 36 | self.method = o3d.t.pipelines.odometry.Method.Hybrid 37 | elif method_name == "point_to_plane": 38 | self.method = o3d.t.pipelines.odometry.Method.PointToPlane 39 | else: 40 | raise ValueError("Odometry method does not exist!") 41 | 42 | def update_last_rgbd(self, image: np.ndarray, depth: np.ndarray) -> None: 43 | """ Updates the last RGB-D frame stored in the system with a new RGB-D frame constructed from provided image and depth. 44 | Args: 45 | image: The new RGB image as a numpy ndarray.
46 | depth: The new depth image as a numpy ndarray. 47 | """ 48 | self.last_rgbd = o3d.t.geometry.RGBDImage( 49 | o3d.t.geometry.Image(np.ascontiguousarray( 50 | image).astype(np.float32)).to(self.device), 51 | o3d.t.geometry.Image(np.ascontiguousarray(depth).astype(np.float32)).to(self.device)) 52 | 53 | def estimate_rel_pose(self, image: np.ndarray, depth: np.ndarray, init_transform=np.eye(4)): 54 | """ Estimates the relative pose of the current frame with respect to the last frame using RGB-D odometry. 55 | Args: 56 | image: The current RGB image as a numpy ndarray. 57 | depth: The current depth image as a numpy ndarray. 58 | init_transform: An initial transformation guess as a numpy ndarray. Defaults to the identity matrix. 59 | Returns: 60 | The relative transformation matrix as a numpy ndarray. 61 | """ 62 | rgbd = o3d.t.geometry.RGBDImage( 63 | o3d.t.geometry.Image(np.ascontiguousarray(image).astype(np.float32)).to(self.device), 64 | o3d.t.geometry.Image(np.ascontiguousarray(depth).astype(np.float32)).to(self.device)) 65 | rel_transform = o3d.t.pipelines.odometry.rgbd_odometry_multi_scale( 66 | self.last_rgbd, rgbd, self.intrinsics, o3c.Tensor(init_transform), 67 | self.depth_scale, self.max_depth, self.criteria_list, self.method) 68 | self.last_rgbd = rgbd.clone() 69 | 70 | # Adjust for the coordinate system difference 71 | rel_transform = rel_transform.transformation.cpu().numpy() 72 | rel_transform[0, [1, 2, 3]] *= -1 73 | rel_transform[1, [0, 2, 3]] *= -1 74 | rel_transform[2, [0, 1, 3]] *= -1 75 | 76 | return rel_transform 77 | -------------------------------------------------------------------------------- /src/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VladimirYugay/Gaussian-SLAM/eaec10d73ce7511563882b8856896e06d1f804e3/src/evaluation/__init__.py -------------------------------------------------------------------------------- /src/evaluation/evaluate_merged_map.py: -------------------------------------------------------------------------------- 1 | """ This module is responsible for merging submaps. 
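Rough usage sketch (illustrative, not the exact evaluation entry point): the submap checkpoints written during
SLAM are fused into one point cloud and then refined against the keyframes, e.g.
    pt_cloud = merge_submaps(submap_ckpt_paths)
    refined_model = refine_global_map(pt_cloud, training_frames, max_iterations)
where training_frames is consumed with next(), e.g. an iterator cycling over a RenderFrames dataloader.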
""" 2 | from argparse import ArgumentParser 3 | 4 | import faiss 5 | import numpy as np 6 | import open3d as o3d 7 | import torch 8 | from torch.utils.data import Dataset 9 | from tqdm import tqdm 10 | 11 | from src.entities.arguments import OptimizationParams 12 | from src.entities.gaussian_model import GaussianModel 13 | from src.entities.losses import isotropic_loss, l1_loss, ssim 14 | from src.utils.utils import (batch_search_faiss, get_render_settings, 15 | np2ptcloud, render_gaussian_model, torch2np) 16 | 17 | 18 | class RenderFrames(Dataset): 19 | """A dataset class for loading keyframes along with their estimated camera poses and render settings.""" 20 | def __init__(self, dataset, render_poses: np.ndarray, height: int, width: int, fx: float, fy: float): 21 | self.dataset = dataset 22 | self.render_poses = render_poses 23 | self.height = height 24 | self.width = width 25 | self.fx = fx 26 | self.fy = fy 27 | self.device = "cuda" 28 | self.stride = 1 29 | if len(dataset) > 1000: 30 | self.stride = len(dataset) // 1000 31 | 32 | def __len__(self) -> int: 33 | return len(self.dataset) // self.stride 34 | 35 | def __getitem__(self, idx): 36 | idx = idx * self.stride 37 | color = (torch.from_numpy( 38 | self.dataset[idx][1]) / 255.0).float().to(self.device) 39 | depth = torch.from_numpy(self.dataset[idx][2]).float().to(self.device) 40 | estimate_c2w = self.render_poses[idx] 41 | estimate_w2c = np.linalg.inv(estimate_c2w) 42 | frame = { 43 | "frame_id": idx, 44 | "color": color, 45 | "depth": depth, 46 | "render_settings": get_render_settings( 47 | self.width, self.height, self.dataset.intrinsics, estimate_w2c) 48 | } 49 | return frame 50 | 51 | 52 | def merge_submaps(submaps_paths: list, radius: float = 0.0001, device: str = "cuda") -> o3d.geometry.PointCloud: 53 | """ Merge submaps into a single point cloud, which is then used for global map refinement. 54 | Args: 55 | segments_paths (list): Folder path of the submaps. 56 | radius (float, optional): Nearest neighbor distance threshold for adding a point. Defaults to 0.0001. 57 | device (str, optional): Defaults to "cuda". 
58 | 59 | Returns: 60 | o3d.geometry.PointCloud: merged point cloud 61 | """ 62 | pts_index = faiss.IndexFlatL2(3) 63 | if device == "cuda": 64 | pts_index = faiss.index_cpu_to_gpu( 65 | faiss.StandardGpuResources(), 66 | 0, 67 | faiss.IndexIVFFlat(faiss.IndexFlatL2(3), 3, 500, faiss.METRIC_L2)) 68 | pts_index.nprobe = 5 69 | merged_pts = [] 70 | print("Merging segments") 71 | for submap_path in tqdm(submaps_paths): 72 | gaussian_params = torch.load(submap_path)["gaussian_params"] 73 | current_pts = gaussian_params["xyz"].to(device).float() 74 | pts_index.train(current_pts) 75 | distances, _ = batch_search_faiss(pts_index, current_pts, 8) 76 | neighbor_num = (distances < radius).sum(axis=1).int() 77 | ids_to_include = torch.where(neighbor_num == 0)[0] 78 | pts_index.add(current_pts[ids_to_include]) 79 | merged_pts.append(current_pts[ids_to_include]) 80 | pts = torch2np(torch.vstack(merged_pts)) 81 | pt_cloud = np2ptcloud(pts, np.zeros_like(pts)) 82 | 83 | # Downsampling if the total number of points is too large 84 | if len(pt_cloud.points) > 1_000_000: 85 | voxel_size = 0.02 86 | pt_cloud = pt_cloud.voxel_down_sample(voxel_size) 87 | print(f"Downsampled point cloud to {len(pt_cloud.points)} points") 88 | filtered_pt_cloud, _ = pt_cloud.remove_statistical_outlier(nb_neighbors=40, std_ratio=3.0) 89 | del pts_index 90 | return filtered_pt_cloud 91 | 92 | 93 | def refine_global_map(pt_cloud: o3d.geometry.PointCloud, training_frames: list, max_iterations: int) -> GaussianModel: 94 | """Refines a global map based on the merged point cloud and training keyframes frames. 95 | Args: 96 | pt_cloud (o3d.geometry.PointCloud): The merged point cloud used for refinement. 97 | training_frames (list): A list of training frames for map refinement. 98 | max_iterations (int): The maximum number of iterations to perform for refinement. 99 | Returns: 100 | GaussianModel: The refined global map as a Gaussian model. 
101 | """ 102 | opt_params = OptimizationParams(ArgumentParser(description="Training script parameters")) 103 | 104 | gaussian_model = GaussianModel(3) 105 | gaussian_model.active_sh_degree = 3 106 | gaussian_model.training_setup(opt_params) 107 | gaussian_model.add_points(pt_cloud) 108 | 109 | iteration = 0 110 | for iteration in tqdm(range(max_iterations), desc="Refinement"): 111 | training_frame = next(training_frames) 112 | gt_color, gt_depth, render_settings = ( 113 | training_frame["color"].squeeze(0), 114 | training_frame["depth"].squeeze(0), 115 | training_frame["render_settings"]) 116 | 117 | render_dict = render_gaussian_model(gaussian_model, render_settings) 118 | rendered_color, rendered_depth = (render_dict["color"].permute(1, 2, 0), render_dict["depth"]) 119 | 120 | reg_loss = isotropic_loss(gaussian_model.get_scaling()) 121 | depth_mask = (gt_depth > 0) 122 | color_loss = (1.0 - opt_params.lambda_dssim) * l1_loss( 123 | rendered_color[depth_mask, :], gt_color[depth_mask, :] 124 | ) + opt_params.lambda_dssim * (1.0 - ssim(rendered_color, gt_color)) 125 | depth_loss = l1_loss( 126 | rendered_depth[:, depth_mask], gt_depth[depth_mask]) 127 | 128 | total_loss = color_loss + depth_loss + reg_loss 129 | total_loss.backward() 130 | 131 | with torch.no_grad(): 132 | if iteration % 500 == 0: 133 | prune_mask = (gaussian_model.get_opacity() < 0.005).squeeze() 134 | gaussian_model.prune_points(prune_mask) 135 | 136 | # Optimizer step 137 | gaussian_model.optimizer.step() 138 | gaussian_model.optimizer.zero_grad(set_to_none=True) 139 | iteration += 1 140 | 141 | return gaussian_model 142 | -------------------------------------------------------------------------------- /src/evaluation/evaluate_reconstruction.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import open3d as o3d 7 | import torch 8 | import trimesh 9 | from evaluate_3d_reconstruction import run_evaluation 10 | from tqdm import tqdm 11 | 12 | 13 | def normalize(x): 14 | return x / np.linalg.norm(x) 15 | 16 | 17 | def get_align_transformation(rec_meshfile, gt_meshfile): 18 | """ 19 | Get the transformation matrix to align the reconstructed mesh to the ground truth mesh. 20 | """ 21 | o3d_rec_mesh = o3d.io.read_triangle_mesh(rec_meshfile) 22 | o3d_gt_mesh = o3d.io.read_triangle_mesh(gt_meshfile) 23 | o3d_rec_pc = o3d.geometry.PointCloud(points=o3d_rec_mesh.vertices) 24 | o3d_gt_pc = o3d.geometry.PointCloud(points=o3d_gt_mesh.vertices) 25 | trans_init = np.eye(4) 26 | threshold = 0.1 27 | reg_p2p = o3d.pipelines.registration.registration_icp( 28 | o3d_rec_pc, 29 | o3d_gt_pc, 30 | threshold, 31 | trans_init, 32 | o3d.pipelines.registration.TransformationEstimationPointToPoint(), 33 | ) 34 | transformation = reg_p2p.transformation 35 | return transformation 36 | 37 | 38 | def check_proj(points, W, H, fx, fy, cx, cy, c2w): 39 | """ 40 | Check if points can be projected into the camera view. 
41 | 42 | Returns: 43 | bool: True if there are points can be projected 44 | 45 | """ 46 | c2w = c2w.copy() 47 | c2w[:3, 1] *= -1.0 48 | c2w[:3, 2] *= -1.0 49 | points = torch.from_numpy(points).cuda().clone() 50 | w2c = np.linalg.inv(c2w) 51 | w2c = torch.from_numpy(w2c).cuda().float() 52 | K = torch.from_numpy( 53 | np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]]).reshape(3, 3) 54 | ).cuda() 55 | ones = torch.ones_like(points[:, 0]).reshape(-1, 1).cuda() 56 | homo_points = ( 57 | torch.cat([points, ones], dim=1).reshape(-1, 4, 1).cuda().float() 58 | ) # (N, 4) 59 | cam_cord_homo = w2c @ homo_points # (N, 4, 1)=(4,4)*(N, 4, 1) 60 | cam_cord = cam_cord_homo[:, :3] # (N, 3, 1) 61 | cam_cord[:, 0] *= -1 62 | uv = K.float() @ cam_cord.float() 63 | z = uv[:, -1:] + 1e-5 64 | uv = uv[:, :2] / z 65 | uv = uv.float().squeeze(-1).cpu().numpy() 66 | edge = 0 67 | mask = ( 68 | (0 <= -z[:, 0, 0].cpu().numpy()) 69 | & (uv[:, 0] < W - edge) 70 | & (uv[:, 0] > edge) 71 | & (uv[:, 1] < H - edge) 72 | & (uv[:, 1] > edge) 73 | ) 74 | return mask.sum() > 0 75 | 76 | 77 | def get_cam_position(gt_meshfile): 78 | mesh_gt = trimesh.load(gt_meshfile) 79 | to_origin, extents = trimesh.bounds.oriented_bounds(mesh_gt) 80 | extents[2] *= 0.7 81 | extents[1] *= 0.7 82 | extents[0] *= 0.3 83 | transform = np.linalg.inv(to_origin) 84 | transform[2, 3] += 0.4 85 | return extents, transform 86 | 87 | 88 | def viewmatrix(z, up, pos): 89 | vec2 = normalize(z) 90 | vec1_avg = up 91 | vec0 = normalize(np.cross(vec1_avg, vec2)) 92 | vec1 = normalize(np.cross(vec2, vec0)) 93 | m = np.stack([vec0, vec1, vec2, pos], 1) 94 | return m 95 | 96 | 97 | def calc_2d_metric( 98 | rec_meshfile, gt_meshfile, unseen_gt_pointcloud_file, align=True, n_imgs=1000 99 | ): 100 | """ 101 | 2D reconstruction metric, depth L1 loss. 
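    The error is computed by rendering depth maps of the ground-truth and reconstructed meshes from `n_imgs` randomly sampled virtual views (views that would see the unseen-region point cloud are rejected and resampled), then averaging the per-view depth L1 difference; the result is reported in centimeters, assuming the meshes are expressed in meters.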
102 | 103 | """ 104 | H = 500 105 | W = 500 106 | focal = 300 107 | fx = focal 108 | fy = focal 109 | cx = H / 2.0 - 0.5 110 | cy = W / 2.0 - 0.5 111 | 112 | gt_mesh = o3d.io.read_triangle_mesh(gt_meshfile) 113 | rec_mesh = o3d.io.read_triangle_mesh(rec_meshfile) 114 | pc_unseen = np.load(unseen_gt_pointcloud_file) 115 | if align: 116 | transformation = get_align_transformation(rec_meshfile, gt_meshfile) 117 | rec_mesh = rec_mesh.transform(transformation) 118 | 119 | # get vacant area inside the room 120 | extents, transform = get_cam_position(gt_meshfile) 121 | 122 | vis = o3d.visualization.Visualizer() 123 | vis.create_window(width=W, height=H, visible=False) 124 | vis.get_render_option().mesh_show_back_face = True 125 | errors = [] 126 | for i in tqdm(range(n_imgs)): 127 | while True: 128 | # sample view, and check if unseen region is not inside the camera view 129 | # if inside, then needs to resample 130 | up = [0, 0, -1] 131 | origin = trimesh.sample.volume_rectangular( 132 | extents, 1, transform=transform) 133 | origin = origin.reshape(-1) 134 | tx = round(random.uniform(-10000, +10000), 2) 135 | ty = round(random.uniform(-10000, +10000), 2) 136 | tz = round(random.uniform(-10000, +10000), 2) 137 | # will be normalized, so sample from range [0.0,1.0] 138 | target = [tx, ty, tz] 139 | target = np.array(target) - np.array(origin) 140 | c2w = viewmatrix(target, up, origin) 141 | tmp = np.eye(4) 142 | tmp[:3, :] = c2w # sample translations 143 | c2w = tmp 144 | # if unseen points are projected into current view (c2w) 145 | seen = check_proj(pc_unseen, W, H, fx, fy, cx, cy, c2w) 146 | if ~seen: 147 | break 148 | 149 | param = o3d.camera.PinholeCameraParameters() 150 | param.extrinsic = np.linalg.inv(c2w) # 4x4 numpy array 151 | 152 | param.intrinsic = o3d.camera.PinholeCameraIntrinsic( 153 | W, H, fx, fy, cx, cy) 154 | 155 | ctr = vis.get_view_control() 156 | ctr.set_constant_z_far(20) 157 | ctr.convert_from_pinhole_camera_parameters(param, True) 158 | 159 | vis.add_geometry( 160 | gt_mesh, 161 | reset_bounding_box=True, 162 | ) 163 | ctr.convert_from_pinhole_camera_parameters(param, True) 164 | vis.poll_events() 165 | vis.update_renderer() 166 | gt_depth = vis.capture_depth_float_buffer(True) 167 | gt_depth = np.asarray(gt_depth) 168 | vis.remove_geometry( 169 | gt_mesh, 170 | reset_bounding_box=True, 171 | ) 172 | 173 | vis.add_geometry( 174 | rec_mesh, 175 | reset_bounding_box=True, 176 | ) 177 | ctr.convert_from_pinhole_camera_parameters(param, True) 178 | vis.poll_events() 179 | vis.update_renderer() 180 | ours_depth = vis.capture_depth_float_buffer(True) 181 | ours_depth = np.asarray(ours_depth) 182 | vis.remove_geometry( 183 | rec_mesh, 184 | reset_bounding_box=True, 185 | ) 186 | 187 | # filter missing surfaces where depth is 0 188 | if (ours_depth > 0).sum() > 0: 189 | errors += [ 190 | np.abs(gt_depth[ours_depth > 0] - 191 | ours_depth[ours_depth > 0]).mean() 192 | ] 193 | else: 194 | continue 195 | 196 | errors = np.array(errors) 197 | return {"depth l1": errors.mean() * 100} 198 | 199 | 200 | def clean_mesh(mesh): 201 | mesh_tri = trimesh.Trimesh( 202 | vertices=np.asarray(mesh.vertices), 203 | faces=np.asarray(mesh.triangles), 204 | vertex_colors=np.asarray(mesh.vertex_colors), 205 | ) 206 | components = trimesh.graph.connected_components( 207 | edges=mesh_tri.edges_sorted) 208 | 209 | min_len = 200 210 | components_to_keep = [c for c in components if len(c) >= min_len] 211 | 212 | new_vertices = [] 213 | new_faces = [] 214 | new_colors = [] 215 | vertex_count = 0 216 | for 
component in components_to_keep: 217 | vertices = mesh_tri.vertices[component] 218 | colors = mesh_tri.visual.vertex_colors[component] 219 | 220 | # Create a mapping from old vertex indices to new vertex indices 221 | index_mapping = { 222 | old_idx: vertex_count + new_idx for new_idx, old_idx in enumerate(component) 223 | } 224 | vertex_count += len(vertices) 225 | 226 | # Select faces that are part of the current connected component and update vertex indices 227 | faces_in_component = mesh_tri.faces[ 228 | np.any(np.isin(mesh_tri.faces, component), axis=1) 229 | ] 230 | reindexed_faces = np.vectorize(index_mapping.get)(faces_in_component) 231 | 232 | new_vertices.extend(vertices) 233 | new_faces.extend(reindexed_faces) 234 | new_colors.extend(colors) 235 | 236 | cleaned_mesh_tri = trimesh.Trimesh(vertices=new_vertices, faces=new_faces) 237 | cleaned_mesh_tri.visual.vertex_colors = np.array(new_colors) 238 | 239 | cleaned_mesh_tri.update_faces(cleaned_mesh_tri.nondegenerate_faces()) 240 | cleaned_mesh_tri.update_faces(cleaned_mesh_tri.unique_faces()) 241 | print( 242 | f"Mesh cleaning (before/after), vertices: {len(mesh_tri.vertices)}/{len(cleaned_mesh_tri.vertices)}, faces: {len(mesh_tri.faces)}/{len(cleaned_mesh_tri.faces)}") 243 | 244 | cleaned_mesh = o3d.geometry.TriangleMesh( 245 | o3d.utility.Vector3dVector(cleaned_mesh_tri.vertices), 246 | o3d.utility.Vector3iVector(cleaned_mesh_tri.faces), 247 | ) 248 | vertex_colors = np.asarray(cleaned_mesh_tri.visual.vertex_colors)[ 249 | :, :3] / 255.0 250 | cleaned_mesh.vertex_colors = o3d.utility.Vector3dVector( 251 | vertex_colors.astype(np.float64) 252 | ) 253 | 254 | return cleaned_mesh 255 | 256 | 257 | def evaluate_reconstruction( 258 | mesh_path: Path, 259 | gt_mesh_path: Path, 260 | unseen_pc_path: Path, 261 | output_path: Path, 262 | to_clean=True, 263 | ): 264 | if to_clean: 265 | mesh = o3d.io.read_triangle_mesh(str(mesh_path)) 266 | print(mesh) 267 | cleaned_mesh = clean_mesh(mesh) 268 | cleaned_mesh_path = output_path / "mesh" / "cleaned_mesh.ply" 269 | o3d.io.write_triangle_mesh(str(cleaned_mesh_path), cleaned_mesh) 270 | mesh_path = cleaned_mesh_path 271 | 272 | result_3d = run_evaluation( 273 | str(mesh_path.parts[-1]), 274 | str(mesh_path.parent), 275 | str(gt_mesh_path).split("/")[-1].split(".")[0], 276 | distance_thresh=0.01, 277 | full_path_to_gt_ply=gt_mesh_path, 278 | icp_align=True, 279 | ) 280 | 281 | try: 282 | result_2d = calc_2d_metric(str(mesh_path), str(gt_mesh_path), str(unseen_pc_path), align=True, n_imgs=1000) 283 | except Exception as e: 284 | print(e) 285 | result_2d = {"depth l1": None} 286 | 287 | result = {**result_3d, **result_2d} 288 | with open(str(output_path / "reconstruction_metrics.json"), "w") as f: 289 | json.dump(result, f) 290 | -------------------------------------------------------------------------------- /src/evaluation/evaluate_trajectory.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | 8 | class NumpyFloatValuesEncoder(json.JSONEncoder): 9 | def default(self, obj): 10 | if isinstance(obj, np.float32): 11 | return float(obj) 12 | return json.JSONEncoder.default(self, obj) 13 | 14 | 15 | def align(model, data): 16 | """Align two trajectories using the method of Horn (closed-form).
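    The rotation is recovered from an SVD of the cross-covariance of the zero-centered trajectories (with a sign correction so that det(rot) = +1), and the translation aligns the centroids: trans = mean(data) - rot * mean(model).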
17 | 18 | Input: 19 | model -- first trajectory (3xn) 20 | data -- second trajectory (3xn) 21 | 22 | Output: 23 | rot -- rotation matrix (3x3) 24 | trans -- translation vector (3x1) 25 | trans_error -- translational error per point (1xn) 26 | 27 | """ 28 | np.set_printoptions(precision=3, suppress=True) 29 | model_zerocentered = model - model.mean(1) 30 | data_zerocentered = data - data.mean(1) 31 | 32 | W = np.zeros((3, 3)) 33 | for column in range(model.shape[1]): 34 | W += np.outer(model_zerocentered[:, 35 | column], data_zerocentered[:, column]) 36 | U, d, Vh = np.linalg.linalg.svd(W.transpose()) 37 | S = np.matrix(np.identity(3)) 38 | if (np.linalg.det(U) * np.linalg.det(Vh) < 0): 39 | S[2, 2] = -1 40 | rot = U * S * Vh 41 | trans = data.mean(1) - rot * model.mean(1) 42 | 43 | model_aligned = rot * model + trans 44 | alignment_error = model_aligned - data 45 | 46 | trans_error = np.sqrt( 47 | np.sum(np.multiply(alignment_error, alignment_error), 0)).A[0] 48 | 49 | return rot, trans, trans_error 50 | 51 | 52 | def align_trajectories(t_pred: np.ndarray, t_gt: np.ndarray): 53 | """ 54 | Args: 55 | t_pred: (n, 3) translations 56 | t_gt: (n, 3) translations 57 | Returns: 58 | t_align: (n, 3) aligned translations 59 | """ 60 | t_align = np.matrix(t_pred).transpose() 61 | R, t, _ = align(t_align, np.matrix(t_gt).transpose()) 62 | t_align = R * t_align + t 63 | t_align = np.asarray(t_align).T 64 | return t_align 65 | 66 | 67 | def pose_error(t_pred: np.ndarray, t_gt: np.ndarray, align=False): 68 | """ 69 | Args: 70 | t_pred: (n, 3) translations 71 | t_gt: (n, 3) translations 72 | Returns: 73 | dict: error dict 74 | """ 75 | n = t_pred.shape[0] 76 | trans_error = np.linalg.norm(t_pred - t_gt, axis=1) 77 | return { 78 | "compared_pose_pairs": n, 79 | "rmse": np.sqrt(np.dot(trans_error, trans_error) / n), 80 | "mean": np.mean(trans_error), 81 | "median": np.median(trans_error), 82 | "std": np.std(trans_error), 83 | "min": np.min(trans_error), 84 | "max": np.max(trans_error) 85 | } 86 | 87 | 88 | def plot_2d(pts, ax=None, color="green", label="None", title="3D Trajectory in 2D"): 89 | if ax is None: 90 | _, ax = plt.subplots() 91 | ax.scatter(pts[:, 0], pts[:, 1], color=color, label=label, s=0.7) 92 | ax.set_xlabel('X') 93 | ax.set_ylabel('Y') 94 | ax.set_title(title) 95 | return ax 96 | 97 | 98 | def evaluate_trajectory(estimated_poses: np.ndarray, gt_poses: np.ndarray, output_path: Path): 99 | output_path.mkdir(exist_ok=True, parents=True) 100 | # Truncate the ground truth trajectory if needed 101 | if gt_poses.shape[0] > estimated_poses.shape[0]: 102 | gt_poses = gt_poses[:estimated_poses.shape[0]] 103 | valid = ~np.any(np.isnan(gt_poses) | 104 | np.isinf(gt_poses), axis=(1, 2)) 105 | gt_poses = gt_poses[valid] 106 | estimated_poses = estimated_poses[valid] 107 | 108 | gt_t = gt_poses[:, :3, 3] 109 | estimated_t = estimated_poses[:, :3, 3] 110 | estimated_t_aligned = align_trajectories(estimated_t, gt_t) 111 | ate = pose_error(estimated_t, gt_t) 112 | ate_aligned = pose_error(estimated_t_aligned, gt_t) 113 | 114 | with open(str(output_path / "ate.json"), "w") as f: 115 | f.write(json.dumps(ate, cls=NumpyFloatValuesEncoder)) 116 | 117 | with open(str(output_path / "ate_aligned.json"), "w") as f: 118 | f.write(json.dumps(ate_aligned, cls=NumpyFloatValuesEncoder)) 119 | 120 | ate_rmse, ate_rmse_aligned = ate["rmse"], ate_aligned["rmse"] 121 | ax = plot_2d( 122 | estimated_t, label=f"ate-rmse: {round(ate_rmse * 100, 2)} cm", color="orange") 123 | ax = plot_2d(estimated_t_aligned, ax, 124 | 
label=f"ate-rsme (aligned): {round(ate_rmse_aligned * 100, 2)} cm", color="lightskyblue") 125 | ax = plot_2d(gt_t, ax, label="GT", color="green") 126 | ax.legend() 127 | plt.savefig(str(output_path / "eval_trajectory.png"), dpi=300) 128 | plt.close() 129 | print( 130 | f"ATE-RMSE: {ate_rmse * 100:.2f} cm, ATE-RMSE (aligned): {ate_rmse_aligned * 100:.2f} cm") 131 | -------------------------------------------------------------------------------- /src/evaluation/evaluator.py: -------------------------------------------------------------------------------- 1 | """ This module is responsible for evaluating rendering, trajectory and reconstruction metrics""" 2 | import traceback 3 | from argparse import ArgumentParser 4 | from copy import deepcopy 5 | from itertools import cycle 6 | from pathlib import Path 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import open3d as o3d 11 | import torch 12 | import torchvision 13 | from pytorch_msssim import ms_ssim 14 | from scipy.ndimage import median_filter 15 | from torch.utils.data import DataLoader 16 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity 17 | from torchvision.utils import save_image 18 | from tqdm import tqdm 19 | 20 | from src.entities.arguments import OptimizationParams 21 | from src.entities.datasets import get_dataset 22 | from src.entities.gaussian_model import GaussianModel 23 | from src.evaluation.evaluate_merged_map import (RenderFrames, merge_submaps, 24 | refine_global_map) 25 | from src.evaluation.evaluate_reconstruction import evaluate_reconstruction 26 | from src.evaluation.evaluate_trajectory import evaluate_trajectory 27 | from src.utils.io_utils import load_config, save_dict_to_json 28 | from src.utils.mapper_utils import calc_psnr 29 | from src.utils.utils import (get_render_settings, np2torch, 30 | render_gaussian_model, setup_seed, torch2np) 31 | 32 | 33 | def filter_depth_outliers(depth_map, kernel_size=3, threshold=1.0): 34 | median_filtered = median_filter(depth_map, size=kernel_size) 35 | abs_diff = np.abs(depth_map - median_filtered) 36 | outlier_mask = abs_diff > threshold 37 | depth_map_filtered = np.where(outlier_mask, median_filtered, depth_map) 38 | return depth_map_filtered 39 | 40 | 41 | class Evaluator(object): 42 | 43 | def __init__(self, checkpoint_path, config_path, config=None, save_render=False) -> None: 44 | if config is None: 45 | self.config = load_config(config_path) 46 | else: 47 | self.config = config 48 | setup_seed(self.config["seed"]) 49 | 50 | self.checkpoint_path = Path(checkpoint_path) 51 | self.device = "cuda" 52 | self.dataset = get_dataset(self.config["dataset_name"])({**self.config["data"], **self.config["cam"]}) 53 | self.scene_name = self.config["data"]["scene_name"] 54 | self.dataset_name = self.config["dataset_name"] 55 | self.gt_poses = np.array(self.dataset.poses) 56 | self.fx, self.fy = self.dataset.intrinsics[0, 0], self.dataset.intrinsics[1, 1] 57 | self.cx, self.cy = self.dataset.intrinsics[0, 58 | 2], self.dataset.intrinsics[1, 2] 59 | self.width, self.height = self.dataset.width, self.dataset.height 60 | self.save_render = save_render 61 | if self.save_render: 62 | self.render_path = self.checkpoint_path / "rendered_imgs" 63 | self.render_path.mkdir(exist_ok=True, parents=True) 64 | 65 | self.estimated_c2w = torch2np(torch.load(self.checkpoint_path / "estimated_c2w.ckpt", map_location=self.device)) 66 | self.submaps_paths = sorted(list((self.checkpoint_path / "submaps").glob('*'))) 67 | 68 | def run_trajectory_eval(self): 69 | 
""" Evaluates the estimated trajectory """ 70 | print("Running trajectory evaluation...") 71 | evaluate_trajectory(self.estimated_c2w, self.gt_poses, self.checkpoint_path) 72 | 73 | def run_rendering_eval(self): 74 | """ Renderes the submaps and evaluates the PSNR, LPIPS, SSIM and depth L1 metrics.""" 75 | print("Running rendering evaluation...") 76 | psnr, lpips, ssim, depth_l1 = [], [], [], [] 77 | color_transform = torchvision.transforms.ToTensor() 78 | lpips_model = LearnedPerceptualImagePatchSimilarity( 79 | net_type='alex', normalize=True).to(self.device) 80 | opt_settings = OptimizationParams(ArgumentParser( 81 | description="Training script parameters")) 82 | 83 | submaps_paths = sorted( 84 | list((self.checkpoint_path / "submaps").glob('*.ckpt'))) 85 | for submap_path in tqdm(submaps_paths): 86 | submap = torch.load(submap_path, map_location=self.device) 87 | gaussian_model = GaussianModel() 88 | gaussian_model.training_setup(opt_settings) 89 | gaussian_model.restore_from_params( 90 | submap["gaussian_params"], opt_settings) 91 | 92 | for keyframe_id in submap["submap_keyframes"]: 93 | 94 | _, gt_color, gt_depth, _ = self.dataset[keyframe_id] 95 | gt_color = color_transform(gt_color).to(self.device) 96 | gt_depth = np2torch(gt_depth).to(self.device) 97 | 98 | estimate_c2w = self.estimated_c2w[keyframe_id] 99 | estimate_w2c = np.linalg.inv(estimate_c2w) 100 | render_dict = render_gaussian_model( 101 | gaussian_model, get_render_settings(self.width, self.height, self.dataset.intrinsics, estimate_w2c)) 102 | rendered_color, rendered_depth = render_dict["color"].detach( 103 | ), render_dict["depth"][0].detach() 104 | rendered_color = torch.clamp(rendered_color, min=0.0, max=1.0) 105 | if self.save_render: 106 | torchvision.utils.save_image( 107 | rendered_color, self.render_path / f"{keyframe_id:05d}.png") 108 | 109 | mse_loss = torch.nn.functional.mse_loss( 110 | rendered_color, gt_color) 111 | psnr_value = (-10. 
* torch.log10(mse_loss)).item() 112 | lpips_value = lpips_model( 113 | rendered_color[None], gt_color[None]).item() 114 | ssim_value = ms_ssim( 115 | rendered_color[None], gt_color[None], data_range=1.0, size_average=True).item() 116 | depth_l1_value = torch.abs( 117 | (rendered_depth - gt_depth)).mean().item() 118 | 119 | psnr.append(psnr_value) 120 | lpips.append(lpips_value) 121 | ssim.append(ssim_value) 122 | depth_l1.append(depth_l1_value) 123 | 124 | num_frames = len(psnr) 125 | metrics = { 126 | "psnr": sum(psnr) / num_frames, 127 | "lpips": sum(lpips) / num_frames, 128 | "ssim": sum(ssim) / num_frames, 129 | "depth_l1_train_view": sum(depth_l1) / num_frames, 130 | "num_renders": num_frames 131 | } 132 | save_dict_to_json(metrics, "rendering_metrics.json", 133 | directory=self.checkpoint_path) 134 | 135 | x = list(range(len(psnr))) 136 | fig, axs = plt.subplots(1, 3, figsize=(12, 4)) 137 | axs[0].plot(x, psnr, label="PSNR") 138 | axs[0].legend() 139 | axs[0].set_title("PSNR") 140 | axs[1].plot(x, ssim, label="SSIM") 141 | axs[1].legend() 142 | axs[1].set_title("SSIM") 143 | axs[2].plot(x, depth_l1, label="Depth L1 (Train view)") 144 | axs[2].legend() 145 | axs[2].set_title("Depth L1 Render") 146 | plt.tight_layout() 147 | plt.savefig(str(self.checkpoint_path / 148 | "rendering_metrics.png"), dpi=300) 149 | print(metrics) 150 | 151 | def run_reconstruction_eval(self): 152 | """ Reconstructs the mesh, evaluates it, render novel view depth maps from it, and evaluates them as well """ 153 | print("Running reconstruction evaluation...") 154 | if self.config["dataset_name"] != "replica": 155 | print("dataset is not supported, skipping reconstruction eval") 156 | return 157 | (self.checkpoint_path / "mesh").mkdir(exist_ok=True, parents=True) 158 | opt_settings = OptimizationParams(ArgumentParser( 159 | description="Training script parameters")) 160 | intrinsic = o3d.camera.PinholeCameraIntrinsic( 161 | self.width, self.height, self.fx, self.fy, self.cx, self.cy) 162 | scale = 1.0 163 | volume = o3d.pipelines.integration.ScalableTSDFVolume( 164 | voxel_length=5.0 * scale / 512.0, 165 | sdf_trunc=0.04 * scale, 166 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8) 167 | 168 | submaps_paths = sorted(list((self.checkpoint_path / "submaps").glob('*.ckpt'))) 169 | for submap_path in tqdm(submaps_paths): 170 | submap = torch.load(submap_path, map_location=self.device) 171 | gaussian_model = GaussianModel() 172 | gaussian_model.training_setup(opt_settings) 173 | gaussian_model.restore_from_params( 174 | submap["gaussian_params"], opt_settings) 175 | 176 | for keyframe_id in submap["submap_keyframes"]: 177 | estimate_c2w = self.estimated_c2w[keyframe_id] 178 | estimate_w2c = np.linalg.inv(estimate_c2w) 179 | render_dict = render_gaussian_model( 180 | gaussian_model, get_render_settings(self.width, self.height, self.dataset.intrinsics, estimate_w2c)) 181 | rendered_color, rendered_depth = render_dict["color"].detach( 182 | ), render_dict["depth"][0].detach() 183 | rendered_color = torch.clamp(rendered_color, min=0.0, max=1.0) 184 | 185 | rendered_color = ( 186 | torch2np(rendered_color.permute(1, 2, 0)) * 255).astype(np.uint8) 187 | rendered_depth = torch2np(rendered_depth) 188 | rendered_depth = filter_depth_outliers( 189 | rendered_depth, kernel_size=20, threshold=0.1) 190 | 191 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( 192 | o3d.geometry.Image(np.ascontiguousarray(rendered_color)), 193 | o3d.geometry.Image(rendered_depth), 194 | depth_scale=scale, 195 | 
depth_trunc=30, 196 | convert_rgb_to_intensity=False) 197 | volume.integrate(rgbd, intrinsic, estimate_w2c) 198 | 199 | o3d_mesh = volume.extract_triangle_mesh() 200 | compensate_vector = (-0.0 * scale / 512.0, 2.5 * 201 | scale / 512.0, -2.5 * scale / 512.0) 202 | o3d_mesh = o3d_mesh.translate(compensate_vector) 203 | file_name = self.checkpoint_path / "mesh" / "final_mesh.ply" 204 | o3d.io.write_triangle_mesh(str(file_name), o3d_mesh) 205 | evaluate_reconstruction(file_name, 206 | f"data/Replica-SLAM/cull_replica/{self.scene_name}.ply", 207 | f"data/Replica-SLAM/cull_replica/{self.scene_name}_pc_unseen.npy", 208 | self.checkpoint_path) 209 | 210 | def run_global_map_eval(self): 211 | """ Merges the map, evaluates it over training and novel views """ 212 | print("Running global map evaluation...") 213 | 214 | training_frames = RenderFrames(self.dataset, self.estimated_c2w, self.height, self.width, self.fx, self.fy) 215 | training_frames = DataLoader(training_frames, batch_size=1, shuffle=True) 216 | training_frames = cycle(training_frames) 217 | merged_cloud = merge_submaps(self.submaps_paths) 218 | refined_merged_gaussian_model = refine_global_map(merged_cloud, training_frames, 10000) 219 | ply_path = self.checkpoint_path / f"{self.config['data']['scene_name']}_global_map.ply" 220 | refined_merged_gaussian_model.save_ply(ply_path) 221 | print(f"Refined global map saved to {ply_path}") 222 | 223 | if self.config["dataset_name"] != "scannetpp": 224 | return # "NVS evaluation only supported for scannetpp" 225 | 226 | eval_config = deepcopy(self.config) 227 | print(f"✨ Eval NVS for scene {self.config['data']['scene_name']}...") 228 | (self.checkpoint_path / "nvs_eval").mkdir(exist_ok=True, parents=True) 229 | eval_config["data"]["use_train_split"] = False 230 | test_set = get_dataset(eval_config["dataset_name"])({**eval_config["data"], **eval_config["cam"]}) 231 | test_poses = torch.stack([torch.from_numpy(test_set[i][3]) for i in range(len(test_set))], dim=0) 232 | test_frames = RenderFrames(test_set, test_poses, self.height, self.width, self.fx, self.fy) 233 | 234 | psnr_list = [] 235 | for i in tqdm(range(len(test_set))): 236 | gt_color, _, render_settings = ( 237 | test_frames[i]["color"], 238 | test_frames[i]["depth"], 239 | test_frames[i]["render_settings"]) 240 | render_dict = render_gaussian_model(refined_merged_gaussian_model, render_settings) 241 | rendered_color, _ = (render_dict["color"].permute(1, 2, 0), render_dict["depth"],) 242 | rendered_color = torch.clip(rendered_color, 0, 1) 243 | save_image(rendered_color.permute(2, 0, 1), self.checkpoint_path / f"nvs_eval/{i:04d}.jpg") 244 | psnr = calc_psnr(gt_color, rendered_color).mean() 245 | psnr_list.append(psnr.item()) 246 | print(f"PSNR List: {psnr_list}") 247 | print(f"Avg. 
NVS PSNR: {np.array(psnr_list).mean()}") 248 | 249 | def run(self): 250 | """ Runs the general evaluation flow """ 251 | 252 | print("Starting evaluation...🍺") 253 | 254 | try: 255 | self.run_trajectory_eval() 256 | except Exception: 257 | print("Could not run trajectory eval") 258 | traceback.print_exc() 259 | 260 | try: 261 | self.run_rendering_eval() 262 | except Exception: 263 | print("Could not run rendering eval") 264 | traceback.print_exc() 265 | 266 | try: 267 | self.run_reconstruction_eval() 268 | except Exception: 269 | print("Could not run reconstruction eval") 270 | traceback.print_exc() 271 | 272 | try: 273 | self.run_global_map_eval() 274 | except Exception: 275 | print("Could not run global map eval") 276 | traceback.print_exc() 277 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VladimirYugay/Gaussian-SLAM/eaec10d73ce7511563882b8856896e06d1f804e3/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/gaussian_model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 
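# Illustrative usage of the helpers below (a sketch; assumes `sh` holds per-Gaussian SH coefficients of
# shape (N, 3, (deg + 1) ** 2) and `dirs` unit viewing directions of shape (N, 3)):
#   colors = eval_sh(deg, sh, dirs)            # view-dependent RGB, shape (N, 3)
#   base_sh = RGB2SH(torch.tensor([r, g, b]))  # degree-0 coefficients for a flat RGB colour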
23 | import numpy as np 24 | import torch 25 | import torch.nn.functional as F 26 | 27 | C0 = 0.28209479177387814 28 | C1 = 0.4886025119029199 29 | C2 = [ 30 | 1.0925484305920792, 31 | -1.0925484305920792, 32 | 0.31539156525252005, 33 | -1.0925484305920792, 34 | 0.5462742152960396 35 | ] 36 | C3 = [ 37 | -0.5900435899266435, 38 | 2.890611442640554, 39 | -0.4570457994644658, 40 | 0.3731763325901154, 41 | -0.4570457994644658, 42 | 1.445305721320277, 43 | -0.5900435899266435 44 | ] 45 | C4 = [ 46 | 2.5033429417967046, 47 | -1.7701307697799304, 48 | 0.9461746957575601, 49 | -0.6690465435572892, 50 | 0.10578554691520431, 51 | -0.6690465435572892, 52 | 0.47308734787878004, 53 | -1.7701307697799304, 54 | 0.6258357354491761, 55 | ] 56 | 57 | 58 | def eval_sh(deg, sh, dirs): 59 | """ 60 | Evaluate spherical harmonics at unit directions 61 | using hardcoded SH polynomials. 62 | Works with torch/np/jnp. 63 | ... Can be 0 or more batch dimensions. 64 | Args: 65 | deg: int SH deg. Currently, 0-3 supported 66 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 67 | dirs: jnp.ndarray unit directions [..., 3] 68 | Returns: 69 | [..., C] 70 | """ 71 | assert deg <= 4 and deg >= 0 72 | coeff = (deg + 1) ** 2 73 | assert sh.shape[-1] >= coeff 74 | 75 | result = C0 * sh[..., 0] 76 | if deg > 0: 77 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 78 | result = (result - 79 | C1 * y * sh[..., 1] + 80 | C1 * z * sh[..., 2] - 81 | C1 * x * sh[..., 3]) 82 | 83 | if deg > 1: 84 | xx, yy, zz = x * x, y * y, z * z 85 | xy, yz, xz = x * y, y * z, x * z 86 | result = (result + 87 | C2[0] * xy * sh[..., 4] + 88 | C2[1] * yz * sh[..., 5] + 89 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 90 | C2[3] * xz * sh[..., 7] + 91 | C2[4] * (xx - yy) * sh[..., 8]) 92 | 93 | if deg > 2: 94 | result = (result + 95 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 96 | C3[1] * xy * z * sh[..., 10] + 97 | C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] + 98 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 99 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 100 | C3[5] * z * (xx - yy) * sh[..., 14] + 101 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 102 | 103 | if deg > 3: 104 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 105 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 106 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 107 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 108 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 109 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 110 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 111 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 112 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 113 | return result 114 | 115 | 116 | def RGB2SH(rgb): 117 | return (rgb - 0.5) / C0 118 | 119 | 120 | def SH2RGB(sh): 121 | return sh * C0 + 0.5 122 | 123 | 124 | def inverse_sigmoid(x): 125 | return torch.log(x/(1-x)) 126 | 127 | 128 | def get_expon_lr_func( 129 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 130 | ): 131 | """ 132 | Copied from Plenoxels 133 | 134 | Continuous learning rate decay function. Adapted from JaxNeRF 135 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 136 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 
137 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 138 | function of lr_delay_mult, such that the initial learning rate is 139 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 140 | to the normal learning rate when steps>lr_delay_steps. 141 | :param conf: config subtree 'lr' or similar 142 | :param max_steps: int, the number of steps during optimization. 143 | :return HoF which takes step as input 144 | """ 145 | 146 | def helper(step): 147 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 148 | # Disable this parameter 149 | return 0.0 150 | if lr_delay_steps > 0: 151 | # A kind of reverse cosine decay. 152 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 153 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 154 | ) 155 | else: 156 | delay_rate = 1.0 157 | t = np.clip(step / max_steps, 0, 1) 158 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 159 | return delay_rate * log_lerp 160 | 161 | return helper 162 | 163 | 164 | def strip_lowerdiag(L): 165 | uncertainty = torch.zeros( 166 | (L.shape[0], 6), dtype=torch.float, device="cuda") 167 | 168 | uncertainty[:, 0] = L[:, 0, 0] 169 | uncertainty[:, 1] = L[:, 0, 1] 170 | uncertainty[:, 2] = L[:, 0, 2] 171 | uncertainty[:, 3] = L[:, 1, 1] 172 | uncertainty[:, 4] = L[:, 1, 2] 173 | uncertainty[:, 5] = L[:, 2, 2] 174 | return uncertainty 175 | 176 | 177 | def strip_symmetric(sym): 178 | return strip_lowerdiag(sym) 179 | 180 | 181 | def build_rotation(r): 182 | 183 | q = F.normalize(r, p=2, dim=1) 184 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 185 | 186 | r = q[:, 0] 187 | x = q[:, 1] 188 | y = q[:, 2] 189 | z = q[:, 3] 190 | 191 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 192 | R[:, 0, 1] = 2 * (x*y - r*z) 193 | R[:, 0, 2] = 2 * (x*z + r*y) 194 | R[:, 1, 0] = 2 * (x*y + r*z) 195 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 196 | R[:, 1, 2] = 2 * (y*z - r*x) 197 | R[:, 2, 0] = 2 * (x*z - r*y) 198 | R[:, 2, 1] = 2 * (y*z + r*x) 199 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 200 | return R 201 | 202 | 203 | def build_scaling_rotation(s, r): 204 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 205 | R = build_rotation(r) 206 | 207 | L[:, 0, 0] = s[:, 0] 208 | L[:, 1, 1] = s[:, 1] 209 | L[:, 2, 2] = s[:, 2] 210 | 211 | L = R @ L 212 | return L 213 | -------------------------------------------------------------------------------- /src/utils/io_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | from typing import Union 5 | 6 | import open3d as o3d 7 | import torch 8 | import wandb 9 | import yaml 10 | 11 | 12 | def mkdir_decorator(func): 13 | """A decorator that creates the directory specified in the function's 'directory' keyword 14 | argument before calling the function. 15 | Args: 16 | func: The function to be decorated. 17 | Returns: 18 | The wrapper function. 19 | """ 20 | def wrapper(*args, **kwargs): 21 | output_path = Path(kwargs["directory"]) 22 | output_path.mkdir(parents=True, exist_ok=True) 23 | return func(*args, **kwargs) 24 | return wrapper 25 | 26 | 27 | @mkdir_decorator 28 | def save_clouds(clouds: list, cloud_names: list, *, directory: Union[str, Path]) -> None: 29 | """ Saves a list of point clouds to the specified directory, creating the directory if it does not exist. 30 | Args: 31 | clouds: A list of point cloud objects to be saved. 
32 | cloud_names: A list of filenames for the point clouds, corresponding by index to the clouds. 33 | directory: The directory where the point clouds will be saved. 34 | """ 35 | for cld_name, cloud in zip(cloud_names, clouds): 36 | o3d.io.write_point_cloud(str(directory / cld_name), cloud) 37 | 38 | 39 | @mkdir_decorator 40 | def save_dict_to_ckpt(dictionary, file_name: str, *, directory: Union[str, Path]) -> None: 41 | """ Saves a dictionary to a checkpoint file in the specified directory, creating the directory if it does not exist. 42 | Args: 43 | dictionary: The dictionary to be saved. 44 | file_name: The name of the checkpoint file. 45 | directory: The directory where the checkpoint file will be saved. 46 | """ 47 | torch.save(dictionary, directory / file_name, 48 | _use_new_zipfile_serialization=False) 49 | 50 | 51 | @mkdir_decorator 52 | def save_dict_to_yaml(dictionary, file_name: str, *, directory: Union[str, Path]) -> None: 53 | """ Saves a dictionary to a YAML file in the specified directory, creating the directory if it does not exist. 54 | Args: 55 | dictionary: The dictionary to be saved. 56 | file_name: The name of the YAML file. 57 | directory: The directory where the YAML file will be saved. 58 | """ 59 | with open(directory / file_name, "w") as f: 60 | yaml.dump(dictionary, f) 61 | 62 | 63 | @mkdir_decorator 64 | def save_dict_to_json(dictionary, file_name: str, *, directory: Union[str, Path]) -> None: 65 | """ Saves a dictionary to a JSON file in the specified directory, creating the directory if it does not exist. 66 | Args: 67 | dictionary: The dictionary to be saved. 68 | file_name: The name of the JSON file. 69 | directory: The directory where the JSON file will be saved. 70 | """ 71 | with open(directory / file_name, "w") as f: 72 | json.dump(dictionary, f) 73 | 74 | 75 | def load_config(path: str, default_path: str = None) -> dict: 76 | """ 77 | Loads a configuration file and optionally merges it with a default configuration file. 78 | 79 | This function loads a configuration from the given path. If the configuration specifies an inheritance 80 | path (`inherit_from`), or if a `default_path` is provided, it loads the base configuration and updates it 81 | with the specific configuration. 82 | 83 | Args: 84 | path: The path to the specific configuration file. 85 | default_path: An optional path to a default configuration file that is loaded if the specific configuration 86 | does not specify an inheritance or as a base for the inheritance. 87 | 88 | Returns: 89 | A dictionary containing the merged configuration. 90 | """ 91 | # load configuration from per scene/dataset cfg. 92 | with open(path, 'r') as f: 93 | cfg_special = yaml.full_load(f) 94 | inherit_from = cfg_special.get('inherit_from') 95 | cfg = dict() 96 | if inherit_from is not None: 97 | cfg = load_config(inherit_from, default_path) 98 | elif default_path is not None: 99 | with open(default_path, 'r') as f: 100 | cfg = yaml.full_load(f) 101 | update_recursive(cfg, cfg_special) 102 | return cfg 103 | 104 | 105 | def update_recursive(dict1: dict, dict2: dict) -> None: 106 | """ Recursively updates the first dictionary with the contents of the second dictionary. 107 | 108 | This function iterates through `dict2` and updates `dict1` with its contents. If a key from `dict2` 109 | exists in `dict1` and its value is also a dictionary, the function updates the value recursively. 110 | Otherwise, it overwrites the value in `dict1` with the value from `dict2`. 
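    Example:
        update_recursive({"a": {"b": 1}}, {"a": {"c": 2}, "d": 3})
        # dict1 is now {"a": {"b": 1, "c": 2}, "d": 3}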
111 | 112 | Args: 113 | dict1: The dictionary to be updated. 114 | dict2: The dictionary whose entries are used to update `dict1`. 115 | 116 | Returns: 117 | None: The function modifies `dict1` in place. 118 | """ 119 | for k, v in dict2.items(): 120 | if k not in dict1: 121 | dict1[k] = dict() 122 | if isinstance(v, dict): 123 | update_recursive(dict1[k], v) 124 | else: 125 | dict1[k] = v 126 | 127 | 128 | def log_metrics_to_wandb(json_files: list, output_path: str, section: str = "Evaluation") -> None: 129 | """ Logs metrics from JSON files to Weights & Biases under a specified section. 130 | 131 | This function reads metrics from a list of JSON files and logs them to Weights & Biases (wandb). 132 | Each metric is prefixed with a specified section name for organized logging. 133 | 134 | Args: 135 | json_files: A list of filenames for JSON files containing metrics to be logged. 136 | output_path: The directory path where the JSON files are located. 137 | section: The section under which to log the metrics in wandb. Defaults to "Evaluation". 138 | 139 | Returns: 140 | None: Metrics are logged to wandb and the function does not return a value. 141 | """ 142 | for json_file in json_files: 143 | file_path = os.path.join(output_path, json_file) 144 | if os.path.exists(file_path): 145 | with open(file_path, 'r') as file: 146 | metrics = json.load(file) 147 | prefixed_metrics = { 148 | f"{section}/{key}": value for key, value in metrics.items()} 149 | wandb.log(prefixed_metrics) 150 | -------------------------------------------------------------------------------- /src/utils/mapper_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import cv2 3 | import faiss 4 | import faiss.contrib.torch_utils 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def compute_opt_views_distribution(keyframes_num, iterations_num, current_frame_iter) -> np.ndarray: 10 | """ Computes the probability distribution for selecting views based on the current iteration. 11 | Args: 12 | keyframes_num: The total number of keyframes. 13 | iterations_num: The total number of iterations planned. 14 | current_frame_iter: The current iteration number. 15 | Returns: 16 | An array representing the probability distribution of keyframes. 17 | """ 18 | if keyframes_num == 1: 19 | return np.array([1.0]) 20 | prob = np.full(keyframes_num, (iterations_num - current_frame_iter) / (keyframes_num - 1)) 21 | prob[0] = current_frame_iter 22 | prob /= prob.sum() 23 | return prob 24 | 25 | 26 | def compute_camera_frustum_corners(depth_map: np.ndarray, pose: np.ndarray, intrinsics: np.ndarray) -> np.ndarray: 27 | """ Computes the 3D coordinates of the camera frustum corners based on the depth map, pose, and intrinsics. 28 | Args: 29 | depth_map: The depth map of the scene. 30 | pose: The camera pose matrix. 31 | intrinsics: The camera intrinsic matrix. 32 | Returns: 33 | An array of 3D coordinates for the frustum corners. 
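        Note: the eight corners are the image-plane corners back-projected at the minimum and maximum valid depth of the depth map and then transformed into world coordinates with the given pose.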
34 | """ 35 | height, width = depth_map.shape 36 | depth_map = depth_map[depth_map > 0] 37 | min_depth, max_depth = depth_map.min(), depth_map.max() 38 | corners = np.array( 39 | [ 40 | [0, 0, min_depth], 41 | [width, 0, min_depth], 42 | [0, height, min_depth], 43 | [width, height, min_depth], 44 | [0, 0, max_depth], 45 | [width, 0, max_depth], 46 | [0, height, max_depth], 47 | [width, height, max_depth], 48 | ] 49 | ) 50 | x = (corners[:, 0] - intrinsics[0, 2]) * corners[:, 2] / intrinsics[0, 0] 51 | y = (corners[:, 1] - intrinsics[1, 2]) * corners[:, 2] / intrinsics[1, 1] 52 | z = corners[:, 2] 53 | corners_3d = np.vstack((x, y, z, np.ones(x.shape[0]))).T 54 | corners_3d = pose @ corners_3d.T 55 | return corners_3d.T[:, :3] 56 | 57 | 58 | def compute_camera_frustum_planes(frustum_corners: np.ndarray) -> torch.Tensor: 59 | """ Computes the planes of the camera frustum from its corners. 60 | Args: 61 | frustum_corners: An array of 3D coordinates representing the corners of the frustum. 62 | 63 | Returns: 64 | A tensor of frustum planes. 65 | """ 66 | # near, far, left, right, top, bottom 67 | planes = torch.stack( 68 | [ 69 | torch.cross( 70 | frustum_corners[2] - frustum_corners[0], 71 | frustum_corners[1] - frustum_corners[0] 72 | ), 73 | torch.cross( 74 | frustum_corners[6] - frustum_corners[4], 75 | frustum_corners[5] - frustum_corners[4] 76 | ), 77 | torch.cross( 78 | frustum_corners[4] - frustum_corners[0], 79 | frustum_corners[2] - frustum_corners[0] 80 | ), 81 | torch.cross( 82 | frustum_corners[7] - frustum_corners[3], 83 | frustum_corners[1] - frustum_corners[3] 84 | ), 85 | torch.cross( 86 | frustum_corners[5] - frustum_corners[1], 87 | frustum_corners[0] - frustum_corners[1] 88 | ), 89 | torch.cross( 90 | frustum_corners[6] - frustum_corners[2], 91 | frustum_corners[3] - frustum_corners[2] 92 | ), 93 | ] 94 | ) 95 | D = torch.stack([-torch.dot(plane, frustum_corners[i]) for i, plane in enumerate(planes)]) 96 | return torch.cat([planes, D[:, None]], dim=1).float() 97 | 98 | 99 | def compute_frustum_aabb(frustum_corners: torch.Tensor): 100 | """ Computes a mask indicating which points lie inside a given axis-aligned bounding box (AABB). 101 | Args: 102 | points: An array of 3D points. 103 | min_corner: The minimum corner of the AABB. 104 | max_corner: The maximum corner of the AABB. 105 | Returns: 106 | A boolean array indicating whether each point lies inside the AABB. 107 | """ 108 | return torch.min(frustum_corners, axis=0).values, torch.max(frustum_corners, axis=0).values 109 | 110 | 111 | def points_inside_aabb_mask(points: np.ndarray, min_corner: np.ndarray, max_corner: np.ndarray) -> np.ndarray: 112 | """ Computes a mask indicating which points lie inside the camera frustum. 113 | Args: 114 | points: A tensor of 3D points. 115 | frustum_planes: A tensor representing the planes of the frustum. 116 | Returns: 117 | A boolean tensor indicating whether each point lies inside the frustum. 118 | """ 119 | return ( 120 | (points[:, 0] >= min_corner[0]) 121 | & (points[:, 0] <= max_corner[0]) 122 | & (points[:, 1] >= min_corner[1]) 123 | & (points[:, 1] <= max_corner[1]) 124 | & (points[:, 2] >= min_corner[2]) 125 | & (points[:, 2] <= max_corner[2])) 126 | 127 | 128 | def points_inside_frustum_mask(points: torch.Tensor, frustum_planes: torch.Tensor) -> torch.Tensor: 129 | """ Computes a mask indicating which points lie inside the camera frustum. 130 | Args: 131 | points: A tensor of 3D points. 132 | frustum_planes: A tensor representing the planes of the frustum. 
133 | Returns: 134 | A boolean tensor indicating whether each point lies inside the frustum. 135 | """ 136 | num_pts = points.shape[0] 137 | ones = torch.ones(num_pts, 1).to(points.device) 138 | plane_product = torch.cat([points, ones], axis=1) @ frustum_planes.T 139 | return torch.all(plane_product <= 0, axis=1) 140 | 141 | 142 | def compute_frustum_point_ids(pts: torch.Tensor, frustum_corners: torch.Tensor, device: str = "cuda"): 143 | """ Identifies points within the camera frustum, optimizing for computation on a specified device. 144 | Args: 145 | pts: A tensor of 3D points. 146 | frustum_corners: A tensor of 3D coordinates representing the corners of the frustum. 147 | device: The computation device ("cuda" or "cpu"). 148 | Returns: 149 | Indices of points lying inside the frustum. 150 | """ 151 | if pts.shape[0] == 0: 152 | return torch.tensor([], dtype=torch.int64, device=device) 153 | # Broad phase 154 | pts = pts.to(device) 155 | frustum_corners = frustum_corners.to(device) 156 | 157 | min_corner, max_corner = compute_frustum_aabb(frustum_corners) 158 | inside_aabb_mask = points_inside_aabb_mask(pts, min_corner, max_corner) 159 | 160 | # Narrow phase 161 | frustum_planes = compute_camera_frustum_planes(frustum_corners) 162 | frustum_planes = frustum_planes.to(device) 163 | inside_frustum_mask = points_inside_frustum_mask(pts[inside_aabb_mask], frustum_planes) 164 | 165 | inside_aabb_mask[inside_aabb_mask == 1] = inside_frustum_mask 166 | return torch.where(inside_aabb_mask)[0] 167 | 168 | 169 | def sample_pixels_based_on_gradient(image: np.ndarray, num_samples: int) -> np.ndarray: 170 | """ Samples pixel indices based on the gradient magnitude of an image. 171 | Args: 172 | image: The image from which to sample pixels. 173 | num_samples: The number of pixels to sample. 174 | Returns: 175 | Indices of the sampled pixels. 176 | """ 177 | gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) 178 | grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3) 179 | grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3) 180 | grad_magnitude = cv2.magnitude(grad_x, grad_y) 181 | 182 | # Normalize the gradient magnitude to create a probability map 183 | prob_map = grad_magnitude / np.sum(grad_magnitude) 184 | 185 | # Flatten the probability map 186 | prob_map_flat = prob_map.flatten() 187 | 188 | # Sample pixel indices based on the probability map 189 | sampled_indices = np.random.choice(prob_map_flat.size, size=num_samples, p=prob_map_flat) 190 | return sampled_indices.T 191 | 192 | 193 | def compute_new_points_ids(frustum_points: torch.Tensor, new_pts: torch.Tensor, 194 | radius: float = 0.03, device: str = "cpu") -> torch.Tensor: 195 | """ Having newly initialized points, decides which of them should be added to the submap. 196 | For every new point, if there are no neighbors within the radius in the frustum points, 197 | it is added to the submap. 
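    The neighbor check uses a FAISS L2 index built over the frustum points; candidates are queried in chunks of 65535 points for their 8 nearest neighbors, and a point is kept only if none of those neighbors falls within the given radius.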
198 | Args: 199 | frustum_points: Point within a current frustum of the active submap of shape (N, 3) 200 | new_pts: New 3D Gaussian means which are about to be added to the submap of shape (N, 3) 201 | radius: Radius whithin which the points are considered to be neighbors 202 | device: Execution device 203 | Returns: 204 | Indicies of the new points that should be added to the submap of shape (N) 205 | """ 206 | if frustum_points.shape[0] == 0: 207 | return torch.arange(new_pts.shape[0]) 208 | if device == "cpu": 209 | pts_index = faiss.IndexFlatL2(3) 210 | else: 211 | pts_index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, faiss.IndexFlatL2(3)) 212 | frustum_points = frustum_points.to(device) 213 | new_pts = new_pts.to(device) 214 | pts_index.add(frustum_points) 215 | 216 | split_pos = torch.split(new_pts, 65535, dim=0) 217 | distances, ids = [], [] 218 | for split_p in split_pos: 219 | distance, id = pts_index.search(split_p.float(), 8) 220 | distances.append(distance) 221 | ids.append(id) 222 | distances = torch.cat(distances, dim=0) 223 | ids = torch.cat(ids, dim=0) 224 | neighbor_num = (distances < radius).sum(axis=1).int() 225 | pts_index.reset() 226 | return torch.where(neighbor_num == 0)[0] 227 | 228 | 229 | def rotation_to_euler(R: torch.Tensor) -> torch.Tensor: 230 | """ 231 | Converts a rotation matrix to Euler angles. 232 | Args: 233 | R: A rotation matrix. 234 | Returns: 235 | Euler angles corresponding to the rotation matrix. 236 | """ 237 | sy = torch.sqrt(R[0, 0] ** 2 + R[1, 0] ** 2) 238 | singular = sy < 1e-6 239 | 240 | if not singular: 241 | x = torch.atan2(R[2, 1], R[2, 2]) 242 | y = torch.atan2(-R[2, 0], sy) 243 | z = torch.atan2(R[1, 0], R[0, 0]) 244 | else: 245 | x = torch.atan2(-R[1, 2], R[1, 1]) 246 | y = torch.atan2(-R[2, 0], sy) 247 | z = 0 248 | 249 | return torch.tensor([x, y, z]) * (180 / np.pi) 250 | 251 | 252 | def exceeds_motion_thresholds(current_c2w: torch.Tensor, last_submap_c2w: torch.Tensor, 253 | rot_thre: float = 50, trans_thre: float = 0.5) -> bool: 254 | """ Checks if a camera motion exceeds certain rotation and translation thresholds 255 | Args: 256 | current_c2w: The current camera-to-world transformation matrix. 257 | last_submap_c2w: The last submap's camera-to-world transformation matrix. 258 | rot_thre: The rotation threshold for triggering a new submap. 259 | trans_thre: The translation threshold for triggering a new submap. 260 | 261 | Returns: 262 | A boolean indicating whether a new submap is required. 263 | """ 264 | delta_pose = torch.matmul(torch.linalg.inv(last_submap_c2w).float(), current_c2w.float()) 265 | translation_diff = torch.norm(delta_pose[:3, 3]) 266 | rot_euler_diff_deg = torch.abs(rotation_to_euler(delta_pose[:3, :3])) 267 | exceeds_thresholds = (translation_diff > trans_thre) or torch.any(rot_euler_diff_deg > rot_thre) 268 | return exceeds_thresholds.item() 269 | 270 | 271 | def geometric_edge_mask(rgb_image: np.ndarray, dilate: bool = True, RGB: bool = False) -> np.ndarray: 272 | """ Computes an edge mask for an RGB image using geometric edges. 273 | Args: 274 | rgb_image: The RGB image. 275 | dilate: Whether to dilate the edges. 276 | RGB: Indicates if the image format is RGB (True) or BGR (False). 277 | Returns: 278 | An edge mask of the input image. 
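        Example (illustrative; assumes `rgb` is an H x W x 3 uint8 image):
            >>> edges = geometric_edge_mask(rgb, dilate=True, RGB=True)  # uint8 mask, 255 on (dilated) Canny edges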
279 | """ 280 | # Convert the image to grayscale as Canny edge detection requires a single channel image 281 | gray_image = cv2.cvtColor( 282 | rgb_image, cv2.COLOR_BGR2GRAY if not RGB else cv2.COLOR_RGB2GRAY) 283 | if gray_image.dtype != np.uint8: 284 | gray_image = gray_image.astype(np.uint8) 285 | edges = cv2.Canny(gray_image, threshold1=100, threshold2=200, apertureSize=3, L2gradient=True) 286 | # Define the structuring element for dilation, you can change the size for a thicker/thinner mask 287 | if dilate: 288 | kernel = np.ones((2, 2), np.uint8) 289 | edges = cv2.dilate(edges, kernel, iterations=1) 290 | return edges 291 | 292 | 293 | def calc_psnr(img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor: 294 | """ Calculates the Peak Signal-to-Noise Ratio (PSNR) between two images. 295 | Args: 296 | img1: The first image. 297 | img2: The second image. 298 | Returns: 299 | The PSNR value. 300 | """ 301 | mse = ((img1 - img2) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 302 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 303 | 304 | 305 | def create_point_cloud(image: np.ndarray, depth: np.ndarray, intrinsics: np.ndarray, pose: np.ndarray) -> np.ndarray: 306 | """ 307 | Creates a point cloud from an image, depth map, camera intrinsics, and pose. 308 | 309 | Args: 310 | image: The RGB image of shape (H, W, 3) 311 | depth: The depth map of shape (H, W) 312 | intrinsics: The camera intrinsic parameters of shape (3, 3) 313 | pose: The camera pose of shape (4, 4) 314 | Returns: 315 | A point cloud of shape (N, 6) with last dimension representing (x, y, z, r, g, b) 316 | """ 317 | height, width = depth.shape 318 | # Create a mesh grid of pixel coordinates 319 | u, v = np.meshgrid(np.arange(width), np.arange(height)) 320 | # Convert pixel coordinates to camera coordinates 321 | x = (u - intrinsics[0, 2]) * depth / intrinsics[0, 0] 322 | y = (v - intrinsics[1, 2]) * depth / intrinsics[1, 1] 323 | z = depth 324 | # Stack the coordinates together 325 | points = np.stack((x, y, z, np.ones_like(z)), axis=-1) 326 | # Reshape the coordinates for matrix multiplication 327 | points = points.reshape(-1, 4) 328 | # Transform points to world coordinates 329 | posed_points = pose @ points.T 330 | posed_points = posed_points.T[:, :3] 331 | # Flatten the image to get colors for each point 332 | colors = image.reshape(-1, 3) 333 | # Concatenate posed points with their corresponding color 334 | point_cloud = np.concatenate((posed_points, colors), axis=-1) 335 | 336 | return point_cloud 337 | -------------------------------------------------------------------------------- /src/utils/tracker_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from scipy.spatial.transform import Rotation 4 | from typing import Union 5 | from src.utils.utils import np2torch 6 | 7 | 8 | def multiply_quaternions(q: torch.Tensor, r: torch.Tensor) -> torch.Tensor: 9 | """Performs batch-wise quaternion multiplication. 10 | 11 | Given two quaternions, this function computes their product. The operation is 12 | vectorized and can be performed on batches of quaternions. 13 | 14 | Args: 15 | q: A tensor representing the first quaternion or a batch of quaternions. 16 | Expected shape is (... , 4), where the last dimension contains quaternion components (w, x, y, z). 17 | r: A tensor representing the second quaternion or a batch of quaternions with the same shape as q. 
18 | Returns: 19 | A tensor of the same shape as the input tensors, representing the product of the input quaternions. 20 | """ 21 | w0, x0, y0, z0 = q[..., 0], q[..., 1], q[..., 2], q[..., 3] 22 | w1, x1, y1, z1 = r[..., 0], r[..., 1], r[..., 2], r[..., 3] 23 | 24 | w = -x1 * x0 - y1 * y0 - z1 * z0 + w1 * w0 25 | x = x1 * w0 + y1 * z0 - z1 * y0 + w1 * x0 26 | y = -x1 * z0 + y1 * w0 + z1 * x0 + w1 * y0 27 | z = x1 * y0 - y1 * x0 + z1 * w0 + w1 * z0 28 | return torch.stack((w, x, y, z), dim=-1) 29 | 30 | 31 | def transformation_to_quaternion(RT: Union[torch.Tensor, np.ndarray]): 32 | """ Converts a rotation-translation matrix to a tensor representing quaternion and translation. 33 | 34 | This function takes a 3x4 transformation matrix (rotation and translation) and converts it 35 | into a tensor that combines the quaternion representation of the rotation and the translation vector. 36 | 37 | Args: 38 | RT: A 3x4 matrix representing the rotation and translation. This can be a NumPy array 39 | or a torch.Tensor. If it's a torch.Tensor and resides on a GPU, it will be moved to CPU. 40 | 41 | Returns: 42 | A tensor combining the quaternion (in w, x, y, z order) and translation vector. The tensor 43 | will be moved to the original device if the input was a GPU tensor. 44 | """ 45 | gpu_id = -1 46 | if isinstance(RT, torch.Tensor): 47 | if RT.get_device() != -1: 48 | RT = RT.detach().cpu() 49 | gpu_id = RT.get_device() 50 | RT = RT.numpy() 51 | R, T = RT[:3, :3], RT[:3, 3] 52 | 53 | rot = Rotation.from_matrix(R) 54 | quad = rot.as_quat(canonical=True) 55 | quad = np.roll(quad, 1) 56 | tensor = np.concatenate([quad, T], 0) 57 | tensor = torch.from_numpy(tensor).float() 58 | if gpu_id != -1: 59 | tensor = tensor.to(gpu_id) 60 | return tensor 61 | 62 | 63 | def extrapolate_poses(poses: np.ndarray) -> np.ndarray: 64 | """ Generates an interpolated pose based on the first two poses in the given array. 65 | Args: 66 | poses: An array of poses, where each pose is represented by a 4x4 transformation matrix. 67 | Returns: 68 | A 4x4 numpy ndarray representing the interpolated transformation matrix. 69 | """ 70 | return poses[1, :] @ np.linalg.inv(poses[0, :]) @ poses[1, :] 71 | 72 | 73 | def compute_camera_opt_params(estimate_rel_w2c: np.ndarray) -> tuple: 74 | """ Computes the camera's rotation and translation parameters from an world-to-camera transformation matrix. 75 | This function extracts the rotation component of the transformation matrix, converts it to a quaternion, 76 | and reorders it to match a specific convention. Both rotation and translation parameters are converted 77 | to torch Parameters and intended to be optimized in a PyTorch model. 78 | Args: 79 | estimate_rel_w2c: A 4x4 numpy ndarray representing the estimated world-to-camera transformation matrix. 80 | Returns: 81 | A tuple containing two torch.nn.Parameters: camera's rotation and camera's translation. 
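        Example (illustrative sketch; assumes `estimate_rel_w2c` is a 4x4 numpy world-to-camera matrix):
            >>> opt_cam_rot, opt_cam_trans = compute_camera_opt_params(estimate_rel_w2c)
            >>> # opt_cam_rot is a unit quaternion in (w, x, y, z) order, opt_cam_trans a 3-vector, both CUDA Parameters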
82 | """ 83 | quaternion = Rotation.from_matrix(estimate_rel_w2c[:3, :3]).as_quat(canonical=True) 84 | quaternion = quaternion[[3, 0, 1, 2]] 85 | opt_cam_rot = torch.nn.Parameter(np2torch(quaternion, "cuda")) 86 | opt_cam_trans = torch.nn.Parameter(np2torch(estimate_rel_w2c[:3, 3], "cuda")) 87 | return opt_cam_rot, opt_cam_trans 88 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import open3d as o3d 6 | import torch 7 | from gaussian_rasterizer import GaussianRasterizationSettings, GaussianRasterizer 8 | 9 | 10 | def setup_seed(seed: int) -> None: 11 | """ Sets the seed for generating random numbers to ensure reproducibility across multiple runs. 12 | Args: 13 | seed: The seed value to set for random number generators in torch, numpy, and random. 14 | """ 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed_all(seed) 17 | os.environ["PYTHONHASHSEED"] = str(seed) 18 | np.random.seed(seed) 19 | random.seed(seed) 20 | torch.backends.cudnn.deterministic = True 21 | torch.backends.cudnn.benchmark = False 22 | 23 | 24 | def torch2np(tensor: torch.Tensor) -> np.ndarray: 25 | """ Converts a PyTorch tensor to a NumPy ndarray. 26 | Args: 27 | tensor: The PyTorch tensor to convert. 28 | Returns: 29 | A NumPy ndarray with the same data and dtype as the input tensor. 30 | """ 31 | return tensor.detach().cpu().numpy() 32 | 33 | 34 | def np2torch(array: np.ndarray, device: str = "cpu") -> torch.Tensor: 35 | """Converts a NumPy ndarray to a PyTorch tensor. 36 | Args: 37 | array: The NumPy ndarray to convert. 38 | device: The device to which the tensor is sent. Defaults to 'cpu'. 39 | 40 | Returns: 41 | A PyTorch tensor with the same data as the input array. 42 | """ 43 | return torch.from_numpy(array).float().to(device) 44 | 45 | 46 | def np2ptcloud(pts: np.ndarray, rgb=None) -> o3d.geometry.PointCloud: 47 | """converts numpy array to point cloud 48 | Args: 49 | pts (ndarray): point cloud 50 | Returns: 51 | (PointCloud): resulting point cloud 52 | """ 53 | cloud = o3d.geometry.PointCloud() 54 | cloud.points = o3d.utility.Vector3dVector(pts) 55 | if rgb is not None: 56 | cloud.colors = o3d.utility.Vector3dVector(rgb) 57 | return cloud 58 | 59 | 60 | def dict2device(dict: dict, device: str = "cpu") -> dict: 61 | """Sends all tensors in a dictionary to a specified device. 62 | Args: 63 | dict: The dictionary containing tensors. 64 | device: The device to send the tensors to. Defaults to 'cpu'. 65 | Returns: 66 | The dictionary with all tensors sent to the specified device. 67 | """ 68 | for k, v in dict.items(): 69 | if isinstance(v, torch.Tensor): 70 | dict[k] = v.to(device) 71 | return dict 72 | 73 | 74 | def get_render_settings(w, h, intrinsics, w2c, near=0.01, far=100, sh_degree=0): 75 | """ 76 | Constructs and returns a GaussianRasterizationSettings object for rendering, 77 | configured with given camera parameters. 78 | 79 | Args: 80 | width (int): The width of the image. 81 | height (int): The height of the image. 82 | intrinsic (array): 3*3, Intrinsic camera matrix. 83 | w2c (array): World to camera transformation matrix. 84 | near (float, optional): The near plane for the camera. Defaults to 0.01. 85 | far (float, optional): The far plane for the camera. Defaults to 100. 86 | 87 | Returns: 88 | GaussianRasterizationSettings: Configured settings for Gaussian rasterization. 
89 | """ 90 | fx, fy, cx, cy = intrinsics[0, 0], intrinsics[1, 91 | 1], intrinsics[0, 2], intrinsics[1, 2] 92 | w2c = torch.tensor(w2c).cuda().float() 93 | cam_center = torch.inverse(w2c)[:3, 3] 94 | viewmatrix = w2c.transpose(0, 1) 95 | opengl_proj = torch.tensor([[2 * fx / w, 0.0, -(w - 2 * cx) / w, 0.0], 96 | [0.0, 2 * fy / h, -(h - 2 * cy) / h, 0.0], 97 | [0.0, 0.0, far / 98 | (far - near), -(far * near) / (far - near)], 99 | [0.0, 0.0, 1.0, 0.0]], device='cuda').float().transpose(0, 1) 100 | full_proj_matrix = viewmatrix.unsqueeze( 101 | 0).bmm(opengl_proj.unsqueeze(0)).squeeze(0) 102 | return GaussianRasterizationSettings( 103 | image_height=h, 104 | image_width=w, 105 | tanfovx=w / (2 * fx), 106 | tanfovy=h / (2 * fy), 107 | bg=torch.tensor([0, 0, 0], device='cuda').float(), 108 | scale_modifier=1.0, 109 | viewmatrix=viewmatrix, 110 | projmatrix=full_proj_matrix, 111 | sh_degree=sh_degree, 112 | campos=cam_center, 113 | prefiltered=False, 114 | debug=False) 115 | 116 | 117 | def render_gaussian_model(gaussian_model, render_settings, 118 | override_means_3d=None, override_means_2d=None, 119 | override_scales=None, override_rotations=None, 120 | override_opacities=None, override_colors=None): 121 | """ 122 | Renders a Gaussian model with specified rendering settings, allowing for 123 | optional overrides of various model parameters. 124 | 125 | Args: 126 | gaussian_model: A Gaussian model object that provides methods to get 127 | various properties like xyz coordinates, opacity, features, etc. 128 | render_settings: Configuration settings for the GaussianRasterizer. 129 | override_means_3d (Optional): If provided, these values will override 130 | the 3D mean values from the Gaussian model. 131 | override_means_2d (Optional): If provided, these values will override 132 | the 2D mean values. Defaults to zeros if not provided. 133 | override_scales (Optional): If provided, these values will override the 134 | scale values from the Gaussian model. 135 | override_rotations (Optional): If provided, these values will override 136 | the rotation values from the Gaussian model. 137 | override_opacities (Optional): If provided, these values will override 138 | the opacity values from the Gaussian model. 139 | override_colors (Optional): If provided, these values will override the 140 | color values from the Gaussian model. 141 | Returns: 142 | A dictionary containing the rendered color, depth, radii, and 2D means 143 | of the Gaussian model. The keys of this dictionary are 'color', 'depth', 144 | 'radii', and 'means2D', each mapping to their respective rendered values. 
145 | """ 146 | renderer = GaussianRasterizer(raster_settings=render_settings) 147 | 148 | if override_means_3d is None: 149 | means3D = gaussian_model.get_xyz() 150 | else: 151 | means3D = override_means_3d 152 | 153 | if override_means_2d is None: 154 | means2D = torch.zeros_like( 155 | means3D, dtype=means3D.dtype, requires_grad=True, device="cuda") 156 | means2D.retain_grad() 157 | else: 158 | means2D = override_means_2d 159 | 160 | if override_opacities is None: 161 | opacities = gaussian_model.get_opacity() 162 | else: 163 | opacities = override_opacities 164 | 165 | shs, colors_precomp = None, None 166 | if override_colors is not None: 167 | colors_precomp = override_colors 168 | else: 169 | shs = gaussian_model.get_features() 170 | 171 | render_args = { 172 | "means3D": means3D, 173 | "means2D": means2D, 174 | "opacities": opacities, 175 | "colors_precomp": colors_precomp, 176 | "shs": shs, 177 | "scales": gaussian_model.get_scaling() if override_scales is None else override_scales, 178 | "rotations": gaussian_model.get_rotation() if override_rotations is None else override_rotations, 179 | "cov3D_precomp": None 180 | } 181 | color, depth, alpha, radii = renderer(**render_args) 182 | 183 | return {"color": color, "depth": depth, "radii": radii, "means2D": means2D, "alpha": alpha} 184 | 185 | 186 | def batch_search_faiss(indexer, query_points, k): 187 | """ 188 | Perform a batch search on a IndexIVFFlat indexer to circumvent the search size limit of 65535. 189 | 190 | Args: 191 | indexer: The FAISS indexer object. 192 | query_points: A tensor of query points. 193 | k (int): The number of nearest neighbors to find. 194 | 195 | Returns: 196 | distances (torch.Tensor): The distances of the nearest neighbors. 197 | ids (torch.Tensor): The indices of the nearest neighbors. 
198 | """ 199 | split_pos = torch.split(query_points, 65535, dim=0) 200 | distances_list, ids_list = [], [] 201 | 202 | for split_p in split_pos: 203 | distance, id = indexer.search(split_p.float(), k) 204 | distances_list.append(distance.clone()) 205 | ids_list.append(id.clone()) 206 | distances = torch.cat(distances_list, dim=0) 207 | ids = torch.cat(ids_list, dim=0) 208 | 209 | return distances, ids 210 | -------------------------------------------------------------------------------- /src/utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from copy import deepcopy 3 | from typing import List, Union 4 | 5 | import numpy as np 6 | import open3d as o3d 7 | from matplotlib import colors 8 | 9 | COLORS_ANSI = OrderedDict({ 10 | "blue": "\033[94m", 11 | "orange": "\033[93m", 12 | "green": "\033[92m", 13 | "red": "\033[91m", 14 | "purple": "\033[95m", 15 | "brown": "\033[93m", # No exact match, using yellow 16 | "pink": "\033[95m", 17 | "gray": "\033[90m", 18 | "olive": "\033[93m", # No exact match, using yellow 19 | "cyan": "\033[96m", 20 | "end": "\033[0m", # Reset color 21 | }) 22 | 23 | 24 | COLORS_MATPLOTLIB = OrderedDict({ 25 | 'blue': '#1f77b4', 26 | 'orange': '#ff7f0e', 27 | 'green': '#2ca02c', 28 | 'red': '#d62728', 29 | 'purple': '#9467bd', 30 | 'brown': '#8c564b', 31 | 'pink': '#e377c2', 32 | 'gray': '#7f7f7f', 33 | 'yellow-green': '#bcbd22', 34 | 'cyan': '#17becf' 35 | }) 36 | 37 | 38 | COLORS_MATPLOTLIB_RGB = OrderedDict({ 39 | 'blue': np.array([31, 119, 180]) / 255.0, 40 | 'orange': np.array([255, 127, 14]) / 255.0, 41 | 'green': np.array([44, 160, 44]) / 255.0, 42 | 'red': np.array([214, 39, 40]) / 255.0, 43 | 'purple': np.array([148, 103, 189]) / 255.0, 44 | 'brown': np.array([140, 86, 75]) / 255.0, 45 | 'pink': np.array([227, 119, 194]) / 255.0, 46 | 'gray': np.array([127, 127, 127]) / 255.0, 47 | 'yellow-green': np.array([188, 189, 34]) / 255.0, 48 | 'cyan': np.array([23, 190, 207]) / 255.0 49 | }) 50 | 51 | 52 | def get_color(color_name: str): 53 | """ Returns the RGB values of a given color name as a normalized numpy array. 54 | Args: 55 | color_name: The name of the color. Can be any color name from CSS4_COLORS. 56 | Returns: 57 | A numpy array representing the RGB values of the specified color, normalized to the range [0, 1]. 58 | """ 59 | if color_name == "custom_yellow": 60 | return np.asarray([255.0, 204.0, 102.0]) / 255.0 61 | if color_name == "custom_blue": 62 | return np.asarray([102.0, 153.0, 255.0]) / 255.0 63 | assert color_name in colors.CSS4_COLORS 64 | return np.asarray(colors.to_rgb(colors.CSS4_COLORS[color_name])) 65 | 66 | 67 | def plot_ptcloud(point_clouds: Union[List, o3d.geometry.PointCloud], show_frame: bool = True): 68 | """ Visualizes one or more point clouds, optionally showing the coordinate frame. 69 | Args: 70 | point_clouds: A single point cloud or a list of point clouds to be visualized. 71 | show_frame: If True, displays the coordinate frame in the visualization. Defaults to True. 
72 | """ 73 | # rotate down up 74 | if not isinstance(point_clouds, list): 75 | point_clouds = [point_clouds] 76 | if show_frame: 77 | mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=1, origin=[0, 0, 0]) 78 | point_clouds = point_clouds + [mesh_frame] 79 | o3d.visualization.draw_geometries(point_clouds) 80 | 81 | 82 | def draw_registration_result_original_color(source: o3d.geometry.PointCloud, target: o3d.geometry.PointCloud, 83 | transformation: np.ndarray): 84 | """ Visualizes the result of a point cloud registration, keeping the original color of the source point cloud. 85 | Args: 86 | source: The source point cloud. 87 | target: The target point cloud. 88 | transformation: The transformation matrix applied to the source point cloud. 89 | """ 90 | source_temp = deepcopy(source) 91 | source_temp.transform(transformation) 92 | o3d.visualization.draw_geometries([source_temp, target]) 93 | 94 | 95 | def draw_registration_result(source: o3d.geometry.PointCloud, target: o3d.geometry.PointCloud, 96 | transformation: np.ndarray, source_color: str = "blue", target_color: str = "orange"): 97 | """ Visualizes the result of a point cloud registration, coloring the source and target point clouds. 98 | Args: 99 | source: The source point cloud. 100 | target: The target point cloud. 101 | transformation: The transformation matrix applied to the source point cloud. 102 | source_color: The color to apply to the source point cloud. Defaults to "blue". 103 | target_color: The color to apply to the target point cloud. Defaults to "orange". 104 | """ 105 | source_temp = deepcopy(source) 106 | source_temp.paint_uniform_color(COLORS_MATPLOTLIB_RGB[source_color]) 107 | 108 | target_temp = deepcopy(target) 109 | target_temp.paint_uniform_color(COLORS_MATPLOTLIB_RGB[target_color]) 110 | 111 | source_temp.transform(transformation) 112 | o3d.visualization.draw_geometries([source_temp, target_temp]) 113 | --------------------------------------------------------------------------------