├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── gaussian_slam.gif
│   └── gaussian_slam.mp4
├── configs
│   ├── Replica
│   │   ├── office0.yaml
│   │   ├── office1.yaml
│   │   ├── office2.yaml
│   │   ├── office3.yaml
│   │   ├── office4.yaml
│   │   ├── replica.yaml
│   │   ├── room0.yaml
│   │   ├── room1.yaml
│   │   └── room2.yaml
│   ├── ScanNet
│   │   ├── scannet.yaml
│   │   ├── scene0000_00.yaml
│   │   ├── scene0059_00.yaml
│   │   ├── scene0106_00.yaml
│   │   ├── scene0169_00.yaml
│   │   ├── scene0181_00.yaml
│   │   └── scene0207_00.yaml
│   ├── TUM_RGBD
│   │   ├── rgbd_dataset_freiburg1_desk.yaml
│   │   ├── rgbd_dataset_freiburg2_xyz.yaml
│   │   ├── rgbd_dataset_freiburg3_long_office_household.yaml
│   │   └── tum_rgbd.yaml
│   └── scannetpp
│       ├── 281bc17764.yaml
│       ├── 2e74812d00.yaml
│       ├── 8b5caf3398.yaml
│       ├── b20a261fdf.yaml
│       ├── fb05e13ad1.yaml
│       └── scannetpp.yaml
├── environment.yml
├── run_evaluation.py
├── run_slam.py
├── scripts
│   ├── download_replica.sh
│   ├── download_tum.sh
│   └── reproduce_sbatch.sh
└── src
    ├── entities
    │   ├── __init__.py
    │   ├── arguments.py
    │   ├── datasets.py
    │   ├── gaussian_model.py
    │   ├── gaussian_slam.py
    │   ├── logger.py
    │   ├── losses.py
    │   ├── mapper.py
    │   ├── tracker.py
    │   └── visual_odometer.py
    ├── evaluation
    │   ├── __init__.py
    │   ├── evaluate_merged_map.py
    │   ├── evaluate_reconstruction.py
    │   ├── evaluate_trajectory.py
    │   └── evaluator.py
    └── utils
        ├── __init__.py
        ├── gaussian_model_utils.py
        ├── io_utils.py
        ├── mapper_utils.py
        ├── tracker_utils.py
        ├── utils.py
        └── vis_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | .vscode
3 | output
4 | build
5 | diff_rasterization/diff_rast.egg-info
6 | diff_rasterization/dist
7 | tensorboard_3d
8 | screenshots
9 | debug
10 | wandb
11 | data
12 | *.txt
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Gaussian-SLAM: Photo-realistic Dense SLAM with Gaussian Splatting
2 |
3 | Vladimir Yugay · Yue Li* · Theo Gevers · Martin Oswald
4 |
5 | *Significant contribution
6 |
24 | ## ⚙️ Setting Things Up
25 |
26 | Clone the repo:
27 |
28 | ```
29 | git clone https://github.com/VladimirYugay/Gaussian-SLAM
30 | ```
31 |
32 | Make sure that gcc and g++ paths on your system are exported:
33 |
34 | ```
35 | export CC=<path_to_gcc>
36 | export CXX=<path_to_g++>
37 | ```
38 |
39 | To find the gcc and g++ paths on your machine you can use `which gcc` and `which g++`.
40 |
41 |
42 | Then setup environment from the provided conda environment file,
43 |
44 | ```
45 | conda env create -f environment.yml
46 | conda activate gslam
47 | ```
48 | We tested our code on RTX 3090 and RTX A6000 GPUs under Ubuntu 22 and CentOS 7.5.
49 |
50 | ## 🔨 Running Gaussian-SLAM
51 |
52 | Here we elaborate on how to load the necessary data, configure Gaussian-SLAM for your use case, debug it, and reproduce the results reported in the paper.
53 |
54 |
55 | Downloading the Data
56 | We tested our code on the Replica, TUM_RGBD, ScanNet, and ScanNet++ datasets. We also provide scripts for downloading Replica and TUM_RGBD. Install git lfs before using the scripts by running "git lfs install".
57 | For downloading ScanNet, follow the procedure described on the ScanNet website.
58 | For downloading ScanNet++, follow the procedure described on the ScanNet++ website.
59 | The config files are named after the sequences that we used for our method.
60 |
61 |
62 |
63 | Running the code
64 | Start the system with the command:
65 |
66 | ```
67 | python run_slam.py configs/<dataset_name>/<config_name>.yaml --input_path <path_to_the_scene> --output_path <output_path>
68 | ```
69 | For example:
70 | ```
71 | python run_slam.py configs/Replica/room0.yaml --input_path /home/datasets/Replica/room0 --output_path output/Replica/room0
72 | ```
73 | You can also configure input and output paths in the config yaml file.
74 |
75 |
76 |
77 | Reproducing Results
78 | While we made all parts of our code deterministic, the differentiable rasterizer of Gaussian Splatting is not. The metrics can therefore differ slightly from run to run. In the paper we report average metrics computed over three seeds: 0, 1, and 2.
79 |
80 | You can reproduce the results for a single scene by running:
81 |
82 | ```
83 | python run_slam.py configs/<dataset_name>/<config_name>.yaml --input_path <path_to_the_scene> --output_path <output_path>
84 | ```
85 |
86 | If you are running on a SLURM cluster, you can reproduce the results for all scenes in a dataset by running the script:
87 | ```
88 | ./scripts/reproduce_sbatch.sh
89 | ```
90 | Please note that evaluating the ```depth_L1``` metric requires reconstructing the mesh, which in turn requires a headless installation of open3d if you are running on a cluster.
91 |
92 |
93 |
94 | Demo
95 | We used the camera path tool from the gaussian-splatting-lightning repo to make the fly-through video of the reconstructed scenes. We thank its author for the great work.
96 |
97 |
98 | ## 📌 Citation
99 |
100 | If you find our paper and code useful, please cite us:
101 |
102 | ```bib
103 | @misc{yugay2023gaussianslam,
104 | title={Gaussian-SLAM: Photo-realistic Dense SLAM with Gaussian Splatting},
105 | author={Vladimir Yugay and Yue Li and Theo Gevers and Martin R. Oswald},
106 | year={2023},
107 | eprint={2312.10070},
108 | archivePrefix={arXiv},
109 | primaryClass={cs.CV}
110 | }
111 | ```
--------------------------------------------------------------------------------
/assets/gaussian_slam.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VladimirYugay/Gaussian-SLAM/eaec10d73ce7511563882b8856896e06d1f804e3/assets/gaussian_slam.gif
--------------------------------------------------------------------------------
/assets/gaussian_slam.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VladimirYugay/Gaussian-SLAM/eaec10d73ce7511563882b8856896e06d1f804e3/assets/gaussian_slam.mp4
--------------------------------------------------------------------------------
/configs/Replica/office0.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | data:
3 | scene_name: office0
4 | input_path: data/Replica-SLAM/Replica/office0/
5 | output_path: output/Replica/office0/
6 |
--------------------------------------------------------------------------------
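Each scene config overrides only a handful of fields and pulls everything else from the dataset-level base config through `inherit_from`, which is resolved when `load_config` (src/utils/io_utils.py) reads the file. Below is a minimal sketch of such a recursive merge; the inline YAML strings stand in for the files on disk, and `deep_update` is an illustrative helper, not the project's actual implementation.

```python
import yaml

def deep_update(base: dict, override: dict) -> dict:
    """Recursively merge `override` into `base`; scene-level values win."""
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            deep_update(base[key], value)
        else:
            base[key] = value
    return base

base_cfg = yaml.safe_load("""
dataset_name: replica
mapping:
  new_submap_every: 50
  map_every: 5
""")
scene_cfg = yaml.safe_load("""
data:
  scene_name: office0
  input_path: data/Replica-SLAM/Replica/office0/
""")

config = deep_update(base_cfg, scene_cfg)  # base values plus scene-specific overrides
print(config["mapping"]["map_every"], config["data"]["scene_name"])  # 5 office0
```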
/configs/Replica/office1.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | data:
3 | scene_name: office1
4 | input_path: data/Replica-SLAM/Replica/office1/
5 | output_path: output/Replica/office1/
6 |
--------------------------------------------------------------------------------
/configs/Replica/office2.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | data:
3 | scene_name: office2
4 | input_path: data/Replica-SLAM/Replica/office2/
5 | output_path: output/Replica/office2/
6 |
--------------------------------------------------------------------------------
/configs/Replica/office3.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | data:
3 | scene_name: office3
4 | input_path: data/Replica-SLAM/Replica/office3/
5 | output_path: output/Replica/office3/
6 |
--------------------------------------------------------------------------------
/configs/Replica/office4.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | data:
3 | scene_name: office4
4 | input_path: data/Replica-SLAM/Replica/office4/
5 | output_path: output/Replica/office4/
6 |
--------------------------------------------------------------------------------
/configs/Replica/replica.yaml:
--------------------------------------------------------------------------------
1 | project_name: "Gaussian_SLAM_replica"
2 | dataset_name: "replica"
3 | checkpoint_path: null
4 | use_wandb: False
5 | frame_limit: -1 # for debugging, set to -1 to disable
6 | seed: 0
7 | mapping:
8 | new_submap_every: 50
9 | map_every: 5
10 | iterations: 100
11 | new_submap_iterations: 1000
12 | new_submap_points_num: 600000
13 | new_submap_gradient_points_num: 50000
14 | new_frame_sample_size: -1
15 | new_points_radius: 0.0000001
16 | current_view_opt_iterations: 0.4 # What portion of iterations to spend on the current view
17 | alpha_thre: 0.6
18 | pruning_thre: 0.1
19 | submap_using_motion_heuristic: True
20 | tracking:
21 | gt_camera: False
22 | w_color_loss: 0.95
23 | iterations: 60
24 | cam_rot_lr: 0.0002
25 | cam_trans_lr: 0.002
26 | odometry_type: "odometer" # gt, const_speed, odometer
27 | help_camera_initialization: False # temp option to help const_init
28 | init_err_ratio: 5
29 | odometer_method: "point_to_plane" # hybrid or point_to_plane
30 | filter_alpha: False
31 | filter_outlier_depth: True
32 | alpha_thre: 0.98
33 | soft_alpha: True
34 | mask_invalid_depth: False
35 | cam:
36 | H: 680
37 | W: 1200
38 | fx: 600.0
39 | fy: 600.0
40 | cx: 599.5
41 | cy: 339.5
42 | depth_scale: 6553.5
43 |
--------------------------------------------------------------------------------
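The `cam` block above fully specifies the pinhole camera used for rendering. As a worked example, the same computation that `BaseDataset` (src/entities/datasets.py) performs turns the Replica intrinsics into an intrinsics matrix and a 90° horizontal field of view:

```python
import math
import numpy as np

H, W = 680, 1200
fx, fy, cx, cy = 600.0, 600.0, 599.5, 339.5

K = np.array([[fx, 0.0, cx],
              [0.0, fy, cy],
              [0.0, 0.0, 1.0]])          # pinhole intrinsics matrix

fovx = 2 * math.atan(W / (2 * fx))       # 2 * atan(1.0)   -> 90.0 degrees
fovy = 2 * math.atan(H / (2 * fy))       # 2 * atan(0.567) -> ~59.1 degrees
print(math.degrees(fovx), math.degrees(fovy))
```

The raw Replica depth images are stored as 16-bit PNGs; dividing them by `depth_scale` (6553.5 here) converts the values to metric depth, which is exactly what the dataset loaders do in `__getitem__`.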
/configs/Replica/room0.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | data:
3 | scene_name: room0
4 | input_path: data/Replica-SLAM/room0/
5 | output_path: output/Replica/room0/
6 |
--------------------------------------------------------------------------------
/configs/Replica/room1.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | data:
3 | scene_name: room1
4 | input_path: data/Replica-SLAM/Replica/room1/
5 | output_path: output/Replica/room1/
6 |
--------------------------------------------------------------------------------
/configs/Replica/room2.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | data:
3 | scene_name: room2
4 | input_path: data/Replica-SLAM/Replica/room2/
5 | output_path: output/Replica/room2/
6 |
--------------------------------------------------------------------------------
/configs/ScanNet/scannet.yaml:
--------------------------------------------------------------------------------
1 | project_name: "Gaussian_SLAM_scannet"
2 | dataset_name: "scan_net"
3 | checkpoint_path: null
4 | use_wandb: False
5 | frame_limit: -1 # for debugging, set to -1 to disable
6 | seed: 0
7 | mapping:
8 | new_submap_every: 50
9 | map_every: 1
10 | iterations: 100
11 | new_submap_iterations: 100
12 | new_submap_points_num: 100000
13 | new_submap_gradient_points_num: 50000
14 | new_frame_sample_size: 30000
15 | new_points_radius: 0.0001
16 | current_view_opt_iterations: 0.4 # What portion of iterations to spend on the current view
17 | alpha_thre: 0.6
18 | pruning_thre: 0.5
19 | submap_using_motion_heuristic: False
20 | tracking:
21 | gt_camera: False
22 | w_color_loss: 0.6
23 | iterations: 200
24 | cam_rot_lr: 0.002
25 | cam_trans_lr: 0.01
26 | odometry_type: "const_speed" # gt, const_speed, odometer
27 | help_camera_initialization: False # temp option to help const_init
28 | init_err_ratio: 5
29 | odometer_method: "hybrid" # hybrid or point_to_plane
30 | filter_alpha: True
31 | filter_outlier_depth: True
32 | alpha_thre: 0.98
33 | soft_alpha: True
34 | mask_invalid_depth: True
35 | cam:
36 | H: 480
37 | W: 640
38 | fx: 577.590698
39 | fy: 578.729797
40 | cx: 318.905426
41 | cy: 242.683609
42 | depth_scale: 1000.
43 | crop_edge: 12
--------------------------------------------------------------------------------
/configs/ScanNet/scene0000_00.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | data:
3 | input_path: data/scannet/scans/scene0000_00
4 | output_path: output/ScanNet/scene0000
5 | scene_name: scene0000_00
--------------------------------------------------------------------------------
/configs/ScanNet/scene0059_00.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | data:
3 | input_path: data/scannet/scans/scene0059_00
4 | output_path: output/ScanNet/scene0059
5 | scene_name: scene0059_00
--------------------------------------------------------------------------------
/configs/ScanNet/scene0106_00.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | data:
3 | input_path: data/scannet/scans/scene0106_00
4 | output_path: output/ScanNet/scene0106
5 | scene_name: scene0106_00
--------------------------------------------------------------------------------
/configs/ScanNet/scene0169_00.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | data:
3 | input_path: data/scannet/scans/scene0169_00
4 | output_path: output/ScanNet/scene0169
5 | scene_name: scene0169_00
6 | cam:
7 | fx: 574.540771
8 | fy: 577.583740
9 | cx: 322.522827
10 | cy: 238.558853
--------------------------------------------------------------------------------
/configs/ScanNet/scene0181_00.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | data:
3 | input_path: data/scannet/scans/scene0181_00
4 | output_path: output/ScanNet/scene0181
5 | scene_name: scene0181_00
6 | cam:
7 | fx: 575.547668
8 | fy: 577.459778
9 | cx: 323.171967
10 | cy: 236.417465
11 |
--------------------------------------------------------------------------------
/configs/ScanNet/scene0207_00.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | data:
3 | input_path: data/scannet/scans/scene0207_00
4 | output_path: output/ScanNet/scene0207
5 | scene_name: scene0207_00
--------------------------------------------------------------------------------
/configs/TUM_RGBD/rgbd_dataset_freiburg1_desk.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/TUM_RGBD/tum_rgbd.yaml
2 | data:
3 | input_path: data/TUM_RGBD-SLAM/rgbd_dataset_freiburg1_desk
4 | output_path: output/TUM_RGBD/rgbd_dataset_freiburg1_desk/
5 | scene_name: rgbd_dataset_freiburg1_desk
6 | cam: #intrinsic is different per scene in TUM
7 | H: 480
8 | W: 640
9 | fx: 517.3
10 | fy: 516.5
11 | cx: 318.6
12 | cy: 255.3
13 | crop_edge: 50
14 | distortion: [0.2624, -0.9531, -0.0054, 0.0026, 1.1633]
--------------------------------------------------------------------------------
/configs/TUM_RGBD/rgbd_dataset_freiburg2_xyz.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/TUM_RGBD/tum_rgbd.yaml
2 | data:
3 | input_path: data/TUM_RGBD-SLAM/rgbd_dataset_freiburg2_xyz
4 | output_path: output/TUM_RGBD/rgbd_dataset_freiburg2_xyz/
5 | scene_name: rgbd_dataset_freiburg2_xyz
6 | cam: #intrinsic is different per scene in TUM
7 | H: 480
8 | W: 640
9 | fx: 520.9
10 | fy: 521.0
11 | cx: 325.1
12 | cy: 249.7
13 | crop_edge: 8
14 | distortion: [0.2312, -0.7849, -0.0033, -0.0001, 0.9172]
--------------------------------------------------------------------------------
/configs/TUM_RGBD/rgbd_dataset_freiburg3_long_office_household.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/TUM_RGBD/tum_rgbd.yaml
2 | data:
3 | input_path: data/TUM_RGBD-SLAM/rgbd_dataset_freiburg3_long_office_household/
4 | output_path: output/TUM_RGBD/rgbd_dataset_freiburg3_long_office_household/
5 | scene_name: rgbd_dataset_freiburg3_long_office_household
6 | cam: #intrinsic is different per scene in TUM
7 | H: 480
8 | W: 640
9 | fx: 517.3
10 | fy: 516.5
11 | cx: 318.6
12 | cy: 255.3
13 | crop_edge: 50
14 | distortion: [0.2624, -0.9531, -0.0054, 0.0026, 1.1633]
--------------------------------------------------------------------------------
/configs/TUM_RGBD/tum_rgbd.yaml:
--------------------------------------------------------------------------------
1 | project_name: "Gaussian_SLAM_tumrgbd"
2 | dataset_name: "tum_rgbd"
3 | checkpoint_path: null
4 | use_wandb: False
5 | frame_limit: -1 # for debugging, set to -1 to disable
6 | seed: 0
7 | mapping:
8 | new_submap_every: 50
9 | map_every: 1
10 | iterations: 100
11 | new_submap_iterations: 100
12 | new_submap_points_num: 100000
13 | new_submap_gradient_points_num: 50000
14 | new_frame_sample_size: 30000
15 | new_points_radius: 0.0001
16 | current_view_opt_iterations: 0.4 # What portion of iterations to spend on the current view
17 | alpha_thre: 0.6
18 | pruning_thre: 0.5
19 | submap_using_motion_heuristic: True
20 | tracking:
21 | gt_camera: False
22 | w_color_loss: 0.6
23 | iterations: 200
24 | cam_rot_lr: 0.002
25 | cam_trans_lr: 0.01
26 | odometry_type: "const_speed" # gt, const_speed, odometer
27 | help_camera_initialization: False # temp option to help const_init
28 | init_err_ratio: 5
29 | odometer_method: "hybrid" # hybrid or point_to_plane
30 | filter_alpha: False
31 | filter_outlier_depth: False
32 | alpha_thre: 0.98
33 | soft_alpha: True
34 | mask_invalid_depth: True
35 | cam:
36 | crop_edge: 16
37 | depth_scale: 5000.0
--------------------------------------------------------------------------------
/configs/scannetpp/281bc17764.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/scannetpp/scannetpp.yaml
2 | data:
3 | input_path: data/scannetpp/data/281bc17764
4 | output_path: output/ScanNetPP/281bc17764
5 | scene_name: "281bc17764"
6 | use_train_split: True
7 | frame_limit: 250
8 | cam:
9 | H: 584
10 | W: 876
11 | fx: 312.79197434640764
12 | fy: 313.48022477591036
13 | cx: 438.0
14 | cy: 292.0
--------------------------------------------------------------------------------
/configs/scannetpp/2e74812d00.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/scannetpp/scannetpp.yaml
2 | data:
3 | input_path: data/scannetpp/data/2e74812d00
4 | output_path: output/ScanNetPP/2e74812d00
5 | scene_name: "2e74812d00"
6 | use_train_split: True
7 | frame_limit: 250
8 | cam:
9 | H: 584
10 | W: 876
11 | fx: 312.0984049779051
12 | fy: 312.4823067146056
13 | cx: 438.0
14 | cy: 292.0
--------------------------------------------------------------------------------
/configs/scannetpp/8b5caf3398.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/scannetpp/scannetpp.yaml
2 | data:
3 | input_path: data/scannetpp/data/8b5caf3398
4 | output_path: output/ScanNetPP/8b5caf3398
5 | scene_name: "8b5caf3398"
6 | use_train_split: True
7 | frame_limit: 250
8 | cam:
9 | H: 584
10 | W: 876
11 | fx: 316.3837659917395
12 | fy: 319.18649362678593
13 | cx: 438.0
14 | cy: 292.0
15 |
--------------------------------------------------------------------------------
/configs/scannetpp/b20a261fdf.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/scannetpp/scannetpp.yaml
2 | data:
3 | input_path: data/scannetpp/data/b20a261fdf
4 | output_path: output/ScanNetPP/b20a261fdf
5 | scene_name: "b20a261fdf"
6 | use_train_split: True
7 | frame_limit: 250
8 | cam:
9 | H: 584
10 | W: 876
11 | fx: 312.7099188244687
12 | fy: 313.5121746848229
13 | cx: 438.0
14 | cy: 292.0
--------------------------------------------------------------------------------
/configs/scannetpp/fb05e13ad1.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/scannetpp/scannetpp.yaml
2 | data:
3 | input_path: data/scannetpp/data/fb05e13ad1
4 | output_path: output/ScanNetPP/fb05e13ad1
5 | scene_name: "fb05e13ad1"
6 | use_train_split: True
7 | frame_limit: 250
8 | cam:
9 | H: 584
10 | W: 876
11 | fx: 231.8197441948914
12 | fy: 231.9980523882361
13 | cx: 438.0
14 | cy: 292.0
--------------------------------------------------------------------------------
/configs/scannetpp/scannetpp.yaml:
--------------------------------------------------------------------------------
1 | project_name: "Gaussian_SLAM_scannetpp"
2 | dataset_name: "scannetpp"
3 | checkpoint_path: null
4 | use_wandb: False
5 | frame_limit: -1 # set to -1 to disable
6 | seed: 0
7 | mapping:
8 | new_submap_every: 100
9 | map_every: 2
10 | iterations: 500
11 | new_submap_iterations: 500
12 | new_submap_points_num: 400000
13 | new_submap_gradient_points_num: 50000
14 | new_frame_sample_size: 100000
15 | new_points_radius: 0.00000001
16 | current_view_opt_iterations: 0.4 # What portion of iterations to spend on the current view
17 | alpha_thre: 0.6
18 | pruning_thre: 0.5
19 | submap_using_motion_heuristic: False
20 | tracking:
21 | gt_camera: False
22 | w_color_loss: 0.5
23 | iterations: 300
24 | cam_rot_lr: 0.002
25 | cam_trans_lr: 0.01
26 | odometry_type: "const_speed" # gt, const_speed, odometer
27 | help_camera_initialization: True
28 | init_err_ratio: 50
29 | odometer_method: "point_to_plane" # hybrid or point_to_plane
30 | filter_alpha: True
31 | filter_outlier_depth: True
32 | alpha_thre: 0.98
33 | soft_alpha: False
34 | mask_invalid_depth: True
35 | cam:
36 | crop_edge: 0
37 | depth_scale: 1000.0
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: gslam
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - nvidia/label/cuda-12.1.0
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - python=3.10
10 | - faiss-gpu=1.8.0
11 | - cuda-toolkit=12.1
12 | - pytorch=2.1.2
13 | - pytorch-cuda=12.1
14 | - torchvision=0.16.2
15 | - pip
16 | - pip:
17 | - open3d==0.18.0
18 | - wandb
19 | - trimesh
20 | - pytorch_msssim
21 | - torchmetrics
22 | - tqdm
23 | - imageio
24 | - opencv-python
25 | - plyfile
26 | - git+https://github.com/eriksandstroem/evaluate_3d_reconstruction_lib.git@9b3cc08be5440db9c375cc21e3bd65bb4a337db7
27 | - git+https://github.com/VladimirYugay/simple-knn.git@c7e51a06a4cd84c25e769fee29ab391fe5d5ff8d
28 | - git+https://github.com/VladimirYugay/gaussian_rasterizer.git@9c40173fcc8d9b16778a1a8040295bc2f9ebf129
--------------------------------------------------------------------------------
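After creating and activating the `gslam` environment, a quick sanity check (a suggestion, not something shipped with the repo) is to confirm that CUDA-enabled PyTorch and Open3D import correctly, since both the rasterizer and the simple-knn extension require a GPU at runtime:

```python
import torch
import open3d as o3d

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("open3d", o3d.__version__)
# torch.cuda.is_available() should be True before running run_slam.py;
# the gaussian_rasterizer and simple_knn packages have no CPU fallback.
```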
/run_evaluation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from src.evaluation.evaluator import Evaluator
4 |
5 |
6 | def get_args():
7 | parser = argparse.ArgumentParser(description='Arguments to compute the mesh')
8 | parser.add_argument('--checkpoint_path', type=str, help='SLAM checkpoint path', default="output/slam/full_experiment/")
9 | parser.add_argument('--config_path', type=str, help='Config path', default="")
10 | return parser.parse_args()
11 |
12 |
13 | if __name__ == "__main__":
14 | args = get_args()
15 | if args.config_path == "":
16 | args.config_path = Path(args.checkpoint_path) / "config.yaml"
17 |
18 | evaluator = Evaluator(Path(args.checkpoint_path), Path(args.config_path))
19 | evaluator.run()
20 |
--------------------------------------------------------------------------------
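`run_slam.py` already runs this evaluation at the end of a SLAM run; the standalone script is useful for re-evaluating an existing output directory. A minimal sketch of the equivalent programmatic call (the output path is a placeholder):

```python
from pathlib import Path
from src.evaluation.evaluator import Evaluator

out = Path("output/Replica/room0")         # placeholder: any finished run directory
Evaluator(out, out / "config.yaml").run()  # same call run_slam.py makes after SLAM finishes
```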
/run_slam.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import uuid
5 |
6 | import wandb
7 |
8 | from src.entities.gaussian_slam import GaussianSLAM
9 | from src.evaluation.evaluator import Evaluator
10 | from src.utils.io_utils import load_config, log_metrics_to_wandb
11 | from src.utils.utils import setup_seed
12 |
13 |
14 | def get_args():
15 | parser = argparse.ArgumentParser(
16 | description='Arguments to compute the mesh')
17 | parser.add_argument('config_path', type=str,
18 | help='Path to the configuration yaml file')
19 | parser.add_argument('--input_path', default="")
20 | parser.add_argument('--output_path', default="")
21 | parser.add_argument('--track_w_color_loss', type=float)
22 | parser.add_argument('--track_alpha_thre', type=float)
23 | parser.add_argument('--track_iters', type=int)
24 | parser.add_argument('--track_filter_alpha', action='store_true')
25 | parser.add_argument('--track_filter_outlier', action='store_true')
26 | parser.add_argument('--track_wo_filter_alpha', action='store_true')
27 | parser.add_argument("--track_wo_filter_outlier", action="store_true")
28 | parser.add_argument("--track_cam_trans_lr", type=float)
29 | parser.add_argument('--alpha_seeding_thre', type=float)
30 | parser.add_argument('--map_every', type=int)
31 | parser.add_argument("--map_iters", type=int)
32 | parser.add_argument('--new_submap_every', type=int)
33 | parser.add_argument('--project_name', type=str)
34 | parser.add_argument('--group_name', type=str)
35 | parser.add_argument('--gt_camera', action='store_true')
36 | parser.add_argument('--help_camera_initialization', action='store_true')
37 | parser.add_argument('--soft_alpha', action='store_true')
38 | parser.add_argument('--seed', type=int)
39 | parser.add_argument('--submap_using_motion_heuristic', action='store_true')
40 | parser.add_argument('--new_submap_points_num', type=int)
41 | return parser.parse_args()
42 |
43 |
44 | def update_config_with_args(config, args):
45 | if args.input_path:
46 | config["data"]["input_path"] = args.input_path
47 | if args.output_path:
48 | config["data"]["output_path"] = args.output_path
49 | if args.track_w_color_loss is not None:
50 | config["tracking"]["w_color_loss"] = args.track_w_color_loss
51 | if args.track_iters is not None:
52 | config["tracking"]["iterations"] = args.track_iters
53 | if args.track_filter_alpha:
54 | config["tracking"]["filter_alpha"] = True
55 | if args.track_wo_filter_alpha:
56 | config["tracking"]["filter_alpha"] = False
57 | if args.track_filter_outlier:
58 | config["tracking"]["filter_outlier_depth"] = True
59 | if args.track_wo_filter_outlier:
60 | config["tracking"]["filter_outlier_depth"] = False
61 | if args.track_alpha_thre is not None:
62 | config["tracking"]["alpha_thre"] = args.track_alpha_thre
63 | if args.map_every:
64 | config["mapping"]["map_every"] = args.map_every
65 | if args.map_iters:
66 | config["mapping"]["iterations"] = args.map_iters
67 | if args.new_submap_every:
68 | config["mapping"]["new_submap_every"] = args.new_submap_every
69 | if args.project_name:
70 | config["project_name"] = args.project_name
71 | if args.alpha_seeding_thre is not None:
72 | config["mapping"]["alpha_thre"] = args.alpha_seeding_thre
73 | if args.seed is not None:  # allow overriding with seed 0
74 | config["seed"] = args.seed
75 | if args.help_camera_initialization:
76 | config["tracking"]["help_camera_initialization"] = True
77 | if args.soft_alpha:
78 | config["tracking"]["soft_alpha"] = True
79 | if args.submap_using_motion_heuristic:
80 | config["mapping"]["submap_using_motion_heuristic"] = True
81 | if args.new_submap_points_num:
82 | config["mapping"]["new_submap_points_num"] = args.new_submap_points_num
83 | if args.track_cam_trans_lr:
84 | config["tracking"]["cam_trans_lr"] = args.track_cam_trans_lr
85 | return config
86 |
87 |
88 | if __name__ == "__main__":
89 | args = get_args()
90 | config = load_config(args.config_path)
91 | config = update_config_with_args(config, args)
92 |
93 | if os.getenv('DISABLE_WANDB') == 'true':
94 | config["use_wandb"] = False
95 | if config["use_wandb"]:
96 | wandb.init(
97 | project=config["project_name"],
98 | config=config,
99 | dir="/home/yli3/scratch/outputs/slam/wandb",
100 | group=config["data"]["scene_name"]
101 | if not args.group_name
102 | else args.group_name,
103 | name=f'{config["data"]["scene_name"]}_{time.strftime("%Y%m%d_%H%M%S", time.localtime())}_{str(uuid.uuid4())[:5]}',
104 | )
105 | wandb.run.log_code(".", include_fn=lambda path: path.endswith(".py"))
106 |
107 | setup_seed(config["seed"])
108 | gslam = GaussianSLAM(config)
109 | gslam.run()
110 |
111 | evaluator = Evaluator(gslam.output_path, gslam.output_path / "config.yaml")
112 | evaluator.run()
113 | if config["use_wandb"]:
114 | evals = ["rendering_metrics.json",
115 | "reconstruction_metrics.json", "ate_aligned.json"]
116 | log_metrics_to_wandb(evals, gslam.output_path, "Evaluation")
117 | wandb.finish()
118 | print("All done.✨")
119 |
--------------------------------------------------------------------------------
/scripts/download_replica.sh:
--------------------------------------------------------------------------------
1 | mkdir -p data
2 | cd data
3 | git clone https://huggingface.co/datasets/voviktyl/Replica-SLAM
--------------------------------------------------------------------------------
/scripts/download_tum.sh:
--------------------------------------------------------------------------------
1 | mkdir -p data
2 | cd data
3 | git clone https://huggingface.co/datasets/voviktyl/TUM_RGBD-SLAM
--------------------------------------------------------------------------------
/scripts/reproduce_sbatch.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --output=output/logs/%A_%a.log # please change accordingly
3 | #SBATCH --error=output/logs/%A_%a.log # please change accordingly
4 | #SBATCH -N 1
5 | #SBATCH -n 1
6 | #SBATCH --gpus-per-node=1
7 | #SBATCH --partition=gpu
8 | #SBATCH --cpus-per-task=12
9 | #SBATCH --time=24:00:00
10 | #SBATCH --array=0-4 # number of scenes, 0-7 for Replica, 0-2 for TUM_RGBD, 0-5 for ScanNet, 0-4 for ScanNet++
11 |
12 | dataset="Replica" # set dataset
13 | if [ "$dataset" == "Replica" ]; then
14 | scenes=("room0" "room1" "room2" "office0" "office1" "office2" "office3" "office4")
15 | INPUT_PATH="data/Replica-SLAM"
16 | elif [ "$dataset" == "TUM_RGBD" ]; then
17 | scenes=("rgbd_dataset_freiburg1_desk" "rgbd_dataset_freiburg2_xyz" "rgbd_dataset_freiburg3_long_office_household")
18 | INPUT_PATH="data/TUM_RGBD-SLAM"
19 | elif [ "$dataset" == "ScanNet" ]; then
20 | scenes=("scene0000_00" "scene0059_00" "scene0106_00" "scene0169_00" "scene0181_00" "scene0207_00")
21 | INPUT_PATH="data/scannet/scans"
22 | elif [ "$dataset" == "ScanNetPP" ]; then
23 | scenes=("b20a261fdf" "8b5caf3398" "fb05e13ad1" "2e74812d00" "281bc17764")
24 | INPUT_PATH="data/scannetpp/data"
25 | else
26 | echo "Dataset not recognized!"
27 | exit 1
28 | fi
29 |
30 | OUTPUT_PATH="output"
31 | CONFIG_PATH="configs/${dataset}"
32 | EXPERIMENT_NAME="reproduce"
33 | SCENE_NAME=${scenes[$SLURM_ARRAY_TASK_ID]}
34 |
35 | source # please change accordingly
36 | conda activate gslam
37 |
38 | echo "Job for dataset: $dataset, scene: $SCENE_NAME"
39 | echo "Starting on: $(date)"
40 | echo "Running on node: $(hostname)"
41 |
42 | # Your command to run the experiment
43 | python run_slam.py "${CONFIG_PATH}/${SCENE_NAME}.yaml" \
44 | --input_path "${INPUT_PATH}/${SCENE_NAME}" \
45 | --output_path "${OUTPUT_PATH}/${dataset}/${EXPERIMENT_NAME}/${SCENE_NAME}" \
46 | --group_name "${EXPERIMENT_NAME}" \
47 |
48 | echo "Job for scene $SCENE_NAME completed."
49 | echo "Started at: $START_TIME"
50 | echo "Finished at: $(date)"
--------------------------------------------------------------------------------
/src/entities/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VladimirYugay/Gaussian-SLAM/eaec10d73ce7511563882b8856896e06d1f804e3/src/entities/__init__.py
--------------------------------------------------------------------------------
/src/entities/arguments.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import sys
14 | from argparse import ArgumentParser, Namespace
15 |
16 |
17 | class GroupParams:
18 | pass
19 |
20 |
21 | class ParamGroup:
22 | def __init__(self, parser: ArgumentParser, name: str, fill_none=False):
23 | group = parser.add_argument_group(name)
24 | for key, value in vars(self).items():
25 | shorthand = False
26 | if key.startswith("_"):
27 | shorthand = True
28 | key = key[1:]
29 | t = type(value)
30 | value = value if not fill_none else None
31 | if shorthand:
32 | if t == bool:
33 | group.add_argument(
34 | "--" + key, ("-" + key[0:1]), default=value, action="store_true")
35 | else:
36 | group.add_argument(
37 | "--" + key, ("-" + key[0:1]), default=value, type=t)
38 | else:
39 | if t == bool:
40 | group.add_argument(
41 | "--" + key, default=value, action="store_true")
42 | else:
43 | group.add_argument("--" + key, default=value, type=t)
44 |
45 | def extract(self, args):
46 | group = GroupParams()
47 | for arg in vars(args).items():
48 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
49 | setattr(group, arg[0], arg[1])
50 | return group
51 |
52 |
53 | class OptimizationParams(ParamGroup):
54 | def __init__(self, parser):
55 | self.iterations = 30_000
56 | self.position_lr_init = 0.0001
57 | self.position_lr_final = 0.0000016
58 | self.position_lr_delay_mult = 0.01
59 | self.position_lr_max_steps = 30_000
60 | self.feature_lr = 0.0025
61 | self.opacity_lr = 0.05
62 | self.scaling_lr = 0.005 # before 0.005
63 | self.rotation_lr = 0.001
64 | self.percent_dense = 0.01
65 | self.lambda_dssim = 0.2
66 | self.densification_interval = 100
67 | self.opacity_reset_interval = 3000
68 | self.densify_from_iter = 500
69 | self.densify_until_iter = 15_000
70 | self.densify_grad_threshold = 0.0002
71 | super().__init__(parser, "Optimization Parameters")
72 |
73 |
74 | def get_combined_args(parser: ArgumentParser):
75 | cmdlne_string = sys.argv[1:]
76 | cfgfile_string = "Namespace()"
77 | args_cmdline = parser.parse_args(cmdlne_string)
78 |
79 | try:
80 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
81 | print("Looking for config file in", cfgfilepath)
82 | with open(cfgfilepath) as cfg_file:
83 | print("Config file found: {}".format(cfgfilepath))
84 | cfgfile_string = cfg_file.read()
85 | except TypeError:
86 | print("Config file not found at")
87 | pass
88 | args_cfgfile = eval(cfgfile_string)
89 |
90 | merged_dict = vars(args_cfgfile).copy()
91 | for k, v in vars(args_cmdline).items():
92 | if v is not None:
93 | merged_dict[k] = v
94 | return Namespace(**merged_dict)
95 |
--------------------------------------------------------------------------------
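`OptimizationParams` registers every attribute set in its `__init__` as a command-line flag, and `extract` copies the parsed values into a plain `GroupParams` namespace, which is what `GaussianModel.training_setup` consumes. A small usage sketch with the defaults only:

```python
from argparse import ArgumentParser
from src.entities.arguments import OptimizationParams

parser = ArgumentParser()
opt_group = OptimizationParams(parser)   # adds --iterations, --feature_lr, --scaling_lr, ...
args = parser.parse_args([])             # no overrides: keep the defaults defined above
opt = opt_group.extract(args)            # GroupParams namespace
print(opt.position_lr_init, opt.densify_grad_threshold)  # 0.0001 0.0002
```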
/src/entities/datasets.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from pathlib import Path
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 | import json
9 | import imageio
10 |
11 |
12 | class BaseDataset(torch.utils.data.Dataset):
13 |
14 | def __init__(self, dataset_config: dict):
15 | self.dataset_path = Path(dataset_config["input_path"])
16 | self.frame_limit = dataset_config.get("frame_limit", -1)
17 | self.dataset_config = dataset_config
18 | self.height = dataset_config["H"]
19 | self.width = dataset_config["W"]
20 | self.fx = dataset_config["fx"]
21 | self.fy = dataset_config["fy"]
22 | self.cx = dataset_config["cx"]
23 | self.cy = dataset_config["cy"]
24 |
25 | self.depth_scale = dataset_config["depth_scale"]
26 | self.distortion = np.array(
27 | dataset_config['distortion']) if 'distortion' in dataset_config else None
28 | self.crop_edge = dataset_config['crop_edge'] if 'crop_edge' in dataset_config else 0
29 | if self.crop_edge:
30 | self.height -= 2 * self.crop_edge
31 | self.width -= 2 * self.crop_edge
32 | self.cx -= self.crop_edge
33 | self.cy -= self.crop_edge
34 |
35 | self.fovx = 2 * math.atan(self.width / (2 * self.fx))
36 | self.fovy = 2 * math.atan(self.height / (2 * self.fy))
37 | self.intrinsics = np.array(
38 | [[self.fx, 0, self.cx], [0, self.fy, self.cy], [0, 0, 1]])
39 |
40 | self.color_paths = []
41 | self.depth_paths = []
42 |
43 | def __len__(self):
44 | return len(self.color_paths) if self.frame_limit < 0 else int(self.frame_limit)
45 |
46 |
47 | class Replica(BaseDataset):
48 |
49 | def __init__(self, dataset_config: dict):
50 | super().__init__(dataset_config)
51 | self.color_paths = sorted(
52 | list((self.dataset_path / "results").glob("frame*.jpg")))
53 | self.depth_paths = sorted(
54 | list((self.dataset_path / "results").glob("depth*.png")))
55 | self.load_poses(self.dataset_path / "traj.txt")
56 | print(f"Loaded {len(self.color_paths)} frames")
57 |
58 | def load_poses(self, path):
59 | self.poses = []
60 | with open(path, "r") as f:
61 | lines = f.readlines()
62 | for line in lines:
63 | c2w = np.array(list(map(float, line.split()))).reshape(4, 4)
64 | self.poses.append(c2w.astype(np.float32))
65 |
66 | def __getitem__(self, index):
67 | color_data = cv2.imread(str(self.color_paths[index]))
68 | color_data = cv2.cvtColor(color_data, cv2.COLOR_BGR2RGB)
69 | depth_data = cv2.imread(
70 | str(self.depth_paths[index]), cv2.IMREAD_UNCHANGED)
71 | depth_data = depth_data.astype(np.float32) / self.depth_scale
72 | return index, color_data, depth_data, self.poses[index]
73 |
74 |
75 | class TUM_RGBD(BaseDataset):
76 | def __init__(self, dataset_config: dict):
77 | super().__init__(dataset_config)
78 | self.color_paths, self.depth_paths, self.poses = self.loadtum(
79 | self.dataset_path, frame_rate=32)
80 |
81 | def parse_list(self, filepath, skiprows=0):
82 | """ read list data """
83 | return np.loadtxt(filepath, delimiter=' ', dtype=np.unicode_, skiprows=skiprows)
84 |
85 | def associate_frames(self, tstamp_image, tstamp_depth, tstamp_pose, max_dt=0.08):
86 | """ pair images, depths, and poses """
87 | associations = []
88 | for i, t in enumerate(tstamp_image):
89 | if tstamp_pose is None:
90 | j = np.argmin(np.abs(tstamp_depth - t))
91 | if (np.abs(tstamp_depth[j] - t) < max_dt):
92 | associations.append((i, j))
93 | else:
94 | j = np.argmin(np.abs(tstamp_depth - t))
95 | k = np.argmin(np.abs(tstamp_pose - t))
96 | if (np.abs(tstamp_depth[j] - t) < max_dt) and (np.abs(tstamp_pose[k] - t) < max_dt):
97 | associations.append((i, j, k))
98 | return associations
99 |
100 | def loadtum(self, datapath, frame_rate=-1):
101 | """ read video data in tum-rgbd format """
102 | if os.path.isfile(os.path.join(datapath, 'groundtruth.txt')):
103 | pose_list = os.path.join(datapath, 'groundtruth.txt')
104 | elif os.path.isfile(os.path.join(datapath, 'pose.txt')):
105 | pose_list = os.path.join(datapath, 'pose.txt')
106 |
107 | image_list = os.path.join(datapath, 'rgb.txt')
108 | depth_list = os.path.join(datapath, 'depth.txt')
109 |
110 | image_data = self.parse_list(image_list)
111 | depth_data = self.parse_list(depth_list)
112 | pose_data = self.parse_list(pose_list, skiprows=1)
113 | pose_vecs = pose_data[:, 1:].astype(np.float64)
114 |
115 | tstamp_image = image_data[:, 0].astype(np.float64)
116 | tstamp_depth = depth_data[:, 0].astype(np.float64)
117 | tstamp_pose = pose_data[:, 0].astype(np.float64)
118 | associations = self.associate_frames(
119 | tstamp_image, tstamp_depth, tstamp_pose)
120 |
121 | indicies = [0]
122 | for i in range(1, len(associations)):
123 | t0 = tstamp_image[associations[indicies[-1]][0]]
124 | t1 = tstamp_image[associations[i][0]]
125 | if t1 - t0 > 1.0 / frame_rate:
126 | indicies += [i]
127 |
128 | images, poses, depths = [], [], []
129 | inv_pose = None
130 | for ix in indicies:
131 | (i, j, k) = associations[ix]
132 | images += [os.path.join(datapath, image_data[i, 1])]
133 | depths += [os.path.join(datapath, depth_data[j, 1])]
134 | c2w = self.pose_matrix_from_quaternion(pose_vecs[k])
135 | if inv_pose is None:
136 | inv_pose = np.linalg.inv(c2w)
137 | c2w = np.eye(4)
138 | else:
139 | c2w = inv_pose@c2w
140 | poses += [c2w.astype(np.float32)]
141 |
142 | return images, depths, poses
143 |
144 | def pose_matrix_from_quaternion(self, pvec):
145 | """ convert 4x4 pose matrix to (t, q) """
146 | from scipy.spatial.transform import Rotation
147 |
148 | pose = np.eye(4)
149 | pose[:3, :3] = Rotation.from_quat(pvec[3:]).as_matrix()
150 | pose[:3, 3] = pvec[:3]
151 | return pose
152 |
153 | def __getitem__(self, index):
154 | color_data = cv2.imread(str(self.color_paths[index]))
155 | if self.distortion is not None:
156 | color_data = cv2.undistort(
157 | color_data, self.intrinsics, self.distortion)
158 | color_data = cv2.cvtColor(color_data, cv2.COLOR_BGR2RGB)
159 |
160 | depth_data = cv2.imread(
161 | str(self.depth_paths[index]), cv2.IMREAD_UNCHANGED)
162 | depth_data = depth_data.astype(np.float32) / self.depth_scale
163 | edge = self.crop_edge
164 | if edge > 0:
165 | color_data = color_data[edge:-edge, edge:-edge]
166 | depth_data = depth_data[edge:-edge, edge:-edge]
167 | # Interpolate depth values for splatting
168 | return index, color_data, depth_data, self.poses[index]
169 |
170 |
171 | class ScanNet(BaseDataset):
172 | def __init__(self, dataset_config: dict):
173 | super().__init__(dataset_config)
174 | self.color_paths = sorted(list(
175 | (self.dataset_path / "color").glob("*.jpg")), key=lambda x: int(os.path.basename(x)[:-4]))
176 | self.depth_paths = sorted(list(
177 | (self.dataset_path / "depth").glob("*.png")), key=lambda x: int(os.path.basename(x)[:-4]))
178 | self.load_poses(self.dataset_path / "pose")
179 |
180 | def load_poses(self, path):
181 | self.poses = []
182 | pose_paths = sorted(path.glob('*.txt'),
183 | key=lambda x: int(os.path.basename(x)[:-4]))
184 | for pose_path in pose_paths:
185 | with open(pose_path, "r") as f:
186 | lines = f.readlines()
187 | ls = []
188 | for line in lines:
189 | ls.append(list(map(float, line.split(' '))))
190 | c2w = np.array(ls).reshape(4, 4).astype(np.float32)
191 | self.poses.append(c2w)
192 |
193 | def __getitem__(self, index):
194 | color_data = cv2.imread(str(self.color_paths[index]))
195 | if self.distortion is not None:
196 | color_data = cv2.undistort(
197 | color_data, self.intrinsics, self.distortion)
198 | color_data = cv2.cvtColor(color_data, cv2.COLOR_BGR2RGB)
199 | color_data = cv2.resize(color_data, (self.dataset_config["W"], self.dataset_config["H"]))
200 |
201 | depth_data = cv2.imread(
202 | str(self.depth_paths[index]), cv2.IMREAD_UNCHANGED)
203 | depth_data = depth_data.astype(np.float32) / self.depth_scale
204 | edge = self.crop_edge
205 | if edge > 0:
206 | color_data = color_data[edge:-edge, edge:-edge]
207 | depth_data = depth_data[edge:-edge, edge:-edge]
208 | # Interpolate depth values for splatting
209 | return index, color_data, depth_data, self.poses[index]
210 |
211 |
212 | class ScanNetPP(BaseDataset):
213 | def __init__(self, dataset_config: dict):
214 | super().__init__(dataset_config)
215 | self.use_train_split = dataset_config["use_train_split"]
216 | self.train_test_split = json.load(open(f"{self.dataset_path}/dslr/train_test_lists.json", "r"))
217 | if self.use_train_split:
218 | self.image_names = self.train_test_split["train"]
219 | else:
220 | self.image_names = self.train_test_split["test"]
221 | self.load_data()
222 |
223 | def load_data(self):
224 | self.poses = []
225 | cams_path = self.dataset_path / "dslr" / "nerfstudio" / "transforms_undistorted.json"
226 | cams_metadata = json.load(open(str(cams_path), "r"))
227 | frames_key = "frames" if self.use_train_split else "test_frames"
228 | frames_metadata = cams_metadata[frames_key]
229 | frame2idx = {frame["file_path"]: index for index, frame in enumerate(frames_metadata)}
230 | P = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]).astype(np.float32)
231 | for image_name in self.image_names:
232 | frame_metadata = frames_metadata[frame2idx[image_name]]
233 | # if self.ignore_bad and frame_metadata['is_bad']:
234 | # continue
235 | color_path = str(self.dataset_path / "dslr" / "undistorted_images" / image_name)
236 | depth_path = str(self.dataset_path / "dslr" / "undistorted_depths" / image_name.replace('.JPG', '.png'))
237 | self.color_paths.append(color_path)
238 | self.depth_paths.append(depth_path)
239 | c2w = np.array(frame_metadata["transform_matrix"]).astype(np.float32)
240 | c2w = P @ c2w @ P.T  # conjugate the pose by diag(1, -1, -1, 1) to flip the y and z axes
241 | self.poses.append(c2w)
242 |
243 | def __len__(self):
244 | if self.use_train_split:
245 | return len(self.image_names) if self.frame_limit < 0 else int(self.frame_limit)
246 | else:
247 | return len(self.image_names)
248 |
249 | def __getitem__(self, index):
250 |
251 | color_data = np.asarray(imageio.imread(self.color_paths[index]), dtype=float)
252 | color_data = cv2.resize(color_data, (self.width, self.height), interpolation=cv2.INTER_LINEAR)
253 | color_data = color_data.astype(np.uint8)
254 |
255 | depth_data = np.asarray(imageio.imread(self.depth_paths[index]), dtype=np.int64)
256 | depth_data = cv2.resize(depth_data.astype(float), (self.width, self.height), interpolation=cv2.INTER_NEAREST)
257 | depth_data = depth_data.astype(np.float32) / self.depth_scale
258 | return index, color_data, depth_data, self.poses[index]
259 |
260 |
261 | def get_dataset(dataset_name: str):
262 | if dataset_name == "replica":
263 | return Replica
264 | elif dataset_name == "tum_rgbd":
265 | return TUM_RGBD
266 | elif dataset_name == "scan_net":
267 | return ScanNet
268 | elif dataset_name == "scannetpp":
269 | return ScanNetPP
270 | raise NotImplementedError(f"Dataset {dataset_name} not implemented")
271 |
--------------------------------------------------------------------------------
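The loaders above are constructed from a flat dictionary holding the `data` and `cam` entries of the config. A minimal sketch of building the Replica loader directly (the input path is a placeholder and must contain the `results/` images and `traj.txt` poses the class expects):

```python
from src.entities.datasets import get_dataset

dataset_config = {
    "input_path": "data/Replica-SLAM/room0/",   # placeholder path
    "H": 680, "W": 1200,
    "fx": 600.0, "fy": 600.0, "cx": 599.5, "cy": 339.5,
    "depth_scale": 6553.5,
    "frame_limit": -1,
}
dataset = get_dataset("replica")(dataset_config)   # returns the Replica class, then instantiates it
frame_id, color, depth, c2w = dataset[0]           # RGB (H, W, 3), metric depth (H, W), 4x4 pose
print(len(dataset), color.shape, depth.dtype)
```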
/src/entities/gaussian_model.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | from pathlib import Path
12 |
13 | import numpy as np
14 | import open3d as o3d
15 | import torch
16 | from plyfile import PlyData, PlyElement
17 | from simple_knn._C import distCUDA2
18 | from torch import nn
19 |
20 | from src.utils.gaussian_model_utils import (RGB2SH, build_scaling_rotation,
21 | get_expon_lr_func, inverse_sigmoid,
22 | strip_symmetric)
23 |
24 |
25 | class GaussianModel:
26 | def __init__(self, sh_degree: int = 3, isotropic=False):
27 | self.gaussian_param_names = [
28 | "active_sh_degree",
29 | "xyz",
30 | "features_dc",
31 | "features_rest",
32 | "scaling",
33 | "rotation",
34 | "opacity",
35 | "max_radii2D",
36 | "xyz_gradient_accum",
37 | "denom",
38 | "spatial_lr_scale",
39 | "optimizer",
40 | ]
41 | self.max_sh_degree = sh_degree
42 | self.active_sh_degree = sh_degree # temp
43 | self._xyz = torch.empty(0).cuda()
44 | self._features_dc = torch.empty(0).cuda()
45 | self._features_rest = torch.empty(0).cuda()
46 | self._scaling = torch.empty(0).cuda()
47 | self._rotation = torch.empty(0, 4).cuda()
48 | self._opacity = torch.empty(0).cuda()
49 | self.max_radii2D = torch.empty(0)
50 | self.xyz_gradient_accum = torch.empty(0)
51 | self.denom = torch.empty(0)
52 | self.optimizer = None
53 | self.percent_dense = 0
54 | self.spatial_lr_scale = 1
55 | self.setup_functions()
56 | self.isotropic = isotropic
57 |
58 | def restore_from_params(self, params_dict, training_args):
59 | self.training_setup(training_args)
60 | self.densification_postfix(
61 | params_dict["xyz"],
62 | params_dict["features_dc"],
63 | params_dict["features_rest"],
64 | params_dict["opacity"],
65 | params_dict["scaling"],
66 | params_dict["rotation"])
67 |
68 | def build_covariance_from_scaling_rotation(self, scaling, scaling_modifier, rotation):
69 | L = build_scaling_rotation(scaling_modifier * scaling, rotation)
70 | actual_covariance = L @ L.transpose(1, 2)
71 | symm = strip_symmetric(actual_covariance)
72 | return symm
73 |
74 | def setup_functions(self):
75 | self.scaling_activation = torch.exp
76 | self.scaling_inverse_activation = torch.log
77 | self.opacity_activation = torch.sigmoid
78 | self.inverse_opacity_activation = inverse_sigmoid
79 | self.rotation_activation = torch.nn.functional.normalize
80 |
81 | def capture_dict(self):
82 | return {
83 | "active_sh_degree": self.active_sh_degree,
84 | "xyz": self._xyz.clone().detach().cpu(),
85 | "features_dc": self._features_dc.clone().detach().cpu(),
86 | "features_rest": self._features_rest.clone().detach().cpu(),
87 | "scaling": self._scaling.clone().detach().cpu(),
88 | "rotation": self._rotation.clone().detach().cpu(),
89 | "opacity": self._opacity.clone().detach().cpu(),
90 | "max_radii2D": self.max_radii2D.clone().detach().cpu(),
91 | "xyz_gradient_accum": self.xyz_gradient_accum.clone().detach().cpu(),
92 | "denom": self.denom.clone().detach().cpu(),
93 | "spatial_lr_scale": self.spatial_lr_scale,
94 | "optimizer": self.optimizer.state_dict(),
95 | }
96 |
97 | def get_size(self):
98 | return self._xyz.shape[0]
99 |
100 | def get_scaling(self):
101 | if self.isotropic:
102 | scale = self.scaling_activation(self._scaling)[:, 0:1] # Extract the first column
103 | scales = scale.repeat(1, 3) # Replicate this column three times
104 | return scales
105 | return self.scaling_activation(self._scaling)
106 |
107 | def get_rotation(self):
108 | return self.rotation_activation(self._rotation)
109 |
110 | def get_xyz(self):
111 | return self._xyz
112 |
113 | def get_features(self):
114 | features_dc = self._features_dc
115 | features_rest = self._features_rest
116 | return torch.cat((features_dc, features_rest), dim=1)
117 |
118 | def get_opacity(self):
119 | return self.opacity_activation(self._opacity)
120 |
121 | def get_active_sh_degree(self):
122 | return self.active_sh_degree
123 |
124 | def get_covariance(self, scaling_modifier=1):
125 | return self.build_covariance_from_scaling_rotation(self.get_scaling(), scaling_modifier, self._rotation)
126 |
127 | def add_points(self, pcd: o3d.geometry.PointCloud, global_scale_init=True):
128 | fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
129 | fused_color = RGB2SH(torch.tensor(
130 | np.asarray(pcd.colors)).float().cuda())
131 | features = (torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda())
132 | features[:, :3, 0] = fused_color
133 | features[:, 3:, 1:] = 0.0
134 | print("Number of added points: ", fused_point_cloud.shape[0])
135 |
136 | if global_scale_init:
137 | global_points = torch.cat((self.get_xyz(),torch.from_numpy(np.asarray(pcd.points)).float().cuda()))
138 | dist2 = torch.clamp_min(distCUDA2(global_points), 0.0000001)
139 | dist2 = dist2[self.get_size():]
140 | else:
141 | dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
142 | scales = torch.log(1.0 * torch.sqrt(dist2))[..., None].repeat(1, 3)
143 | # scales = torch.log(0.001 * torch.ones_like(dist2))[..., None].repeat(1, 3)
144 | rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
145 | rots[:, 0] = 1
146 | opacities = inverse_sigmoid(0.5 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
147 | new_xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
148 | new_features_dc = nn.Parameter(features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True))
149 | new_features_rest = nn.Parameter(features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True))
150 | new_scaling = nn.Parameter(scales.requires_grad_(True))
151 | new_rotation = nn.Parameter(rots.requires_grad_(True))
152 | new_opacities = nn.Parameter(opacities.requires_grad_(True))
153 | self.densification_postfix(
154 | new_xyz,
155 | new_features_dc,
156 | new_features_rest,
157 | new_opacities,
158 | new_scaling,
159 | new_rotation,
160 | )
161 |
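# Illustrative usage sketch of add_points: seeding the model from an Open3D point
# cloud. A CUDA device is required, since distCUDA2 (simple_knn) and the parameter
# tensors live on the GPU.
#
#   pcd = o3d.geometry.PointCloud()
#   pcd.points = o3d.utility.Vector3dVector(xyz)   # (N, 3) positions
#   pcd.colors = o3d.utility.Vector3dVector(rgb)   # (N, 3) colors in [0, 1]
#   model = GaussianModel(sh_degree=3, isotropic=False)
#   model.add_points(pcd)                          # appends N new Gaussians to the model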
162 | def training_setup(self, training_args):
163 | self.percent_dense = training_args.percent_dense
164 | self.xyz_gradient_accum = torch.zeros(
165 | (self.get_xyz().shape[0], 1), device="cuda"
166 | )
167 | self.denom = torch.zeros((self.get_xyz().shape[0], 1), device="cuda")
168 |
169 | params = [
170 | {"params": [self._xyz], "lr": training_args.position_lr_init, "name": "xyz"},
171 | {"params": [self._features_dc], "lr": training_args.feature_lr, "name": "f_dc"},
172 | {"params": [self._features_rest], "lr": training_args.feature_lr / 20.0, "name": "f_rest"},
173 | {"params": [self._opacity], "lr": training_args.opacity_lr, "name": "opacity"},
174 | {"params": [self._scaling], "lr": training_args.scaling_lr, "name": "scaling"},
175 | {"params": [self._rotation], "lr": training_args.rotation_lr, "name": "rotation"},
176 | ]
177 |
178 | self.optimizer = torch.optim.Adam(params, lr=0.0, eps=1e-15)
179 | self.xyz_scheduler_args = get_expon_lr_func(
180 | lr_init=training_args.position_lr_init * self.spatial_lr_scale,
181 | lr_final=training_args.position_lr_final * self.spatial_lr_scale,
182 | lr_delay_mult=training_args.position_lr_delay_mult,
183 | max_steps=training_args.position_lr_max_steps,
184 | )
185 |
186 | def training_setup_camera(self, cam_rot, cam_trans, cfg):
187 | self.xyz_gradient_accum = torch.zeros(
188 | (self.get_xyz().shape[0], 1), device="cuda"
189 | )
190 | self.denom = torch.zeros((self.get_xyz().shape[0], 1), device="cuda")
191 | params = [
192 | {"params": [self._xyz], "lr": 0.0, "name": "xyz"},
193 | {"params": [self._features_dc], "lr": 0.0, "name": "f_dc"},
194 | {"params": [self._features_rest], "lr": 0.0, "name": "f_rest"},
195 | {"params": [self._opacity], "lr": 0.0, "name": "opacity"},
196 | {"params": [self._scaling], "lr": 0.0, "name": "scaling"},
197 | {"params": [self._rotation], "lr": 0.0, "name": "rotation"},
198 | {"params": [cam_rot], "lr": cfg["cam_rot_lr"],
199 | "name": "cam_unnorm_rot"},
200 | {"params": [cam_trans], "lr": cfg["cam_trans_lr"],
201 | "name": "cam_trans"},
202 | ]
203 | self.optimizer = torch.optim.Adam(params, amsgrad=True)
204 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
205 | self.optimizer, "min", factor=0.98, patience=10, verbose=False)
206 |
207 | def construct_list_of_attributes(self):
208 | l = ["x", "y", "z", "nx", "ny", "nz"]
209 | # All channels except the 3 DC
210 | for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]):
211 | l.append("f_dc_{}".format(i))
212 | for i in range(self._features_rest.shape[1] * self._features_rest.shape[2]):
213 | l.append("f_rest_{}".format(i))
214 | l.append("opacity")
215 | for i in range(self._scaling.shape[1]):
216 | l.append("scale_{}".format(i))
217 | for i in range(self._rotation.shape[1]):
218 | l.append("rot_{}".format(i))
219 | return l
220 |
221 | def save_ply(self, path):
222 | Path(path).parent.mkdir(parents=True, exist_ok=True)
223 |
224 | xyz = self._xyz.detach().cpu().numpy()
225 | normals = np.zeros_like(xyz)
226 | f_dc = (
227 | self._features_dc.detach()
228 | .transpose(1, 2)
229 | .flatten(start_dim=1)
230 | .contiguous()
231 | .cpu()
232 | .numpy())
233 | f_rest = (
234 | self._features_rest.detach()
235 | .transpose(1, 2)
236 | .flatten(start_dim=1)
237 | .contiguous()
238 | .cpu()
239 | .numpy())
240 | opacities = self._opacity.detach().cpu().numpy()
241 | if self.isotropic:
242 | # tile into shape (P, 3)
243 | scale = np.tile(self._scaling.detach().cpu().numpy()[:, 0].reshape(-1, 1), (1, 3))
244 | else:
245 | scale = self._scaling.detach().cpu().numpy()
246 | rotation = self._rotation.detach().cpu().numpy()
247 |
248 | dtype_full = [(attribute, "f4") for attribute in self.construct_list_of_attributes()]
249 |
250 | elements = np.empty(xyz.shape[0], dtype=dtype_full)
251 | attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
252 | elements[:] = list(map(tuple, attributes))
253 | el = PlyElement.describe(elements, "vertex")
254 | PlyData([el]).write(path)
255 |
256 | def load_ply(self, path):
257 | plydata = PlyData.read(path)
258 |
259 | xyz = np.stack((
260 | np.asarray(plydata.elements[0]["x"]),
261 | np.asarray(plydata.elements[0]["y"]),
262 | np.asarray(plydata.elements[0]["z"])),
263 | axis=1)
264 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
265 |
266 | features_dc = np.zeros((xyz.shape[0], 3, 1))
267 | features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
268 | features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
269 | features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
270 |
271 | extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
272 | extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split("_")[-1]))
273 | assert len(extra_f_names) == 3 * (self.max_sh_degree + 1) ** 2 - 3
274 | features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
275 | for idx, attr_name in enumerate(extra_f_names):
276 | features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
277 | # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
278 | features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
279 |
280 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
281 | scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1]))
282 | scales = np.zeros((xyz.shape[0], len(scale_names)))
283 | for idx, attr_name in enumerate(scale_names):
284 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
285 |
286 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
287 | rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1]))
288 | rots = np.zeros((xyz.shape[0], len(rot_names)))
289 | for idx, attr_name in enumerate(rot_names):
290 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
291 |
292 | self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
293 | self._features_dc = nn.Parameter(
294 | torch.tensor(features_dc, dtype=torch.float, device="cuda")
295 | .transpose(1, 2).contiguous().requires_grad_(True))
296 | self._features_rest = nn.Parameter(
297 | torch.tensor(features_extra, dtype=torch.float, device="cuda")
298 | .transpose(1, 2)
299 | .contiguous()
300 | .requires_grad_(True)
301 | )
302 | self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
303 | self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
304 | self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
305 |
306 | self.active_sh_degree = self.max_sh_degree
307 |
308 | def replace_tensor_to_optimizer(self, tensor, name):
309 | optimizable_tensors = {}
310 | for group in self.optimizer.param_groups:
311 | if group["name"] == name:
312 | stored_state = self.optimizer.state.get(group["params"][0], None)
313 | stored_state["exp_avg"] = torch.zeros_like(tensor)
314 | stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
315 |
316 | del self.optimizer.state[group["params"][0]]
317 | group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
318 | self.optimizer.state[group["params"][0]] = stored_state
319 |
320 | optimizable_tensors[group["name"]] = group["params"][0]
321 | return optimizable_tensors
322 |
323 | def _prune_optimizer(self, mask):
324 | optimizable_tensors = {}
325 | for group in self.optimizer.param_groups:
326 | stored_state = self.optimizer.state.get(group["params"][0], None)
327 | if stored_state is not None:
328 | stored_state["exp_avg"] = stored_state["exp_avg"][mask]
329 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
330 |
331 | del self.optimizer.state[group["params"][0]]
332 | group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
333 | self.optimizer.state[group["params"][0]] = stored_state
334 | optimizable_tensors[group["name"]] = group["params"][0]
335 | else:
336 | group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
337 | optimizable_tensors[group["name"]] = group["params"][0]
338 | return optimizable_tensors
339 |
340 | def prune_points(self, mask):
341 | valid_points_mask = ~mask
342 | optimizable_tensors = self._prune_optimizer(valid_points_mask)
343 |
344 | self._xyz = optimizable_tensors["xyz"]
345 | self._features_dc = optimizable_tensors["f_dc"]
346 | self._features_rest = optimizable_tensors["f_rest"]
347 | self._opacity = optimizable_tensors["opacity"]
348 | self._scaling = optimizable_tensors["scaling"]
349 | self._rotation = optimizable_tensors["rotation"]
350 |
351 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
352 |
353 | self.denom = self.denom[valid_points_mask]
354 | self.max_radii2D = self.max_radii2D[valid_points_mask]
355 |
356 | def cat_tensors_to_optimizer(self, tensors_dict):
357 | optimizable_tensors = {}
358 | for group in self.optimizer.param_groups:
359 | assert len(group["params"]) == 1
360 | extension_tensor = tensors_dict[group["name"]]
361 | stored_state = self.optimizer.state.get(group["params"][0], None)
362 | if stored_state is not None:
363 | stored_state["exp_avg"] = torch.cat(
364 | (stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
365 | stored_state["exp_avg_sq"] = torch.cat(
366 | (stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
367 |
368 | del self.optimizer.state[group["params"][0]]
369 | group["params"][0] = nn.Parameter(
370 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
371 | self.optimizer.state[group["params"][0]] = stored_state
372 |
373 | optimizable_tensors[group["name"]] = group["params"][0]
374 | else:
375 | group["params"][0] = nn.Parameter(
376 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
377 | optimizable_tensors[group["name"]] = group["params"][0]
378 |
379 | return optimizable_tensors
380 |
381 | def densification_postfix(self, new_xyz, new_features_dc, new_features_rest,
382 | new_opacities, new_scaling, new_rotation):
383 | d = {
384 | "xyz": new_xyz,
385 | "f_dc": new_features_dc,
386 | "f_rest": new_features_rest,
387 | "opacity": new_opacities,
388 | "scaling": new_scaling,
389 | "rotation": new_rotation,
390 | }
391 |
392 | optimizable_tensors = self.cat_tensors_to_optimizer(d)
393 | self._xyz = optimizable_tensors["xyz"]
394 | self._features_dc = optimizable_tensors["f_dc"]
395 | self._features_rest = optimizable_tensors["f_rest"]
396 | self._opacity = optimizable_tensors["opacity"]
397 | self._scaling = optimizable_tensors["scaling"]
398 | self._rotation = optimizable_tensors["rotation"]
399 |
400 | self.xyz_gradient_accum = torch.zeros((self.get_xyz().shape[0], 1), device="cuda")
401 | self.denom = torch.zeros((self.get_xyz().shape[0], 1), device="cuda")
402 | self.max_radii2D = torch.zeros(
403 | (self.get_xyz().shape[0]), device="cuda")
404 |
405 | def add_densification_stats(self, viewspace_point_tensor, update_filter):
406 | self.xyz_gradient_accum[update_filter] += torch.norm(
407 | viewspace_point_tensor.grad[update_filter, :2], dim=-1, keepdim=True)
408 | self.denom[update_filter] += 1
409 |
--------------------------------------------------------------------------------
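
The densification and pruning helpers above (`cat_tensors_to_optimizer`, `_prune_optimizer`) keep the Adam optimizer in sync with a changing number of Gaussians: the parameter in each named group is rebuilt, and its `exp_avg`/`exp_avg_sq` moment buffers are zero-padded or masked to match. Below is a minimal, self-contained sketch of that pattern on a single toy parameter group; it is illustrative only and not code from this repository.

```python
import torch
import torch.nn as nn

# A single named parameter group, mirroring the layout used in training_setup.
xyz = nn.Parameter(torch.randn(4, 3))
optimizer = torch.optim.Adam([{"params": [xyz], "lr": 1e-2, "name": "xyz"}], eps=1e-15)

# One optimizer step so Adam holds exp_avg / exp_avg_sq state for the parameter.
xyz.sum().backward()
optimizer.step()

# "Densify": append two new points and zero-pad the moment buffers accordingly.
extension = torch.zeros(2, 3)
group = optimizer.param_groups[0]
state = optimizer.state.pop(group["params"][0])
state["exp_avg"] = torch.cat((state["exp_avg"], torch.zeros_like(extension)), dim=0)
state["exp_avg_sq"] = torch.cat((state["exp_avg_sq"], torch.zeros_like(extension)), dim=0)
new_xyz = torch.cat((group["params"][0].detach(), extension), dim=0)
group["params"][0] = nn.Parameter(new_xyz.requires_grad_(True))
optimizer.state[group["params"][0]] = state

print(group["params"][0].shape)  # torch.Size([6, 3])
```

Pruning works the same way in reverse: the boolean mask applied to the parameter rows is also applied to `exp_avg` and `exp_avg_sq`, as done in `_prune_optimizer`.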
/src/entities/gaussian_slam.py:
--------------------------------------------------------------------------------
1 | """ This module includes the Gaussian-SLAM class, which is responsible for controlling Mapper and Tracker
2 | It also decides when to start a new submap and when to update the estimated camera poses.
3 | """
4 | import os
5 | import pprint
6 | from argparse import ArgumentParser
7 | from datetime import datetime
8 | from pathlib import Path
9 |
10 | import numpy as np
11 | import torch
12 |
13 | from src.entities.arguments import OptimizationParams
14 | from src.entities.datasets import get_dataset
15 | from src.entities.gaussian_model import GaussianModel
16 | from src.entities.mapper import Mapper
17 | from src.entities.tracker import Tracker
18 | from src.entities.logger import Logger
19 | from src.utils.io_utils import save_dict_to_ckpt, save_dict_to_yaml
20 | from src.utils.mapper_utils import exceeds_motion_thresholds
21 | from src.utils.utils import np2torch, setup_seed, torch2np
22 | from src.utils.vis_utils import * # noqa - needed for debugging
23 |
24 |
25 | class GaussianSLAM(object):
26 |
27 | def __init__(self, config: dict) -> None:
28 |
29 | self._setup_output_path(config)
30 | self.device = "cuda"
31 | self.config = config
32 |
33 | self.scene_name = config["data"]["scene_name"]
34 | self.dataset_name = config["dataset_name"]
35 | self.dataset = get_dataset(config["dataset_name"])({**config["data"], **config["cam"]})
36 |
37 | n_frames = len(self.dataset)
38 | frame_ids = list(range(n_frames))
39 | self.mapping_frame_ids = frame_ids[::config["mapping"]["map_every"]] + [n_frames - 1]
40 |
41 | self.estimated_c2ws = torch.empty(len(self.dataset), 4, 4)
42 | self.estimated_c2ws[0] = torch.from_numpy(self.dataset[0][3])
43 |
44 | save_dict_to_yaml(config, "config.yaml", directory=self.output_path)
45 |
46 | self.submap_using_motion_heuristic = config["mapping"]["submap_using_motion_heuristic"]
47 |
48 | self.keyframes_info = {}
49 | self.opt = OptimizationParams(ArgumentParser(description="Training script parameters"))
50 |
51 | if self.submap_using_motion_heuristic:
52 | self.new_submap_frame_ids = [0]
53 | else:
54 | self.new_submap_frame_ids = frame_ids[::config["mapping"]["new_submap_every"]] + [n_frames - 1]
55 | self.new_submap_frame_ids.pop(0)
56 |
57 | self.logger = Logger(self.output_path, config["use_wandb"])
58 | self.mapper = Mapper(config["mapping"], self.dataset, self.logger)
59 | self.tracker = Tracker(config["tracking"], self.dataset, self.logger)
60 |
61 | print('Tracking config')
62 | pprint.PrettyPrinter().pprint(config["tracking"])
63 | print('Mapping config')
64 | pprint.PrettyPrinter().pprint(config["mapping"])
65 |
66 | def _setup_output_path(self, config: dict) -> None:
67 | """ Sets up the output path for saving results based on the provided configuration. If the output path is not
68 | specified in the configuration, it creates a new directory with a timestamp.
69 | Args:
70 | config: A dictionary containing the experiment configuration including data and output path information.
71 | """
72 | if "output_path" not in config["data"]:
73 | output_path = Path(config["data"]["output_path"])
74 | self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
75 | self.output_path = output_path / self.timestamp
76 | else:
77 | self.output_path = Path(config["data"]["output_path"])
78 | self.output_path.mkdir(exist_ok=True, parents=True)
79 | os.makedirs(self.output_path / "mapping_vis", exist_ok=True)
80 | os.makedirs(self.output_path / "tracking_vis", exist_ok=True)
81 |
82 | def should_start_new_submap(self, frame_id: int) -> bool:
83 | """ Determines whether a new submap should be started based on the motion heuristic or specific frame IDs.
84 | Args:
85 | frame_id: The ID of the current frame being processed.
86 | Returns:
87 | A boolean indicating whether to start a new submap.
88 | """
89 | if self.submap_using_motion_heuristic:
90 | if exceeds_motion_thresholds(
91 | self.estimated_c2ws[frame_id], self.estimated_c2ws[self.new_submap_frame_ids[-1]],
92 | rot_thre=50, trans_thre=0.5):
93 | return True
94 | elif frame_id in self.new_submap_frame_ids:
95 | return True
96 | return False
97 |
98 | def start_new_submap(self, frame_id: int, gaussian_model: GaussianModel) -> None:
99 | """ Initializes a new submap, saving the current submap's checkpoint and resetting the Gaussian model.
100 | This function updates the submap count and optionally marks the current frame ID for new submap initiation.
101 | Args:
102 | frame_id: The ID of the current frame at which the new submap is started.
103 | gaussian_model: The current GaussianModel instance to capture and reset for the new submap.
104 | Returns:
105 | A new, reset GaussianModel instance for the new submap.
106 | """
107 | gaussian_params = gaussian_model.capture_dict()
108 | submap_ckpt_name = str(self.submap_id).zfill(6)
109 | submap_ckpt = {
110 | "gaussian_params": gaussian_params,
111 | "submap_keyframes": sorted(list(self.keyframes_info.keys()))
112 | }
113 | save_dict_to_ckpt(
114 | submap_ckpt, f"{submap_ckpt_name}.ckpt", directory=self.output_path / "submaps")
115 | gaussian_model = GaussianModel(0)
116 | gaussian_model.training_setup(self.opt)
117 | self.mapper.keyframes = []
118 | self.keyframes_info = {}
119 | if self.submap_using_motion_heuristic:
120 | self.new_submap_frame_ids.append(frame_id)
121 | self.mapping_frame_ids.append(frame_id)
122 | self.submap_id += 1
123 | return gaussian_model
124 |
125 | def run(self) -> None:
126 | """ Starts the main program flow for Gaussian-SLAM, including tracking and mapping. """
127 | setup_seed(self.config["seed"])
128 | gaussian_model = GaussianModel(0)
129 | gaussian_model.training_setup(self.opt)
130 | self.submap_id = 0
131 |
132 | for frame_id in range(len(self.dataset)):
133 |
134 | if frame_id in [0, 1]:
135 | estimated_c2w = self.dataset[frame_id][-1]
136 | else:
137 | estimated_c2w = self.tracker.track(
138 | frame_id, gaussian_model,
139 | torch2np(self.estimated_c2ws[torch.tensor([0, frame_id - 2, frame_id - 1])]))
140 | self.estimated_c2ws[frame_id] = np2torch(estimated_c2w)
141 |
142 | # Reinitialize gaussian model for new segment
143 | if self.should_start_new_submap(frame_id):
144 | save_dict_to_ckpt(self.estimated_c2ws[:frame_id + 1], "estimated_c2w.ckpt", directory=self.output_path)
145 | gaussian_model = self.start_new_submap(frame_id, gaussian_model)
146 |
147 | if frame_id in self.mapping_frame_ids:
148 | print("\nMapping frame", frame_id)
149 | gaussian_model.training_setup(self.opt)
150 | estimate_c2w = torch2np(self.estimated_c2ws[frame_id])
151 | new_submap = not bool(self.keyframes_info)
152 | opt_dict = self.mapper.map(frame_id, estimate_c2w, gaussian_model, new_submap)
153 |
154 | # Keyframes info update
155 | self.keyframes_info[frame_id] = {
156 | "keyframe_id": len(self.keyframes_info.keys()),
157 | "opt_dict": opt_dict
158 | }
159 | save_dict_to_ckpt(self.estimated_c2ws[:frame_id + 1], "estimated_c2w.ckpt", directory=self.output_path)
160 |
--------------------------------------------------------------------------------
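
`GaussianSLAM` is driven entirely by a nested configuration dictionary: `data`/`cam` describe the dataset, `mapping` and `tracking` parameterize the two sub-modules, and `seed`/`use_wandb` control the run. A hedged driver sketch follows; the config path is hypothetical and only the keys visible in `__init__` and `run()` are assumed, with the full set living in the repository's YAML configs.

```python
import yaml  # requires PyYAML

from src.entities.gaussian_slam import GaussianSLAM

# Hypothetical config path; any scene config with the expected nested keys works.
with open("path/to/scene_config.yaml", "r") as f:
    config = yaml.safe_load(f)

slam = GaussianSLAM(config)  # sets up output dir, dataset, Mapper, Tracker
slam.run()                   # per-frame tracking + mapping; submap checkpoints and
                             # estimated_c2w.ckpt are written to the output directory
```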
/src/entities/logger.py:
--------------------------------------------------------------------------------
1 | """ This module includes the Logger class, which is responsible for logging for both Mapper and the Tracker """
2 | from pathlib import Path
3 | from typing import Union
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 | import wandb
9 |
10 |
11 | class Logger(object):
12 |
13 | def __init__(self, output_path: Union[Path, str], use_wandb=False) -> None:
14 | self.output_path = Path(output_path)
15 | (self.output_path / "mapping_vis").mkdir(exist_ok=True, parents=True)
16 | self.use_wandb = use_wandb
17 |
18 | def log_tracking_iteration(self, frame_id, cur_pose, gt_quat, gt_trans, total_loss,
19 | color_loss, depth_loss, iter, num_iters,
20 | wandb_output=False, print_output=False) -> None:
21 | """ Logs tracking iteration metrics including pose error, losses, and optionally reports to Weights & Biases.
22 | Logs the error between the current pose estimate and ground truth quaternion and translation,
23 | as well as various loss metrics. Can output to wandb if enabled and specified, and print to console.
24 | Args:
25 | frame_id: Identifier for the current frame.
26 | cur_pose: The current estimated pose as a tensor (quaternion + translation).
27 | gt_quat: Ground truth quaternion.
28 | gt_trans: Ground truth translation.
29 | total_loss: Total computed loss for the current iteration.
30 | color_loss: Computed color loss for the current iteration.
31 | depth_loss: Computed depth loss for the current iteration.
32 | iter: The current iteration number.
33 | num_iters: The total number of iterations planned.
34 | wandb_output: Whether to output the log to wandb.
35 | print_output: Whether to print the log output.
36 | """
37 |
38 | quad_err = torch.abs(cur_pose[:4] - gt_quat).mean().item()
39 | trans_err = torch.abs(cur_pose[4:] - gt_trans).mean().item()
40 | if self.use_wandb and wandb_output:
41 | wandb.log(
42 | {
43 | "Tracking/idx": frame_id,
44 | "Tracking/cam_quad_err": quad_err,
45 | "Tracking/cam_position_err": trans_err,
46 | "Tracking/total_loss": total_loss.item(),
47 | "Tracking/color_loss": color_loss.item(),
48 | "Tracking/depth_loss": depth_loss.item(),
49 | "Tracking/num_iters": num_iters,
50 | })
51 | if iter == num_iters - 1:
52 | msg = f"frame_id: {frame_id}, cam_quad_err: {quad_err:.5f}, cam_trans_err: {trans_err:.5f} "
53 | else:
54 | msg = f"iter: {iter}, color_loss: {color_loss.item():.5f}, depth_loss: {depth_loss.item():.5f} "
55 | msg = msg + f", cam_quad_err: {quad_err:.5f}, cam_trans_err: {trans_err:.5f}"
56 | if print_output:
57 | print(msg, flush=True)
58 |
59 | def log_mapping_iteration(self, frame_id, new_pts_num, model_size, iter_opt_time, opt_dict: dict) -> None:
60 | """ Logs mapping iteration metrics including the number of new points, model size, and optimization times,
61 | and optionally reports to Weights & Biases (wandb).
62 | Args:
63 | frame_id: Identifier for the current frame.
64 | new_pts_num: The number of new points added in the current mapping iteration.
65 | model_size: The total size of the model after the current mapping iteration.
66 | iter_opt_time: Time taken per optimization iteration.
67 | opt_dict: A dictionary containing optimization metrics such as PSNR, color loss, and depth loss.
68 | """
69 | if self.use_wandb:
70 | wandb.log({"Mapping/idx": frame_id,
71 | "Mapping/num_total_gs": model_size,
72 | "Mapping/num_new_gs": new_pts_num,
73 | "Mapping/per_iteration_time": iter_opt_time,
74 | "Mapping/psnr_render": opt_dict["psnr_render"],
75 | "Mapping/color_loss": opt_dict[frame_id]["color_loss"],
76 | "Mapping/depth_loss": opt_dict[frame_id]["depth_loss"]})
77 |
78 | def vis_mapping_iteration(self, frame_id, iter, color, depth, gt_color, gt_depth, seeding_mask=None) -> None:
79 | """
80 | Visualization of depth, color images and save to file.
81 |
82 | Args:
83 | frame_id (int): current frame index.
84 | iter (int): the iteration number.
85 |             color / depth: rendered color (H, W, 3) and rendered depth (H, W, 1) tensors.
86 |             gt_color / gt_depth: ground truth color and depth tensors of matching shapes.
87 |             seeding_mask: densification mask used by the mapper when adding Gaussians, if not None.
88 | """
89 | gt_depth_np = gt_depth.cpu().numpy()
90 | gt_color_np = gt_color.cpu().numpy()
91 |
92 | depth_np = depth.detach().cpu().numpy()
93 | color = torch.round(color * 255.0) / 255.0
94 | color_np = color.detach().cpu().numpy()
95 | depth_residual = np.abs(gt_depth_np - depth_np)
96 | depth_residual[gt_depth_np == 0.0] = 0.0
97 |         # clip the residual at 5cm so that smaller errors stay visible in the colormap
98 | depth_residual = np.clip(depth_residual, 0.0, 0.05)
99 |
100 | color_residual = np.abs(gt_color_np - color_np)
101 | color_residual[np.squeeze(gt_depth_np == 0.0)] = 0.0
102 |
103 | # Determine Aspect Ratio and Figure Size
104 | aspect_ratio = color.shape[1] / color.shape[0]
105 | fig_height = 8
106 | # Adjust the multiplier as needed for better spacing
107 | fig_width = fig_height * aspect_ratio * 1.2
108 |
109 | fig, axs = plt.subplots(2, 3, figsize=(fig_width, fig_height))
110 | axs[0, 0].imshow(gt_depth_np, cmap="jet", vmin=0, vmax=6)
111 | axs[0, 0].set_title('Input Depth', fontsize=16)
112 | axs[0, 0].set_xticks([])
113 | axs[0, 0].set_yticks([])
114 | axs[0, 1].imshow(depth_np, cmap="jet", vmin=0, vmax=6)
115 | axs[0, 1].set_title('Rendered Depth', fontsize=16)
116 | axs[0, 1].set_xticks([])
117 | axs[0, 1].set_yticks([])
118 | axs[0, 2].imshow(depth_residual, cmap="plasma")
119 | axs[0, 2].set_title('Depth Residual', fontsize=16)
120 | axs[0, 2].set_xticks([])
121 | axs[0, 2].set_yticks([])
122 | gt_color_np = np.clip(gt_color_np, 0, 1)
123 | color_np = np.clip(color_np, 0, 1)
124 | color_residual = np.clip(color_residual, 0, 1)
125 | axs[1, 0].imshow(gt_color_np, cmap="plasma")
126 | axs[1, 0].set_title('Input RGB', fontsize=16)
127 | axs[1, 0].set_xticks([])
128 | axs[1, 0].set_yticks([])
129 | axs[1, 1].imshow(color_np, cmap="plasma")
130 | axs[1, 1].set_title('Rendered RGB', fontsize=16)
131 | axs[1, 1].set_xticks([])
132 | axs[1, 1].set_yticks([])
133 | if seeding_mask is not None:
134 | axs[1, 2].imshow(seeding_mask, cmap="gray")
135 | axs[1, 2].set_title('Densification Mask', fontsize=16)
136 | axs[1, 2].set_xticks([])
137 | axs[1, 2].set_yticks([])
138 | else:
139 | axs[1, 2].imshow(color_residual, cmap="plasma")
140 | axs[1, 2].set_title('RGB Residual', fontsize=16)
141 | axs[1, 2].set_xticks([])
142 | axs[1, 2].set_yticks([])
143 |
144 | for ax in axs.flatten():
145 | ax.axis('off')
146 | fig.tight_layout()
147 | plt.subplots_adjust(top=0.90) # Adjust top margin
148 | fig_name = str(self.output_path / "mapping_vis" / f'{frame_id:04d}_{iter:04d}.jpg')
149 | fig_title = f"Mapper Color/Depth at frame {frame_id:04d} iters {iter:04d}"
150 | plt.suptitle(fig_title, y=0.98, fontsize=20)
151 | plt.savefig(fig_name, dpi=250, bbox_inches='tight')
152 | plt.clf()
153 | plt.close()
154 | if self.use_wandb:
155 | log_title = "Mapping_vis/" + f'{frame_id:04d}_{iter:04d}'
156 | wandb.log({log_title: [wandb.Image(fig_name)]})
157 | print(f"Saved rendering vis of color/depth at {frame_id:04d}_{iter:04d}.jpg")
158 |
--------------------------------------------------------------------------------
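
`log_tracking_iteration` expects the current pose as a 7-vector (wxyz quaternion followed by translation) and the losses as scalar tensors. A small sketch with dummy values (the output directory and all numbers are illustrative):

```python
import torch

from src.entities.logger import Logger

logger = Logger("output/logger_demo", use_wandb=False)  # hypothetical output directory

cur_pose = torch.tensor([1.0, 0.0, 0.0, 0.0, 0.1, 0.0, 0.0])  # wxyz quaternion + translation
gt_quat = torch.tensor([1.0, 0.0, 0.0, 0.0])
gt_trans = torch.tensor([0.0, 0.0, 0.0])

logger.log_tracking_iteration(
    frame_id=0, cur_pose=cur_pose, gt_quat=gt_quat, gt_trans=gt_trans,
    total_loss=torch.tensor(1.0), color_loss=torch.tensor(0.6), depth_loss=torch.tensor(0.4),
    iter=0, num_iters=60, wandb_output=False, print_output=True)
```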
/src/entities/losses.py:
--------------------------------------------------------------------------------
1 | from math import exp
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 |
8 | def l1_loss(network_output: torch.Tensor, gt: torch.Tensor, agg="mean") -> torch.Tensor:
9 | """
10 | Computes the L1 loss, which is the mean absolute error between the network output and the ground truth.
11 |
12 | Args:
13 | network_output: The output from the network.
14 | gt: The ground truth tensor.
15 | agg: The aggregation method to be used. Defaults to "mean".
16 | Returns:
17 | The computed L1 loss.
18 | """
19 | l1_loss = torch.abs(network_output - gt)
20 | if agg == "mean":
21 | return l1_loss.mean()
22 | elif agg == "sum":
23 | return l1_loss.sum()
24 | elif agg == "none":
25 | return l1_loss
26 | else:
27 | raise ValueError("Invalid aggregation method.")
28 |
29 |
30 | def gaussian(window_size: int, sigma: float) -> torch.Tensor:
31 | """
32 | Creates a 1D Gaussian kernel.
33 |
34 | Args:
35 | window_size: The size of the window for the Gaussian kernel.
36 | sigma: The standard deviation of the Gaussian kernel.
37 |
38 | Returns:
39 | The 1D Gaussian kernel.
40 | """
41 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 /
42 | float(2 * sigma ** 2)) for x in range(window_size)])
43 | return gauss / gauss.sum()
44 |
45 |
46 | def create_window(window_size: int, channel: int) -> Variable:
47 | """
48 | Creates a 2D Gaussian window/kernel for SSIM computation.
49 |
50 | Args:
51 | window_size: The size of the window to be created.
52 | channel: The number of channels in the image.
53 |
54 | Returns:
55 | A 2D Gaussian window expanded to match the number of channels.
56 | """
57 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
58 | _2D_window = _1D_window.mm(
59 | _1D_window.t()).float().unsqueeze(0).unsqueeze(0)
60 | window = Variable(_2D_window.expand(
61 | channel, 1, window_size, window_size).contiguous())
62 | return window
63 |
64 |
65 | def ssim(img1: torch.Tensor, img2: torch.Tensor, window_size: int = 11, size_average: bool = True) -> torch.Tensor:
66 | """
67 | Computes the Structural Similarity Index (SSIM) between two images.
68 |
69 | Args:
70 | img1: The first image.
71 | img2: The second image.
72 | window_size: The size of the window to be used in SSIM computation. Defaults to 11.
73 | size_average: If True, averages the SSIM over all pixels. Defaults to True.
74 |
75 | Returns:
76 | The computed SSIM value.
77 | """
78 | channel = img1.size(-3)
79 | window = create_window(window_size, channel)
80 |
81 | if img1.is_cuda:
82 | window = window.cuda(img1.get_device())
83 | window = window.type_as(img1)
84 |
85 | return _ssim(img1, img2, window, window_size, channel, size_average)
86 |
87 |
88 | def _ssim(img1: torch.Tensor, img2: torch.Tensor, window: Variable, window_size: int,
89 | channel: int, size_average: bool = True) -> torch.Tensor:
90 | """
91 | Internal function to compute the Structural Similarity Index (SSIM) between two images.
92 |
93 | Args:
94 | img1: The first image.
95 | img2: The second image.
96 | window: The Gaussian window/kernel for SSIM computation.
97 | window_size: The size of the window to be used in SSIM computation.
98 | channel: The number of channels in the image.
99 | size_average: If True, averages the SSIM over all pixels.
100 |
101 | Returns:
102 | The computed SSIM value.
103 | """
104 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
105 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
106 |
107 | mu1_sq = mu1.pow(2)
108 | mu2_sq = mu2.pow(2)
109 | mu1_mu2 = mu1 * mu2
110 |
111 | sigma1_sq = F.conv2d(img1 * img1, window,
112 | padding=window_size // 2, groups=channel) - mu1_sq
113 | sigma2_sq = F.conv2d(img2 * img2, window,
114 | padding=window_size // 2, groups=channel) - mu2_sq
115 | sigma12 = F.conv2d(img1 * img2, window,
116 | padding=window_size // 2, groups=channel) - mu1_mu2
117 |
118 | C1 = 0.01 ** 2
119 | C2 = 0.03 ** 2
120 |
121 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
122 | ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
123 |
124 | if size_average:
125 | return ssim_map.mean()
126 | else:
127 | return ssim_map.mean(1).mean(1).mean(1)
128 |
129 |
130 | def isotropic_loss(scaling: torch.Tensor) -> torch.Tensor:
131 | """
132 | Computes loss enforcing isotropic scaling for the 3D Gaussians
133 | Args:
134 | scaling: scaling tensor of 3D Gaussians of shape (n, 3)
135 | Returns:
136 | The computed isotropic loss
137 | """
138 | mean_scaling = scaling.mean(dim=1, keepdim=True)
139 | isotropic_diff = torch.abs(scaling - mean_scaling * torch.ones_like(scaling))
140 | return isotropic_diff.mean()
141 |
--------------------------------------------------------------------------------
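
The mapping objective combines these losses: an L1/D-SSIM photometric term weighted by `lambda_dssim`, a depth L1 term, and the isotropic regularizer on the Gaussian scales. A minimal sketch on random tensors (the 0.2 weight stands in for `OptimizationParams.lambda_dssim`, which is defined elsewhere in the repo):

```python
import torch

from src.entities.losses import isotropic_loss, l1_loss, ssim

rendered = torch.rand(3, 64, 64)  # (C, H, W) image in [0, 1]
gt = torch.rand(3, 64, 64)
scaling = torch.rand(1000, 3)     # per-Gaussian scales

color_loss = 0.8 * l1_loss(rendered, gt) + 0.2 * (1.0 - ssim(rendered, gt))
reg_loss = isotropic_loss(scaling)  # pulls each Gaussian towards equal scale on all axes
print(color_loss.item(), reg_loss.item())
```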
/src/entities/mapper.py:
--------------------------------------------------------------------------------
1 | """ This module includes the Mapper class, which is responsible scene mapping: Paragraph 3.2 """
2 | import time
3 | from argparse import ArgumentParser
4 |
5 | import numpy as np
6 | import torch
7 | import torchvision
8 |
9 | from src.entities.arguments import OptimizationParams
10 | from src.entities.datasets import TUM_RGBD, BaseDataset, ScanNet
11 | from src.entities.gaussian_model import GaussianModel
12 | from src.entities.logger import Logger
13 | from src.entities.losses import isotropic_loss, l1_loss, ssim
14 | from src.utils.mapper_utils import (calc_psnr, compute_camera_frustum_corners,
15 | compute_frustum_point_ids,
16 | compute_new_points_ids,
17 | compute_opt_views_distribution,
18 | create_point_cloud, geometric_edge_mask,
19 | sample_pixels_based_on_gradient)
20 | from src.utils.utils import (get_render_settings, np2ptcloud, np2torch,
21 | render_gaussian_model, torch2np)
22 | from src.utils.vis_utils import * # noqa - needed for debugging
23 |
24 |
25 | class Mapper(object):
26 | def __init__(self, config: dict, dataset: BaseDataset, logger: Logger) -> None:
27 | """ Sets up the mapper parameters
28 | Args:
29 | config: configuration of the mapper
30 | dataset: The dataset object used for extracting camera parameters and reading the data
31 | logger: The logger object used for logging the mapping process and saving visualizations
32 | """
33 | self.config = config
34 | self.logger = logger
35 | self.dataset = dataset
36 | self.iterations = config["iterations"]
37 | self.new_submap_iterations = config["new_submap_iterations"]
38 | self.new_submap_points_num = config["new_submap_points_num"]
39 | self.new_submap_gradient_points_num = config["new_submap_gradient_points_num"]
40 | self.new_frame_sample_size = config["new_frame_sample_size"]
41 | self.new_points_radius = config["new_points_radius"]
42 | self.alpha_thre = config["alpha_thre"]
43 | self.pruning_thre = config["pruning_thre"]
44 | self.current_view_opt_iterations = config["current_view_opt_iterations"]
45 | self.opt = OptimizationParams(ArgumentParser(description="Training script parameters"))
46 | self.keyframes = []
47 |
48 | def compute_seeding_mask(self, gaussian_model: GaussianModel, keyframe: dict, new_submap: bool) -> np.ndarray:
49 | """
50 | Computes a binary mask to identify regions within a keyframe where new Gaussian models should be seeded
51 | based on alpha masks or color gradient
52 | Args:
53 | gaussian_model: The current submap
54 | keyframe (dict): Keyframe dict containing color, depth, and render settings
55 |             new_submap (bool): A boolean indicating whether the seeding occurs in the current submap or in a new submap
56 |         Returns:
57 |             np.ndarray: A binary mask of shape (H, W) indicating regions suitable for seeding new 3D Gaussian models
58 | """
59 | seeding_mask = None
60 | if new_submap:
61 | color_for_mask = (torch2np(keyframe["color"].permute(1, 2, 0)) * 255).astype(np.uint8)
62 | seeding_mask = geometric_edge_mask(color_for_mask, RGB=True)
63 | else:
64 | render_dict = render_gaussian_model(gaussian_model, keyframe["render_settings"])
65 | alpha_mask = (render_dict["alpha"] < self.alpha_thre)
66 | gt_depth_tensor = keyframe["depth"][None]
67 | depth_error = torch.abs(gt_depth_tensor - render_dict["depth"]) * (gt_depth_tensor > 0)
68 | depth_error_mask = (render_dict["depth"] > gt_depth_tensor) * (depth_error > 40 * depth_error.median())
69 | seeding_mask = alpha_mask | depth_error_mask
70 | seeding_mask = torch2np(seeding_mask[0])
71 | return seeding_mask
72 |
73 | def seed_new_gaussians(self, gt_color: np.ndarray, gt_depth: np.ndarray, intrinsics: np.ndarray,
74 | estimate_c2w: np.ndarray, seeding_mask: np.ndarray, is_new_submap: bool) -> np.ndarray:
75 | """
76 |         Seeds means for new 3D Gaussians based on ground truth color and depth, camera intrinsics,
77 | estimated camera-to-world transformation, a seeding mask, and a flag indicating whether this is a new submap.
78 | Args:
79 | gt_color: The ground truth color image as a numpy array with shape (H, W, 3).
80 | gt_depth: The ground truth depth map as a numpy array with shape (H, W).
81 | intrinsics: The camera intrinsics matrix as a numpy array with shape (3, 3).
82 | estimate_c2w: The estimated camera-to-world transformation matrix as a numpy array with shape (4, 4).
83 | seeding_mask: A binary mask indicating where to seed new Gaussians, with shape (H, W).
84 | is_new_submap: Flag indicating whether the seeding is for a new submap (True) or an existing submap (False).
85 | Returns:
86 | np.ndarray: An array of 3D points where new Gaussians will be initialized, with shape (N, 3)
87 |
88 | """
89 | pts = create_point_cloud(gt_color, 1.005 * gt_depth, intrinsics, estimate_c2w)
90 | flat_gt_depth = gt_depth.flatten()
91 |         non_zero_depth_mask = flat_gt_depth > 0.  # filter out pixels with zero (invalid) depth in gt_depth
92 | valid_ids = np.flatnonzero(seeding_mask)
93 | if is_new_submap:
94 | if self.new_submap_points_num < 0:
95 | uniform_ids = np.arange(pts.shape[0])
96 | else:
97 | uniform_ids = np.random.choice(pts.shape[0], self.new_submap_points_num, replace=False)
98 | gradient_ids = sample_pixels_based_on_gradient(gt_color, self.new_submap_gradient_points_num)
99 | combined_ids = np.concatenate((uniform_ids, gradient_ids))
100 | combined_ids = np.concatenate((combined_ids, valid_ids))
101 | sample_ids = np.unique(combined_ids)
102 | else:
103 | if self.new_frame_sample_size < 0 or len(valid_ids) < self.new_frame_sample_size:
104 | sample_ids = valid_ids
105 | else:
106 | sample_ids = np.random.choice(valid_ids, size=self.new_frame_sample_size, replace=False)
107 | sample_ids = sample_ids[non_zero_depth_mask[sample_ids]]
108 | return pts[sample_ids, :].astype(np.float32)
109 |
110 | def optimize_submap(self, keyframes: list, gaussian_model: GaussianModel, iterations: int = 100) -> dict:
111 | """
112 |         Optimizes the submap by refining the parameters of the 3D Gaussians based on the observations
113 | from keyframes observing the submap.
114 | Args:
115 | keyframes: A list of tuples consisting of frame id and keyframe dictionary
116 | gaussian_model: An instance of the GaussianModel class representing the initial state
117 | of the Gaussian model to be optimized.
118 | iterations: The number of iterations to perform the optimization process. Defaults to 100.
119 | Returns:
120 | losses_dict: Dictionary with the optimization statistics
121 | """
122 |
123 | iteration = 0
124 | losses_dict = {}
125 |
126 | current_frame_iters = self.current_view_opt_iterations * iterations
127 | distribution = compute_opt_views_distribution(len(keyframes), iterations, current_frame_iters)
128 | start_time = time.time()
129 | while iteration < iterations + 1:
130 | gaussian_model.optimizer.zero_grad(set_to_none=True)
131 | keyframe_id = np.random.choice(np.arange(len(keyframes)), p=distribution)
132 |
133 | frame_id, keyframe = keyframes[keyframe_id]
134 | render_pkg = render_gaussian_model(gaussian_model, keyframe["render_settings"])
135 |
136 | image, depth = render_pkg["color"], render_pkg["depth"]
137 | gt_image = keyframe["color"]
138 | gt_depth = keyframe["depth"]
139 |
140 | mask = (gt_depth > 0) & (~torch.isnan(depth)).squeeze(0)
141 | color_loss = (1.0 - self.opt.lambda_dssim) * l1_loss(
142 | image[:, mask], gt_image[:, mask]) + self.opt.lambda_dssim * (1.0 - ssim(image, gt_image))
143 |
144 | depth_loss = l1_loss(depth[:, mask], gt_depth[mask])
145 | reg_loss = isotropic_loss(gaussian_model.get_scaling())
146 | total_loss = color_loss + depth_loss + reg_loss
147 | total_loss.backward()
148 |
149 | losses_dict[frame_id] = {"color_loss": color_loss.item(),
150 | "depth_loss": depth_loss.item(),
151 | "total_loss": total_loss.item()}
152 |
153 | with torch.no_grad():
154 |
155 | if iteration == iterations // 2 or iteration == iterations:
156 | prune_mask = (gaussian_model.get_opacity()
157 | < self.pruning_thre).squeeze()
158 | gaussian_model.prune_points(prune_mask)
159 |
160 | # Optimizer step
161 | if iteration < iterations:
162 | gaussian_model.optimizer.step()
163 | gaussian_model.optimizer.zero_grad(set_to_none=True)
164 |
165 | iteration += 1
166 | optimization_time = time.time() - start_time
167 | losses_dict["optimization_time"] = optimization_time
168 | losses_dict["optimization_iter_time"] = optimization_time / iterations
169 | return losses_dict
170 |
171 | def grow_submap(self, gt_depth: np.ndarray, estimate_c2w: np.ndarray, gaussian_model: GaussianModel,
172 | pts: np.ndarray, filter_cloud: bool) -> int:
173 | """
174 | Expands the submap by integrating new points from the current keyframe
175 | Args:
176 | gt_depth: The ground truth depth map for the current keyframe, as a 2D numpy array.
177 | estimate_c2w: The estimated camera-to-world transformation matrix for the current keyframe of shape (4x4)
178 | gaussian_model (GaussianModel): The Gaussian model representing the current state of the submap.
179 | pts: The current set of 3D points in the keyframe of shape (N, 3)
180 | filter_cloud: A boolean flag indicating whether to apply filtering to the point cloud to remove
181 | outliers or noise before integrating it into the map.
182 | Returns:
183 | int: The number of points added to the submap
184 | """
185 | gaussian_points = gaussian_model.get_xyz()
186 | camera_frustum_corners = compute_camera_frustum_corners(gt_depth, estimate_c2w, self.dataset.intrinsics)
187 | reused_pts_ids = compute_frustum_point_ids(
188 | gaussian_points, np2torch(camera_frustum_corners), device="cuda")
189 | new_pts_ids = compute_new_points_ids(gaussian_points[reused_pts_ids], np2torch(pts[:, :3]).contiguous(),
190 | radius=self.new_points_radius, device="cuda")
191 | new_pts_ids = torch2np(new_pts_ids)
192 | if new_pts_ids.shape[0] > 0:
193 | cloud_to_add = np2ptcloud(pts[new_pts_ids, :3], pts[new_pts_ids, 3:] / 255.0)
194 | if filter_cloud:
195 | cloud_to_add, _ = cloud_to_add.remove_statistical_outlier(nb_neighbors=40, std_ratio=2.0)
196 | gaussian_model.add_points(cloud_to_add)
197 | gaussian_model._features_dc.requires_grad = False
198 | gaussian_model._features_rest.requires_grad = False
199 | print("Gaussian model size", gaussian_model.get_size())
200 | return new_pts_ids.shape[0]
201 |
202 | def map(self, frame_id: int, estimate_c2w: np.ndarray, gaussian_model: GaussianModel, is_new_submap: bool) -> dict:
203 |         """ Carries out the mapping process described in Paper Section 3.2.
204 | The process goes as follows: seed new gaussians -> add to the submap -> optimize the submap
205 | Args:
206 | frame_id: current keyframe id
207 | estimate_c2w (np.ndarray): The estimated camera-to-world transformation matrix of shape (4x4)
208 | gaussian_model (GaussianModel): The current Gaussian model of the submap
209 | is_new_submap (bool): A boolean flag indicating whether the current frame initiates a new submap
210 | Returns:
211 | opt_dict: Dictionary with statistics about the optimization process
212 | """
213 |
214 | _, gt_color, gt_depth, _ = self.dataset[frame_id]
215 | estimate_w2c = np.linalg.inv(estimate_c2w)
216 |
217 | color_transform = torchvision.transforms.ToTensor()
218 | keyframe = {
219 | "color": color_transform(gt_color).cuda(),
220 | "depth": np2torch(gt_depth, device="cuda"),
221 | "render_settings": get_render_settings(
222 | self.dataset.width, self.dataset.height, self.dataset.intrinsics, estimate_w2c)}
223 |
224 | seeding_mask = self.compute_seeding_mask(gaussian_model, keyframe, is_new_submap)
225 | pts = self.seed_new_gaussians(
226 | gt_color, gt_depth, self.dataset.intrinsics, estimate_c2w, seeding_mask, is_new_submap)
227 |
228 | filter_cloud = isinstance(self.dataset, (TUM_RGBD, ScanNet)) and not is_new_submap
229 |
230 | new_pts_num = self.grow_submap(gt_depth, estimate_c2w, gaussian_model, pts, filter_cloud)
231 |
232 | max_iterations = self.iterations
233 | if is_new_submap:
234 | max_iterations = self.new_submap_iterations
235 | start_time = time.time()
236 | opt_dict = self.optimize_submap([(frame_id, keyframe)] + self.keyframes, gaussian_model, max_iterations)
237 | optimization_time = time.time() - start_time
238 | print("Optimization time: ", optimization_time)
239 |
240 | self.keyframes.append((frame_id, keyframe))
241 |
242 | # Visualise the mapping for the current frame
243 | with torch.no_grad():
244 | render_pkg_vis = render_gaussian_model(gaussian_model, keyframe["render_settings"])
245 | image_vis, depth_vis = render_pkg_vis["color"], render_pkg_vis["depth"]
246 | psnr_value = calc_psnr(image_vis, keyframe["color"]).mean().item()
247 | opt_dict["psnr_render"] = psnr_value
248 | print(f"PSNR this frame: {psnr_value}")
249 | self.logger.vis_mapping_iteration(
250 | frame_id, max_iterations,
251 | image_vis.clone().detach().permute(1, 2, 0),
252 | depth_vis.clone().detach().permute(1, 2, 0),
253 | keyframe["color"].permute(1, 2, 0),
254 | keyframe["depth"].unsqueeze(-1),
255 | seeding_mask=seeding_mask)
256 |
257 | # Log the mapping numbers for the current frame
258 | self.logger.log_mapping_iteration(frame_id, new_pts_num, gaussian_model.get_size(),
259 | optimization_time/max_iterations, opt_dict)
260 | return opt_dict
261 |
--------------------------------------------------------------------------------
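
For an existing submap, `compute_seeding_mask` seeds new Gaussians where the current render is either too transparent (low accumulated alpha) or lies well behind the sensor depth. The criterion in isolation, on dummy tensors (threshold values illustrative):

```python
import torch

alpha = torch.rand(1, 48, 64)               # rendered accumulated opacity
rendered_depth = torch.rand(1, 48, 64) * 4.0
gt_depth = torch.rand(1, 48, 64) * 4.0      # zeros would mark invalid sensor depth
alpha_thre = 0.6                            # stand-in for config["alpha_thre"]

alpha_mask = alpha < alpha_thre
depth_error = torch.abs(gt_depth - rendered_depth) * (gt_depth > 0)
depth_error_mask = (rendered_depth > gt_depth) & (depth_error > 40 * depth_error.median())
seeding_mask = (alpha_mask | depth_error_mask)[0]  # (H, W) boolean seeding mask
print(f"{seeding_mask.float().mean().item():.2%} of pixels selected for seeding")
```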
/src/entities/tracker.py:
--------------------------------------------------------------------------------
1 | """ This module includes the Mapper class, which is responsible scene mapping: Paper Section 3.4 """
2 | from argparse import ArgumentParser
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | import torchvision
8 | from scipy.spatial.transform import Rotation as R
9 |
10 | from src.entities.arguments import OptimizationParams
11 | from src.entities.losses import l1_loss
12 | from src.entities.gaussian_model import GaussianModel
13 | from src.entities.logger import Logger
14 | from src.entities.datasets import BaseDataset
15 | from src.entities.visual_odometer import VisualOdometer
16 | from src.utils.gaussian_model_utils import build_rotation
17 | from src.utils.tracker_utils import (compute_camera_opt_params,
18 | extrapolate_poses, multiply_quaternions,
19 | transformation_to_quaternion)
20 | from src.utils.utils import (get_render_settings, np2torch,
21 | render_gaussian_model, torch2np)
22 |
23 |
24 | class Tracker(object):
25 | def __init__(self, config: dict, dataset: BaseDataset, logger: Logger) -> None:
26 | """ Initializes the Tracker with a given configuration, dataset, and logger.
27 | Args:
28 | config: Configuration dictionary specifying hyperparameters and operational settings.
29 | dataset: The dataset object providing access to the sequence of frames.
30 | logger: Logger object for logging the tracking process.
31 | """
32 | self.dataset = dataset
33 | self.logger = logger
34 | self.config = config
35 | self.filter_alpha = self.config["filter_alpha"]
36 | self.filter_outlier_depth = self.config["filter_outlier_depth"]
37 | self.alpha_thre = self.config["alpha_thre"]
38 | self.soft_alpha = self.config["soft_alpha"]
39 | self.mask_invalid_depth_in_color_loss = self.config["mask_invalid_depth"]
40 | self.w_color_loss = self.config["w_color_loss"]
41 | self.transform = torchvision.transforms.ToTensor()
42 | self.opt = OptimizationParams(ArgumentParser(description="Training script parameters"))
43 | self.frame_depth_loss = []
44 | self.frame_color_loss = []
45 | self.odometry_type = self.config["odometry_type"]
46 | self.help_camera_initialization = self.config["help_camera_initialization"]
47 | self.init_err_ratio = self.config["init_err_ratio"]
48 | self.odometer = VisualOdometer(self.dataset.intrinsics, self.config["odometer_method"])
49 |
50 | def compute_losses(self, gaussian_model: GaussianModel, render_settings: dict,
51 | opt_cam_rot: torch.Tensor, opt_cam_trans: torch.Tensor,
52 | gt_color: torch.Tensor, gt_depth: torch.Tensor, depth_mask: torch.Tensor) -> tuple:
53 | """ Computes the tracking losses with respect to ground truth color and depth.
54 | Args:
55 | gaussian_model: The current state of the Gaussian model of the scene.
56 | render_settings: Dictionary containing rendering settings such as image dimensions and camera intrinsics.
57 | opt_cam_rot: Optimizable tensor representing the camera's rotation.
58 | opt_cam_trans: Optimizable tensor representing the camera's translation.
59 | gt_color: Ground truth color image tensor.
60 | gt_depth: Ground truth depth image tensor.
61 | depth_mask: Binary mask indicating valid depth values in the ground truth depth image.
62 | Returns:
63 | A tuple containing losses and renders
64 | """
65 | rel_transform = torch.eye(4).cuda().float()
66 | rel_transform[:3, :3] = build_rotation(F.normalize(opt_cam_rot[None]))[0]
67 | rel_transform[:3, 3] = opt_cam_trans
68 |
69 | pts = gaussian_model.get_xyz()
70 | pts_ones = torch.ones(pts.shape[0], 1).cuda().float()
71 | pts4 = torch.cat((pts, pts_ones), dim=1)
72 | transformed_pts = (rel_transform @ pts4.T).T[:, :3]
73 |
74 | quat = F.normalize(opt_cam_rot[None])
75 | _rotations = multiply_quaternions(gaussian_model.get_rotation(), quat.unsqueeze(0)).squeeze(0)
76 |
77 | render_dict = render_gaussian_model(gaussian_model, render_settings,
78 | override_means_3d=transformed_pts, override_rotations=_rotations)
79 | rendered_color, rendered_depth = render_dict["color"], render_dict["depth"]
80 | alpha_mask = render_dict["alpha"] > self.alpha_thre
81 |
82 | tracking_mask = torch.ones_like(alpha_mask).bool()
83 | tracking_mask &= depth_mask
84 | depth_err = torch.abs(rendered_depth - gt_depth) * depth_mask
85 |
86 | if self.filter_alpha:
87 | tracking_mask &= alpha_mask
88 | if self.filter_outlier_depth and torch.median(depth_err) > 0:
89 | tracking_mask &= depth_err < 50 * torch.median(depth_err)
90 |
91 | color_loss = l1_loss(rendered_color, gt_color, agg="none")
92 | depth_loss = l1_loss(rendered_depth, gt_depth, agg="none") * tracking_mask
93 |
94 | if self.soft_alpha:
95 | alpha = render_dict["alpha"] ** 3
96 | color_loss *= alpha
97 | depth_loss *= alpha
98 | if self.mask_invalid_depth_in_color_loss:
99 | color_loss *= tracking_mask
100 | else:
101 | color_loss *= tracking_mask
102 |
103 | color_loss = color_loss.sum()
104 | depth_loss = depth_loss.sum()
105 |
106 | return color_loss, depth_loss, rendered_color, rendered_depth, alpha_mask
107 |
108 | def track(self, frame_id: int, gaussian_model: GaussianModel, prev_c2ws: np.ndarray) -> np.ndarray:
109 | """
110 | Updates the camera pose estimation for the current frame based on the provided image and depth, using either ground truth poses,
111 | constant speed assumption, or visual odometry.
112 | Args:
113 | frame_id: Index of the current frame being processed.
114 | gaussian_model: The current Gaussian model of the scene.
115 | prev_c2ws: Array containing the camera-to-world transformation matrices for the frames (0, i - 2, i - 1)
116 | Returns:
117 | The updated camera-to-world transformation matrix for the current frame.
118 | """
119 | _, image, depth, gt_c2w = self.dataset[frame_id]
120 |
121 | if (self.help_camera_initialization or self.odometry_type == "odometer") and self.odometer.last_rgbd is None:
122 | _, last_image, last_depth, _ = self.dataset[frame_id - 1]
123 | self.odometer.update_last_rgbd(last_image, last_depth)
124 |
125 | if self.odometry_type == "gt":
126 | return gt_c2w
127 | elif self.odometry_type == "const_speed":
128 | init_c2w = extrapolate_poses(prev_c2ws[1:])
129 | elif self.odometry_type == "odometer":
130 | odometer_rel = self.odometer.estimate_rel_pose(image, depth)
131 | init_c2w = prev_c2ws[-1] @ odometer_rel
132 |
133 | last_c2w = prev_c2ws[-1]
134 | last_w2c = np.linalg.inv(last_c2w)
135 | init_rel = init_c2w @ np.linalg.inv(last_c2w)
136 | init_rel_w2c = np.linalg.inv(init_rel)
137 | reference_w2c = last_w2c
138 | render_settings = get_render_settings(
139 | self.dataset.width, self.dataset.height, self.dataset.intrinsics, reference_w2c)
140 | opt_cam_rot, opt_cam_trans = compute_camera_opt_params(init_rel_w2c)
141 | gaussian_model.training_setup_camera(opt_cam_rot, opt_cam_trans, self.config)
142 |
143 | gt_color = self.transform(image).cuda()
144 | gt_depth = np2torch(depth, "cuda")
145 | depth_mask = gt_depth > 0.0
146 | gt_trans = np2torch(gt_c2w[:3, 3])
147 | gt_quat = np2torch(R.from_matrix(gt_c2w[:3, :3]).as_quat(canonical=True)[[3, 0, 1, 2]])
148 | num_iters = self.config["iterations"]
149 | current_min_loss = float("inf")
150 |
151 | print(f"\nTracking frame {frame_id}")
152 | # Initial loss check
153 | color_loss, depth_loss, _, _, _ = self.compute_losses(gaussian_model, render_settings, opt_cam_rot,
154 | opt_cam_trans, gt_color, gt_depth, depth_mask)
155 | if len(self.frame_color_loss) > 0 and (
156 | color_loss.item() > self.init_err_ratio * np.median(self.frame_color_loss)
157 | or depth_loss.item() > self.init_err_ratio * np.median(self.frame_depth_loss)
158 | ):
159 | num_iters *= 2
160 | print(f"Higher initial loss, increasing num_iters to {num_iters}")
161 | if self.help_camera_initialization and self.odometry_type != "odometer":
162 | _, last_image, last_depth, _ = self.dataset[frame_id - 1]
163 | self.odometer.update_last_rgbd(last_image, last_depth)
164 | odometer_rel = self.odometer.estimate_rel_pose(image, depth)
165 | init_c2w = last_c2w @ odometer_rel
166 | init_rel = init_c2w @ np.linalg.inv(last_c2w)
167 | init_rel_w2c = np.linalg.inv(init_rel)
168 | opt_cam_rot, opt_cam_trans = compute_camera_opt_params(init_rel_w2c)
169 | gaussian_model.training_setup_camera(opt_cam_rot, opt_cam_trans, self.config)
170 | render_settings = get_render_settings(
171 | self.dataset.width, self.dataset.height, self.dataset.intrinsics, last_w2c)
172 | print(f"re-init with odometer for frame {frame_id}")
173 |
174 | for iter in range(num_iters):
175 |             color_loss, depth_loss, _, _, _ = self.compute_losses(
176 | gaussian_model, render_settings, opt_cam_rot, opt_cam_trans, gt_color, gt_depth, depth_mask)
177 |
178 | total_loss = (self.w_color_loss * color_loss + (1 - self.w_color_loss) * depth_loss)
179 | total_loss.backward()
180 | gaussian_model.optimizer.step()
181 | gaussian_model.optimizer.zero_grad(set_to_none=True)
182 |
183 | with torch.no_grad():
184 | if total_loss.item() < current_min_loss:
185 | current_min_loss = total_loss.item()
186 | best_w2c = torch.eye(4)
187 | best_w2c[:3, :3] = build_rotation(F.normalize(opt_cam_rot[None].clone().detach().cpu()))[0]
188 | best_w2c[:3, 3] = opt_cam_trans.clone().detach().cpu()
189 |
190 | cur_quat, cur_trans = F.normalize(opt_cam_rot[None].clone().detach()), opt_cam_trans.clone().detach()
191 | cur_rel_w2c = torch.eye(4)
192 | cur_rel_w2c[:3, :3] = build_rotation(cur_quat)[0]
193 | cur_rel_w2c[:3, 3] = cur_trans
194 | if iter == num_iters - 1:
195 | cur_w2c = torch.from_numpy(reference_w2c) @ best_w2c
196 | else:
197 | cur_w2c = torch.from_numpy(reference_w2c) @ cur_rel_w2c
198 | cur_c2w = torch.inverse(cur_w2c)
199 | cur_cam = transformation_to_quaternion(cur_c2w)
200 | if (gt_quat * cur_cam[:4]).sum() < 0: # for logging purpose
201 | gt_quat *= -1
202 | if iter == num_iters - 1:
203 | self.frame_color_loss.append(color_loss.item())
204 | self.frame_depth_loss.append(depth_loss.item())
205 | self.logger.log_tracking_iteration(
206 | frame_id, cur_cam, gt_quat, gt_trans, total_loss, color_loss, depth_loss, iter, num_iters,
207 | wandb_output=True, print_output=True)
208 | elif iter % 20 == 0:
209 | self.logger.log_tracking_iteration(
210 | frame_id, cur_cam, gt_quat, gt_trans, total_loss, color_loss, depth_loss, iter, num_iters,
211 | wandb_output=False, print_output=True)
212 |
213 | final_c2w = torch.inverse(torch.from_numpy(reference_w2c) @ best_w2c)
214 | final_c2w[-1, :] = torch.tensor([0., 0., 0., 1.], dtype=final_c2w.dtype, device=final_c2w.device)
215 | return torch2np(final_c2w)
216 |
--------------------------------------------------------------------------------
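
The tracker represents camera orientation as a wxyz quaternion (the convention expected by `build_rotation`), while SciPy's `as_quat` returns xyzw, hence the `[[3, 0, 1, 2]]` reindexing in `track()`. A small standalone sketch of that conversion (`canonical=True` needs a recent SciPy):

```python
import numpy as np
from scipy.spatial.transform import Rotation as R

c2w = np.eye(4)
c2w[:3, :3] = R.from_euler("y", 30, degrees=True).as_matrix()
c2w[:3, 3] = [0.1, 0.0, 0.2]

# xyzw (SciPy) -> wxyz (tracker convention)
gt_quat = R.from_matrix(c2w[:3, :3]).as_quat(canonical=True)[[3, 0, 1, 2]]
gt_trans = c2w[:3, 3]
print(gt_quat, gt_trans)
```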
/src/entities/visual_odometer.py:
--------------------------------------------------------------------------------
1 | """ This module includes the Odometer class, which is allows for fast pose estimation from RGBD neighbor frames """
2 | import numpy as np
3 | import open3d as o3d
4 | import open3d.core as o3c
5 |
6 |
7 | class VisualOdometer(object):
8 |
9 | def __init__(self, intrinsics: np.ndarray, method_name="hybrid", device="cuda"):
10 | """ Initializes the visual odometry system with specified intrinsics, method, and device.
11 | Args:
12 | intrinsics: Camera intrinsic parameters.
13 | method_name: The name of the odometry computation method to use ('hybrid' or 'point_to_plane').
14 | device: The computation device ('cuda' or 'cpu').
15 | """
16 | device = "CUDA:0" if device == "cuda" else "CPU:0"
17 | self.device = o3c.Device(device)
18 | self.intrinsics = o3d.core.Tensor(intrinsics, o3d.core.Dtype.Float64)
19 | self.last_abs_pose = None
20 | self.last_frame = None
21 | self.criteria_list = [
22 | o3d.t.pipelines.odometry.OdometryConvergenceCriteria(500),
23 | o3d.t.pipelines.odometry.OdometryConvergenceCriteria(500),
24 | o3d.t.pipelines.odometry.OdometryConvergenceCriteria(500)]
25 | self.setup_method(method_name)
26 | self.max_depth = 10.0
27 | self.depth_scale = 1.0
28 | self.last_rgbd = None
29 |
30 | def setup_method(self, method_name: str) -> None:
31 | """ Sets up the odometry computation method based on the provided method name.
32 | Args:
33 | method_name: The name of the odometry method to use ('hybrid' or 'point_to_plane').
34 | """
35 | if method_name == "hybrid":
36 | self.method = o3d.t.pipelines.odometry.Method.Hybrid
37 | elif method_name == "point_to_plane":
38 | self.method = o3d.t.pipelines.odometry.Method.PointToPlane
39 | else:
40 | raise ValueError("Odometry method does not exist!")
41 |
42 | def update_last_rgbd(self, image: np.ndarray, depth: np.ndarray) -> None:
43 | """ Updates the last RGB-D frame stored in the system with a new RGB-D frame constructed from provided image and depth.
44 | Args:
45 | image: The new RGB image as a numpy ndarray.
46 | depth: The new depth image as a numpy ndarray.
47 | """
48 | self.last_rgbd = o3d.t.geometry.RGBDImage(
49 | o3d.t.geometry.Image(np.ascontiguousarray(
50 | image).astype(np.float32)).to(self.device),
51 | o3d.t.geometry.Image(np.ascontiguousarray(depth).astype(np.float32)).to(self.device))
52 |
53 | def estimate_rel_pose(self, image: np.ndarray, depth: np.ndarray, init_transform=np.eye(4)):
54 | """ Estimates the relative pose of the current frame with respect to the last frame using RGB-D odometry.
55 | Args:
56 | image: The current RGB image as a numpy ndarray.
57 | depth: The current depth image as a numpy ndarray.
58 | init_transform: An initial transformation guess as a numpy ndarray. Defaults to the identity matrix.
59 | Returns:
60 | The relative transformation matrix as a numpy ndarray.
61 | """
62 | rgbd = o3d.t.geometry.RGBDImage(
63 | o3d.t.geometry.Image(np.ascontiguousarray(image).astype(np.float32)).to(self.device),
64 | o3d.t.geometry.Image(np.ascontiguousarray(depth).astype(np.float32)).to(self.device))
65 | rel_transform = o3d.t.pipelines.odometry.rgbd_odometry_multi_scale(
66 | self.last_rgbd, rgbd, self.intrinsics, o3c.Tensor(init_transform),
67 | self.depth_scale, self.max_depth, self.criteria_list, self.method)
68 | self.last_rgbd = rgbd.clone()
69 |
70 | # Adjust for the coordinate system difference
71 | rel_transform = rel_transform.transformation.cpu().numpy()
72 | rel_transform[0, [1, 2, 3]] *= -1
73 | rel_transform[1, [0, 2, 3]] *= -1
74 | rel_transform[2, [0, 1, 3]] *= -1
75 |
76 | return rel_transform
77 |
--------------------------------------------------------------------------------
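
A hedged usage sketch for `VisualOdometer`: the intrinsics and frames below are synthetic placeholders (in the pipeline the Tracker feeds consecutive dataset frames), and `device="cpu"` avoids requiring a CUDA-enabled Open3D build. Feeding the same frame twice should return a transform close to identity.

```python
import numpy as np

from src.entities.visual_odometer import VisualOdometer

intrinsics = np.array([[525.0, 0.0, 319.5],
                       [0.0, 525.0, 239.5],
                       [0.0, 0.0, 1.0]])
odometer = VisualOdometer(intrinsics, method_name="hybrid", device="cpu")

# Smooth synthetic RGB-D frame: a wavy depth surface with matching grayscale texture.
ys, xs = np.meshgrid(np.arange(480), np.arange(640), indexing="ij")
depth = (1.5 + 0.3 * np.sin(xs / 40.0) * np.cos(ys / 40.0)).astype(np.float32)
color = np.repeat(((depth - 1.2) * 255.0)[..., None], 3, axis=-1).astype(np.float32)

odometer.update_last_rgbd(color, depth)
rel_transform = odometer.estimate_rel_pose(color, depth)  # ~identity for identical frames
print(rel_transform)
```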
/src/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VladimirYugay/Gaussian-SLAM/eaec10d73ce7511563882b8856896e06d1f804e3/src/evaluation/__init__.py
--------------------------------------------------------------------------------
/src/evaluation/evaluate_merged_map.py:
--------------------------------------------------------------------------------
1 | """ This module is responsible for merging submaps. """
2 | from argparse import ArgumentParser
3 |
4 | import faiss
5 | import numpy as np
6 | import open3d as o3d
7 | import torch
8 | from torch.utils.data import Dataset
9 | from tqdm import tqdm
10 |
11 | from src.entities.arguments import OptimizationParams
12 | from src.entities.gaussian_model import GaussianModel
13 | from src.entities.losses import isotropic_loss, l1_loss, ssim
14 | from src.utils.utils import (batch_search_faiss, get_render_settings,
15 | np2ptcloud, render_gaussian_model, torch2np)
16 |
17 |
18 | class RenderFrames(Dataset):
19 | """A dataset class for loading keyframes along with their estimated camera poses and render settings."""
20 | def __init__(self, dataset, render_poses: np.ndarray, height: int, width: int, fx: float, fy: float):
21 | self.dataset = dataset
22 | self.render_poses = render_poses
23 | self.height = height
24 | self.width = width
25 | self.fx = fx
26 | self.fy = fy
27 | self.device = "cuda"
28 | self.stride = 1
29 | if len(dataset) > 1000:
30 | self.stride = len(dataset) // 1000
31 |
32 | def __len__(self) -> int:
33 | return len(self.dataset) // self.stride
34 |
35 | def __getitem__(self, idx):
36 | idx = idx * self.stride
37 | color = (torch.from_numpy(
38 | self.dataset[idx][1]) / 255.0).float().to(self.device)
39 | depth = torch.from_numpy(self.dataset[idx][2]).float().to(self.device)
40 | estimate_c2w = self.render_poses[idx]
41 | estimate_w2c = np.linalg.inv(estimate_c2w)
42 | frame = {
43 | "frame_id": idx,
44 | "color": color,
45 | "depth": depth,
46 | "render_settings": get_render_settings(
47 | self.width, self.height, self.dataset.intrinsics, estimate_w2c)
48 | }
49 | return frame
50 |
51 |
52 | def merge_submaps(submaps_paths: list, radius: float = 0.0001, device: str = "cuda") -> o3d.geometry.PointCloud:
53 | """ Merge submaps into a single point cloud, which is then used for global map refinement.
54 | Args:
55 | submaps_paths (list): Paths to the submap checkpoint files.
56 | radius (float, optional): Nearest neighbor distance threshold for adding a point. Defaults to 0.0001.
57 | device (str, optional): Defaults to "cuda".
58 |
59 | Returns:
60 | o3d.geometry.PointCloud: merged point cloud
61 | """
62 | pts_index = faiss.IndexFlatL2(3)
63 | if device == "cuda":
64 | pts_index = faiss.index_cpu_to_gpu(
65 | faiss.StandardGpuResources(),
66 | 0,
67 | faiss.IndexIVFFlat(faiss.IndexFlatL2(3), 3, 500, faiss.METRIC_L2))
68 | pts_index.nprobe = 5
69 | merged_pts = []
70 | print("Merging segments")
71 | for submap_path in tqdm(submaps_paths):
72 | gaussian_params = torch.load(submap_path)["gaussian_params"]
73 | current_pts = gaussian_params["xyz"].to(device).float()
74 | pts_index.train(current_pts)
75 | distances, _ = batch_search_faiss(pts_index, current_pts, 8)
76 | neighbor_num = (distances < radius).sum(axis=1).int()
77 | ids_to_include = torch.where(neighbor_num == 0)[0]
78 | pts_index.add(current_pts[ids_to_include])
79 | merged_pts.append(current_pts[ids_to_include])
80 | pts = torch2np(torch.vstack(merged_pts))
81 | pt_cloud = np2ptcloud(pts, np.zeros_like(pts))
82 |
83 | # Downsampling if the total number of points is too large
84 | if len(pt_cloud.points) > 1_000_000:
85 | voxel_size = 0.02
86 | pt_cloud = pt_cloud.voxel_down_sample(voxel_size)
87 | print(f"Downsampled point cloud to {len(pt_cloud.points)} points")
88 | filtered_pt_cloud, _ = pt_cloud.remove_statistical_outlier(nb_neighbors=40, std_ratio=3.0)
89 | del pts_index
90 | return filtered_pt_cloud
91 |
92 |
93 | def refine_global_map(pt_cloud: o3d.geometry.PointCloud, training_frames: list, max_iterations: int) -> GaussianModel:
94 | """Refines a global map based on the merged point cloud and training keyframes.
95 | Args:
96 | pt_cloud (o3d.geometry.PointCloud): The merged point cloud used for refinement.
97 | training_frames (list): A list of training frames for map refinement.
98 | max_iterations (int): The maximum number of iterations to perform for refinement.
99 | Returns:
100 | GaussianModel: The refined global map as a Gaussian model.
101 | """
102 | opt_params = OptimizationParams(ArgumentParser(description="Training script parameters"))
103 |
104 | gaussian_model = GaussianModel(3)
105 | gaussian_model.active_sh_degree = 3
106 | gaussian_model.training_setup(opt_params)
107 | gaussian_model.add_points(pt_cloud)
108 |
109 | iteration = 0
110 | for iteration in tqdm(range(max_iterations), desc="Refinement"):
111 | training_frame = next(training_frames)
112 | gt_color, gt_depth, render_settings = (
113 | training_frame["color"].squeeze(0),
114 | training_frame["depth"].squeeze(0),
115 | training_frame["render_settings"])
116 |
117 | render_dict = render_gaussian_model(gaussian_model, render_settings)
118 | rendered_color, rendered_depth = (render_dict["color"].permute(1, 2, 0), render_dict["depth"])
119 |
120 | reg_loss = isotropic_loss(gaussian_model.get_scaling())
121 | depth_mask = (gt_depth > 0)
122 | color_loss = (1.0 - opt_params.lambda_dssim) * l1_loss(
123 | rendered_color[depth_mask, :], gt_color[depth_mask, :]
124 | ) + opt_params.lambda_dssim * (1.0 - ssim(rendered_color, gt_color))
125 | depth_loss = l1_loss(
126 | rendered_depth[:, depth_mask], gt_depth[depth_mask])
127 |
128 | total_loss = color_loss + depth_loss + reg_loss
129 | total_loss.backward()
130 |
131 | with torch.no_grad():
132 | if iteration % 500 == 0:
133 | prune_mask = (gaussian_model.get_opacity() < 0.005).squeeze()
134 | gaussian_model.prune_points(prune_mask)
135 |
136 | # Optimizer step
137 | gaussian_model.optimizer.step()
138 | gaussian_model.optimizer.zero_grad(set_to_none=True)
139 | iteration += 1
140 |
141 | return gaussian_model
142 |
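Taken together, the two functions above implement the global-map refinement flow used by the evaluator. A hedged sketch of that flow (the checkpoint path and iteration count are assumptions; `training_frames` is a cycled loader as sketched earlier):

```
submaps_paths = sorted((checkpoint_path / "submaps").glob("*.ckpt"))
merged_cloud = merge_submaps(submaps_paths)
global_map = refine_global_map(merged_cloud, training_frames, max_iterations=10000)
global_map.save_ply(checkpoint_path / "global_map.ply")
```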
--------------------------------------------------------------------------------
/src/evaluation/evaluate_reconstruction.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import open3d as o3d
7 | import torch
8 | import trimesh
9 | from evaluate_3d_reconstruction import run_evaluation
10 | from tqdm import tqdm
11 |
12 |
13 | def normalize(x):
14 | return x / np.linalg.norm(x)
15 |
16 |
17 | def get_align_transformation(rec_meshfile, gt_meshfile):
18 | """
19 | Get the transformation matrix to align the reconstructed mesh to the ground truth mesh.
20 | """
21 | o3d_rec_mesh = o3d.io.read_triangle_mesh(rec_meshfile)
22 | o3d_gt_mesh = o3d.io.read_triangle_mesh(gt_meshfile)
23 | o3d_rec_pc = o3d.geometry.PointCloud(points=o3d_rec_mesh.vertices)
24 | o3d_gt_pc = o3d.geometry.PointCloud(points=o3d_gt_mesh.vertices)
25 | trans_init = np.eye(4)
26 | threshold = 0.1
27 | reg_p2p = o3d.pipelines.registration.registration_icp(
28 | o3d_rec_pc,
29 | o3d_gt_pc,
30 | threshold,
31 | trans_init,
32 | o3d.pipelines.registration.TransformationEstimationPointToPoint(),
33 | )
34 | transformation = reg_p2p.transformation
35 | return transformation
36 |
37 |
38 | def check_proj(points, W, H, fx, fy, cx, cy, c2w):
39 | """
40 | Check if points can be projected into the camera view.
41 |
42 | Returns:
43 | bool: True if any of the points can be projected into the camera view
44 |
45 | """
46 | c2w = c2w.copy()
47 | c2w[:3, 1] *= -1.0
48 | c2w[:3, 2] *= -1.0
49 | points = torch.from_numpy(points).cuda().clone()
50 | w2c = np.linalg.inv(c2w)
51 | w2c = torch.from_numpy(w2c).cuda().float()
52 | K = torch.from_numpy(
53 | np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]]).reshape(3, 3)
54 | ).cuda()
55 | ones = torch.ones_like(points[:, 0]).reshape(-1, 1).cuda()
56 | homo_points = (
57 | torch.cat([points, ones], dim=1).reshape(-1, 4, 1).cuda().float()
58 | ) # (N, 4)
59 | cam_cord_homo = w2c @ homo_points # (N, 4, 1)=(4,4)*(N, 4, 1)
60 | cam_cord = cam_cord_homo[:, :3] # (N, 3, 1)
61 | cam_cord[:, 0] *= -1
62 | uv = K.float() @ cam_cord.float()
63 | z = uv[:, -1:] + 1e-5
64 | uv = uv[:, :2] / z
65 | uv = uv.float().squeeze(-1).cpu().numpy()
66 | edge = 0
67 | mask = (
68 | (0 <= -z[:, 0, 0].cpu().numpy())
69 | & (uv[:, 0] < W - edge)
70 | & (uv[:, 0] > edge)
71 | & (uv[:, 1] < H - edge)
72 | & (uv[:, 1] > edge)
73 | )
74 | return mask.sum() > 0
75 |
76 |
77 | def get_cam_position(gt_meshfile):
78 | mesh_gt = trimesh.load(gt_meshfile)
79 | to_origin, extents = trimesh.bounds.oriented_bounds(mesh_gt)
80 | extents[2] *= 0.7
81 | extents[1] *= 0.7
82 | extents[0] *= 0.3
83 | transform = np.linalg.inv(to_origin)
84 | transform[2, 3] += 0.4
85 | return extents, transform
86 |
87 |
88 | def viewmatrix(z, up, pos):
89 | vec2 = normalize(z)
90 | vec1_avg = up
91 | vec0 = normalize(np.cross(vec1_avg, vec2))
92 | vec1 = normalize(np.cross(vec2, vec0))
93 | m = np.stack([vec0, vec1, vec2, pos], 1)
94 | return m
95 |
96 |
97 | def calc_2d_metric(
98 | rec_meshfile, gt_meshfile, unseen_gt_pointcloud_file, align=True, n_imgs=1000
99 | ):
100 | """
101 | 2D reconstruction metric, depth L1 loss.
102 |
103 | """
104 | H = 500
105 | W = 500
106 | focal = 300
107 | fx = focal
108 | fy = focal
109 | cx = H / 2.0 - 0.5
110 | cy = W / 2.0 - 0.5
111 |
112 | gt_mesh = o3d.io.read_triangle_mesh(gt_meshfile)
113 | rec_mesh = o3d.io.read_triangle_mesh(rec_meshfile)
114 | pc_unseen = np.load(unseen_gt_pointcloud_file)
115 | if align:
116 | transformation = get_align_transformation(rec_meshfile, gt_meshfile)
117 | rec_mesh = rec_mesh.transform(transformation)
118 |
119 | # get vacant area inside the room
120 | extents, transform = get_cam_position(gt_meshfile)
121 |
122 | vis = o3d.visualization.Visualizer()
123 | vis.create_window(width=W, height=H, visible=False)
124 | vis.get_render_option().mesh_show_back_face = True
125 | errors = []
126 | for i in tqdm(range(n_imgs)):
127 | while True:
128 | # sample view, and check if unseen region is not inside the camera view
129 | # if inside, then needs to resample
130 | up = [0, 0, -1]
131 | origin = trimesh.sample.volume_rectangular(
132 | extents, 1, transform=transform)
133 | origin = origin.reshape(-1)
134 | tx = round(random.uniform(-10000, +10000), 2)
135 | ty = round(random.uniform(-10000, +10000), 2)
136 | tz = round(random.uniform(-10000, +10000), 2)
137 | # the target direction is normalized below, so the sampling range is arbitrary
138 | target = [tx, ty, tz]
139 | target = np.array(target) - np.array(origin)
140 | c2w = viewmatrix(target, up, origin)
141 | tmp = np.eye(4)
142 | tmp[:3, :] = c2w # sample translations
143 | c2w = tmp
144 | # if unseen points are projected into current view (c2w)
145 | seen = check_proj(pc_unseen, W, H, fx, fy, cx, cy, c2w)
146 | if not seen:
147 | break
148 |
149 | param = o3d.camera.PinholeCameraParameters()
150 | param.extrinsic = np.linalg.inv(c2w) # 4x4 numpy array
151 |
152 | param.intrinsic = o3d.camera.PinholeCameraIntrinsic(
153 | W, H, fx, fy, cx, cy)
154 |
155 | ctr = vis.get_view_control()
156 | ctr.set_constant_z_far(20)
157 | ctr.convert_from_pinhole_camera_parameters(param, True)
158 |
159 | vis.add_geometry(
160 | gt_mesh,
161 | reset_bounding_box=True,
162 | )
163 | ctr.convert_from_pinhole_camera_parameters(param, True)
164 | vis.poll_events()
165 | vis.update_renderer()
166 | gt_depth = vis.capture_depth_float_buffer(True)
167 | gt_depth = np.asarray(gt_depth)
168 | vis.remove_geometry(
169 | gt_mesh,
170 | reset_bounding_box=True,
171 | )
172 |
173 | vis.add_geometry(
174 | rec_mesh,
175 | reset_bounding_box=True,
176 | )
177 | ctr.convert_from_pinhole_camera_parameters(param, True)
178 | vis.poll_events()
179 | vis.update_renderer()
180 | ours_depth = vis.capture_depth_float_buffer(True)
181 | ours_depth = np.asarray(ours_depth)
182 | vis.remove_geometry(
183 | rec_mesh,
184 | reset_bounding_box=True,
185 | )
186 |
187 | # filter missing surfaces where depth is 0
188 | if (ours_depth > 0).sum() > 0:
189 | errors += [
190 | np.abs(gt_depth[ours_depth > 0] -
191 | ours_depth[ours_depth > 0]).mean()
192 | ]
193 | else:
194 | continue
195 |
196 | errors = np.array(errors)
197 | return {"depth l1": errors.mean() * 100}
198 |
199 |
200 | def clean_mesh(mesh):
201 | mesh_tri = trimesh.Trimesh(
202 | vertices=np.asarray(mesh.vertices),
203 | faces=np.asarray(mesh.triangles),
204 | vertex_colors=np.asarray(mesh.vertex_colors),
205 | )
206 | components = trimesh.graph.connected_components(
207 | edges=mesh_tri.edges_sorted)
208 |
209 | min_len = 200
210 | components_to_keep = [c for c in components if len(c) >= min_len]
211 |
212 | new_vertices = []
213 | new_faces = []
214 | new_colors = []
215 | vertex_count = 0
216 | for component in components_to_keep:
217 | vertices = mesh_tri.vertices[component]
218 | colors = mesh_tri.visual.vertex_colors[component]
219 |
220 | # Create a mapping from old vertex indices to new vertex indices
221 | index_mapping = {
222 | old_idx: vertex_count + new_idx for new_idx, old_idx in enumerate(component)
223 | }
224 | vertex_count += len(vertices)
225 |
226 | # Select faces that are part of the current connected component and update vertex indices
227 | faces_in_component = mesh_tri.faces[
228 | np.any(np.isin(mesh_tri.faces, component), axis=1)
229 | ]
230 | reindexed_faces = np.vectorize(index_mapping.get)(faces_in_component)
231 |
232 | new_vertices.extend(vertices)
233 | new_faces.extend(reindexed_faces)
234 | new_colors.extend(colors)
235 |
236 | cleaned_mesh_tri = trimesh.Trimesh(vertices=new_vertices, faces=new_faces)
237 | cleaned_mesh_tri.visual.vertex_colors = np.array(new_colors)
238 |
239 | cleaned_mesh_tri.update_faces(cleaned_mesh_tri.nondegenerate_faces())
240 | cleaned_mesh_tri.update_faces(cleaned_mesh_tri.unique_faces())
241 | print(
242 | f"Mesh cleaning (before/after), vertices: {len(mesh_tri.vertices)}/{len(cleaned_mesh_tri.vertices)}, faces: {len(mesh_tri.faces)}/{len(cleaned_mesh_tri.faces)}")
243 |
244 | cleaned_mesh = o3d.geometry.TriangleMesh(
245 | o3d.utility.Vector3dVector(cleaned_mesh_tri.vertices),
246 | o3d.utility.Vector3iVector(cleaned_mesh_tri.faces),
247 | )
248 | vertex_colors = np.asarray(cleaned_mesh_tri.visual.vertex_colors)[
249 | :, :3] / 255.0
250 | cleaned_mesh.vertex_colors = o3d.utility.Vector3dVector(
251 | vertex_colors.astype(np.float64)
252 | )
253 |
254 | return cleaned_mesh
255 |
256 |
257 | def evaluate_reconstruction(
258 | mesh_path: Path,
259 | gt_mesh_path: Path,
260 | unseen_pc_path: Path,
261 | output_path: Path,
262 | to_clean=True,
263 | ):
264 | if to_clean:
265 | mesh = o3d.io.read_triangle_mesh(str(mesh_path))
266 | print(mesh)
267 | cleaned_mesh = clean_mesh(mesh)
268 | cleaned_mesh_path = output_path / "mesh" / "cleaned_mesh.ply"
269 | o3d.io.write_triangle_mesh(str(cleaned_mesh_path), cleaned_mesh)
270 | mesh_path = cleaned_mesh_path
271 |
272 | result_3d = run_evaluation(
273 | str(mesh_path.parts[-1]),
274 | str(mesh_path.parent),
275 | str(gt_mesh_path).split("/")[-1].split(".")[0],
276 | distance_thresh=0.01,
277 | full_path_to_gt_ply=gt_mesh_path,
278 | icp_align=True,
279 | )
280 |
281 | try:
282 | result_2d = calc_2d_metric(str(mesh_path), str(gt_mesh_path), str(unseen_pc_path), align=True, n_imgs=1000)
283 | except Exception as e:
284 | print(e)
285 | result_2d = {"depth l1": None}
286 |
287 | result = {**result_3d, **result_2d}
288 | with open(str(output_path / "reconstruction_metrics.json"), "w") as f:
289 | json.dump(result, f)
290 |
--------------------------------------------------------------------------------
/src/evaluation/evaluate_trajectory.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 |
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 |
7 |
8 | class NumpyFloatValuesEncoder(json.JSONEncoder):
9 | def default(self, obj):
10 | if isinstance(obj, np.float32):
11 | return float(obj)
12 | return json.JSONEncoder.default(self, obj)
13 |
14 |
15 | def align(model, data):
16 | """Align two trajectories using the method of Horn (closed-form).
17 |
18 | Input:
19 | model -- first trajectory (3xn)
20 | data -- second trajectory (3xn)
21 |
22 | Output:
23 | rot -- rotation matrix (3x3)
24 | trans -- translation vector (3x1)
25 | trans_error -- translational error per point (1xn)
26 |
27 | """
28 | np.set_printoptions(precision=3, suppress=True)
29 | model_zerocentered = model - model.mean(1)
30 | data_zerocentered = data - data.mean(1)
31 |
32 | W = np.zeros((3, 3))
33 | for column in range(model.shape[1]):
34 | W += np.outer(model_zerocentered[:,
35 | column], data_zerocentered[:, column])
36 | U, d, Vh = np.linalg.svd(W.transpose())
37 | S = np.matrix(np.identity(3))
38 | if (np.linalg.det(U) * np.linalg.det(Vh) < 0):
39 | S[2, 2] = -1
40 | rot = U * S * Vh
41 | trans = data.mean(1) - rot * model.mean(1)
42 |
43 | model_aligned = rot * model + trans
44 | alignment_error = model_aligned - data
45 |
46 | trans_error = np.sqrt(
47 | np.sum(np.multiply(alignment_error, alignment_error), 0)).A[0]
48 |
49 | return rot, trans, trans_error
50 |
51 |
52 | def align_trajectories(t_pred: np.ndarray, t_gt: np.ndarray):
53 | """
54 | Args:
55 | t_pred: (n, 3) translations
56 | t_gt: (n, 3) translations
57 | Returns:
58 | t_align: (n, 3) aligned translations
59 | """
60 | t_align = np.matrix(t_pred).transpose()
61 | R, t, _ = align(t_align, np.matrix(t_gt).transpose())
62 | t_align = R * t_align + t
63 | t_align = np.asarray(t_align).T
64 | return t_align
65 |
66 |
67 | def pose_error(t_pred: np.ndarray, t_gt: np.ndarray, align=False):
68 | """
69 | Args:
70 | t_pred: (n, 3) translations
71 | t_gt: (n, 3) translations
72 | Returns:
73 | dict: error dict
74 | """
75 | n = t_pred.shape[0]
76 | trans_error = np.linalg.norm(t_pred - t_gt, axis=1)
77 | return {
78 | "compared_pose_pairs": n,
79 | "rmse": np.sqrt(np.dot(trans_error, trans_error) / n),
80 | "mean": np.mean(trans_error),
81 | "median": np.median(trans_error),
82 | "std": np.std(trans_error),
83 | "min": np.min(trans_error),
84 | "max": np.max(trans_error)
85 | }
86 |
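A toy check of the alignment and error functions above, using nothing beyond this file: a trajectory offset by a constant translation has a non-zero raw ATE but aligns back to roughly zero after the Horn alignment.

```
import numpy as np

gt = np.random.rand(100, 3)
pred = gt + np.array([0.5, 0.0, 0.0])  # rigidly shifted trajectory
aligned = align_trajectories(pred, gt)
print(pose_error(pred, gt)["rmse"])      # ~0.5
print(pose_error(aligned, gt)["rmse"])   # ~0.0 after alignment
```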
87 |
88 | def plot_2d(pts, ax=None, color="green", label="None", title="3D Trajectory in 2D"):
89 | if ax is None:
90 | _, ax = plt.subplots()
91 | ax.scatter(pts[:, 0], pts[:, 1], color=color, label=label, s=0.7)
92 | ax.set_xlabel('X')
93 | ax.set_ylabel('Y')
94 | ax.set_title(title)
95 | return ax
96 |
97 |
98 | def evaluate_trajectory(estimated_poses: np.ndarray, gt_poses: np.ndarray, output_path: Path):
99 | output_path.mkdir(exist_ok=True, parents=True)
100 | # Truncate the ground truth trajectory if needed
101 | if gt_poses.shape[0] > estimated_poses.shape[0]:
102 | gt_poses = gt_poses[:estimated_poses.shape[0]]
103 | valid = ~np.any(np.isnan(gt_poses) |
104 | np.isinf(gt_poses), axis=(1, 2))
105 | gt_poses = gt_poses[valid]
106 | estimated_poses = estimated_poses[valid]
107 |
108 | gt_t = gt_poses[:, :3, 3]
109 | estimated_t = estimated_poses[:, :3, 3]
110 | estimated_t_aligned = align_trajectories(estimated_t, gt_t)
111 | ate = pose_error(estimated_t, gt_t)
112 | ate_aligned = pose_error(estimated_t_aligned, gt_t)
113 |
114 | with open(str(output_path / "ate.json"), "w") as f:
115 | f.write(json.dumps(ate, cls=NumpyFloatValuesEncoder))
116 |
117 | with open(str(output_path / "ate_aligned.json"), "w") as f:
118 | f.write(json.dumps(ate_aligned, cls=NumpyFloatValuesEncoder))
119 |
120 | ate_rmse, ate_rmse_aligned = ate["rmse"], ate_aligned["rmse"]
121 | ax = plot_2d(
122 | estimated_t, label=f"ate-rmse: {round(ate_rmse * 100, 2)} cm", color="orange")
123 | ax = plot_2d(estimated_t_aligned, ax,
124 | label=f"ate-rmse (aligned): {round(ate_rmse_aligned * 100, 2)} cm", color="lightskyblue")
125 | ax = plot_2d(gt_t, ax, label="GT", color="green")
126 | ax.legend()
127 | plt.savefig(str(output_path / "eval_trajectory.png"), dpi=300)
128 | plt.close()
129 | print(
130 | f"ATE-RMSE: {ate_rmse * 100:.2f} cm, ATE-RMSE (aligned): {ate_rmse_aligned * 100:.2f} cm")
131 |
--------------------------------------------------------------------------------
/src/evaluation/evaluator.py:
--------------------------------------------------------------------------------
1 | """ This module is responsible for evaluating rendering, trajectory and reconstruction metrics"""
2 | import traceback
3 | from argparse import ArgumentParser
4 | from copy import deepcopy
5 | from itertools import cycle
6 | from pathlib import Path
7 |
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | import open3d as o3d
11 | import torch
12 | import torchvision
13 | from pytorch_msssim import ms_ssim
14 | from scipy.ndimage import median_filter
15 | from torch.utils.data import DataLoader
16 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
17 | from torchvision.utils import save_image
18 | from tqdm import tqdm
19 |
20 | from src.entities.arguments import OptimizationParams
21 | from src.entities.datasets import get_dataset
22 | from src.entities.gaussian_model import GaussianModel
23 | from src.evaluation.evaluate_merged_map import (RenderFrames, merge_submaps,
24 | refine_global_map)
25 | from src.evaluation.evaluate_reconstruction import evaluate_reconstruction
26 | from src.evaluation.evaluate_trajectory import evaluate_trajectory
27 | from src.utils.io_utils import load_config, save_dict_to_json
28 | from src.utils.mapper_utils import calc_psnr
29 | from src.utils.utils import (get_render_settings, np2torch,
30 | render_gaussian_model, setup_seed, torch2np)
31 |
32 |
33 | def filter_depth_outliers(depth_map, kernel_size=3, threshold=1.0):
34 | median_filtered = median_filter(depth_map, size=kernel_size)
35 | abs_diff = np.abs(depth_map - median_filtered)
36 | outlier_mask = abs_diff > threshold
37 | depth_map_filtered = np.where(outlier_mask, median_filtered, depth_map)
38 | return depth_map_filtered
39 |
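A small sanity check for `filter_depth_outliers` (self-contained; the values are arbitrary): a single spiked depth value is replaced by its local median while the rest of the map is untouched.

```
import numpy as np

depth = np.ones((5, 5), dtype=np.float32)
depth[2, 2] = 10.0                                    # isolated outlier
filtered = filter_depth_outliers(depth, kernel_size=3, threshold=1.0)
assert filtered[2, 2] == 1.0                          # outlier replaced by the median
assert np.all(filtered[depth == 1.0] == 1.0)          # inliers untouched
```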
40 |
41 | class Evaluator(object):
42 |
43 | def __init__(self, checkpoint_path, config_path, config=None, save_render=False) -> None:
44 | if config is None:
45 | self.config = load_config(config_path)
46 | else:
47 | self.config = config
48 | setup_seed(self.config["seed"])
49 |
50 | self.checkpoint_path = Path(checkpoint_path)
51 | self.device = "cuda"
52 | self.dataset = get_dataset(self.config["dataset_name"])({**self.config["data"], **self.config["cam"]})
53 | self.scene_name = self.config["data"]["scene_name"]
54 | self.dataset_name = self.config["dataset_name"]
55 | self.gt_poses = np.array(self.dataset.poses)
56 | self.fx, self.fy = self.dataset.intrinsics[0, 0], self.dataset.intrinsics[1, 1]
57 | self.cx, self.cy = self.dataset.intrinsics[0,
58 | 2], self.dataset.intrinsics[1, 2]
59 | self.width, self.height = self.dataset.width, self.dataset.height
60 | self.save_render = save_render
61 | if self.save_render:
62 | self.render_path = self.checkpoint_path / "rendered_imgs"
63 | self.render_path.mkdir(exist_ok=True, parents=True)
64 |
65 | self.estimated_c2w = torch2np(torch.load(self.checkpoint_path / "estimated_c2w.ckpt", map_location=self.device))
66 | self.submaps_paths = sorted(list((self.checkpoint_path / "submaps").glob('*')))
67 |
68 | def run_trajectory_eval(self):
69 | """ Evaluates the estimated trajectory """
70 | print("Running trajectory evaluation...")
71 | evaluate_trajectory(self.estimated_c2w, self.gt_poses, self.checkpoint_path)
72 |
73 | def run_rendering_eval(self):
74 | """ Renders the submaps and evaluates the PSNR, LPIPS, SSIM and depth L1 metrics."""
75 | print("Running rendering evaluation...")
76 | psnr, lpips, ssim, depth_l1 = [], [], [], []
77 | color_transform = torchvision.transforms.ToTensor()
78 | lpips_model = LearnedPerceptualImagePatchSimilarity(
79 | net_type='alex', normalize=True).to(self.device)
80 | opt_settings = OptimizationParams(ArgumentParser(
81 | description="Training script parameters"))
82 |
83 | submaps_paths = sorted(
84 | list((self.checkpoint_path / "submaps").glob('*.ckpt')))
85 | for submap_path in tqdm(submaps_paths):
86 | submap = torch.load(submap_path, map_location=self.device)
87 | gaussian_model = GaussianModel()
88 | gaussian_model.training_setup(opt_settings)
89 | gaussian_model.restore_from_params(
90 | submap["gaussian_params"], opt_settings)
91 |
92 | for keyframe_id in submap["submap_keyframes"]:
93 |
94 | _, gt_color, gt_depth, _ = self.dataset[keyframe_id]
95 | gt_color = color_transform(gt_color).to(self.device)
96 | gt_depth = np2torch(gt_depth).to(self.device)
97 |
98 | estimate_c2w = self.estimated_c2w[keyframe_id]
99 | estimate_w2c = np.linalg.inv(estimate_c2w)
100 | render_dict = render_gaussian_model(
101 | gaussian_model, get_render_settings(self.width, self.height, self.dataset.intrinsics, estimate_w2c))
102 | rendered_color, rendered_depth = render_dict["color"].detach(
103 | ), render_dict["depth"][0].detach()
104 | rendered_color = torch.clamp(rendered_color, min=0.0, max=1.0)
105 | if self.save_render:
106 | torchvision.utils.save_image(
107 | rendered_color, self.render_path / f"{keyframe_id:05d}.png")
108 |
109 | mse_loss = torch.nn.functional.mse_loss(
110 | rendered_color, gt_color)
111 | psnr_value = (-10. * torch.log10(mse_loss)).item()
112 | lpips_value = lpips_model(
113 | rendered_color[None], gt_color[None]).item()
114 | ssim_value = ms_ssim(
115 | rendered_color[None], gt_color[None], data_range=1.0, size_average=True).item()
116 | depth_l1_value = torch.abs(
117 | (rendered_depth - gt_depth)).mean().item()
118 |
119 | psnr.append(psnr_value)
120 | lpips.append(lpips_value)
121 | ssim.append(ssim_value)
122 | depth_l1.append(depth_l1_value)
123 |
124 | num_frames = len(psnr)
125 | metrics = {
126 | "psnr": sum(psnr) / num_frames,
127 | "lpips": sum(lpips) / num_frames,
128 | "ssim": sum(ssim) / num_frames,
129 | "depth_l1_train_view": sum(depth_l1) / num_frames,
130 | "num_renders": num_frames
131 | }
132 | save_dict_to_json(metrics, "rendering_metrics.json",
133 | directory=self.checkpoint_path)
134 |
135 | x = list(range(len(psnr)))
136 | fig, axs = plt.subplots(1, 3, figsize=(12, 4))
137 | axs[0].plot(x, psnr, label="PSNR")
138 | axs[0].legend()
139 | axs[0].set_title("PSNR")
140 | axs[1].plot(x, ssim, label="SSIM")
141 | axs[1].legend()
142 | axs[1].set_title("SSIM")
143 | axs[2].plot(x, depth_l1, label="Depth L1 (Train view)")
144 | axs[2].legend()
145 | axs[2].set_title("Depth L1 Render")
146 | plt.tight_layout()
147 | plt.savefig(str(self.checkpoint_path /
148 | "rendering_metrics.png"), dpi=300)
149 | print(metrics)
150 |
151 | def run_reconstruction_eval(self):
152 | """ Reconstructs the mesh, evaluates it, renders novel-view depth maps from it, and evaluates them as well """
153 | print("Running reconstruction evaluation...")
154 | if self.config["dataset_name"] != "replica":
155 | print("Reconstruction evaluation is only supported for the Replica dataset, skipping")
156 | return
157 | (self.checkpoint_path / "mesh").mkdir(exist_ok=True, parents=True)
158 | opt_settings = OptimizationParams(ArgumentParser(
159 | description="Training script parameters"))
160 | intrinsic = o3d.camera.PinholeCameraIntrinsic(
161 | self.width, self.height, self.fx, self.fy, self.cx, self.cy)
162 | scale = 1.0
163 | volume = o3d.pipelines.integration.ScalableTSDFVolume(
164 | voxel_length=5.0 * scale / 512.0,
165 | sdf_trunc=0.04 * scale,
166 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8)
167 |
168 | submaps_paths = sorted(list((self.checkpoint_path / "submaps").glob('*.ckpt')))
169 | for submap_path in tqdm(submaps_paths):
170 | submap = torch.load(submap_path, map_location=self.device)
171 | gaussian_model = GaussianModel()
172 | gaussian_model.training_setup(opt_settings)
173 | gaussian_model.restore_from_params(
174 | submap["gaussian_params"], opt_settings)
175 |
176 | for keyframe_id in submap["submap_keyframes"]:
177 | estimate_c2w = self.estimated_c2w[keyframe_id]
178 | estimate_w2c = np.linalg.inv(estimate_c2w)
179 | render_dict = render_gaussian_model(
180 | gaussian_model, get_render_settings(self.width, self.height, self.dataset.intrinsics, estimate_w2c))
181 | rendered_color, rendered_depth = render_dict["color"].detach(
182 | ), render_dict["depth"][0].detach()
183 | rendered_color = torch.clamp(rendered_color, min=0.0, max=1.0)
184 |
185 | rendered_color = (
186 | torch2np(rendered_color.permute(1, 2, 0)) * 255).astype(np.uint8)
187 | rendered_depth = torch2np(rendered_depth)
188 | rendered_depth = filter_depth_outliers(
189 | rendered_depth, kernel_size=20, threshold=0.1)
190 |
191 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
192 | o3d.geometry.Image(np.ascontiguousarray(rendered_color)),
193 | o3d.geometry.Image(rendered_depth),
194 | depth_scale=scale,
195 | depth_trunc=30,
196 | convert_rgb_to_intensity=False)
197 | volume.integrate(rgbd, intrinsic, estimate_w2c)
198 |
199 | o3d_mesh = volume.extract_triangle_mesh()
200 | compensate_vector = (-0.0 * scale / 512.0, 2.5 *
201 | scale / 512.0, -2.5 * scale / 512.0)
202 | o3d_mesh = o3d_mesh.translate(compensate_vector)
203 | file_name = self.checkpoint_path / "mesh" / "final_mesh.ply"
204 | o3d.io.write_triangle_mesh(str(file_name), o3d_mesh)
205 | evaluate_reconstruction(file_name,
206 | f"data/Replica-SLAM/cull_replica/{self.scene_name}.ply",
207 | f"data/Replica-SLAM/cull_replica/{self.scene_name}_pc_unseen.npy",
208 | self.checkpoint_path)
209 |
210 | def run_global_map_eval(self):
211 | """ Merges the map, evaluates it over training and novel views """
212 | print("Running global map evaluation...")
213 |
214 | training_frames = RenderFrames(self.dataset, self.estimated_c2w, self.height, self.width, self.fx, self.fy)
215 | training_frames = DataLoader(training_frames, batch_size=1, shuffle=True)
216 | training_frames = cycle(training_frames)
217 | merged_cloud = merge_submaps(self.submaps_paths)
218 | refined_merged_gaussian_model = refine_global_map(merged_cloud, training_frames, 10000)
219 | ply_path = self.checkpoint_path / f"{self.config['data']['scene_name']}_global_map.ply"
220 | refined_merged_gaussian_model.save_ply(ply_path)
221 | print(f"Refined global map saved to {ply_path}")
222 |
223 | if self.config["dataset_name"] != "scannetpp":
224 | return # "NVS evaluation only supported for scannetpp"
225 |
226 | eval_config = deepcopy(self.config)
227 | print(f"✨ Eval NVS for scene {self.config['data']['scene_name']}...")
228 | (self.checkpoint_path / "nvs_eval").mkdir(exist_ok=True, parents=True)
229 | eval_config["data"]["use_train_split"] = False
230 | test_set = get_dataset(eval_config["dataset_name"])({**eval_config["data"], **eval_config["cam"]})
231 | test_poses = torch.stack([torch.from_numpy(test_set[i][3]) for i in range(len(test_set))], dim=0)
232 | test_frames = RenderFrames(test_set, test_poses, self.height, self.width, self.fx, self.fy)
233 |
234 | psnr_list = []
235 | for i in tqdm(range(len(test_set))):
236 | gt_color, _, render_settings = (
237 | test_frames[i]["color"],
238 | test_frames[i]["depth"],
239 | test_frames[i]["render_settings"])
240 | render_dict = render_gaussian_model(refined_merged_gaussian_model, render_settings)
241 | rendered_color, _ = (render_dict["color"].permute(1, 2, 0), render_dict["depth"],)
242 | rendered_color = torch.clip(rendered_color, 0, 1)
243 | save_image(rendered_color.permute(2, 0, 1), self.checkpoint_path / f"nvs_eval/{i:04d}.jpg")
244 | psnr = calc_psnr(gt_color, rendered_color).mean()
245 | psnr_list.append(psnr.item())
246 | print(f"PSNR List: {psnr_list}")
247 | print(f"Avg. NVS PSNR: {np.array(psnr_list).mean()}")
248 |
249 | def run(self):
250 | """ Runs the general evaluation flow """
251 |
252 | print("Starting evaluation...🍺")
253 |
254 | try:
255 | self.run_trajectory_eval()
256 | except Exception:
257 | print("Could not run trajectory eval")
258 | traceback.print_exc()
259 |
260 | try:
261 | self.run_rendering_eval()
262 | except Exception:
263 | print("Could not run rendering eval")
264 | traceback.print_exc()
265 |
266 | try:
267 | self.run_reconstruction_eval()
268 | except Exception:
269 | print("Could not run reconstruction eval")
270 | traceback.print_exc()
271 |
272 | try:
273 | self.run_global_map_eval()
274 | except Exception:
275 | print("Could not run global map eval")
276 | traceback.print_exc()
277 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VladimirYugay/Gaussian-SLAM/eaec10d73ce7511563882b8856896e06d1f804e3/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/gaussian_model_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 | import numpy as np
24 | import torch
25 | import torch.nn.functional as F
26 |
27 | C0 = 0.28209479177387814
28 | C1 = 0.4886025119029199
29 | C2 = [
30 | 1.0925484305920792,
31 | -1.0925484305920792,
32 | 0.31539156525252005,
33 | -1.0925484305920792,
34 | 0.5462742152960396
35 | ]
36 | C3 = [
37 | -0.5900435899266435,
38 | 2.890611442640554,
39 | -0.4570457994644658,
40 | 0.3731763325901154,
41 | -0.4570457994644658,
42 | 1.445305721320277,
43 | -0.5900435899266435
44 | ]
45 | C4 = [
46 | 2.5033429417967046,
47 | -1.7701307697799304,
48 | 0.9461746957575601,
49 | -0.6690465435572892,
50 | 0.10578554691520431,
51 | -0.6690465435572892,
52 | 0.47308734787878004,
53 | -1.7701307697799304,
54 | 0.6258357354491761,
55 | ]
56 |
57 |
58 | def eval_sh(deg, sh, dirs):
59 | """
60 | Evaluate spherical harmonics at unit directions
61 | using hardcoded SH polynomials.
62 | Works with torch/np/jnp.
63 | ... Can be 0 or more batch dimensions.
64 | Args:
65 | deg: int SH deg. Currently, 0-3 supported
66 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
67 | dirs: jnp.ndarray unit directions [..., 3]
68 | Returns:
69 | [..., C]
70 | """
71 | assert deg <= 4 and deg >= 0
72 | coeff = (deg + 1) ** 2
73 | assert sh.shape[-1] >= coeff
74 |
75 | result = C0 * sh[..., 0]
76 | if deg > 0:
77 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
78 | result = (result -
79 | C1 * y * sh[..., 1] +
80 | C1 * z * sh[..., 2] -
81 | C1 * x * sh[..., 3])
82 |
83 | if deg > 1:
84 | xx, yy, zz = x * x, y * y, z * z
85 | xy, yz, xz = x * y, y * z, x * z
86 | result = (result +
87 | C2[0] * xy * sh[..., 4] +
88 | C2[1] * yz * sh[..., 5] +
89 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
90 | C2[3] * xz * sh[..., 7] +
91 | C2[4] * (xx - yy) * sh[..., 8])
92 |
93 | if deg > 2:
94 | result = (result +
95 | C3[0] * y * (3 * xx - yy) * sh[..., 9] +
96 | C3[1] * xy * z * sh[..., 10] +
97 | C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
98 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
99 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
100 | C3[5] * z * (xx - yy) * sh[..., 14] +
101 | C3[6] * x * (xx - 3 * yy) * sh[..., 15])
102 |
103 | if deg > 3:
104 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
105 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
106 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
107 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
108 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
109 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
110 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
111 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
112 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
113 | return result
114 |
115 |
116 | def RGB2SH(rgb):
117 | return (rgb - 0.5) / C0
118 |
119 |
120 | def SH2RGB(sh):
121 | return sh * C0 + 0.5
122 |
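A quick round-trip check for the DC (degree-0) spherical-harmonics conversion above; a minimal sketch using only the functions in this file:

```
import numpy as np

rgb = np.array([0.2, 0.5, 0.9])
sh0 = RGB2SH(rgb)                                   # DC coefficient per channel
assert np.allclose(SH2RGB(sh0), rgb)                # inverse mapping
# Degree-0 evaluation ignores the view direction and returns rgb - 0.5:
assert np.allclose(eval_sh(0, sh0[:, None], None) + 0.5, rgb)
```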
123 |
124 | def inverse_sigmoid(x):
125 | return torch.log(x/(1-x))
126 |
127 |
128 | def get_expon_lr_func(
129 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
130 | ):
131 | """
132 | Copied from Plenoxels
133 |
134 | Continuous learning rate decay function. Adapted from JaxNeRF
135 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
136 | is log-linearly interpolated elsewhere (equivalent to exponential decay).
137 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth
138 | function of lr_delay_mult, such that the initial learning rate is
139 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back
140 | to the normal learning rate when steps>lr_delay_steps.
141 | :param conf: config subtree 'lr' or similar
142 | :param max_steps: int, the number of steps during optimization.
143 | :return HoF which takes step as input
144 | """
145 |
146 | def helper(step):
147 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
148 | # Disable this parameter
149 | return 0.0
150 | if lr_delay_steps > 0:
151 | # A kind of reverse cosine decay.
152 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
153 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
154 | )
155 | else:
156 | delay_rate = 1.0
157 | t = np.clip(step / max_steps, 0, 1)
158 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
159 | return delay_rate * log_lerp
160 |
161 | return helper
162 |
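For example, a schedule decaying log-linearly from 1e-2 to 1e-4 over 30k steps passes through the geometric midpoint 1e-3 halfway through:

```
lr_fn = get_expon_lr_func(lr_init=1e-2, lr_final=1e-4, max_steps=30000)
print(lr_fn(0))        # 0.01
print(lr_fn(15000))    # ~0.001 (geometric midpoint)
print(lr_fn(30000))    # 0.0001
```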
163 |
164 | def strip_lowerdiag(L):
165 | uncertainty = torch.zeros(
166 | (L.shape[0], 6), dtype=torch.float, device="cuda")
167 |
168 | uncertainty[:, 0] = L[:, 0, 0]
169 | uncertainty[:, 1] = L[:, 0, 1]
170 | uncertainty[:, 2] = L[:, 0, 2]
171 | uncertainty[:, 3] = L[:, 1, 1]
172 | uncertainty[:, 4] = L[:, 1, 2]
173 | uncertainty[:, 5] = L[:, 2, 2]
174 | return uncertainty
175 |
176 |
177 | def strip_symmetric(sym):
178 | return strip_lowerdiag(sym)
179 |
180 |
181 | def build_rotation(r):
182 |
183 | q = F.normalize(r, p=2, dim=1)
184 | R = torch.zeros((q.size(0), 3, 3), device='cuda')
185 |
186 | r = q[:, 0]
187 | x = q[:, 1]
188 | y = q[:, 2]
189 | z = q[:, 3]
190 |
191 | R[:, 0, 0] = 1 - 2 * (y*y + z*z)
192 | R[:, 0, 1] = 2 * (x*y - r*z)
193 | R[:, 0, 2] = 2 * (x*z + r*y)
194 | R[:, 1, 0] = 2 * (x*y + r*z)
195 | R[:, 1, 1] = 1 - 2 * (x*x + z*z)
196 | R[:, 1, 2] = 2 * (y*z - r*x)
197 | R[:, 2, 0] = 2 * (x*z - r*y)
198 | R[:, 2, 1] = 2 * (y*z + r*x)
199 | R[:, 2, 2] = 1 - 2 * (x*x + y*y)
200 | return R
201 |
202 |
203 | def build_scaling_rotation(s, r):
204 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
205 | R = build_rotation(r)
206 |
207 | L[:, 0, 0] = s[:, 0]
208 | L[:, 1, 1] = s[:, 1]
209 | L[:, 2, 2] = s[:, 2]
210 |
211 | L = R @ L
212 | return L
213 |
--------------------------------------------------------------------------------
/src/utils/io_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from pathlib import Path
4 | from typing import Union
5 |
6 | import open3d as o3d
7 | import torch
8 | import wandb
9 | import yaml
10 |
11 |
12 | def mkdir_decorator(func):
13 | """A decorator that creates the directory specified in the function's 'directory' keyword
14 | argument before calling the function.
15 | Args:
16 | func: The function to be decorated.
17 | Returns:
18 | The wrapper function.
19 | """
20 | def wrapper(*args, **kwargs):
21 | output_path = Path(kwargs["directory"])
22 | output_path.mkdir(parents=True, exist_ok=True)
23 | return func(*args, **kwargs)
24 | return wrapper
25 |
26 |
27 | @mkdir_decorator
28 | def save_clouds(clouds: list, cloud_names: list, *, directory: Union[str, Path]) -> None:
29 | """ Saves a list of point clouds to the specified directory, creating the directory if it does not exist.
30 | Args:
31 | clouds: A list of point cloud objects to be saved.
32 | cloud_names: A list of filenames for the point clouds, corresponding by index to the clouds.
33 | directory: The directory where the point clouds will be saved.
34 | """
35 | for cld_name, cloud in zip(cloud_names, clouds):
36 | o3d.io.write_point_cloud(str(directory / cld_name), cloud)
37 |
38 |
39 | @mkdir_decorator
40 | def save_dict_to_ckpt(dictionary, file_name: str, *, directory: Union[str, Path]) -> None:
41 | """ Saves a dictionary to a checkpoint file in the specified directory, creating the directory if it does not exist.
42 | Args:
43 | dictionary: The dictionary to be saved.
44 | file_name: The name of the checkpoint file.
45 | directory: The directory where the checkpoint file will be saved.
46 | """
47 | torch.save(dictionary, directory / file_name,
48 | _use_new_zipfile_serialization=False)
49 |
50 |
51 | @mkdir_decorator
52 | def save_dict_to_yaml(dictionary, file_name: str, *, directory: Union[str, Path]) -> None:
53 | """ Saves a dictionary to a YAML file in the specified directory, creating the directory if it does not exist.
54 | Args:
55 | dictionary: The dictionary to be saved.
56 | file_name: The name of the YAML file.
57 | directory: The directory where the YAML file will be saved.
58 | """
59 | with open(directory / file_name, "w") as f:
60 | yaml.dump(dictionary, f)
61 |
62 |
63 | @mkdir_decorator
64 | def save_dict_to_json(dictionary, file_name: str, *, directory: Union[str, Path]) -> None:
65 | """ Saves a dictionary to a JSON file in the specified directory, creating the directory if it does not exist.
66 | Args:
67 | dictionary: The dictionary to be saved.
68 | file_name: The name of the JSON file.
69 | directory: The directory where the JSON file will be saved.
70 | """
71 | with open(directory / file_name, "w") as f:
72 | json.dump(dictionary, f)
73 |
74 |
75 | def load_config(path: str, default_path: str = None) -> dict:
76 | """
77 | Loads a configuration file and optionally merges it with a default configuration file.
78 |
79 | This function loads a configuration from the given path. If the configuration specifies an inheritance
80 | path (`inherit_from`), or if a `default_path` is provided, it loads the base configuration and updates it
81 | with the specific configuration.
82 |
83 | Args:
84 | path: The path to the specific configuration file.
85 | default_path: An optional path to a default configuration file. It is loaded as the base when the specific
86 | configuration declares no `inherit_from`, and otherwise forwarded while resolving the inheritance.
87 |
88 | Returns:
89 | A dictionary containing the merged configuration.
90 | """
91 | # load configuration from per scene/dataset cfg.
92 | with open(path, 'r') as f:
93 | cfg_special = yaml.full_load(f)
94 | inherit_from = cfg_special.get('inherit_from')
95 | cfg = dict()
96 | if inherit_from is not None:
97 | cfg = load_config(inherit_from, default_path)
98 | elif default_path is not None:
99 | with open(default_path, 'r') as f:
100 | cfg = yaml.full_load(f)
101 | update_recursive(cfg, cfg_special)
102 | return cfg
103 |
104 |
105 | def update_recursive(dict1: dict, dict2: dict) -> None:
106 | """ Recursively updates the first dictionary with the contents of the second dictionary.
107 |
108 | This function iterates through `dict2` and updates `dict1` with its contents. If a key from `dict2`
109 | exists in `dict1` and its value is also a dictionary, the function updates the value recursively.
110 | Otherwise, it overwrites the value in `dict1` with the value from `dict2`.
111 |
112 | Args:
113 | dict1: The dictionary to be updated.
114 | dict2: The dictionary whose entries are used to update `dict1`.
115 |
116 | Returns:
117 | None: The function modifies `dict1` in place.
118 | """
119 | for k, v in dict2.items():
120 | if k not in dict1:
121 | dict1[k] = dict()
122 | if isinstance(v, dict):
123 | update_recursive(dict1[k], v)
124 | else:
125 | dict1[k] = v
126 |
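For example, a per-scene config overriding a single nested value from an inherited base (as `load_config` does); the keys here are hypothetical:

```
base = {"cam": {"fx": 600.0, "fy": 600.0}, "seed": 0}
scene = {"cam": {"fx": 580.0}}
update_recursive(base, scene)
print(base)  # {'cam': {'fx': 580.0, 'fy': 600.0}, 'seed': 0}
```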
127 |
128 | def log_metrics_to_wandb(json_files: list, output_path: str, section: str = "Evaluation") -> None:
129 | """ Logs metrics from JSON files to Weights & Biases under a specified section.
130 |
131 | This function reads metrics from a list of JSON files and logs them to Weights & Biases (wandb).
132 | Each metric is prefixed with a specified section name for organized logging.
133 |
134 | Args:
135 | json_files: A list of filenames for JSON files containing metrics to be logged.
136 | output_path: The directory path where the JSON files are located.
137 | section: The section under which to log the metrics in wandb. Defaults to "Evaluation".
138 |
139 | Returns:
140 | None: Metrics are logged to wandb and the function does not return a value.
141 | """
142 | for json_file in json_files:
143 | file_path = os.path.join(output_path, json_file)
144 | if os.path.exists(file_path):
145 | with open(file_path, 'r') as file:
146 | metrics = json.load(file)
147 | prefixed_metrics = {
148 | f"{section}/{key}": value for key, value in metrics.items()}
149 | wandb.log(prefixed_metrics)
150 |
--------------------------------------------------------------------------------
/src/utils/mapper_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import cv2
3 | import faiss
4 | import faiss.contrib.torch_utils
5 | import numpy as np
6 | import torch
7 |
8 |
9 | def compute_opt_views_distribution(keyframes_num, iterations_num, current_frame_iter) -> np.ndarray:
10 | """ Computes the probability distribution for selecting views based on the current iteration.
11 | Args:
12 | keyframes_num: The total number of keyframes.
13 | iterations_num: The total number of iterations planned.
14 | current_frame_iter: The current iteration number.
15 | Returns:
16 | An array representing the probability distribution of keyframes.
17 | """
18 | if keyframes_num == 1:
19 | return np.array([1.0])
20 | prob = np.full(keyframes_num, (iterations_num - current_frame_iter) / (keyframes_num - 1))
21 | prob[0] = current_frame_iter
22 | prob /= prob.sum()
23 | return prob
24 |
25 |
26 | def compute_camera_frustum_corners(depth_map: np.ndarray, pose: np.ndarray, intrinsics: np.ndarray) -> np.ndarray:
27 | """ Computes the 3D coordinates of the camera frustum corners based on the depth map, pose, and intrinsics.
28 | Args:
29 | depth_map: The depth map of the scene.
30 | pose: The camera pose matrix.
31 | intrinsics: The camera intrinsic matrix.
32 | Returns:
33 | An array of 3D coordinates for the frustum corners.
34 | """
35 | height, width = depth_map.shape
36 | depth_map = depth_map[depth_map > 0]
37 | min_depth, max_depth = depth_map.min(), depth_map.max()
38 | corners = np.array(
39 | [
40 | [0, 0, min_depth],
41 | [width, 0, min_depth],
42 | [0, height, min_depth],
43 | [width, height, min_depth],
44 | [0, 0, max_depth],
45 | [width, 0, max_depth],
46 | [0, height, max_depth],
47 | [width, height, max_depth],
48 | ]
49 | )
50 | x = (corners[:, 0] - intrinsics[0, 2]) * corners[:, 2] / intrinsics[0, 0]
51 | y = (corners[:, 1] - intrinsics[1, 2]) * corners[:, 2] / intrinsics[1, 1]
52 | z = corners[:, 2]
53 | corners_3d = np.vstack((x, y, z, np.ones(x.shape[0]))).T
54 | corners_3d = pose @ corners_3d.T
55 | return corners_3d.T[:, :3]
56 |
57 |
58 | def compute_camera_frustum_planes(frustum_corners: np.ndarray) -> torch.Tensor:
59 | """ Computes the planes of the camera frustum from its corners.
60 | Args:
61 | frustum_corners: An array of 3D coordinates representing the corners of the frustum.
62 |
63 | Returns:
64 | A tensor of frustum planes.
65 | """
66 | # near, far, left, right, top, bottom
67 | planes = torch.stack(
68 | [
69 | torch.cross(
70 | frustum_corners[2] - frustum_corners[0],
71 | frustum_corners[1] - frustum_corners[0]
72 | ),
73 | torch.cross(
74 | frustum_corners[6] - frustum_corners[4],
75 | frustum_corners[5] - frustum_corners[4]
76 | ),
77 | torch.cross(
78 | frustum_corners[4] - frustum_corners[0],
79 | frustum_corners[2] - frustum_corners[0]
80 | ),
81 | torch.cross(
82 | frustum_corners[7] - frustum_corners[3],
83 | frustum_corners[1] - frustum_corners[3]
84 | ),
85 | torch.cross(
86 | frustum_corners[5] - frustum_corners[1],
87 | frustum_corners[0] - frustum_corners[1]
88 | ),
89 | torch.cross(
90 | frustum_corners[6] - frustum_corners[2],
91 | frustum_corners[3] - frustum_corners[2]
92 | ),
93 | ]
94 | )
95 | D = torch.stack([-torch.dot(plane, frustum_corners[i]) for i, plane in enumerate(planes)])
96 | return torch.cat([planes, D[:, None]], dim=1).float()
97 |
98 |
99 | def compute_frustum_aabb(frustum_corners: torch.Tensor):
100 | """ Computes the axis-aligned bounding box (AABB) of the camera frustum.
101 | Args:
102 | frustum_corners: A tensor of 3D coordinates representing the corners of the frustum.
103 | Returns:
104 | min_corner: The minimum corner of the AABB.
105 | max_corner: The maximum corner of the AABB.
106 | Both are returned as tensors of shape (3,).
107 | """
108 | return torch.min(frustum_corners, axis=0).values, torch.max(frustum_corners, axis=0).values
109 |
110 |
111 | def points_inside_aabb_mask(points: np.ndarray, min_corner: np.ndarray, max_corner: np.ndarray) -> np.ndarray:
112 | """ Computes a mask indicating which points lie inside a given axis-aligned bounding box (AABB).
113 | Args:
114 | points: An array of 3D points.
115 | min_corner: The minimum corner of the AABB.
116 | max_corner: The maximum corner of the AABB.
117 | Returns:
118 | A boolean array indicating whether each point lies inside the AABB. """
119 | return (
120 | (points[:, 0] >= min_corner[0])
121 | & (points[:, 0] <= max_corner[0])
122 | & (points[:, 1] >= min_corner[1])
123 | & (points[:, 1] <= max_corner[1])
124 | & (points[:, 2] >= min_corner[2])
125 | & (points[:, 2] <= max_corner[2]))
126 |
127 |
128 | def points_inside_frustum_mask(points: torch.Tensor, frustum_planes: torch.Tensor) -> torch.Tensor:
129 | """ Computes a mask indicating which points lie inside the camera frustum.
130 | Args:
131 | points: A tensor of 3D points.
132 | frustum_planes: A tensor representing the planes of the frustum.
133 | Returns:
134 | A boolean tensor indicating whether each point lies inside the frustum.
135 | """
136 | num_pts = points.shape[0]
137 | ones = torch.ones(num_pts, 1).to(points.device)
138 | plane_product = torch.cat([points, ones], axis=1) @ frustum_planes.T
139 | return torch.all(plane_product <= 0, axis=1)
140 |
141 |
142 | def compute_frustum_point_ids(pts: torch.Tensor, frustum_corners: torch.Tensor, device: str = "cuda"):
143 | """ Identifies points within the camera frustum, optimizing for computation on a specified device.
144 | Args:
145 | pts: A tensor of 3D points.
146 | frustum_corners: A tensor of 3D coordinates representing the corners of the frustum.
147 | device: The computation device ("cuda" or "cpu").
148 | Returns:
149 | Indices of points lying inside the frustum.
150 | """
151 | if pts.shape[0] == 0:
152 | return torch.tensor([], dtype=torch.int64, device=device)
153 | # Broad phase
154 | pts = pts.to(device)
155 | frustum_corners = frustum_corners.to(device)
156 |
157 | min_corner, max_corner = compute_frustum_aabb(frustum_corners)
158 | inside_aabb_mask = points_inside_aabb_mask(pts, min_corner, max_corner)
159 |
160 | # Narrow phase
161 | frustum_planes = compute_camera_frustum_planes(frustum_corners)
162 | frustum_planes = frustum_planes.to(device)
163 | inside_frustum_mask = points_inside_frustum_mask(pts[inside_aabb_mask], frustum_planes)
164 |
165 | inside_aabb_mask[inside_aabb_mask == 1] = inside_frustum_mask
166 | return torch.where(inside_aabb_mask)[0]
167 |
168 |
169 | def sample_pixels_based_on_gradient(image: np.ndarray, num_samples: int) -> np.ndarray:
170 | """ Samples pixel indices based on the gradient magnitude of an image.
171 | Args:
172 | image: The image from which to sample pixels.
173 | num_samples: The number of pixels to sample.
174 | Returns:
175 | Indices of the sampled pixels.
176 | """
177 | gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
178 | grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
179 | grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
180 | grad_magnitude = cv2.magnitude(grad_x, grad_y)
181 |
182 | # Normalize the gradient magnitude to create a probability map
183 | prob_map = grad_magnitude / np.sum(grad_magnitude)
184 |
185 | # Flatten the probability map
186 | prob_map_flat = prob_map.flatten()
187 |
188 | # Sample pixel indices based on the probability map
189 | sampled_indices = np.random.choice(prob_map_flat.size, size=num_samples, p=prob_map_flat)
190 | return sampled_indices.T
191 |
192 |
193 | def compute_new_points_ids(frustum_points: torch.Tensor, new_pts: torch.Tensor,
194 | radius: float = 0.03, device: str = "cpu") -> torch.Tensor:
195 | """ Having newly initialized points, decides which of them should be added to the submap.
196 | For every new point, if there are no neighbors within the radius in the frustum points,
197 | it is added to the submap.
198 | Args:
199 | frustum_points: Points within the current frustum of the active submap, of shape (N, 3)
200 | new_pts: New 3D Gaussian means which are about to be added to the submap, of shape (N, 3)
201 | radius: Radius within which points are considered to be neighbors
202 | device: Execution device
203 | Returns:
204 | Indices (1-D tensor) of the new points that should be added to the submap
205 | """
206 | if frustum_points.shape[0] == 0:
207 | return torch.arange(new_pts.shape[0])
208 | if device == "cpu":
209 | pts_index = faiss.IndexFlatL2(3)
210 | else:
211 | pts_index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, faiss.IndexFlatL2(3))
212 | frustum_points = frustum_points.to(device)
213 | new_pts = new_pts.to(device)
214 | pts_index.add(frustum_points)
215 |
216 | split_pos = torch.split(new_pts, 65535, dim=0)
217 | distances, ids = [], []
218 | for split_p in split_pos:
219 | distance, id = pts_index.search(split_p.float(), 8)
220 | distances.append(distance)
221 | ids.append(id)
222 | distances = torch.cat(distances, dim=0)
223 | ids = torch.cat(ids, dim=0)
224 | neighbor_num = (distances < radius).sum(axis=1).int()
225 | pts_index.reset()
226 | return torch.where(neighbor_num == 0)[0]
227 |
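A small example of the neighbor test above (note that `faiss.IndexFlatL2` reports squared L2 distances, so the `radius` threshold effectively applies to squared distances); the point values are arbitrary:

```
import torch

existing = torch.tensor([[0.0, 0.0, 0.0]])
candidates = torch.tensor([[0.0, 0.0, 0.001],   # has a neighbor within the radius -> rejected
                           [1.0, 0.0, 0.0]])    # isolated -> kept
print(compute_new_points_ids(existing, candidates))  # tensor([1])
```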
228 |
229 | def rotation_to_euler(R: torch.Tensor) -> torch.Tensor:
230 | """
231 | Converts a rotation matrix to Euler angles.
232 | Args:
233 | R: A rotation matrix.
234 | Returns:
235 | Euler angles corresponding to the rotation matrix.
236 | """
237 | sy = torch.sqrt(R[0, 0] ** 2 + R[1, 0] ** 2)
238 | singular = sy < 1e-6
239 |
240 | if not singular:
241 | x = torch.atan2(R[2, 1], R[2, 2])
242 | y = torch.atan2(-R[2, 0], sy)
243 | z = torch.atan2(R[1, 0], R[0, 0])
244 | else:
245 | x = torch.atan2(-R[1, 2], R[1, 1])
246 | y = torch.atan2(-R[2, 0], sy)
247 | z = 0
248 |
249 | return torch.tensor([x, y, z]) * (180 / np.pi)
250 |
251 |
252 | def exceeds_motion_thresholds(current_c2w: torch.Tensor, last_submap_c2w: torch.Tensor,
253 | rot_thre: float = 50, trans_thre: float = 0.5) -> bool:
254 | """ Checks if a camera motion exceeds certain rotation and translation thresholds
255 | Args:
256 | current_c2w: The current camera-to-world transformation matrix.
257 | last_submap_c2w: The last submap's camera-to-world transformation matrix.
258 | rot_thre: The rotation threshold for triggering a new submap.
259 | trans_thre: The translation threshold for triggering a new submap.
260 |
261 | Returns:
262 | A boolean indicating whether a new submap is required.
263 | """
264 | delta_pose = torch.matmul(torch.linalg.inv(last_submap_c2w).float(), current_c2w.float())
265 | translation_diff = torch.norm(delta_pose[:3, 3])
266 | rot_euler_diff_deg = torch.abs(rotation_to_euler(delta_pose[:3, :3]))
267 | exceeds_thresholds = (translation_diff > trans_thre) or torch.any(rot_euler_diff_deg > rot_thre)
268 | return exceeds_thresholds.item()
269 |
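For instance, with the default thresholds above, a pure 0.6 m translation triggers a new submap while zero motion does not (a self-contained check):

```
import torch

last_c2w = torch.eye(4)
current_c2w = torch.eye(4)
current_c2w[0, 3] = 0.6                                  # 0.6 m > 0.5 m default threshold
assert exceeds_motion_thresholds(current_c2w, last_c2w)  # new submap needed
assert not exceeds_motion_thresholds(last_c2w, last_c2w)
```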
270 |
271 | def geometric_edge_mask(rgb_image: np.ndarray, dilate: bool = True, RGB: bool = False) -> np.ndarray:
272 | """ Computes an edge mask for an RGB image using geometric edges.
273 | Args:
274 | rgb_image: The RGB image.
275 | dilate: Whether to dilate the edges.
276 | RGB: Indicates if the image format is RGB (True) or BGR (False).
277 | Returns:
278 | An edge mask of the input image.
279 | """
280 | # Convert the image to grayscale as Canny edge detection requires a single channel image
281 | gray_image = cv2.cvtColor(
282 | rgb_image, cv2.COLOR_BGR2GRAY if not RGB else cv2.COLOR_RGB2GRAY)
283 | if gray_image.dtype != np.uint8:
284 | gray_image = gray_image.astype(np.uint8)
285 | edges = cv2.Canny(gray_image, threshold1=100, threshold2=200, apertureSize=3, L2gradient=True)
286 | # Define the structuring element for dilation, you can change the size for a thicker/thinner mask
287 | if dilate:
288 | kernel = np.ones((2, 2), np.uint8)
289 | edges = cv2.dilate(edges, kernel, iterations=1)
290 | return edges
291 |
292 |
293 | def calc_psnr(img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
294 | """ Calculates the Peak Signal-to-Noise Ratio (PSNR) between two images.
295 | Args:
296 | img1: The first image.
297 | img2: The second image.
298 | Returns:
299 | The PSNR value.
300 | """
301 | mse = ((img1 - img2) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
302 | return 20 * torch.log10(1.0 / torch.sqrt(mse))
303 |
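As a worked example of the PSNR formula above: two images differing by a constant 0.1 have an MSE of 0.01, giving 20 * log10(1 / 0.1) = 20 dB.

```
import torch

a = torch.zeros(1, 3, 8, 8)
b = torch.full((1, 3, 8, 8), 0.1)
print(calc_psnr(a, b))  # ~20 dB per image in the batch
```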
304 |
305 | def create_point_cloud(image: np.ndarray, depth: np.ndarray, intrinsics: np.ndarray, pose: np.ndarray) -> np.ndarray:
306 | """
307 | Creates a point cloud from an image, depth map, camera intrinsics, and pose.
308 |
309 | Args:
310 | image: The RGB image of shape (H, W, 3)
311 | depth: The depth map of shape (H, W)
312 | intrinsics: The camera intrinsic parameters of shape (3, 3)
313 | pose: The camera pose of shape (4, 4)
314 | Returns:
315 | A point cloud of shape (N, 6) with last dimension representing (x, y, z, r, g, b)
316 | """
317 | height, width = depth.shape
318 | # Create a mesh grid of pixel coordinates
319 | u, v = np.meshgrid(np.arange(width), np.arange(height))
320 | # Convert pixel coordinates to camera coordinates
321 | x = (u - intrinsics[0, 2]) * depth / intrinsics[0, 0]
322 | y = (v - intrinsics[1, 2]) * depth / intrinsics[1, 1]
323 | z = depth
324 | # Stack the coordinates together
325 | points = np.stack((x, y, z, np.ones_like(z)), axis=-1)
326 | # Reshape the coordinates for matrix multiplication
327 | points = points.reshape(-1, 4)
328 | # Transform points to world coordinates
329 | posed_points = pose @ points.T
330 | posed_points = posed_points.T[:, :3]
331 | # Flatten the image to get colors for each point
332 | colors = image.reshape(-1, 3)
333 | # Concatenate posed points with their corresponding color
334 | point_cloud = np.concatenate((posed_points, colors), axis=-1)
335 |
336 | return point_cloud
337 |
--------------------------------------------------------------------------------
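
The mapper utilities above decide when camera motion has drifted far enough from the last submap to start a new one, and how to lift an RGB-D frame into a colored world-space point cloud. The following is a minimal, hypothetical usage sketch (not part of the repository); the pinhole intrinsics and the synthetic frame are made up for illustration:

```
import numpy as np
import torch

from src.utils.mapper_utils import create_point_cloud, exceeds_motion_thresholds

# Made-up pinhole intrinsics and a synthetic RGB-D frame.
h, w = 480, 640
rgb = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
depth = np.random.uniform(0.5, 3.0, (h, w)).astype(np.float32)
intrinsics = np.array([[525.0, 0.0, 319.5],
                       [0.0, 525.0, 239.5],
                       [0.0, 0.0, 1.0]])

# Current and last-submap camera-to-world poses: 0.7 units of translation
# exceeds the default 0.5 threshold, so a new submap would be triggered.
last_submap_c2w = torch.eye(4)
current_c2w = torch.eye(4)
current_c2w[0, 3] = 0.7

if exceeds_motion_thresholds(current_c2w, last_submap_c2w):
    # (N, 6) array: world-space xyz followed by each pixel's rgb values.
    point_cloud = create_point_cloud(rgb, depth, intrinsics, current_c2w.numpy())
    print(point_cloud.shape)  # (480 * 640, 6)
```
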
/src/utils/tracker_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from scipy.spatial.transform import Rotation
4 | from typing import Union
5 | from src.utils.utils import np2torch
6 |
7 |
8 | def multiply_quaternions(q: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
9 | """Performs batch-wise quaternion multiplication.
10 |
11 | Given two quaternions, this function computes their product. The operation is
12 | vectorized and can be performed on batches of quaternions.
13 |
14 | Args:
15 | q: A tensor representing the first quaternion or a batch of quaternions.
16 | Expected shape is (... , 4), where the last dimension contains quaternion components (w, x, y, z).
17 | r: A tensor representing the second quaternion or a batch of quaternions with the same shape as q.
18 | Returns:
19 | A tensor of the same shape as the input tensors, representing the product of the input quaternions.
20 | """
21 | w0, x0, y0, z0 = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
22 | w1, x1, y1, z1 = r[..., 0], r[..., 1], r[..., 2], r[..., 3]
23 |
24 | w = -x1 * x0 - y1 * y0 - z1 * z0 + w1 * w0
25 | x = x1 * w0 + y1 * z0 - z1 * y0 + w1 * x0
26 | y = -x1 * z0 + y1 * w0 + z1 * x0 + w1 * y0
27 | z = x1 * y0 - y1 * x0 + z1 * w0 + w1 * z0
28 | return torch.stack((w, x, y, z), dim=-1)
29 |
30 |
31 | def transformation_to_quaternion(RT: Union[torch.Tensor, np.ndarray]):
32 | """ Converts a rotation-translation matrix to a tensor representing quaternion and translation.
33 |
34 | This function takes a 3x4 transformation matrix (rotation and translation) and converts it
35 | into a tensor that combines the quaternion representation of the rotation and the translation vector.
36 |
37 | Args:
38 | RT: A 3x4 matrix representing the rotation and translation. This can be a NumPy array
39 | or a torch.Tensor. If it's a torch.Tensor and resides on a GPU, it will be moved to CPU.
40 |
41 | Returns:
42 | A tensor combining the quaternion (in w, x, y, z order) and translation vector. The tensor
43 | will be moved to the original device if the input was a GPU tensor.
44 | """
45 | gpu_id = -1
46 | if isinstance(RT, torch.Tensor):
47 | if RT.get_device() != -1:
48 | gpu_id = RT.get_device()
49 | RT = RT.detach().cpu()
50 | RT = RT.numpy()
51 | R, T = RT[:3, :3], RT[:3, 3]
52 |
53 | rot = Rotation.from_matrix(R)
54 | quad = rot.as_quat(canonical=True)
55 | quad = np.roll(quad, 1)
56 | tensor = np.concatenate([quad, T], 0)
57 | tensor = torch.from_numpy(tensor).float()
58 | if gpu_id != -1:
59 | tensor = tensor.to(gpu_id)
60 | return tensor
61 |
62 |
63 | def extrapolate_poses(poses: np.ndarray) -> np.ndarray:
64 | """ Generates an interpolated pose based on the first two poses in the given array.
65 | Args:
66 | poses: An array of poses, where each pose is represented by a 4x4 transformation matrix.
67 | Returns:
68 | A 4x4 numpy ndarray representing the interpolated transformation matrix.
69 | """
70 | return poses[1, :] @ np.linalg.inv(poses[0, :]) @ poses[1, :]
71 |
72 |
73 | def compute_camera_opt_params(estimate_rel_w2c: np.ndarray) -> tuple:
74 | """ Computes the camera's rotation and translation parameters from an world-to-camera transformation matrix.
75 | This function extracts the rotation component of the transformation matrix, converts it to a quaternion,
76 | and reorders it to match a specific convention. Both rotation and translation parameters are converted
77 | to torch Parameters and intended to be optimized in a PyTorch model.
78 | Args:
79 | estimate_rel_w2c: A 4x4 numpy ndarray representing the estimated world-to-camera transformation matrix.
80 | Returns:
81 | A tuple containing two torch.nn.Parameters: camera's rotation and camera's translation.
82 | """
83 | quaternion = Rotation.from_matrix(estimate_rel_w2c[:3, :3]).as_quat(canonical=True)
84 | quaternion = quaternion[[3, 0, 1, 2]]
85 | opt_cam_rot = torch.nn.Parameter(np2torch(quaternion, "cuda"))
86 | opt_cam_trans = torch.nn.Parameter(np2torch(estimate_rel_w2c[:3, 3], "cuda"))
87 | return opt_cam_rot, opt_cam_trans
88 |
--------------------------------------------------------------------------------
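
The tracker utilities above implement the pieces of a constant-velocity tracking front-end: `extrapolate_poses` predicts the next camera pose from the previous two, and `compute_camera_opt_params` turns a world-to-camera guess into a learnable quaternion and translation. A minimal, hypothetical sketch (not part of the repository); it assumes a CUDA device, since the parameters are created on "cuda":

```
import numpy as np

from src.utils.tracker_utils import compute_camera_opt_params, extrapolate_poses

# Two previous camera-to-world poses: identity, then 5 cm forward along z.
prev_poses = np.stack([np.eye(4), np.eye(4)])
prev_poses[1, 2, 3] = 0.05

# Constant-velocity guess for the next pose: poses[1] @ inv(poses[0]) @ poses[1].
init_c2w = extrapolate_poses(prev_poses)

# Turn the corresponding world-to-camera matrix into optimizable parameters:
# a (4,) quaternion in (w, x, y, z) order and a (3,) translation, both on the GPU.
init_w2c = np.linalg.inv(init_c2w)
opt_cam_rot, opt_cam_trans = compute_camera_opt_params(init_w2c)
print(opt_cam_rot.shape, opt_cam_trans.shape)
```
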
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import numpy as np
5 | import open3d as o3d
6 | import torch
7 | from gaussian_rasterizer import GaussianRasterizationSettings, GaussianRasterizer
8 |
9 |
10 | def setup_seed(seed: int) -> None:
11 | """ Sets the seed for generating random numbers to ensure reproducibility across multiple runs.
12 | Args:
13 | seed: The seed value to set for random number generators in torch, numpy, and random.
14 | """
15 | torch.manual_seed(seed)
16 | torch.cuda.manual_seed_all(seed)
17 | os.environ["PYTHONHASHSEED"] = str(seed)
18 | np.random.seed(seed)
19 | random.seed(seed)
20 | torch.backends.cudnn.deterministic = True
21 | torch.backends.cudnn.benchmark = False
22 |
23 |
24 | def torch2np(tensor: torch.Tensor) -> np.ndarray:
25 | """ Converts a PyTorch tensor to a NumPy ndarray.
26 | Args:
27 | tensor: The PyTorch tensor to convert.
28 | Returns:
29 | A NumPy ndarray with the same data and dtype as the input tensor.
30 | """
31 | return tensor.detach().cpu().numpy()
32 |
33 |
34 | def np2torch(array: np.ndarray, device: str = "cpu") -> torch.Tensor:
35 | """Converts a NumPy ndarray to a PyTorch tensor.
36 | Args:
37 | array: The NumPy ndarray to convert.
38 | device: The device to which the tensor is sent. Defaults to 'cpu'.
39 |
40 | Returns:
41 | A PyTorch tensor with the same data as the input array.
42 | """
43 | return torch.from_numpy(array).float().to(device)
44 |
45 |
46 | def np2ptcloud(pts: np.ndarray, rgb=None) -> o3d.geometry.PointCloud:
47 | """converts numpy array to point cloud
48 | Args:
49 | pts (ndarray): point cloud
50 | Returns:
51 | (PointCloud): resulting point cloud
52 | """
53 | cloud = o3d.geometry.PointCloud()
54 | cloud.points = o3d.utility.Vector3dVector(pts)
55 | if rgb is not None:
56 | cloud.colors = o3d.utility.Vector3dVector(rgb)
57 | return cloud
58 |
59 |
60 | def dict2device(dict: dict, device: str = "cpu") -> dict:
61 | """Sends all tensors in a dictionary to a specified device.
62 | Args:
63 | dict: The dictionary containing tensors.
64 | device: The device to send the tensors to. Defaults to 'cpu'.
65 | Returns:
66 | The dictionary with all tensors sent to the specified device.
67 | """
68 | for k, v in dict.items():
69 | if isinstance(v, torch.Tensor):
70 | dict[k] = v.to(device)
71 | return dict
72 |
73 |
74 | def get_render_settings(w, h, intrinsics, w2c, near=0.01, far=100, sh_degree=0):
75 | """
76 | Constructs and returns a GaussianRasterizationSettings object for rendering,
77 | configured with given camera parameters.
78 |
79 | Args:
80 | w (int): The width of the image.
81 | h (int): The height of the image.
82 | intrinsics (array): 3x3 intrinsic camera matrix.
83 | w2c (array): World-to-camera transformation matrix.
84 | near (float, optional): The near plane for the camera. Defaults to 0.01.
85 | far (float, optional): The far plane for the camera. Defaults to 100.
86 |
87 | Returns:
88 | GaussianRasterizationSettings: Configured settings for Gaussian rasterization.
89 | """
90 | fx, fy = intrinsics[0, 0], intrinsics[1, 1]
91 | cx, cy = intrinsics[0, 2], intrinsics[1, 2]
92 | w2c = torch.tensor(w2c).cuda().float()
93 | cam_center = torch.inverse(w2c)[:3, 3]
94 | viewmatrix = w2c.transpose(0, 1)
95 | opengl_proj = torch.tensor([[2 * fx / w, 0.0, -(w - 2 * cx) / w, 0.0],
96 | [0.0, 2 * fy / h, -(h - 2 * cy) / h, 0.0],
97 | [0.0, 0.0, far / (far - near), -(far * near) / (far - near)],
98 | [0.0, 0.0, 1.0, 0.0]],
99 | device='cuda').float().transpose(0, 1)
100 | full_proj_matrix = viewmatrix.unsqueeze(
101 | 0).bmm(opengl_proj.unsqueeze(0)).squeeze(0)
102 | return GaussianRasterizationSettings(
103 | image_height=h,
104 | image_width=w,
105 | tanfovx=w / (2 * fx),
106 | tanfovy=h / (2 * fy),
107 | bg=torch.tensor([0, 0, 0], device='cuda').float(),
108 | scale_modifier=1.0,
109 | viewmatrix=viewmatrix,
110 | projmatrix=full_proj_matrix,
111 | sh_degree=sh_degree,
112 | campos=cam_center,
113 | prefiltered=False,
114 | debug=False)
115 |
116 |
117 | def render_gaussian_model(gaussian_model, render_settings,
118 | override_means_3d=None, override_means_2d=None,
119 | override_scales=None, override_rotations=None,
120 | override_opacities=None, override_colors=None):
121 | """
122 | Renders a Gaussian model with specified rendering settings, allowing for
123 | optional overrides of various model parameters.
124 |
125 | Args:
126 | gaussian_model: A Gaussian model object that provides methods to get
127 | various properties like xyz coordinates, opacity, features, etc.
128 | render_settings: Configuration settings for the GaussianRasterizer.
129 | override_means_3d (Optional): If provided, these values will override
130 | the 3D mean values from the Gaussian model.
131 | override_means_2d (Optional): If provided, these values will override
132 | the 2D mean values. Defaults to zeros if not provided.
133 | override_scales (Optional): If provided, these values will override the
134 | scale values from the Gaussian model.
135 | override_rotations (Optional): If provided, these values will override
136 | the rotation values from the Gaussian model.
137 | override_opacities (Optional): If provided, these values will override
138 | the opacity values from the Gaussian model.
139 | override_colors (Optional): If provided, these values will override the
140 | color values from the Gaussian model.
141 | Returns:
142 | A dictionary containing the rendered color, depth, alpha, radii, and 2D means
143 | of the Gaussian model. The keys of this dictionary are 'color', 'depth', 'alpha',
144 | 'radii', and 'means2D', each mapping to their respective rendered values.
145 | """
146 | renderer = GaussianRasterizer(raster_settings=render_settings)
147 |
148 | if override_means_3d is None:
149 | means3D = gaussian_model.get_xyz()
150 | else:
151 | means3D = override_means_3d
152 |
153 | if override_means_2d is None:
154 | means2D = torch.zeros_like(
155 | means3D, dtype=means3D.dtype, requires_grad=True, device="cuda")
156 | means2D.retain_grad()
157 | else:
158 | means2D = override_means_2d
159 |
160 | if override_opacities is None:
161 | opacities = gaussian_model.get_opacity()
162 | else:
163 | opacities = override_opacities
164 |
165 | shs, colors_precomp = None, None
166 | if override_colors is not None:
167 | colors_precomp = override_colors
168 | else:
169 | shs = gaussian_model.get_features()
170 |
171 | render_args = {
172 | "means3D": means3D,
173 | "means2D": means2D,
174 | "opacities": opacities,
175 | "colors_precomp": colors_precomp,
176 | "shs": shs,
177 | "scales": gaussian_model.get_scaling() if override_scales is None else override_scales,
178 | "rotations": gaussian_model.get_rotation() if override_rotations is None else override_rotations,
179 | "cov3D_precomp": None
180 | }
181 | color, depth, alpha, radii = renderer(**render_args)
182 |
183 | return {"color": color, "depth": depth, "radii": radii, "means2D": means2D, "alpha": alpha}
184 |
185 |
186 | def batch_search_faiss(indexer, query_points, k):
187 | """
188 | Perform a batch search on an IndexIVFFlat indexer to circumvent the search size limit of 65535.
189 |
190 | Args:
191 | indexer: The FAISS indexer object.
192 | query_points: A tensor of query points.
193 | k (int): The number of nearest neighbors to find.
194 |
195 | Returns:
196 | distances (torch.Tensor): The distances of the nearest neighbors.
197 | ids (torch.Tensor): The indices of the nearest neighbors.
198 | """
199 | split_pos = torch.split(query_points, 65535, dim=0)
200 | distances_list, ids_list = [], []
201 |
202 | for split_p in split_pos:
203 | distance, idx = indexer.search(split_p.float(), k)
204 | distances_list.append(distance.clone())
205 | ids_list.append(idx.clone())
206 | distances = torch.cat(distances_list, dim=0)
207 | ids = torch.cat(ids_list, dim=0)
208 |
209 | return distances, ids
210 |
--------------------------------------------------------------------------------
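
`get_render_settings` packages the image size, intrinsics, and world-to-camera pose into the `GaussianRasterizationSettings` consumed by `render_gaussian_model`, which additionally expects a populated `GaussianModel` (see src/entities/gaussian_model.py). The sketch below is hypothetical and not part of the repository; it assumes a CUDA device (the settings are built on the GPU), relies on the `gaussian_rasterizer` package imported at the top of the file, and uses made-up intrinsics:

```
import numpy as np

from src.utils.utils import (get_render_settings, np2ptcloud, np2torch,
                             setup_seed, torch2np)

setup_seed(0)  # reproducible randomness across torch / numpy / random

# Made-up camera: 640x480 image, pinhole intrinsics, camera at the world origin.
w, h = 640, 480
intrinsics = np.array([[525.0, 0.0, 319.5],
                       [0.0, 525.0, 239.5],
                       [0.0, 0.0, 1.0]])
w2c = np.eye(4)

# Rasterization settings for this view (allocates CUDA tensors internally).
render_settings = get_render_settings(w, h, intrinsics, w2c)

# With a populated GaussianModel one would then render the view:
#   render_dict = render_gaussian_model(gaussian_model, render_settings)
#   color, depth = render_dict["color"], render_dict["depth"]

# The small conversion helpers round-trip between numpy, torch, and Open3D.
pts = np.random.rand(1000, 3)
cloud = np2ptcloud(torch2np(np2torch(pts, device="cpu")), rgb=np.random.rand(1000, 3))
print(len(cloud.points))
```
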
/src/utils/vis_utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from copy import deepcopy
3 | from typing import List, Union
4 |
5 | import numpy as np
6 | import open3d as o3d
7 | from matplotlib import colors
8 |
9 | COLORS_ANSI = OrderedDict({
10 | "blue": "\033[94m",
11 | "orange": "\033[93m",
12 | "green": "\033[92m",
13 | "red": "\033[91m",
14 | "purple": "\033[95m",
15 | "brown": "\033[93m", # No exact match, using yellow
16 | "pink": "\033[95m",
17 | "gray": "\033[90m",
18 | "olive": "\033[93m", # No exact match, using yellow
19 | "cyan": "\033[96m",
20 | "end": "\033[0m", # Reset color
21 | })
22 |
23 |
24 | COLORS_MATPLOTLIB = OrderedDict({
25 | 'blue': '#1f77b4',
26 | 'orange': '#ff7f0e',
27 | 'green': '#2ca02c',
28 | 'red': '#d62728',
29 | 'purple': '#9467bd',
30 | 'brown': '#8c564b',
31 | 'pink': '#e377c2',
32 | 'gray': '#7f7f7f',
33 | 'yellow-green': '#bcbd22',
34 | 'cyan': '#17becf'
35 | })
36 |
37 |
38 | COLORS_MATPLOTLIB_RGB = OrderedDict({
39 | 'blue': np.array([31, 119, 180]) / 255.0,
40 | 'orange': np.array([255, 127, 14]) / 255.0,
41 | 'green': np.array([44, 160, 44]) / 255.0,
42 | 'red': np.array([214, 39, 40]) / 255.0,
43 | 'purple': np.array([148, 103, 189]) / 255.0,
44 | 'brown': np.array([140, 86, 75]) / 255.0,
45 | 'pink': np.array([227, 119, 194]) / 255.0,
46 | 'gray': np.array([127, 127, 127]) / 255.0,
47 | 'yellow-green': np.array([188, 189, 34]) / 255.0,
48 | 'cyan': np.array([23, 190, 207]) / 255.0
49 | })
50 |
51 |
52 | def get_color(color_name: str):
53 | """ Returns the RGB values of a given color name as a normalized numpy array.
54 | Args:
55 | color_name: The name of the color. Can be any color name from CSS4_COLORS.
56 | Returns:
57 | A numpy array representing the RGB values of the specified color, normalized to the range [0, 1].
58 | """
59 | if color_name == "custom_yellow":
60 | return np.asarray([255.0, 204.0, 102.0]) / 255.0
61 | if color_name == "custom_blue":
62 | return np.asarray([102.0, 153.0, 255.0]) / 255.0
63 | assert color_name in colors.CSS4_COLORS
64 | return np.asarray(colors.to_rgb(colors.CSS4_COLORS[color_name]))
65 |
66 |
67 | def plot_ptcloud(point_clouds: Union[List, o3d.geometry.PointCloud], show_frame: bool = True):
68 | """ Visualizes one or more point clouds, optionally showing the coordinate frame.
69 | Args:
70 | point_clouds: A single point cloud or a list of point clouds to be visualized.
71 | show_frame: If True, displays the coordinate frame in the visualization. Defaults to True.
72 | """
73 | # rotate down up
74 | if not isinstance(point_clouds, list):
75 | point_clouds = [point_clouds]
76 | if show_frame:
77 | mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=1, origin=[0, 0, 0])
78 | point_clouds = point_clouds + [mesh_frame]
79 | o3d.visualization.draw_geometries(point_clouds)
80 |
81 |
82 | def draw_registration_result_original_color(source: o3d.geometry.PointCloud, target: o3d.geometry.PointCloud,
83 | transformation: np.ndarray):
84 | """ Visualizes the result of a point cloud registration, keeping the original color of the source point cloud.
85 | Args:
86 | source: The source point cloud.
87 | target: The target point cloud.
88 | transformation: The transformation matrix applied to the source point cloud.
89 | """
90 | source_temp = deepcopy(source)
91 | source_temp.transform(transformation)
92 | o3d.visualization.draw_geometries([source_temp, target])
93 |
94 |
95 | def draw_registration_result(source: o3d.geometry.PointCloud, target: o3d.geometry.PointCloud,
96 | transformation: np.ndarray, source_color: str = "blue", target_color: str = "orange"):
97 | """ Visualizes the result of a point cloud registration, coloring the source and target point clouds.
98 | Args:
99 | source: The source point cloud.
100 | target: The target point cloud.
101 | transformation: The transformation matrix applied to the source point cloud.
102 | source_color: The color to apply to the source point cloud. Defaults to "blue".
103 | target_color: The color to apply to the target point cloud. Defaults to "orange".
104 | """
105 | source_temp = deepcopy(source)
106 | source_temp.paint_uniform_color(COLORS_MATPLOTLIB_RGB[source_color])
107 |
108 | target_temp = deepcopy(target)
109 | target_temp.paint_uniform_color(COLORS_MATPLOTLIB_RGB[target_color])
110 |
111 | source_temp.transform(transformation)
112 | o3d.visualization.draw_geometries([source_temp, target_temp])
113 |
--------------------------------------------------------------------------------
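
The visualization helpers above wrap Open3D's viewer for quick inspection of point clouds and registration results. A minimal, hypothetical sketch (not part of the repository) that overlays two random clouds under a made-up rigid transform; it opens an interactive Open3D window:

```
import numpy as np
import open3d as o3d

from src.utils.vis_utils import draw_registration_result, get_color, plot_ptcloud

# Two random point clouds standing in for a source/target pair.
source = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(np.random.rand(500, 3)))
target = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(np.random.rand(500, 3)))

# A small made-up rigid transform applied to the source before drawing.
transformation = np.eye(4)
transformation[:3, 3] = [0.1, 0.0, 0.0]

# Source in blue, target in orange, with the transform applied to the source.
draw_registration_result(source, target, transformation,
                         source_color="blue", target_color="orange")

# Inspect a single cloud (with a coordinate frame) and look up a named color.
plot_ptcloud(source, show_frame=True)
print(get_color("custom_yellow"))  # normalized RGB triple
```
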