├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── assets ├── 1.gif ├── 4.gif ├── collage.gif └── image.png ├── bash_scripts ├── download_replica.sh ├── download_replicav2.sh ├── download_tum.sh ├── nerfcapture.bash ├── nerfcapture2dataset.bash ├── online_demo.bash └── start_docker.bash ├── configs ├── __init__.py ├── data │ ├── TUM │ │ ├── freiburg1_desk.yaml │ │ ├── freiburg1_desk2.yaml │ │ ├── freiburg1_room.yaml │ │ ├── freiburg2_xyz.yaml │ │ └── freiburg3_long_office_household.yaml │ ├── replica.yaml │ ├── replica_v2.yaml │ └── scannet.yaml ├── iphone │ ├── dataset.py │ ├── gaussian_splatting.py │ ├── nerfcapture.py │ ├── online_demo.py │ ├── post_splatam_opt.py │ ├── splatam.py │ └── splatam_viz.py ├── replica │ ├── gaussian_splatting.py │ ├── post_splatam_opt.py │ ├── replica.bash │ ├── replica_eval.py │ ├── splatam.py │ └── splatam_s.py ├── replica_v2 │ ├── eval_novel_view.py │ └── splatam.py ├── scannet │ ├── scannet.bash │ ├── scannet_eval.py │ └── splatam.py ├── scannetpp │ ├── eval_novel_view.bash │ ├── eval_novel_view.py │ ├── gaussian_splatting.py │ ├── post_splatam_opt.py │ ├── scannetpp.bash │ ├── scannetpp_eval.py │ └── splatam.py └── tum │ ├── splatam.py │ ├── tum.bash │ └── tum_eval.py ├── datasets ├── __init__.py └── gradslam_datasets │ ├── README.md │ ├── __init__.py │ ├── ai2thor.py │ ├── azure.py │ ├── basedataset.py │ ├── dataconfig.py │ ├── datautils.py │ ├── geometryutils.py │ ├── icl.py │ ├── nerfcapture.py │ ├── realsense.py │ ├── record3d.py │ ├── replica.py │ ├── scannet.py │ ├── scannetpp.py │ └── tum.py ├── environment.yml ├── requirements.txt ├── scripts ├── __init__.py ├── eval_novel_view.py ├── export_ply.py ├── gaussian_splatting.py ├── iphone_demo.py ├── nerfcapture2dataset.py ├── post_splatam_opt.py └── splatam.py ├── utils ├── __init__.py ├── common_utils.py ├── eval_helpers.py ├── graphics_utils.py ├── gs_external.py ├── gs_helpers.py ├── keyframe_selection.py ├── neighbor_search.py ├── recon_helpers.py ├── slam_external.py └── slam_helpers.py ├── venv_requirements.txt └── viz_scripts ├── final_recon.py └── online_recon.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Sub-Directories 2 | configs/local 3 | data/ 4 | experiments/ 5 | results/ 6 | wandb/ 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # poetry 105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 109 | #poetry.lock 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | #pdm.lock 114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 115 | # in version control. 116 | # https://pdm.fming.dev/#use-with-ide 117 | .pdm.toml 118 | 119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
167 | #.idea/ 168 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "diff-gaussian-rasterization-w-depth.git"] 2 | path = diff-gaussian-rasterization-w-depth.git 3 | url = git@github.com:JonathonLuiten/diff-gaussian-rasterization-w-depth.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Nikhil Varma Keetha 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | -------------------------------------------------------------------------------- /assets/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spla-tam/SplaTAM/da6bbcd24c248dc884ac7f49d62e91b841b26ccc/assets/1.gif -------------------------------------------------------------------------------- /assets/4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spla-tam/SplaTAM/da6bbcd24c248dc884ac7f49d62e91b841b26ccc/assets/4.gif -------------------------------------------------------------------------------- /assets/collage.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spla-tam/SplaTAM/da6bbcd24c248dc884ac7f49d62e91b841b26ccc/assets/collage.gif -------------------------------------------------------------------------------- /assets/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spla-tam/SplaTAM/da6bbcd24c248dc884ac7f49d62e91b841b26ccc/assets/image.png -------------------------------------------------------------------------------- /bash_scripts/download_replica.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data 2 | cd data 3 | # you can also download the Replica.zip manually through 4 | # link: https://caiyun.139.com/m/i?1A5Ch5C3abNiL password: v3fY (the zip is split into smaller zips because of the size limitation of caiyun) 5 | wget https://cvg-data.inf.ethz.ch/nice-slam/data/Replica.zip 6 | unzip Replica.zip -------------------------------------------------------------------------------- /bash_scripts/download_replicav2.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data 2 | cd data 3 | 4 | wget https://huggingface.co/datasets/kxic/vMAP/resolve/main/vmap.zip 5 | 6 | unzip vmap.zip 7 | 8 | mv vmap replica_v2 9 | -------------------------------------------------------------------------------- /bash_scripts/download_tum.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data/TUM_RGBD 2 | cd data/TUM_RGBD 3 | wget https://vision.in.tum.de/rgbd/dataset/freiburg1/rgbd_dataset_freiburg1_desk.tgz 4 | tar -xvzf rgbd_dataset_freiburg1_desk.tgz 5 | wget https://cvg.cit.tum.de/rgbd/dataset/freiburg1/rgbd_dataset_freiburg1_desk2.tgz 6 | tar -xvzf rgbd_dataset_freiburg1_desk2.tgz 7 | wget https://cvg.cit.tum.de/rgbd/dataset/freiburg1/rgbd_dataset_freiburg1_room.tgz 8 | tar -xvzf rgbd_dataset_freiburg1_room.tgz 9 | wget https://vision.in.tum.de/rgbd/dataset/freiburg2/rgbd_dataset_freiburg2_xyz.tgz 10 | tar -xvzf rgbd_dataset_freiburg2_xyz.tgz 11 | wget https://vision.in.tum.de/rgbd/dataset/freiburg3/rgbd_dataset_freiburg3_long_office_household.tgz 12 | tar -xvzf rgbd_dataset_freiburg3_long_office_household.tgz -------------------------------------------------------------------------------- /bash_scripts/nerfcapture.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # check rmem_max and wmem_max, and increase size if necessary 4 | if [ "$#" -ne 1 ]; then 5 | echo "Usage: bash_scripts/nerfcapture.bash <config_file>" 6 | exit 7 | fi 8 | 9 | if [ ! -f $1 ]; then 10 | echo "Config file not found!"
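# Note: the sysctl blocks below raise net.core.rmem_max / net.core.wmem_max,
# the kernel's maximum receive/send socket buffer sizes, to 2147483647
# (INT32_MAX). The capture step streams full RGB-D frames from the iPhone over
# the network (the NerfCapture app publishes them via CycloneDDS), and the
# default buffer limits can drop such large messages.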
11 | exit 12 | fi 13 | 14 | if sysctl -a | grep -q "net.core.rmem_max = 2147483647"; then 15 | echo "rmem_max already set to 2147483647" 16 | else 17 | echo "Setting rmem_max to 2147483647" 18 | sudo sysctl -w net.core.rmem_max=2147483647 19 | fi 20 | 21 | if sysctl -a | grep -q "net.core.wmem_max = 2147483647"; then 22 | echo "wmem_max already set to 2147483647" 23 | else 24 | echo "Setting wmem_max to 2147483647" 25 | sudo sysctl -w net.core.wmem_max=2147483647 26 | fi 27 | 28 | # Capture Dataset 29 | python3 scripts/nerfcapture2dataset.py --config $1 30 | 31 | # Run SplaTAM 32 | python3 scripts/splatam.py $1 33 | 34 | # Visualize SplaTAM Output 35 | python3 viz_scripts/final_recon.py $1 -------------------------------------------------------------------------------- /bash_scripts/nerfcapture2dataset.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # check rmem_max and wmem_max, and increase size if necessary 4 | if [ "$#" -ne 1 ]; then 5 | echo "Usage: bash_scripts/nerfcapture2dataset.bash <config_file>" 6 | exit 7 | fi 8 | 9 | if [ ! -f $1 ]; then 10 | echo "Config file not found!" 11 | exit 12 | fi 13 | 14 | if sysctl -a | grep -q "net.core.rmem_max = 2147483647"; then 15 | echo "rmem_max already set to 2147483647" 16 | else 17 | echo "Setting rmem_max to 2147483647" 18 | sudo sysctl -w net.core.rmem_max=2147483647 19 | fi 20 | 21 | if sysctl -a | grep -q "net.core.wmem_max = 2147483647"; then 22 | echo "wmem_max already set to 2147483647" 23 | else 24 | echo "Setting wmem_max to 2147483647" 25 | sudo sysctl -w net.core.wmem_max=2147483647 26 | fi 27 | 28 | # Capture Dataset 29 | python3 scripts/nerfcapture2dataset.py --config $1 -------------------------------------------------------------------------------- /bash_scripts/online_demo.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # check rmem_max and wmem_max, and increase size if necessary 4 | if [ "$#" -ne 1 ]; then 5 | echo "Usage: bash_scripts/online_demo.bash <config_file>" 6 | exit 7 | fi 8 | 9 | if [ ! -f $1 ]; then 10 | echo "Config file not found!"
11 | exit 12 | fi 13 | 14 | if sysctl -a | grep -q "net.core.rmem_max = 2147483647"; then 15 | echo "rmem_max already set to 2147483647" 16 | else 17 | echo "Setting rmem_max to 2147483647" 18 | sudo sysctl -w net.core.rmem_max=2147483647 19 | fi 20 | 21 | if sysctl -a | grep -q "net.core.wmem_max = 2147483647"; then 22 | echo "wmem_max already set to 2147483647" 23 | else 24 | echo "Setting wmem_max to 2147483647" 25 | sudo sysctl -w net.core.wmem_max=2147483647 26 | fi 27 | 28 | # Online Dataset Capture & SplaTAM 29 | python3 scripts/iphone_demo.py --config $1 30 | 31 | # Visualize SplaTAM Output 32 | python3 viz_scripts/final_recon.py $1 33 | -------------------------------------------------------------------------------- /bash_scripts/start_docker.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | docker run -it \ 4 | --volume="./:/SplaTAM/" \ 5 | --env="NVIDIA_VISIBLE_DEVICES=all" \ 6 | --env="NVIDIA_DRIVER_CAPABILITIES=all" \ 7 | --net=host \ 8 | --privileged \ 9 | --group-add audio \ 10 | --group-add video \ 11 | --ulimit memlock=-1 \ 12 | --ulimit stack=67108864 \ 13 | --name splatam \ 14 | --gpus all \ 15 | nkeetha/splatam:v1 \ 16 | /bin/bash 17 | 18 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spla-tam/SplaTAM/da6bbcd24c248dc884ac7f49d62e91b841b26ccc/configs/__init__.py -------------------------------------------------------------------------------- /configs/data/TUM/freiburg1_desk.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'tum' 2 | camera_params: 3 | image_height: 480 4 | image_width: 640 5 | fx: 517.3 6 | fy: 516.5 7 | cx: 318.6 8 | cy: 255.3 9 | crop_edge: 8 10 | png_depth_scale: 5000.0 11 | # distortion: [0.2624, -0.9531, -0.0054, 0.0026, 1.1633] -------------------------------------------------------------------------------- /configs/data/TUM/freiburg1_desk2.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'tum' 2 | camera_params: 3 | image_height: 480 4 | image_width: 640 5 | fx: 517.3 6 | fy: 516.5 7 | cx: 318.6 8 | cy: 255.3 9 | crop_edge: 8 10 | png_depth_scale: 5000.0 11 | # distortion: [0.2624, -0.9531, -0.0054, 0.0026, 1.1633] -------------------------------------------------------------------------------- /configs/data/TUM/freiburg1_room.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'tum' 2 | camera_params: 3 | image_height: 480 4 | image_width: 640 5 | fx: 517.3 6 | fy: 516.5 7 | cx: 318.6 8 | cy: 255.3 9 | crop_edge: 8 10 | png_depth_scale: 5000.0 11 | # distortion: [0.2624, -0.9531, -0.0054, 0.0026, 1.1633] -------------------------------------------------------------------------------- /configs/data/TUM/freiburg2_xyz.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'tum' 2 | camera_params: 3 | image_height: 480 4 | image_width: 640 5 | fx: 520.9 6 | fy: 521.0 7 | cx: 325.1 8 | cy: 249.7 9 | crop_edge: 8 10 | png_depth_scale: 5000.0 11 | # distortion: [0.2312, -0.7849, -0.0033, -0.0001, 0.9172] -------------------------------------------------------------------------------- /configs/data/TUM/freiburg3_long_office_household.yaml: --------------------------------------------------------------------------------
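# A minimal sketch (an editorial illustration, not a repo API) of how the
# camera_params in these yaml files map to a pinhole intrinsics matrix,
# following the usual NICE-SLAM-style handling of crop_edge and
# png_depth_scale:
#
#   import yaml
#   cam = yaml.safe_load(open("configs/data/TUM/freiburg1_desk.yaml"))["camera_params"]
#   cx = cam["cx"] - cam["crop_edge"]  # cropping crop_edge pixels per border
#   cy = cam["cy"] - cam["crop_edge"]  # shifts the principal point
#   K = [[cam["fx"], 0.0, cx], [0.0, cam["fy"], cy], [0.0, 0.0, 1.0]]
#   # 16-bit depth PNG values are divided by png_depth_scale (5000.0 for TUM)
#   # to obtain depth in metres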
1 | dataset_name: 'tum' 2 | camera_params: 3 | image_height: 480 4 | image_width: 640 5 | fx: 535.4 6 | fy: 539.2 7 | cx: 320.1 8 | cy: 247.6 9 | crop_edge: 8 10 | png_depth_scale: 5000.0 -------------------------------------------------------------------------------- /configs/data/replica.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'replica' 2 | camera_params: 3 | image_height: 680 4 | image_width: 1200 5 | fx: 600.0 6 | fy: 600.0 7 | cx: 599.5 8 | cy: 339.5 9 | png_depth_scale: 6553.5 10 | crop_edge: 0 -------------------------------------------------------------------------------- /configs/data/replica_v2.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'replicav2' 2 | camera_params: 3 | image_height: 680 4 | image_width: 1200 5 | fx: 600.0 6 | fy: 600.0 7 | cx: 599.5 8 | cy: 339.5 9 | png_depth_scale: 1000.0 10 | crop_edge: 0 -------------------------------------------------------------------------------- /configs/data/scannet.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'scannet' 2 | camera_params: 3 | image_height: 968 4 | image_width: 1296 5 | fx: 1169.621094 6 | fy: 1167.105103 7 | cx: 646.295044 8 | cy: 489.927032 9 | png_depth_scale: 1000.0 -------------------------------------------------------------------------------- /configs/iphone/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | primary_device = "cuda:0" 5 | seed = 0 6 | 7 | base_dir = "./experiments/iPhone_Captures" # Root Directory to Save iPhone Dataset 8 | scene_name = "dataset_demo" # Scan Name 9 | num_frames = 10 # Desired number of frames to capture 10 | depth_scale = 10.0 # Depth Scale used when saving depth 11 | overwrite = False # Rewrite over dataset if it exists 12 | 13 | config = dict( 14 | workdir=f"./{base_dir}/{scene_name}", 15 | overwrite=overwrite, 16 | depth_scale=depth_scale, 17 | num_frames=num_frames, 18 | ) -------------------------------------------------------------------------------- /configs/iphone/gaussian_splatting.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | primary_device = "cuda:0" 5 | seed = 0 6 | 7 | base_dir = "./experiments/iPhone_Captures" # Root Directory to Save iPhone Dataset 8 | scene_name = "offline_demo" # Scan Name 9 | num_frames = 10 # Desired number of frames to capture 10 | depth_scale = 10.0 # Depth Scale used when saving depth 11 | overwrite = False # Rewrite over dataset if it exists 12 | 13 | full_res_width = 1920 14 | full_res_height = 1440 15 | downscale_factor = 2.0 16 | densify_downscale_factor = 4.0 17 | 18 | map_every = 1 19 | if num_frames < 25: 20 | keyframe_every = int(num_frames//5) 21 | else: 22 | keyframe_every = 5 23 | mapping_window_size = 32 24 | tracking_iters = 60 25 | mapping_iters = 60 26 | 27 | config = dict( 28 | workdir=f"./{base_dir}/{scene_name}", 29 | run_name="SplaTAM_iPhone", 30 | overwrite=overwrite, 31 | depth_scale=depth_scale, 32 | num_frames=num_frames, 33 | seed=seed, 34 | primary_device=primary_device, 35 | map_every=map_every, # Mapping every nth frame 36 | keyframe_every=keyframe_every, # Keyframe every nth frame 37 | mapping_window_size=mapping_window_size, # Mapping window size 38 | report_global_progress_every=100, # Report Global Progress every nth frame 39 | 
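    # Note: with the full_res 1920x1440 capture above, downscale_factor=2.0
    # gives 960x720 images for tracking/mapping and densify_downscale_factor=4.0
    # gives 480x360 for densification (see the data dict below); presumably the
    # dataloader rescales fx/fy/cx/cy by the same factors when resizing.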
eval_every=1, # Evaluate every nth frame (at end of SLAM) 40 | scene_radius_depth_ratio=3, # Max First Frame Depth to Scene Radius Ratio (For Pruning/Densification) 41 | mean_sq_dist_method="projective", # ["projective", "knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 42 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 43 | report_iter_progress=False, 44 | load_checkpoint=False, 45 | checkpoint_time_idx=130, 46 | save_checkpoints=False, # Save Checkpoints 47 | checkpoint_interval=5, # Checkpoint Interval 48 | use_wandb=False, 49 | data=dict( 50 | dataset_name="nerfcapture", 51 | basedir=base_dir, 52 | sequence=scene_name, 53 | desired_image_height=int(full_res_height//downscale_factor), 54 | desired_image_width=int(full_res_width//downscale_factor), 55 | densification_image_height=int(full_res_height//densify_downscale_factor), 56 | densification_image_width=int(full_res_width//densify_downscale_factor), 57 | start=0, 58 | end=-1, 59 | stride=1, 60 | num_frames=num_frames, 61 | ), 62 | tracking=dict( 63 | use_gt_poses=False, # Use GT Poses for Tracking 64 | forward_prop=True, # Forward Propagate Poses 65 | visualize_tracking_loss=False, # Visualize Tracking Diff Images 66 | num_iters=tracking_iters, 67 | use_sil_for_loss=True, 68 | sil_thres=0.99, 69 | use_l1=True, 70 | use_depth_loss_thres=True, 71 | depth_loss_thres=20000, # Num of Tracking Iters becomes twice if this value is not met 72 | ignore_outlier_depth_loss=False, 73 | use_uncertainty_for_loss_mask=False, 74 | use_uncertainty_for_loss=False, 75 | use_chamfer=False, 76 | loss_weights=dict( 77 | im=0.5, 78 | depth=1.0, 79 | ), 80 | lrs=dict( 81 | means3D=0.0, 82 | rgb_colors=0.0, 83 | unnorm_rotations=0.0, 84 | logit_opacities=0.0, 85 | log_scales=0.0, 86 | cam_unnorm_rots=0.001, 87 | cam_trans=0.004, 88 | ), 89 | ), 90 | mapping=dict( 91 | num_iters=mapping_iters, 92 | add_new_gaussians=True, 93 | sil_thres=0.5, # For Addition of new Gaussians 94 | use_l1=True, 95 | ignore_outlier_depth_loss=False, 96 | use_sil_for_loss=False, 97 | use_uncertainty_for_loss_mask=False, 98 | use_uncertainty_for_loss=False, 99 | use_chamfer=False, 100 | loss_weights=dict( 101 | im=0.5, 102 | depth=1.0, 103 | ), 104 | lrs=dict( 105 | means3D=0.0001, 106 | rgb_colors=0.0025, 107 | unnorm_rotations=0.001, 108 | logit_opacities=0.05, 109 | log_scales=0.001, 110 | cam_unnorm_rots=0.0000, 111 | cam_trans=0.0000, 112 | ), 113 | prune_gaussians=True, # Prune Gaussians during Mapping 114 | pruning_dict=dict( # Needs to be updated based on the number of mapping iterations 115 | start_after=0, 116 | remove_big_after=0, 117 | stop_after=20, 118 | prune_every=20, 119 | removal_opacity_threshold=0.005, 120 | final_removal_opacity_threshold=0.005, 121 | reset_opacities=False, 122 | reset_opacities_every=500, # Doesn't consider iter 0 123 | ), 124 | use_gaussian_splatting_densification=False, # Use Gaussian Splatting-based Densification during Mapping 125 | densify_dict=dict( # Needs to be updated based on the number of mapping iterations 126 | start_after=500, 127 | remove_big_after=3000, 128 | stop_after=5000, 129 | densify_every=100, 130 | grad_thresh=0.0002, 131 | num_to_split_into=2, 132 | removal_opacity_threshold=0.005, 133 | final_removal_opacity_threshold=0.005, 134 | reset_opacities_every=3000, # Doesn't consider iter 0 135 | ), 136 | ), 137 | viz=dict( 138 | render_mode='color', # ['color', 'depth' or 'centers'] 139 | 
offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 140 | show_sil=False, # Show Silhouette instead of RGB 141 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 142 | viz_w=600, viz_h=340, 143 | viz_near=0.01, viz_far=100.0, 144 | view_scale=2, 145 | viz_fps=5, # FPS for Online Recon Viz 146 | enter_interactive_post_online=False, # Enter Interactive Mode after Online Recon Viz 147 | ), 148 | ) -------------------------------------------------------------------------------- /configs/iphone/nerfcapture.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | primary_device = "cuda:0" 5 | seed = 0 6 | 7 | base_dir = "./experiments/iPhone_Captures" # Root Directory to Save iPhone Dataset 8 | scene_name = "offline_demo" # Scan Name 9 | num_frames = 10 # Desired number of frames to capture 10 | depth_scale = 10.0 # Depth Scale used when saving depth 11 | overwrite = False # Rewrite over dataset if it exists 12 | 13 | full_res_width = 1920 14 | full_res_height = 1440 15 | downscale_factor = 2.0 16 | densify_downscale_factor = 4.0 17 | 18 | map_every = 1 19 | if num_frames < 25: 20 | keyframe_every = int(num_frames//5) 21 | else: 22 | keyframe_every = 5 23 | mapping_window_size = 32 24 | tracking_iters = 60 25 | mapping_iters = 60 26 | 27 | config = dict( 28 | workdir=f"./{base_dir}/{scene_name}", 29 | run_name="SplaTAM_iPhone", 30 | overwrite=overwrite, 31 | depth_scale=depth_scale, 32 | num_frames=num_frames, 33 | seed=seed, 34 | primary_device=primary_device, 35 | map_every=map_every, # Mapping every nth frame 36 | keyframe_every=keyframe_every, # Keyframe every nth frame 37 | mapping_window_size=mapping_window_size, # Mapping window size 38 | report_global_progress_every=100, # Report Global Progress every nth frame 39 | eval_every=1, # Evaluate every nth frame (at end of SLAM) 40 | scene_radius_depth_ratio=3, # Max First Frame Depth to Scene Radius Ratio (For Pruning/Densification) 41 | mean_sq_dist_method="projective", # ["projective", "knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 42 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 43 | report_iter_progress=False, 44 | load_checkpoint=False, 45 | checkpoint_time_idx=130, 46 | save_checkpoints=False, # Save Checkpoints 47 | checkpoint_interval=5, # Checkpoint Interval 48 | use_wandb=False, 49 | data=dict( 50 | dataset_name="nerfcapture", 51 | basedir=base_dir, 52 | sequence=scene_name, 53 | desired_image_height=int(full_res_height//downscale_factor), 54 | desired_image_width=int(full_res_width//downscale_factor), 55 | densification_image_height=int(full_res_height//densify_downscale_factor), 56 | densification_image_width=int(full_res_width//densify_downscale_factor), 57 | start=0, 58 | end=-1, 59 | stride=1, 60 | num_frames=num_frames, 61 | ), 62 | tracking=dict( 63 | use_gt_poses=False, # Use GT Poses for Tracking 64 | forward_prop=True, # Forward Propagate Poses 65 | visualize_tracking_loss=False, # Visualize Tracking Diff Images 66 | num_iters=tracking_iters, 67 | use_sil_for_loss=True, 68 | sil_thres=0.99, 69 | use_l1=True, 70 | use_depth_loss_thres=True, 71 | depth_loss_thres=20000, # Num of Tracking Iters becomes twice if this value is not met 72 | ignore_outlier_depth_loss=False, 73 | use_uncertainty_for_loss_mask=False, 74 | 
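    # Note the division of labour encoded in the lrs blocks of this config:
    # during tracking every Gaussian attribute (means3D, rgb_colors,
    # unnorm_rotations, logit_opacities, log_scales) has a zero learning rate
    # and only the camera pose (cam_unnorm_rots, cam_trans) is optimized,
    # while the mapping dict below inverts this, freezing the camera and
    # optimizing the map.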
use_uncertainty_for_loss=False, 75 | use_chamfer=False, 76 | loss_weights=dict( 77 | im=0.5, 78 | depth=1.0, 79 | ), 80 | lrs=dict( 81 | means3D=0.0, 82 | rgb_colors=0.0, 83 | unnorm_rotations=0.0, 84 | logit_opacities=0.0, 85 | log_scales=0.0, 86 | cam_unnorm_rots=0.001, 87 | cam_trans=0.004, 88 | ), 89 | ), 90 | mapping=dict( 91 | num_iters=mapping_iters, 92 | add_new_gaussians=True, 93 | sil_thres=0.5, # For Addition of new Gaussians 94 | use_l1=True, 95 | ignore_outlier_depth_loss=False, 96 | use_sil_for_loss=False, 97 | use_uncertainty_for_loss_mask=False, 98 | use_uncertainty_for_loss=False, 99 | use_chamfer=False, 100 | loss_weights=dict( 101 | im=0.5, 102 | depth=1.0, 103 | ), 104 | lrs=dict( 105 | means3D=0.0001, 106 | rgb_colors=0.0025, 107 | unnorm_rotations=0.001, 108 | logit_opacities=0.05, 109 | log_scales=0.001, 110 | cam_unnorm_rots=0.0000, 111 | cam_trans=0.0000, 112 | ), 113 | prune_gaussians=True, # Prune Gaussians during Mapping 114 | pruning_dict=dict( # Needs to be updated based on the number of mapping iterations 115 | start_after=0, 116 | remove_big_after=0, 117 | stop_after=20, 118 | prune_every=20, 119 | removal_opacity_threshold=0.005, 120 | final_removal_opacity_threshold=0.005, 121 | reset_opacities=False, 122 | reset_opacities_every=500, # Doesn't consider iter 0 123 | ), 124 | use_gaussian_splatting_densification=False, # Use Gaussian Splatting-based Densification during Mapping 125 | densify_dict=dict( # Needs to be updated based on the number of mapping iterations 126 | start_after=500, 127 | remove_big_after=3000, 128 | stop_after=5000, 129 | densify_every=100, 130 | grad_thresh=0.0002, 131 | num_to_split_into=2, 132 | removal_opacity_threshold=0.005, 133 | final_removal_opacity_threshold=0.005, 134 | reset_opacities_every=3000, # Doesn't consider iter 0 135 | ), 136 | ), 137 | viz=dict( 138 | render_mode='color', # ['color', 'depth' or 'centers'] 139 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 140 | show_sil=False, # Show Silhouette instead of RGB 141 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 142 | viz_w=600, viz_h=340, 143 | viz_near=0.01, viz_far=100.0, 144 | view_scale=2, 145 | viz_fps=5, # FPS for Online Recon Viz 146 | enter_interactive_post_online=False, # Enter Interactive Mode after Online Recon Viz 147 | ), 148 | ) -------------------------------------------------------------------------------- /configs/iphone/online_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | primary_device = "cuda:0" 5 | seed = 0 6 | 7 | base_dir = "./experiments/iPhone_Captures" # Root Directory to Save iPhone Dataset 8 | scene_name = "splatam_demo" # Scan Name 9 | num_frames = 10 # Desired number of frames to capture 10 | depth_scale = 10.0 # Depth Scale used when saving depth 11 | overwrite = True # Rewrite over dataset if it exists 12 | 13 | full_res_width = 1920 14 | full_res_height = 1440 15 | downscale_factor = 2.0 16 | densify_downscale_factor = 4.0 17 | 18 | map_every = 1 19 | if num_frames < 25: 20 | keyframe_every = int(num_frames//5) 21 | else: 22 | keyframe_every = 5 23 | mapping_window_size = 32 24 | tracking_iters = 60 25 | mapping_iters = 60 26 | 27 | config = dict( 28 | workdir=f"./{base_dir}/{scene_name}", 29 | run_name="SplaTAM_iPhone", 30 | overwrite=overwrite, 31 | depth_scale=depth_scale, 32 | num_frames=num_frames, 33 | seed=seed, 34 | 
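    # Unlike configs/iphone/dataset.py and splatam.py (overwrite=False), this
    # online demo sets overwrite=True above, so re-running it replaces any
    # existing capture under workdir.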
primary_device=primary_device, 35 | map_every=map_every, # Mapping every nth frame 36 | keyframe_every=keyframe_every, # Keyframe every nth frame 37 | mapping_window_size=mapping_window_size, # Mapping window size 38 | report_global_progress_every=100, # Report Global Progress every nth frame 39 | eval_every=1, # Evaluate every nth frame (at end of SLAM) 40 | scene_radius_depth_ratio=3, # Max First Frame Depth to Scene Radius Ratio (For Pruning/Densification) 41 | mean_sq_dist_method="projective", # ["projective", "knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 42 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 43 | report_iter_progress=False, 44 | load_checkpoint=False, 45 | checkpoint_time_idx=130, 46 | save_checkpoints=False, # Save Checkpoints 47 | checkpoint_interval=5, # Checkpoint Interval 48 | use_wandb=False, 49 | data=dict( 50 | dataset_name="nerfcapture", 51 | basedir=base_dir, 52 | sequence=scene_name, 53 | downscale_factor=downscale_factor, 54 | densify_downscale_factor=densify_downscale_factor, 55 | desired_image_height=int(full_res_height//downscale_factor), 56 | desired_image_width=int(full_res_width//downscale_factor), 57 | densification_image_height=int(full_res_height//densify_downscale_factor), 58 | densification_image_width=int(full_res_width//densify_downscale_factor), 59 | start=0, 60 | end=-1, 61 | stride=1, 62 | num_frames=num_frames, 63 | ), 64 | tracking=dict( 65 | use_gt_poses=False, # Use GT Poses for Tracking 66 | forward_prop=True, # Forward Propagate Poses 67 | visualize_tracking_loss=False, # Visualize Tracking Diff Images 68 | num_iters=tracking_iters, 69 | use_sil_for_loss=True, 70 | sil_thres=0.99, 71 | use_l1=True, 72 | use_depth_loss_thres=True, 73 | depth_loss_thres=20000, # Num of Tracking Iters becomes twice if this value is not met 74 | ignore_outlier_depth_loss=False, 75 | use_uncertainty_for_loss_mask=False, 76 | use_uncertainty_for_loss=False, 77 | use_chamfer=False, 78 | loss_weights=dict( 79 | im=0.5, 80 | depth=1.0, 81 | ), 82 | lrs=dict( 83 | means3D=0.0, 84 | rgb_colors=0.0, 85 | unnorm_rotations=0.0, 86 | logit_opacities=0.0, 87 | log_scales=0.0, 88 | cam_unnorm_rots=0.001, 89 | cam_trans=0.004, 90 | ), 91 | ), 92 | mapping=dict( 93 | num_iters=mapping_iters, 94 | add_new_gaussians=True, 95 | sil_thres=0.5, # For Addition of new Gaussians 96 | use_l1=True, 97 | ignore_outlier_depth_loss=False, 98 | use_sil_for_loss=False, 99 | use_uncertainty_for_loss_mask=False, 100 | use_uncertainty_for_loss=False, 101 | use_chamfer=False, 102 | loss_weights=dict( 103 | im=0.5, 104 | depth=1.0, 105 | ), 106 | lrs=dict( 107 | means3D=0.0001, 108 | rgb_colors=0.0025, 109 | unnorm_rotations=0.001, 110 | logit_opacities=0.05, 111 | log_scales=0.001, 112 | cam_unnorm_rots=0.0000, 113 | cam_trans=0.0000, 114 | ), 115 | prune_gaussians=True, # Prune Gaussians during Mapping 116 | pruning_dict=dict( # Needs to be updated based on the number of mapping iterations 117 | start_after=0, 118 | remove_big_after=0, 119 | stop_after=20, 120 | prune_every=20, 121 | removal_opacity_threshold=0.005, 122 | final_removal_opacity_threshold=0.005, 123 | reset_opacities=False, 124 | reset_opacities_every=500, # Doesn't consider iter 0 125 | ), 126 | use_gaussian_splatting_densification=False, # Use Gaussian Splatting-based Densification during Mapping 127 | densify_dict=dict( # Needs to be updated based on the number of mapping iterations 128 | 
start_after=500, 129 | remove_big_after=3000, 130 | stop_after=5000, 131 | densify_every=100, 132 | grad_thresh=0.0002, 133 | num_to_split_into=2, 134 | removal_opacity_threshold=0.005, 135 | final_removal_opacity_threshold=0.005, 136 | reset_opacities_every=3000, # Doesn't consider iter 0 137 | ), 138 | ), 139 | viz=dict( 140 | render_mode='color', # ['color', 'depth' or 'centers'] 141 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 142 | show_sil=False, # Show Silhouette instead of RGB 143 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 144 | viz_w=600, viz_h=340, 145 | viz_near=0.01, viz_far=100.0, 146 | view_scale=2, 147 | viz_fps=5, # FPS for Online Recon Viz 148 | enter_interactive_post_online=False, # Enter Interactive Mode after Online Recon Viz 149 | ), 150 | ) -------------------------------------------------------------------------------- /configs/iphone/post_splatam_opt.py: -------------------------------------------------------------------------------- 1 | from os.path import join as p_join 2 | 3 | primary_device = "cuda:0" 4 | 5 | base_dir = "./experiments/iPhone_Captures" 6 | scene_name = "splatam_demo" 7 | params_path = f"{base_dir}/{scene_name}/params.npz" 8 | 9 | group_name = "iPhone_Captures" 10 | run_name = f"{scene_name}_post_splatam_opt" 11 | 12 | full_res_width = 1920 13 | full_res_height = 1440 14 | downscale_factor = 2.0 15 | densify_downscale_factor = 4.0 16 | 17 | config = dict( 18 | workdir=f"./experiments/{group_name}", 19 | run_name=run_name, 20 | seed=0, 21 | primary_device=primary_device, 22 | mean_sq_dist_method="projective", # ["projective", "knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 23 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 24 | report_iter_progress=False, 25 | use_wandb=False, 26 | wandb=dict( 27 | entity="theairlab", 28 | project="SplaTAM", 29 | group=group_name, 30 | name=run_name, 31 | save_qual=False, 32 | eval_save_qual=True, 33 | ), 34 | data=dict( 35 | dataset_name="nerfcapture", 36 | basedir=base_dir, 37 | sequence=scene_name, 38 | downscale_factor=downscale_factor, 39 | densify_downscale_factor=densify_downscale_factor, 40 | desired_image_height=int(full_res_height//downscale_factor), 41 | desired_image_width=int(full_res_width//downscale_factor), 42 | densification_image_height=int(full_res_height//densify_downscale_factor), 43 | densification_image_width=int(full_res_width//densify_downscale_factor), 44 | start=0, 45 | end=-1, 46 | stride=1, 47 | num_frames=-1, 48 | eval_stride=1, 49 | eval_num_frames=-1, 50 | param_ckpt_path=params_path, 51 | ), 52 | train=dict( 53 | num_iters_mapping=15000, 54 | sil_thres=0.5, # For Addition of new Gaussians & Visualization 55 | use_sil_for_loss=True, # Use Silhouette for Loss during Tracking 56 | loss_weights=dict( 57 | im=0.5, 58 | depth=1.0, 59 | ), 60 | lrs_mapping=dict( 61 | means3D=0.00032, 62 | rgb_colors=0.0025, 63 | unnorm_rotations=0.001, 64 | logit_opacities=0.05, 65 | log_scales=0.005, 66 | cam_unnorm_rots=0.0000, 67 | cam_trans=0.0000, 68 | ), 69 | lrs_mapping_means3D_final=0.0000032, 70 | lr_delay_mult=0.01, 71 | use_gaussian_splatting_densification=True, # Use Gaussian Splatting-based Densification during Mapping 72 | densify_dict=dict( # Needs to be updated based on the number of mapping iterations 73 | start_after=500, 74 | remove_big_after=3000, 75 | stop_after=15000, 
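        # Densification spans the whole optimization here: stop_after matches
        # num_iters_mapping=15000 above, with a 500-iteration warmup and a
        # densify step every 100 iterations, which matches the default schedule
        # of the original 3D Gaussian Splatting training.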
76 | densify_every=100, 77 | grad_thresh=0.0002, 78 | num_to_split_into=2, 79 | removal_opacity_threshold=0.005, 80 | final_removal_opacity_threshold=0.005, 81 | reset_opacities=True, 82 | reset_opacities_every=3000, # Doesn't consider iter 0 83 | ), 84 | ), 85 | viz=dict( 86 | render_mode='color', # ['color', 'depth' or 'centers'] 87 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 88 | show_sil=False, # Show Silhouette instead of RGB 89 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 90 | viz_w=600, viz_h=340, 91 | viz_near=0.01, viz_far=100.0, 92 | view_scale=2, 93 | viz_fps=5, # FPS for Online Recon Viz 94 | enter_interactive_post_online=True, # Enter Interactive Mode after Online Recon Viz 95 | ), 96 | ) -------------------------------------------------------------------------------- /configs/iphone/splatam.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | primary_device = "cuda:0" 5 | seed = 0 6 | 7 | base_dir = "./experiments/iPhone_Captures" # Root Directory to Save iPhone Dataset 8 | scene_name = "offline_demo" # Scan Name 9 | num_frames = 10 # Desired number of frames to capture 10 | depth_scale = 10.0 # Depth Scale used when saving depth 11 | overwrite = False # Rewrite over dataset if it exists 12 | 13 | full_res_width = 1920 14 | full_res_height = 1440 15 | downscale_factor = 2.0 16 | densify_downscale_factor = 4.0 17 | 18 | map_every = 1 19 | if num_frames < 25: 20 | keyframe_every = int(num_frames//5) 21 | else: 22 | keyframe_every = 5 23 | mapping_window_size = 32 24 | tracking_iters = 60 25 | mapping_iters = 60 26 | 27 | config = dict( 28 | workdir=f"./{base_dir}/{scene_name}", 29 | run_name="SplaTAM_iPhone", 30 | overwrite=overwrite, 31 | depth_scale=depth_scale, 32 | num_frames=num_frames, 33 | seed=seed, 34 | primary_device=primary_device, 35 | map_every=map_every, # Mapping every nth frame 36 | keyframe_every=keyframe_every, # Keyframe every nth frame 37 | mapping_window_size=mapping_window_size, # Mapping window size 38 | report_global_progress_every=100, # Report Global Progress every nth frame 39 | eval_every=1, # Evaluate every nth frame (at end of SLAM) 40 | scene_radius_depth_ratio=3, # Max First Frame Depth to Scene Radius Ratio (For Pruning/Densification) 41 | mean_sq_dist_method="projective", # ["projective", "knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 42 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 43 | report_iter_progress=False, 44 | load_checkpoint=False, 45 | checkpoint_time_idx=130, 46 | save_checkpoints=False, # Save Checkpoints 47 | checkpoint_interval=5, # Checkpoint Interval 48 | use_wandb=False, 49 | data=dict( 50 | dataset_name="nerfcapture", 51 | basedir=base_dir, 52 | sequence=scene_name, 53 | desired_image_height=int(full_res_height//downscale_factor), 54 | desired_image_width=int(full_res_width//downscale_factor), 55 | densification_image_height=int(full_res_height//densify_downscale_factor), 56 | densification_image_width=int(full_res_width//densify_downscale_factor), 57 | start=0, 58 | end=-1, 59 | stride=1, 60 | num_frames=num_frames, 61 | ), 62 | tracking=dict( 63 | use_gt_poses=False, # Use GT Poses for Tracking 64 | forward_prop=True, # Forward Propagate Poses 65 | visualize_tracking_loss=False, # Visualize Tracking Diff Images 66 
| num_iters=tracking_iters, 67 | use_sil_for_loss=True, 68 | sil_thres=0.99, 69 | use_l1=True, 70 | use_depth_loss_thres=True, 71 | depth_loss_thres=20000, # Num of Tracking Iters becomes twice if this value is not met 72 | ignore_outlier_depth_loss=False, 73 | use_uncertainty_for_loss_mask=False, 74 | use_uncertainty_for_loss=False, 75 | use_chamfer=False, 76 | loss_weights=dict( 77 | im=0.5, 78 | depth=1.0, 79 | ), 80 | lrs=dict( 81 | means3D=0.0, 82 | rgb_colors=0.0, 83 | unnorm_rotations=0.0, 84 | logit_opacities=0.0, 85 | log_scales=0.0, 86 | cam_unnorm_rots=0.001, 87 | cam_trans=0.004, 88 | ), 89 | ), 90 | mapping=dict( 91 | num_iters=mapping_iters, 92 | add_new_gaussians=True, 93 | sil_thres=0.5, # For Addition of new Gaussians 94 | use_l1=True, 95 | ignore_outlier_depth_loss=False, 96 | use_sil_for_loss=False, 97 | use_uncertainty_for_loss_mask=False, 98 | use_uncertainty_for_loss=False, 99 | use_chamfer=False, 100 | loss_weights=dict( 101 | im=0.5, 102 | depth=1.0, 103 | ), 104 | lrs=dict( 105 | means3D=0.0001, 106 | rgb_colors=0.0025, 107 | unnorm_rotations=0.001, 108 | logit_opacities=0.05, 109 | log_scales=0.001, 110 | cam_unnorm_rots=0.0000, 111 | cam_trans=0.0000, 112 | ), 113 | prune_gaussians=True, # Prune Gaussians during Mapping 114 | pruning_dict=dict( # Needs to be updated based on the number of mapping iterations 115 | start_after=0, 116 | remove_big_after=0, 117 | stop_after=20, 118 | prune_every=20, 119 | removal_opacity_threshold=0.005, 120 | final_removal_opacity_threshold=0.005, 121 | reset_opacities=False, 122 | reset_opacities_every=500, # Doesn't consider iter 0 123 | ), 124 | use_gaussian_splatting_densification=False, # Use Gaussian Splatting-based Densification during Mapping 125 | densify_dict=dict( # Needs to be updated based on the number of mapping iterations 126 | start_after=500, 127 | remove_big_after=3000, 128 | stop_after=5000, 129 | densify_every=100, 130 | grad_thresh=0.0002, 131 | num_to_split_into=2, 132 | removal_opacity_threshold=0.005, 133 | final_removal_opacity_threshold=0.005, 134 | reset_opacities_every=3000, # Doesn't consider iter 0 135 | ), 136 | ), 137 | viz=dict( 138 | render_mode='color', # ['color', 'depth' or 'centers'] 139 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 140 | show_sil=False, # Show Silhouette instead of RGB 141 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 142 | viz_w=600, viz_h=340, 143 | viz_near=0.01, viz_far=100.0, 144 | view_scale=2, 145 | viz_fps=5, # FPS for Online Recon Viz 146 | enter_interactive_post_online=False, # Enter Interactive Mode after Online Recon Viz 147 | ), 148 | ) -------------------------------------------------------------------------------- /configs/iphone/splatam_viz.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | seed = 0 5 | 6 | config = dict( 7 | scene_path='./experiments/iPhone_Captures/splatam_demo/params.npz', 8 | seed=seed, 9 | viz=dict( 10 | render_mode='color', # ['color', 'depth' or 'centers'] 11 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 12 | show_sil=False, # Show Silhouette instead of RGB 13 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 14 | viz_w=600, viz_h=340, 15 | viz_near=0.01, viz_far=100.0, 16 | view_scale=2, 17 | viz_fps=5, # FPS for Online Recon Viz 18 | 
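    # This file is a visualization-only config: scene_path above points at the
    # params.npz written by a finished SplaTAM run (cf. param_ckpt_path in
    # configs/iphone/post_splatam_opt.py), so no tracking or mapping settings
    # appear here.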
enter_interactive_post_online=True, # Enter Interactive Mode after Online Recon Viz 19 | ), 20 | ) -------------------------------------------------------------------------------- /configs/replica/gaussian_splatting.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | scenes = ["room0", "room1", "room2", 5 | "office0", "office1", "office2", 6 | "office3", "office4"] 7 | 8 | primary_device = "cuda:0" 9 | seed = 0 10 | scene_name = scenes[0] 11 | 12 | map_every = 1 13 | keyframe_every = 5 14 | mapping_window_size = 24 15 | tracking_iters = 40 16 | mapping_iters = 60 17 | 18 | group_name = "Replica_3DGS" 19 | run_name = f"{scene_name}_{seed}" 20 | 21 | config = dict( 22 | workdir=f"./experiments/{group_name}", 23 | run_name=run_name, 24 | seed=seed, 25 | primary_device=primary_device, 26 | map_every=map_every, # Mapping every nth frame 27 | keyframe_every=keyframe_every, # Keyframe every nth frame 28 | mapping_window_size=mapping_window_size, # Mapping window size 29 | report_global_progress_every=5, # Report Global Progress every nth frame 30 | eval_every=5, # Evaluate every nth frame (at end of SLAM) 31 | scene_radius_depth_ratio=3, # Max First Frame Depth to Scene Radius Ratio (For Pruning/Densification) 32 | mean_sq_dist_method="projective", # ["projective", "knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 33 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 34 | report_iter_progress=False, 35 | load_checkpoint=False, 36 | checkpoint_time_idx=0, 37 | save_checkpoints=False, # Save Checkpoints 38 | checkpoint_interval=5, # Checkpoint Interval 39 | use_wandb=True, 40 | wandb=dict( 41 | entity="theairlab", 42 | project="SplaTAM", 43 | group=group_name, 44 | name=run_name, 45 | save_qual=False, 46 | eval_save_qual=True, 47 | ), 48 | data=dict( 49 | basedir="./data/Replica", 50 | gradslam_data_cfg="./configs/data/replica.yaml", 51 | sequence="room0", 52 | desired_image_height_init=170, 53 | desired_image_width_init=300, 54 | desired_image_height=340, 55 | desired_image_width=600, 56 | start=0, 57 | end=-1, 58 | stride=1, 59 | num_frames=2000, 60 | eval_stride=10, 61 | eval_num_frames=200, 62 | ), 63 | train=dict( 64 | num_iters_mapping=30000, 65 | sil_thres=0.5, # For Addition of new Gaussians & Visualization 66 | use_sil_for_loss=True, # Use Silhouette for Loss during Tracking 67 | loss_weights=dict( 68 | im=0.5, 69 | depth=1.0, 70 | ), 71 | lrs_mapping=dict( 72 | means3D=0.00032, 73 | rgb_colors=0.0025, 74 | unnorm_rotations=0.001, 75 | logit_opacities=0.05, 76 | log_scales=0.005, 77 | cam_unnorm_rots=0.0000, 78 | cam_trans=0.0000, 79 | ), 80 | lrs_mapping_means3D_final=0.0000032, 81 | lr_delay_mult=0.01, 82 | use_gaussian_splatting_densification=True, # Use Gaussian Splatting-based Densification during Mapping 83 | densify_dict=dict( # Needs to be updated based on the number of mapping iterations 84 | start_after=500, 85 | remove_big_after=3000, 86 | stop_after=15000, 87 | densify_every=100, 88 | grad_thresh=0.0002, 89 | num_to_split_into=2, 90 | removal_opacity_threshold=0.005, 91 | final_removal_opacity_threshold=0.005, 92 | reset_opacities=True, 93 | reset_opacities_every=3000, # Doesn't consider iter 0 94 | ), 95 | ), 96 | viz=dict( 97 | render_mode='color', # ['color', 'depth' or 'centers'] 98 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the
view direction (For Final Recon Viz) 99 | show_sil=False, # Show Silhouette instead of RGB 100 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 101 | viz_w=600, viz_h=340, 102 | viz_near=0.01, viz_far=100.0, 103 | view_scale=2, 104 | viz_fps=5, # FPS for Online Recon Viz 105 | enter_interactive_post_online=True, # Enter Interactive Mode after Online Recon Viz 106 | ), 107 | ) -------------------------------------------------------------------------------- /configs/replica/post_splatam_opt.py: -------------------------------------------------------------------------------- 1 | from os.path import join as p_join 2 | 3 | primary_device = "cuda:0" 4 | 5 | group_name = "Replica" 6 | run_name = "Post_SplaTAM_Opt" 7 | 8 | config = dict( 9 | workdir=f"./experiments/{group_name}", 10 | run_name=run_name, 11 | seed=0, 12 | primary_device=primary_device, 13 | mean_sq_dist_method="projective", # ["projective", "knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 14 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 15 | report_iter_progress=False, 16 | use_wandb=False, 17 | wandb=dict( 18 | entity="theairlab", 19 | project="SplaTAM", 20 | group=group_name, 21 | name=run_name, 22 | save_qual=False, 23 | eval_save_qual=True, 24 | ), 25 | data=dict( 26 | basedir="./data/Replica", 27 | gradslam_data_cfg="./configs/data/replica.yaml", 28 | sequence="room0", 29 | desired_image_height=680, 30 | desired_image_width=1200, 31 | start=0, 32 | end=-1, 33 | stride=20, 34 | num_frames=100, 35 | eval_stride=5, 36 | eval_num_frames=400, 37 | param_ckpt_path='./experiments/Replica/room0_seed0/params.npz' 38 | ), 39 | train=dict( 40 | num_iters_mapping=15000, 41 | sil_thres=0.5, # For Addition of new Gaussians & Visualization 42 | use_sil_for_loss=True, # Use Silhouette for Loss during Tracking 43 | loss_weights=dict( 44 | im=0.5, 45 | depth=1.0, 46 | ), 47 | lrs_mapping=dict( 48 | means3D=0.00032, 49 | rgb_colors=0.0025, 50 | unnorm_rotations=0.001, 51 | logit_opacities=0.05, 52 | log_scales=0.005, 53 | cam_unnorm_rots=0.0000, 54 | cam_trans=0.0000, 55 | ), 56 | lrs_mapping_means3D_final=0.0000032, 57 | lr_delay_mult=0.01, 58 | use_gaussian_splatting_densification=True, # Use Gaussian Splatting-based Densification during Mapping 59 | densify_dict=dict( # Needs to be updated based on the number of mapping iterations 60 | start_after=500, 61 | remove_big_after=3000, 62 | stop_after=15000, 63 | densify_every=100, 64 | grad_thresh=0.0002, 65 | num_to_split_into=2, 66 | removal_opacity_threshold=0.005, 67 | final_removal_opacity_threshold=0.005, 68 | reset_opacities=True, 69 | reset_opacities_every=3000, # Doesn't consider iter 0 70 | ), 71 | ), 72 | viz=dict( 73 | render_mode='color', # ['color', 'depth' or 'centers'] 74 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 75 | show_sil=False, # Show Silhouette instead of RGB 76 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 77 | viz_w=600, viz_h=340, 78 | viz_near=0.01, viz_far=100.0, 79 | view_scale=2, 80 | viz_fps=5, # FPS for Online Recon Viz 81 | enter_interactive_post_online=True, # Enter Interactive Mode after Online Recon Viz 82 | ), 83 | ) -------------------------------------------------------------------------------- /configs/replica/replica.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for
seed in 0 1 2 4 | do 5 | SEED=${seed} 6 | export SEED 7 | for scene in 0 1 2 3 4 5 6 7 8 | do 9 | SCENE_NUM=${scene} 10 | export SCENE_NUM 11 | echo "Running scene number ${SCENE_NUM} with seed ${SEED}" 12 | python3 -u scripts/splatam.py configs/replica/replica_eval.py 13 | done 14 | done -------------------------------------------------------------------------------- /configs/replica/replica_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | scenes = ["room0", "room1", "room2", 5 | "office0", "office1", "office2", 6 | "office3", "office4"] 7 | 8 | primary_device="cuda:0" 9 | seed = int(os.environ["SEED"]) 10 | scene_name = scenes[int(os.environ["SCENE_NUM"])] 11 | 12 | map_every = 1 13 | keyframe_every = 5 14 | mapping_window_size = 24 15 | tracking_iters = 40 16 | mapping_iters = 60 17 | 18 | group_name = "Replica" 19 | run_name = f"{scene_name}_{seed}" 20 | 21 | config = dict( 22 | workdir=f"./experiments/{group_name}", 23 | run_name=run_name, 24 | seed=seed, 25 | primary_device=primary_device, 26 | map_every=map_every, # Mapping every nth frame 27 | keyframe_every=keyframe_every, # Keyframe every nth frame 28 | mapping_window_size=mapping_window_size, # Mapping window size 29 | report_global_progress_every=500, # Report Global Progress every nth frame 30 | eval_every=5, # Evaluate every nth frame (at end of SLAM) 31 | scene_radius_depth_ratio=3, # Max First Frame Depth to Scene Radius Ratio (For Pruning/Densification) 32 | mean_sq_dist_method="projective", # ["projective", "knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 33 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 34 | report_iter_progress=False, 35 | load_checkpoint=False, 36 | checkpoint_time_idx=0, 37 | save_checkpoints=False, # Save Checkpoints 38 | checkpoint_interval=100, # Checkpoint Interval 39 | use_wandb=True, 40 | wandb=dict( 41 | entity="theairlab", 42 | project="SplaTAM", 43 | group=group_name, 44 | name=run_name, 45 | save_qual=False, 46 | eval_save_qual=True, 47 | ), 48 | data=dict( 49 | basedir="./data/Replica", 50 | gradslam_data_cfg="./configs/data/replica.yaml", 51 | sequence=scene_name, 52 | desired_image_height=680, 53 | desired_image_width=1200, 54 | start=0, 55 | end=-1, 56 | stride=1, 57 | num_frames=-1, 58 | ), 59 | tracking=dict( 60 | use_gt_poses=False, # Use GT Poses for Tracking 61 | forward_prop=True, # Forward Propagate Poses 62 | num_iters=tracking_iters, 63 | use_sil_for_loss=True, 64 | sil_thres=0.99, 65 | use_l1=True, 66 | ignore_outlier_depth_loss=False, 67 | loss_weights=dict( 68 | im=0.5, 69 | depth=1.0, 70 | ), 71 | lrs=dict( 72 | means3D=0.0, 73 | rgb_colors=0.0, 74 | unnorm_rotations=0.0, 75 | logit_opacities=0.0, 76 | log_scales=0.0, 77 | cam_unnorm_rots=0.0004, 78 | cam_trans=0.002, 79 | ), 80 | ), 81 | mapping=dict( 82 | num_iters=mapping_iters, 83 | add_new_gaussians=True, 84 | sil_thres=0.5, # For Addition of new Gaussians 85 | use_l1=True, 86 | use_sil_for_loss=False, 87 | ignore_outlier_depth_loss=False, 88 | loss_weights=dict( 89 | im=0.5, 90 | depth=1.0, 91 | ), 92 | lrs=dict( 93 | means3D=0.0001, 94 | rgb_colors=0.0025, 95 | unnorm_rotations=0.001, 96 | logit_opacities=0.05, 97 | log_scales=0.001, 98 | cam_unnorm_rots=0.0000, 99 | cam_trans=0.0000, 100 | ), 101 | prune_gaussians=True, # Prune Gaussians during Mapping 102 | pruning_dict=dict( # Needs to be updated
based on the number of mapping iterations 103 | start_after=0, 104 | remove_big_after=0, 105 | stop_after=20, 106 | prune_every=20, 107 | removal_opacity_threshold=0.005, 108 | final_removal_opacity_threshold=0.005, 109 | reset_opacities=False, 110 | reset_opacities_every=500, # Doesn't consider iter 0 111 | ), 112 | use_gaussian_splatting_densification=False, # Use Gaussian Splatting-based Densification during Mapping 113 | densify_dict=dict( # Needs to be updated based on the number of mapping iterations 114 | start_after=500, 115 | remove_big_after=3000, 116 | stop_after=5000, 117 | densify_every=100, 118 | grad_thresh=0.0002, 119 | num_to_split_into=2, 120 | removal_opacity_threshold=0.005, 121 | final_removal_opacity_threshold=0.005, 122 | reset_opacities_every=3000, # Doesn't consider iter 0 123 | ), 124 | ), 125 | viz=dict( 126 | render_mode='color', # ['color', 'depth' or 'centers'] 127 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 128 | show_sil=False, # Show Silhouette instead of RGB 129 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 130 | viz_w=600, viz_h=340, 131 | viz_near=0.01, viz_far=100.0, 132 | view_scale=2, 133 | viz_fps=5, # FPS for Online Recon Viz 134 | enter_interactive_post_online=True, # Enter Interactive Mode after Online Recon Viz 135 | ), 136 | ) -------------------------------------------------------------------------------- /configs/replica/splatam.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | scenes = ["room0", "room1", "room2", 5 | "office0", "office1", "office2", 6 | "office3", "office4"] 7 | 8 | primary_device="cuda:0" 9 | seed = 0 10 | scene_name = scenes[0] 11 | 12 | map_every = 1 13 | keyframe_every = 5 14 | mapping_window_size = 24 15 | tracking_iters = 40 16 | mapping_iters = 60 17 | 18 | group_name = "Replica" 19 | run_name = f"{scene_name}_{seed}" 20 | 21 | config = dict( 22 | workdir=f"./experiments/{group_name}", 23 | run_name=run_name, 24 | seed=seed, 25 | primary_device=primary_device, 26 | map_every=map_every, # Mapping every nth frame 27 | keyframe_every=keyframe_every, # Keyframe every nth frame 28 | mapping_window_size=mapping_window_size, # Mapping window size 29 | report_global_progress_every=500, # Report Global Progress every nth frame 30 | eval_every=5, # Evaluate every nth frame (at end of SLAM) 31 | scene_radius_depth_ratio=3, # Max First Frame Depth to Scene Radius Ratio (For Pruning/Densification) 32 | mean_sq_dist_method="projective", # ["projective", "knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 33 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 34 | report_iter_progress=False, 35 | load_checkpoint=False, 36 | checkpoint_time_idx=0, 37 | save_checkpoints=False, # Save Checkpoints 38 | checkpoint_interval=100, # Checkpoint Interval 39 | use_wandb=True, 40 | wandb=dict( 41 | entity="theairlab", 42 | project="SplaTAM", 43 | group=group_name, 44 | name=run_name, 45 | save_qual=False, 46 | eval_save_qual=True, 47 | ), 48 | data=dict( 49 | basedir="./data/Replica", 50 | gradslam_data_cfg="./configs/data/replica.yaml", 51 | sequence=scene_name, 52 | desired_image_height=680, 53 | desired_image_width=1200, 54 | start=0, 55 | end=-1, 56 | stride=1, 57 | num_frames=-1, 58 | ), 59 | tracking=dict( 60 | use_gt_poses=False, # Use GT
Poses for Tracking 61 | forward_prop=True, # Forward Propagate Poses 62 | num_iters=tracking_iters, 63 | use_sil_for_loss=True, 64 | sil_thres=0.99, 65 | use_l1=True, 66 | ignore_outlier_depth_loss=False, 67 | loss_weights=dict( 68 | im=0.5, 69 | depth=1.0, 70 | ), 71 | lrs=dict( 72 | means3D=0.0, 73 | rgb_colors=0.0, 74 | unnorm_rotations=0.0, 75 | logit_opacities=0.0, 76 | log_scales=0.0, 77 | cam_unnorm_rots=0.0004, 78 | cam_trans=0.002, 79 | ), 80 | ), 81 | mapping=dict( 82 | num_iters=mapping_iters, 83 | add_new_gaussians=True, 84 | sil_thres=0.5, # For Addition of new Gaussians 85 | use_l1=True, 86 | use_sil_for_loss=False, 87 | ignore_outlier_depth_loss=False, 88 | loss_weights=dict( 89 | im=0.5, 90 | depth=1.0, 91 | ), 92 | lrs=dict( 93 | means3D=0.0001, 94 | rgb_colors=0.0025, 95 | unnorm_rotations=0.001, 96 | logit_opacities=0.05, 97 | log_scales=0.001, 98 | cam_unnorm_rots=0.0000, 99 | cam_trans=0.0000, 100 | ), 101 | prune_gaussians=True, # Prune Gaussians during Mapping 102 | pruning_dict=dict( # Needs to be updated based on the number of mapping iterations 103 | start_after=0, 104 | remove_big_after=0, 105 | stop_after=20, 106 | prune_every=20, 107 | removal_opacity_threshold=0.005, 108 | final_removal_opacity_threshold=0.005, 109 | reset_opacities=False, 110 | reset_opacities_every=500, # Doesn't consider iter 0 111 | ), 112 | use_gaussian_splatting_densification=False, # Use Gaussian Splatting-based Densification during Mapping 113 | densify_dict=dict( # Needs to be updated based on the number of mapping iterations 114 | start_after=500, 115 | remove_big_after=3000, 116 | stop_after=5000, 117 | densify_every=100, 118 | grad_thresh=0.0002, 119 | num_to_split_into=2, 120 | removal_opacity_threshold=0.005, 121 | final_removal_opacity_threshold=0.005, 122 | reset_opacities_every=3000, # Doesn't consider iter 0 123 | ), 124 | ), 125 | viz=dict( 126 | render_mode='color', # ['color', 'depth' or 'centers'] 127 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 128 | show_sil=False, # Show Silhouette instead of RGB 129 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 130 | viz_w=600, viz_h=340, 131 | viz_near=0.01, viz_far=100.0, 132 | view_scale=2, 133 | viz_fps=5, # FPS for Online Recon Viz 134 | enter_interactive_post_online=True, # Enter Interactive Mode after Online Recon Viz 135 | ), 136 | ) -------------------------------------------------------------------------------- /configs/replica/splatam_s.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | scenes = ["room0", "room1", "room2", 5 | "office0", "office1", "office2", 6 | "office3", "office4"] 7 | 8 | primary_device="cuda:0" 9 | seed = 0 10 | scene_name = scenes[0] 11 | 12 | map_every = 1 13 | keyframe_every = 5 14 | mapping_window_size = 32 15 | tracking_iters = 10 16 | mapping_iters = 15 17 | 18 | group_name = "Replica" 19 | run_name = f"{scene_name}_{seed}" 20 | 21 | config = dict( 22 | workdir=f"./experiments/{group_name}", 23 | run_name=run_name, 24 | seed=seed, 25 | primary_device=primary_device, 26 | map_every=map_every, # Mapping every nth frame 27 | keyframe_every=keyframe_every, # Keyframe every nth frame 28 | mapping_window_size=mapping_window_size, # Mapping window size 29 | report_global_progress_every=500, # Report Global Progress every nth frame 30 | eval_every=5, # Evaluate every nth frame (at end of SLAM) 31 |
scene_radius_depth_ratio=3, # Max First Frame Depth to Scene Radius Ratio (For Pruning/Densification) 32 | mean_sq_dist_method="projective", # ["projective", "knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 33 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 34 | report_iter_progress=False, 35 | load_checkpoint=False, 36 | checkpoint_time_idx=0, 37 | save_checkpoints=False, # Save Checkpoints 38 | checkpoint_interval=100, # Checkpoint Interval 39 | use_wandb=True, 40 | wandb=dict( 41 | entity="theairlab", 42 | project="SplaTAM", 43 | group=group_name, 44 | name=run_name, 45 | save_qual=False, 46 | eval_save_qual=True, 47 | ), 48 | data=dict( 49 | basedir="./data/Replica", 50 | gradslam_data_cfg="./configs/data/replica.yaml", 51 | sequence=scene_name, 52 | desired_image_height=680, 53 | desired_image_width=1200, 54 | tracking_image_height=680, 55 | tracking_image_width=1200, 56 | densification_image_height=340, 57 | densification_image_width=600, 58 | start=0, 59 | end=-1, 60 | stride=1, 61 | num_frames=-1, 62 | ), 63 | tracking=dict( 64 | use_gt_poses=False, # Use GT Poses for Tracking 65 | forward_prop=True, # Forward Propagate Poses 66 | num_iters=tracking_iters, 67 | use_sil_for_loss=True, 68 | sil_thres=0.99, 69 | use_l1=True, 70 | ignore_outlier_depth_loss=False, 71 | loss_weights=dict( 72 | im=0.5, 73 | depth=1.0, 74 | ), 75 | lrs=dict( 76 | means3D=0.0, 77 | rgb_colors=0.0, 78 | unnorm_rotations=0.0, 79 | logit_opacities=0.0, 80 | log_scales=0.0, 81 | cam_unnorm_rots=0.0004, 82 | cam_trans=0.002, 83 | ), 84 | ), 85 | mapping=dict( 86 | num_iters=mapping_iters, 87 | add_new_gaussians=True, 88 | sil_thres=0.5, # For Addition of new Gaussians 89 | use_l1=True, 90 | use_sil_for_loss=False, 91 | ignore_outlier_depth_loss=False, 92 | loss_weights=dict( 93 | im=0.5, 94 | depth=1.0, 95 | ), 96 | lrs=dict( 97 | means3D=0.0001, 98 | rgb_colors=0.0025, 99 | unnorm_rotations=0.001, 100 | logit_opacities=0.05, 101 | log_scales=0.001, 102 | cam_unnorm_rots=0.0000, 103 | cam_trans=0.0000, 104 | ), 105 | prune_gaussians=True, # Prune Gaussians during Mapping 106 | pruning_dict=dict( # Needs to be updated based on the number of mapping iterations 107 | start_after=0, 108 | remove_big_after=0, 109 | stop_after=20, 110 | prune_every=20, 111 | removal_opacity_threshold=0.005, 112 | final_removal_opacity_threshold=0.005, 113 | reset_opacities=False, 114 | reset_opacities_every=500, # Doesn't consider iter 0 115 | ), 116 | use_gaussian_splatting_densification=False, # Use Gaussian Splatting-based Densification during Mapping 117 | densify_dict=dict( # Needs to be updated based on the number of mapping iterations 118 | start_after=500, 119 | remove_big_after=3000, 120 | stop_after=5000, 121 | densify_every=100, 122 | grad_thresh=0.0002, 123 | num_to_split_into=2, 124 | removal_opacity_threshold=0.005, 125 | final_removal_opacity_threshold=0.005, 126 | reset_opacities_every=3000, # Doesn't consider iter 0 127 | ), 128 | ), 129 | viz=dict( 130 | render_mode='color', # ['color', 'depth' or 'centers'] 131 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 132 | show_sil=False, # Show Silhouette instead of RGB 133 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 134 | viz_w=600, viz_h=340, 135 | viz_near=0.01, viz_far=100.0, 136 | view_scale=2, 137 | viz_fps=5, # FPS for Online Recon Viz 138 | 
enter_interactive_post_online=True, # Enter Interactive Mode after Online Recon Viz 139 | ), 140 | ) -------------------------------------------------------------------------------- /configs/replica_v2/eval_novel_view.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | scenes = ["room_0", "room_1", "room_2", 5 | "office_0", "office_1", "office_2", 6 | "office_3", "office_4"] 7 | 8 | primary_device="cuda:0" 9 | seed = 0 10 | scene_name = scenes[0] 11 | 12 | # # SLAM 13 | # use_train_split = True 14 | 15 | # Novel View Synthesis 16 | use_train_split = False 17 | 18 | map_every = 1 19 | keyframe_every = 5 20 | mapping_window_size = 24 21 | tracking_iters = 40 22 | mapping_iters = 60 23 | 24 | group_name = "Replica_V2" 25 | run_name = f"{scene_name}_{seed}" 26 | 27 | config = dict( 28 | workdir=f"./experiments/{group_name}", 29 | run_name=run_name, 30 | seed=seed, 31 | primary_device=primary_device, 32 | map_every=map_every, # Mapping every nth frame 33 | keyframe_every=keyframe_every, # Keyframe every nth frame 34 | mapping_window_size=mapping_window_size, # Mapping window size 35 | report_global_progress_every=500, # Report Global Progress every nth frame 36 | eval_every=5, # Evaluate every nth frame (at end of SLAM) 37 | scene_radius_depth_ratio=3, # Max First Frame Depth to Scene Radius Ratio (For Pruning/Densification) 38 | mean_sq_dist_method="projective", # ["projective", "knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 39 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 40 | report_iter_progress=False, 41 | load_checkpoint=False, 42 | checkpoint_time_idx=0, 43 | save_checkpoints=False, # Save Checkpoints 44 | checkpoint_interval=100, # Checkpoint Interval 45 | use_wandb=True, 46 | wandb=dict( 47 | entity="theairlab", 48 | project="SplaTAM", 49 | group=group_name, 50 | name=run_name, 51 | save_qual=False, 52 | eval_save_qual=True, 53 | ), 54 | data=dict( 55 | basedir="./data/Replica_V2", 56 | gradslam_data_cfg="./configs/data/replica_v2.yaml", 57 | sequence=scene_name, 58 | use_train_split=use_train_split, 59 | desired_image_height=680, 60 | desired_image_width=1200, 61 | start=0, 62 | end=-1, 63 | stride=1, 64 | num_frames=-1, 65 | ), 66 | tracking=dict( 67 | use_gt_poses=False, # Use GT Poses for Tracking 68 | forward_prop=True, # Forward Propagate Poses 69 | num_iters=tracking_iters, 70 | use_sil_for_loss=True, 71 | sil_thres=0.99, 72 | use_l1=True, 73 | ignore_outlier_depth_loss=False, 74 | use_uncertainty_for_loss_mask=False, 75 | use_uncertainty_for_loss=False, 76 | use_chamfer=False, 77 | loss_weights=dict( 78 | im=0.5, 79 | depth=1.0, 80 | ), 81 | lrs=dict( 82 | means3D=0.0, 83 | rgb_colors=0.0, 84 | unnorm_rotations=0.0, 85 | logit_opacities=0.0, 86 | log_scales=0.0, 87 | cam_unnorm_rots=0.0004, 88 | cam_trans=0.002, 89 | ), 90 | ), 91 | mapping=dict( 92 | num_iters=mapping_iters, 93 | add_new_gaussians=True, 94 | sil_thres=0.5, # For Addition of new Gaussians 95 | use_l1=True, 96 | ignore_outlier_depth_loss=False, 97 | use_sil_for_loss=False, 98 | use_uncertainty_for_loss_mask=False, 99 | use_uncertainty_for_loss=False, 100 | use_chamfer=False, 101 | loss_weights=dict( 102 | im=0.5, 103 | depth=1.0, 104 | ), 105 | lrs=dict( 106 | means3D=0.0001, 107 | rgb_colors=0.0025, 108 | unnorm_rotations=0.001, 109 | logit_opacities=0.05, 110 | log_scales=0.001, 111 | 
cam_unnorm_rots=0.0000, 112 | cam_trans=0.0000, 113 | ), 114 | prune_gaussians=True, # Prune Gaussians during Mapping 115 | pruning_dict=dict( # Needs to be updated based on the number of mapping iterations 116 | start_after=0, 117 | remove_big_after=0, 118 | stop_after=20, 119 | prune_every=20, 120 | removal_opacity_threshold=0.005, 121 | final_removal_opacity_threshold=0.005, 122 | reset_opacities=False, 123 | reset_opacities_every=500, # Doesn't consider iter 0 124 | ), 125 | use_gaussian_splatting_densification=False, # Use Gaussian Splatting-based Densification during Mapping 126 | densify_dict=dict( # Needs to be updated based on the number of mapping iterations 127 | start_after=500, 128 | remove_big_after=3000, 129 | stop_after=5000, 130 | densify_every=100, 131 | grad_thresh=0.0002, 132 | num_to_split_into=2, 133 | removal_opacity_threshold=0.005, 134 | final_removal_opacity_threshold=0.005, 135 | reset_opacities_every=3000, # Doesn't consider iter 0 136 | ), 137 | ), 138 | viz=dict( 139 | render_mode='color', # ['color', 'depth' or 'centers'] 140 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 141 | show_sil=False, # Show Silhouette instead of RGB 142 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 143 | viz_w=600, viz_h=340, 144 | viz_near=0.01, viz_far=100.0, 145 | view_scale=2, 146 | viz_fps=5, # FPS for Online Recon Viz 147 | enter_interactive_post_online=True, # Enter Interactive Mode after Online Recon Viz 148 | ), 149 | ) -------------------------------------------------------------------------------- /configs/replica_v2/splatam.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | scenes = ["room_0", "room_1", "room_2", 5 | "office_0", "office_1", "office_2", 6 | "office_3", "office_4"] 7 | 8 | primary_device="cuda:0" 9 | seed = 0 10 | scene_name = scenes[0] 11 | 12 | # SplaTAM 13 | use_train_split = True 14 | 15 | # # Novel View Synthesis 16 | # use_train_split = False 17 | 18 | map_every = 1 19 | keyframe_every = 5 20 | mapping_window_size = 24 21 | tracking_iters = 40 22 | mapping_iters = 60 23 | 24 | group_name = "Replica_V2" 25 | run_name = f"{scene_name}_{seed}" 26 | 27 | config = dict( 28 | workdir=f"./experiments/{group_name}", 29 | run_name=run_name, 30 | seed=seed, 31 | primary_device=primary_device, 32 | map_every=map_every, # Mapping every nth frame 33 | keyframe_every=keyframe_every, # Keyframe every nth frame 34 | mapping_window_size=mapping_window_size, # Mapping window size 35 | report_global_progress_every=500, # Report Global Progress every nth frame 36 | eval_every=5, # Evaluate every nth frame (at end of SLAM) 37 | scene_radius_depth_ratio=3, # Max First Frame Depth to Scene Radius Ratio (For Pruning/Densification) 38 | mean_sq_dist_method="projective", # ["projective", "knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 39 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 40 | report_iter_progress=False, 41 | load_checkpoint=False, 42 | checkpoint_time_idx=0, 43 | save_checkpoints=False, # Save Checkpoints 44 | checkpoint_interval=100, # Checkpoint Interval 45 | use_wandb=True, 46 | wandb=dict( 47 | entity="theairlab", 48 | project="SplaTAM", 49 | group=group_name, 50 | name=run_name, 51 | save_qual=False, 52 | eval_save_qual=True, 53 | ), 54 | data=dict( 55 | 
basedir="./data/Replica_V2", 56 | gradslam_data_cfg="./configs/data/replica_v2.yaml", 57 | sequence=scene_name, 58 | use_train_split=use_train_split, 59 | desired_image_height=680, 60 | desired_image_width=1200, 61 | start=0, 62 | end=-1, 63 | stride=1, 64 | num_frames=-1, 65 | ), 66 | tracking=dict( 67 | use_gt_poses=False, # Use GT Poses for Tracking 68 | forward_prop=True, # Forward Propagate Poses 69 | num_iters=tracking_iters, 70 | use_sil_for_loss=True, 71 | sil_thres=0.99, 72 | use_l1=True, 73 | ignore_outlier_depth_loss=False, 74 | use_uncertainty_for_loss_mask=False, 75 | use_uncertainty_for_loss=False, 76 | use_chamfer=False, 77 | loss_weights=dict( 78 | im=0.5, 79 | depth=1.0, 80 | ), 81 | lrs=dict( 82 | means3D=0.0, 83 | rgb_colors=0.0, 84 | unnorm_rotations=0.0, 85 | logit_opacities=0.0, 86 | log_scales=0.0, 87 | cam_unnorm_rots=0.0004, 88 | cam_trans=0.002, 89 | ), 90 | ), 91 | mapping=dict( 92 | num_iters=mapping_iters, 93 | add_new_gaussians=True, 94 | sil_thres=0.5, # For Addition of new Gaussians 95 | use_l1=True, 96 | ignore_outlier_depth_loss=False, 97 | use_sil_for_loss=False, 98 | use_uncertainty_for_loss_mask=False, 99 | use_uncertainty_for_loss=False, 100 | use_chamfer=False, 101 | loss_weights=dict( 102 | im=0.5, 103 | depth=1.0, 104 | ), 105 | lrs=dict( 106 | means3D=0.0001, 107 | rgb_colors=0.0025, 108 | unnorm_rotations=0.001, 109 | logit_opacities=0.05, 110 | log_scales=0.001, 111 | cam_unnorm_rots=0.0000, 112 | cam_trans=0.0000, 113 | ), 114 | prune_gaussians=True, # Prune Gaussians during Mapping 115 | pruning_dict=dict( # Needs to be updated based on the number of mapping iterations 116 | start_after=0, 117 | remove_big_after=0, 118 | stop_after=20, 119 | prune_every=20, 120 | removal_opacity_threshold=0.005, 121 | final_removal_opacity_threshold=0.005, 122 | reset_opacities=False, 123 | reset_opacities_every=500, # Doesn't consider iter 0 124 | ), 125 | use_gaussian_splatting_densification=False, # Use Gaussian Splatting-based Densification during Mapping 126 | densify_dict=dict( # Needs to be updated based on the number of mapping iterations 127 | start_after=500, 128 | remove_big_after=3000, 129 | stop_after=5000, 130 | densify_every=100, 131 | grad_thresh=0.0002, 132 | num_to_split_into=2, 133 | removal_opacity_threshold=0.005, 134 | final_removal_opacity_threshold=0.005, 135 | reset_opacities_every=3000, # Doesn't consider iter 0 136 | ), 137 | ), 138 | viz=dict( 139 | render_mode='color', # ['color', 'depth' or 'centers'] 140 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 141 | show_sil=False, # Show Silhouette instead of RGB 142 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 143 | viz_w=600, viz_h=340, 144 | viz_near=0.01, viz_far=100.0, 145 | view_scale=2, 146 | viz_fps=5, # FPS for Online Recon Viz 147 | enter_interactive_post_online=True, # Enter Interactive Mode after Online Recon Viz 148 | ), 149 | ) -------------------------------------------------------------------------------- /configs/scannet/scannet.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for scene_num in 0 1 2 3 4 5 4 | do 5 | for seed in 0 1 2 6 | do 7 | SCENE_NUM=${scene_num} 8 | export SCENE_NUM 9 | SEED=${seed} 10 | export SEED 11 | echo "Running scene number ${SCENE_NUM} with seed ${SEED}" 12 | python3 -u scripts/splatam.py configs/scannet/scannet_eval.py 13 | done 14 | done 
-------------------------------------------------------------------------------- /configs/scannet/scannet_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | primary_device = "cuda:0" 5 | 6 | scenes = ["scene0000_00", "scene0059_00", "scene0106_00", 7 | "scene0169_00", "scene0181_00", "scene0207_00"] 8 | 9 | seed = int(os.environ["SEED"]) 10 | scene_name = scenes[int(os.environ["SCENE_NUM"])] 11 | 12 | map_every = 1 13 | keyframe_every = 5 14 | mapping_window_size = 10 15 | tracking_iters = 100 16 | mapping_iters = 30 17 | scene_radius_depth_ratio = 3 18 | 19 | group_name = "ScanNet" 20 | run_name = f"{scene_name}_seed{seed}" 21 | 22 | config = dict( 23 | workdir=f"./experiments/{group_name}", 24 | run_name=run_name, 25 | seed=seed, 26 | primary_device=primary_device, 27 | map_every=map_every, # Mapping every nth frame 28 | keyframe_every=keyframe_every, # Keyframe every nth frame 29 | mapping_window_size=mapping_window_size, # Mapping window size 30 | report_global_progress_every=500, # Report Global Progress every nth frame 31 | eval_every=500, # Evaluate every nth frame (at end of SLAM) 32 | scene_radius_depth_ratio=scene_radius_depth_ratio, # Max First Frame Depth to Scene Radius Ratio (For Pruning/Densification) 33 | mean_sq_dist_method="projective", # ["projective", "knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 34 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 35 | report_iter_progress=False, 36 | load_checkpoint=False, 37 | checkpoint_time_idx=0, 38 | save_checkpoints=False, # Save Checkpoints 39 | checkpoint_interval=100, # Checkpoint Interval 40 | use_wandb=True, 41 | wandb=dict( 42 | entity="theairlab", 43 | project="SplaTAM", 44 | group=group_name, 45 | name=run_name, 46 | save_qual=False, 47 | eval_save_qual=True, 48 | ), 49 | data=dict( 50 | basedir="./data/scannet", 51 | gradslam_data_cfg="./configs/data/scannet.yaml", 52 | sequence=scene_name, 53 | desired_image_height=480, 54 | desired_image_width=640, 55 | start=0, 56 | end=-1, 57 | stride=1, 58 | num_frames=-1, 59 | ), 60 | tracking=dict( 61 | use_gt_poses=False, # Use GT Poses for Tracking 62 | forward_prop=True, # Forward Propagate Poses 63 | num_iters=tracking_iters, 64 | use_sil_for_loss=True, 65 | sil_thres=0.99, 66 | use_l1=True, 67 | ignore_outlier_depth_loss=False, 68 | use_uncertainty_for_loss_mask=False, 69 | use_uncertainty_for_loss=False, 70 | use_chamfer=False, 71 | loss_weights=dict( 72 | im=0.5, 73 | depth=1.0, 74 | ), 75 | lrs=dict( 76 | means3D=0.0, 77 | rgb_colors=0.0, 78 | unnorm_rotations=0.0, 79 | logit_opacities=0.0, 80 | log_scales=0.0, 81 | cam_unnorm_rots=0.0005, 82 | cam_trans=0.0005, 83 | ), 84 | ), 85 | mapping=dict( 86 | num_iters=mapping_iters, 87 | add_new_gaussians=True, 88 | sil_thres=0.5, # For Addition of new Gaussians 89 | use_l1=True, 90 | use_sil_for_loss=False, 91 | ignore_outlier_depth_loss=False, 92 | use_uncertainty_for_loss_mask=False, 93 | use_uncertainty_for_loss=False, 94 | use_chamfer=False, 95 | loss_weights=dict( 96 | im=0.5, 97 | depth=1.0, 98 | ), 99 | lrs=dict( 100 | means3D=0.0001, 101 | rgb_colors=0.0025, 102 | unnorm_rotations=0.001, 103 | logit_opacities=0.05, 104 | log_scales=0.001, 105 | cam_unnorm_rots=0.0000, 106 | cam_trans=0.0000, 107 | ), 108 | prune_gaussians=True, # Prune Gaussians during Mapping 109 | pruning_dict=dict( # Needs to be 
updated based on the number of mapping iterations 110 | start_after=0, 111 | remove_big_after=0, 112 | stop_after=20, 113 | prune_every=20, 114 | removal_opacity_threshold=0.005, 115 | final_removal_opacity_threshold=0.005, 116 | reset_opacities=False, 117 | reset_opacities_every=500, # Doesn't consider iter 0 118 | ), 119 | use_gaussian_splatting_densification=False, # Use Gaussian Splatting-based Densification during Mapping 120 | densify_dict=dict( # Needs to be updated based on the number of mapping iterations 121 | start_after=500, 122 | remove_big_after=3000, 123 | stop_after=5000, 124 | densify_every=100, 125 | grad_thresh=0.0002, 126 | num_to_split_into=2, 127 | removal_opacity_threshold=0.005, 128 | final_removal_opacity_threshold=0.005, 129 | reset_opacities_every=3000, # Doesn't consider iter 0 130 | ), 131 | ), 132 | viz=dict( 133 | render_mode='color', # ['color', 'depth' or 'centers'] 134 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 135 | show_sil=False, # Show Silhouette instead of RGB 136 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 137 | viz_w=600, viz_h=340, 138 | viz_near=0.01, viz_far=100.0, 139 | view_scale=2, 140 | viz_fps=5, # FPS for Online Recon Viz 141 | enter_interactive_post_online=True, # Enter Interactive Mode after Online Recon Viz 142 | ), 143 | ) -------------------------------------------------------------------------------- /configs/scannet/splatam.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | primary_device = "cuda:0" 5 | 6 | scenes = ["scene0000_00", "scene0059_00", "scene0106_00", 7 | "scene0169_00", "scene0181_00", "scene0207_00"] 8 | 9 | seed = int(0) 10 | scene_name = scenes[int(0)] 11 | 12 | map_every = 1 13 | keyframe_every = 5 14 | mapping_window_size = 10 15 | tracking_iters = 100 16 | mapping_iters = 30 17 | scene_radius_depth_ratio = 3 18 | 19 | group_name = "ScanNet" 20 | run_name = f"{scene_name}_seed{seed}" 21 | 22 | config = dict( 23 | workdir=f"./experiments/{group_name}", 24 | run_name=run_name, 25 | seed=seed, 26 | primary_device=primary_device, 27 | map_every=map_every, # Mapping every nth frame 28 | keyframe_every=keyframe_every, # Keyframe every nth frame 29 | mapping_window_size=mapping_window_size, # Mapping window size 30 | report_global_progress_every=500, # Report Global Progress every nth frame 31 | eval_every=500, # Evaluate every nth frame (at end of SLAM) 32 | scene_radius_depth_ratio=scene_radius_depth_ratio, # Max First Frame Depth to Scene Radius Ratio (For Pruning/Densification) 33 | mean_sq_dist_method="projective", # ["projective", "knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 34 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 35 | report_iter_progress=False, 36 | load_checkpoint=False, 37 | checkpoint_time_idx=0, 38 | save_checkpoints=False, # Save Checkpoints 39 | checkpoint_interval=100, # Checkpoint Interval 40 | use_wandb=True, 41 | wandb=dict( 42 | entity="theairlab", 43 | project="SplaTAM", 44 | group=group_name, 45 | name=run_name, 46 | save_qual=False, 47 | eval_save_qual=True, 48 | ), 49 | data=dict( 50 | basedir="./data/scannet", 51 | gradslam_data_cfg="./configs/data/scannet.yaml", 52 | sequence=scene_name, 53 | desired_image_height=480, 54 | desired_image_width=640, 55 | start=0, 56 | end=-1,
57 | stride=1, 58 | num_frames=-1, 59 | ), 60 | tracking=dict( 61 | use_gt_poses=False, # Use GT Poses for Tracking 62 | forward_prop=True, # Forward Propagate Poses 63 | num_iters=tracking_iters, 64 | use_sil_for_loss=True, 65 | sil_thres=0.99, 66 | use_l1=True, 67 | ignore_outlier_depth_loss=False, 68 | use_uncertainty_for_loss_mask=False, 69 | use_uncertainty_for_loss=False, 70 | use_chamfer=False, 71 | loss_weights=dict( 72 | im=0.5, 73 | depth=1.0, 74 | ), 75 | lrs=dict( 76 | means3D=0.0, 77 | rgb_colors=0.0, 78 | unnorm_rotations=0.0, 79 | logit_opacities=0.0, 80 | log_scales=0.0, 81 | cam_unnorm_rots=0.0005, 82 | cam_trans=0.0005, 83 | ), 84 | ), 85 | mapping=dict( 86 | num_iters=mapping_iters, 87 | add_new_gaussians=True, 88 | sil_thres=0.5, # For Addition of new Gaussians 89 | use_l1=True, 90 | use_sil_for_loss=False, 91 | ignore_outlier_depth_loss=False, 92 | use_uncertainty_for_loss_mask=False, 93 | use_uncertainty_for_loss=False, 94 | use_chamfer=False, 95 | loss_weights=dict( 96 | im=0.5, 97 | depth=1.0, 98 | ), 99 | lrs=dict( 100 | means3D=0.0001, 101 | rgb_colors=0.0025, 102 | unnorm_rotations=0.001, 103 | logit_opacities=0.05, 104 | log_scales=0.001, 105 | cam_unnorm_rots=0.0000, 106 | cam_trans=0.0000, 107 | ), 108 | prune_gaussians=True, # Prune Gaussians during Mapping 109 | pruning_dict=dict( # Needs to be updated based on the number of mapping iterations 110 | start_after=0, 111 | remove_big_after=0, 112 | stop_after=20, 113 | prune_every=20, 114 | removal_opacity_threshold=0.005, 115 | final_removal_opacity_threshold=0.005, 116 | reset_opacities=False, 117 | reset_opacities_every=500, # Doesn't consider iter 0 118 | ), 119 | use_gaussian_splatting_densification=False, # Use Gaussian Splatting-based Densification during Mapping 120 | densify_dict=dict( # Needs to be updated based on the number of mapping iterations 121 | start_after=500, 122 | remove_big_after=3000, 123 | stop_after=5000, 124 | densify_every=100, 125 | grad_thresh=0.0002, 126 | num_to_split_into=2, 127 | removal_opacity_threshold=0.005, 128 | final_removal_opacity_threshold=0.005, 129 | reset_opacities_every=3000, # Doesn't consider iter 0 130 | ), 131 | ), 132 | viz=dict( 133 | render_mode='color', # ['color', 'depth' or 'centers'] 134 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 135 | show_sil=False, # Show Silhouette instead of RGB 136 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 137 | viz_w=600, viz_h=340, 138 | viz_near=0.01, viz_far=100.0, 139 | view_scale=2, 140 | viz_fps=5, # FPS for Online Recon Viz 141 | enter_interactive_post_online=True, # Enter Interactive Mode after Online Recon Viz 142 | ), 143 | ) -------------------------------------------------------------------------------- /configs/scannetpp/eval_novel_view.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SCENE=$1 4 | export SCENE 5 | 6 | echo "Evaluating scene number ${SCENE} with seed 0" 7 | python3 -u scripts/eval_novel_view.py configs/scannetpp/eval_novel_view.py -------------------------------------------------------------------------------- /configs/scannetpp/eval_novel_view.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | primary_device = "cuda:0" 5 | 6 | scenes = ["8b5caf3398", "b20a261fdf"] 7 | 8 | seed = 0 9 | 10 | # Export SCENE env variable before running 11 | 
os.environ["SCENE"] = "0" 12 | 13 | # # Train Split Eval 14 | # use_train_split = True 15 | 16 | # Novel View Synthesis Eval 17 | use_train_split = False 18 | 19 | if use_train_split: 20 | scene_num_frames = [-1, 360] 21 | else: 22 | scene_num_frames = [-1, -1] 23 | 24 | scene_name = scenes[int(os.environ["SCENE"])] 25 | num_frames = scene_num_frames[int(os.environ["SCENE"])] 26 | 27 | map_every = 1 28 | keyframe_every = 5 29 | mapping_window_size = 24 30 | tracking_iters = 200 31 | mapping_iters = 60 32 | 33 | group_name = "ScanNet++" 34 | run_name = f"{scene_name}_{seed}" 35 | 36 | config = dict( 37 | scene_path=p_join(f"./experiments/{group_name}", run_name, 'params.npz'), 38 | workdir=f"./experiments/{group_name}", 39 | run_name=run_name, 40 | seed=seed, 41 | primary_device=primary_device, 42 | map_every=map_every, # Mapping every nth frame 43 | keyframe_every=keyframe_every, # Keyframe every nth frame 44 | mapping_window_size=mapping_window_size, # Mapping window size 45 | report_global_progress_every=5, # Report Global Progress every nth frame 46 | eval_every=1, # Evaluate every nth frame (at end of SLAM) 47 | scene_radius_depth_ratio=3, # Max First Frame Depth to Scene Radius Ratio (For Pruning/Densification) 48 | mean_sq_dist_method="projective", # ["projective", "knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 49 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 50 | report_iter_progress=False, 51 | load_checkpoint=False, 52 | checkpoint_time_idx=0, 53 | save_checkpoints=False, # Save Checkpoints 54 | checkpoint_interval=5, # Checkpoint Interval 55 | use_wandb=False, 56 | wandb=dict( 57 | entity="theairlab", 58 | project="SplaTAM", 59 | group=group_name, 60 | name=run_name, 61 | save_qual=False, 62 | eval_save_qual=True, 63 | ), 64 | data=dict( 65 | dataset_name="scannetpp", 66 | basedir="./data/ScanNet++/data", 67 | sequence=scene_name, 68 | ignore_bad=False, 69 | use_train_split=use_train_split, 70 | desired_image_height=584, 71 | desired_image_width=876, 72 | start=0, 73 | end=-1, 74 | stride=1, 75 | num_frames=num_frames, 76 | ), 77 | tracking=dict( 78 | use_gt_poses=False, # Use GT Poses for Tracking 79 | forward_prop=True, # Forward Propagate Poses 80 | visualize_tracking_loss=False, # Visualize Tracking Diff Images 81 | num_iters=tracking_iters, 82 | use_sil_for_loss=True, 83 | sil_thres=0.99, 84 | use_l1=True, 85 | use_depth_loss_thres=True, 86 | depth_loss_thres=20000, # Num of Tracking Iters becomes twice if this value is not met 87 | ignore_outlier_depth_loss=False, 88 | use_uncertainty_for_loss_mask=False, 89 | use_uncertainty_for_loss=False, 90 | use_chamfer=False, 91 | loss_weights=dict( 92 | im=0.5, 93 | depth=1.0, 94 | ), 95 | lrs=dict( 96 | means3D=0.0, 97 | rgb_colors=0.0, 98 | unnorm_rotations=0.0, 99 | logit_opacities=0.0, 100 | log_scales=0.0, 101 | cam_unnorm_rots=0.001, 102 | cam_trans=0.004, 103 | ), 104 | ), 105 | mapping=dict( 106 | num_iters=mapping_iters, 107 | add_new_gaussians=True, 108 | sil_thres=0.5, # For Addition of new Gaussians 109 | use_l1=True, 110 | ignore_outlier_depth_loss=False, 111 | use_sil_for_loss=False, 112 | use_uncertainty_for_loss_mask=False, 113 | use_uncertainty_for_loss=False, 114 | use_chamfer=False, 115 | loss_weights=dict( 116 | im=0.5, 117 | depth=1.0, 118 | ), 119 | lrs=dict( 120 | means3D=0.0001, 121 | rgb_colors=0.0025, 122 | unnorm_rotations=0.001, 123 | logit_opacities=0.05, 124 | log_scales=0.001, 125 
| cam_unnorm_rots=0.0000, 126 | cam_trans=0.0000, 127 | ), 128 | prune_gaussians=True, # Prune Gaussians during Mapping 129 | pruning_dict=dict( # Needs to be updated based on the number of mapping iterations 130 | start_after=0, 131 | remove_big_after=0, 132 | stop_after=20, 133 | prune_every=20, 134 | removal_opacity_threshold=0.005, 135 | final_removal_opacity_threshold=0.005, 136 | reset_opacities=False, 137 | reset_opacities_every=500, # Doesn't consider iter 0 138 | ), 139 | use_gaussian_splatting_densification=False, # Use Gaussian Splatting-based Densification during Mapping 140 | densify_dict=dict( # Needs to be updated based on the number of mapping iterations 141 | start_after=500, 142 | remove_big_after=3000, 143 | stop_after=5000, 144 | densify_every=100, 145 | grad_thresh=0.0002, 146 | num_to_split_into=2, 147 | removal_opacity_threshold=0.005, 148 | final_removal_opacity_threshold=0.005, 149 | reset_opacities_every=3000, # Doesn't consider iter 0 150 | ), 151 | ), 152 | viz=dict( 153 | render_mode='color', # ['color', 'depth' or 'centers'] 154 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 155 | show_sil=False, # Show Silhouette instead of RGB 156 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 157 | viz_w=600, viz_h=340, 158 | viz_near=0.01, viz_far=100.0, 159 | view_scale=2, 160 | viz_fps=5, # FPS for Online Recon Viz 161 | enter_interactive_post_online=True, # Enter Interactive Mode after Online Recon Viz 162 | ), 163 | ) -------------------------------------------------------------------------------- /configs/scannetpp/gaussian_splatting.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | primary_device = "cuda:0" 5 | 6 | scenes = ["8b5caf3398", "b20a261fdf"] 7 | 8 | seed = 0 9 | 10 | # Export SCENE env variable before running 11 | os.environ["SCENE"] = "0" 12 | 13 | # Train Split Eval 14 | use_train_split = True 15 | 16 | if use_train_split: 17 | scene_num_frames = [-1, 360] 18 | else: 19 | scene_num_frames = [-1, -1] 20 | 21 | scene_name = scenes[int(os.environ["SCENE"])] 22 | num_frames = scene_num_frames[int(os.environ["SCENE"])] 23 | 24 | full_res_width = 1752 25 | full_res_height = 1168 26 | downscale_factor = 2.0 27 | densify_downscale_factor = 4.0 28 | 29 | map_every = 1 30 | keyframe_every = 5 31 | mapping_window_size = 24 32 | tracking_iters = 200 33 | mapping_iters = 60 34 | 35 | group_name = "ScanNet++_3DGS" 36 | run_name = f"{scene_name}_{seed}" 37 | 38 | config = dict( 39 | workdir=f"./experiments/{group_name}", 40 | run_name=run_name, 41 | seed=seed, 42 | primary_device=primary_device, 43 | map_every=map_every, # Mapping every nth frame 44 | keyframe_every=keyframe_every, # Keyframe every nth frame 45 | mapping_window_size=mapping_window_size, # Mapping window size 46 | report_global_progress_every=5, # Report Global Progress every nth frame 47 | eval_every=5, # Evaluate every nth frame (at end of SLAM) 48 | scene_radius_depth_ratio=3, # Max First Frame Depth to Scene Radius Ratio (For Pruning/Densification) 49 | mean_sq_dist_method="projective", # ["projective", "knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 50 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 51 | report_iter_progress=False, 52 | load_checkpoint=False, 53 | checkpoint_time_idx=0, 54 |
save_checkpoints=False, # Save Checkpoints 55 | checkpoint_interval=5, # Checkpoint Interval 56 | use_wandb=True, 57 | wandb=dict( 58 | entity="theairlab", 59 | project="SplaTAM_CVPR_Rebuttal", 60 | group=group_name, 61 | name=run_name, 62 | save_qual=False, 63 | eval_save_qual=True, 64 | ), 65 | data=dict( 66 | dataset_name="scannetpp", 67 | basedir="/storage2/datasets/nkeetha/4d/data/ScanNetPP/data", 68 | sequence=scene_name, 69 | ignore_bad=False, 70 | use_train_split=True, 71 | desired_image_height=584, 72 | desired_image_width=876, 73 | desired_image_height_init=584, 74 | desired_image_width_init=876, 75 | start=0, 76 | end=-1, 77 | stride=1, 78 | num_frames=num_frames, 79 | eval_stride=1, 80 | eval_num_frames=-1, 81 | ), 82 | train=dict( 83 | num_iters_mapping=7000, 84 | sil_thres=0.5, # For Addition of new Gaussians & Visualization 85 | use_sil_for_loss=True, # Use Silhouette for Loss during Tracking 86 | loss_weights=dict( 87 | im=1.0, 88 | depth=0.0, 89 | ), 90 | lrs_mapping=dict( 91 | means3D=0.00032, 92 | rgb_colors=0.0025, 93 | unnorm_rotations=0.001, 94 | logit_opacities=0.05, 95 | log_scales=0.005, 96 | cam_unnorm_rots=0.0000, 97 | cam_trans=0.0000, 98 | ), 99 | lrs_mapping_means3D_final=0.0000032, 100 | lr_delay_mult=0.01, 101 | use_gaussian_splatting_densification=True, # Use Gaussian Splatting-based Densification during Mapping 102 | densify_dict=dict( # Needs to be updated based on the number of mapping iterations 103 | start_after=500, 104 | remove_big_after=3000, 105 | stop_after=15000, 106 | densify_every=100, 107 | grad_thresh=0.0002, 108 | num_to_split_into=2, 109 | removal_opacity_threshold=0.005, 110 | final_removal_opacity_threshold=0.005, 111 | reset_opacities=True, 112 | reset_opacities_every=3000, # Doesn't consider iter 0 113 | ), 114 | ), 115 | viz=dict( 116 | render_mode='color', # ['color', 'depth' or 'centers'] 117 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 118 | show_sil=False, # Show Silhouette instead of RGB 119 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 120 | viz_w=600, viz_h=340, 121 | viz_near=0.01, viz_far=100.0, 122 | view_scale=2, 123 | viz_fps=5, # FPS for Online Recon Viz 124 | enter_interactive_post_online=True, # Enter Interactive Mode after Online Recon Viz 125 | ), 126 | ) -------------------------------------------------------------------------------- /configs/scannetpp/post_splatam_opt.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | primary_device = "cuda:0" 5 | 6 | 7 | scenes = ["8b5caf3398", "b20a261fdf"] 8 | 9 | seed = 0 10 | 11 | # Export SCENE env variable before running 12 | os.environ["SCENE"] = "0" 13 | 14 | # Train Split Eval 15 | use_train_split = True 16 | 17 | # # Novel View Synthesis Eval 18 | # use_train_split = False 19 | 20 | if use_train_split: 21 | scene_num_frames = [-1, 360] 22 | else: 23 | scene_num_frames = [-1, -1] 24 | 25 | scene_name = scenes[int(os.environ["SCENE"])] 26 | num_frames = scene_num_frames[int(os.environ["SCENE"])] 27 | 28 | map_every = 1 29 | keyframe_every = 5 30 | mapping_window_size = 24 31 | tracking_iters = 200 32 | mapping_iters = 60 33 | 34 | group_name = "ScanNet++" 35 | run_name = "Post_SplaTAM_Opt" 36 | 37 | config = dict( 38 | workdir=f"./experiments/{group_name}", 39 | run_name=run_name, 40 | seed=0, 41 | primary_device=primary_device, 42 | mean_sq_dist_method="projective", # ["projective", 
"knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 43 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 44 | report_iter_progress=False, 45 | use_wandb=True, 46 | wandb=dict( 47 | entity="theairlab", 48 | project="SplaTAM", 49 | group=group_name, 50 | name=run_name, 51 | save_qual=False, 52 | eval_save_qual=True, 53 | ), 54 | data=dict( 55 | dataset_name="scannetpp", 56 | basedir="./data/ScanNet++/data", 57 | sequence=scene_name, 58 | ignore_bad=False, 59 | use_train_split=True, 60 | desired_image_height=584, 61 | desired_image_width=876, 62 | start=0, 63 | end=-1, 64 | stride=1, 65 | num_frames=num_frames, 66 | eval_stride=1, 67 | eval_num_frames=-1, 68 | param_ckpt_path='./experiments/ScanNet++/8b5caf3398_0/params.npz' 69 | ), 70 | train=dict( 71 | num_iters_mapping=30000, 72 | sil_thres=0.5, # For Addition of new Gaussians & Visualization 73 | use_sil_for_loss=True, # Use Silhouette for Loss during Tracking 74 | loss_weights=dict( 75 | im=1.0, 76 | depth=0.0, 77 | ), 78 | lrs_mapping=dict( 79 | means3D=0.00032, 80 | rgb_colors=0.0025, 81 | unnorm_rotations=0.001, 82 | logit_opacities=0.05, 83 | log_scales=0.005, 84 | cam_unnorm_rots=0.0000, 85 | cam_trans=0.0000, 86 | ), 87 | lrs_mapping_means3D_final=0.0000032, 88 | lr_delay_mult=0.01, 89 | use_gaussian_splatting_densification=True, # Use Gaussian Splatting-based Densification during Mapping 90 | densify_dict=dict( # Needs to be updated based on the number of mapping iterations 91 | start_after=500, 92 | remove_big_after=3000, 93 | stop_after=15000, 94 | densify_every=100, 95 | grad_thresh=0.0002, 96 | num_to_split_into=2, 97 | removal_opacity_threshold=0.005, 98 | final_removal_opacity_threshold=0.005, 99 | reset_opacities=True, 100 | reset_opacities_every=3000, # Doesn't consider iter 0 101 | ), 102 | ), 103 | viz=dict 104 | ( 105 | render_mode='color', # ['color', 'depth' or 'centers'] 106 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 107 | show_sil=False, # Show Silhouette instead of RGB 108 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 109 | viz_w=600, viz_h=340, 110 | viz_near=0.01, viz_far=100.0, 111 | view_scale=2, 112 | viz_fps=5, # FPS for Online Recon Viz 113 | enter_interactive_post_online=True, # Enter Interactive Mode after Online Recon Viz 114 | ), 115 | ) -------------------------------------------------------------------------------- /configs/scannetpp/scannetpp.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for scene in 0 1 4 | do 5 | SCENE=${scene} 6 | export SCENE 7 | echo "Running scene number ${SCENE} with seed 0" 8 | python3 -u scripts/splatam.py configs/scannetpp/scannetpp_eval.py 9 | done -------------------------------------------------------------------------------- /configs/scannetpp/scannetpp_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | primary_device = "cuda:0" 5 | 6 | scenes = ["8b5caf3398", "b20a261fdf"] 7 | 8 | seed = 0 9 | 10 | # Train Split Eval 11 | use_train_split = True 12 | 13 | # # Novel View Synthesis Eval 14 | # use_train_split = False 15 | 16 | if use_train_split: 17 | scene_num_frames = [-1, 360] 18 | else: 19 | scene_num_frames = [-1, -1] 20 | 21 | scene_name = scenes[int(os.environ["SCENE"])] 22 | num_frames = 
scene_num_frames[int(os.environ["SCENE"])] 23 | 24 | map_every = 1 25 | keyframe_every = 5 26 | mapping_window_size = 24 27 | tracking_iters = 200 28 | mapping_iters = 60 29 | 30 | group_name = "ScanNet++" 31 | run_name = f"{scene_name}_{seed}" 32 | 33 | config = dict( 34 | scene_path=p_join(f"./experiments/{group_name}", run_name, 'params.npz'), 35 | workdir=f"./experiments/{group_name}", 36 | run_name=run_name, 37 | seed=seed, 38 | primary_device=primary_device, 39 | map_every=map_every, # Mapping every nth frame 40 | keyframe_every=keyframe_every, # Keyframe every nth frame 41 | mapping_window_size=mapping_window_size, # Mapping window size 42 | report_global_progress_every=5, # Report Global Progress every nth frame 43 | eval_every=1, # Evaluate every nth frame (at end of SLAM) 44 | scene_radius_depth_ratio=3, # Max First Frame Depth to Scene Radius Ratio (For Pruning/Densification) 45 | mean_sq_dist_method="projective", # ["projective", "knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 46 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 47 | report_iter_progress=False, 48 | load_checkpoint=False, 49 | checkpoint_time_idx=0, 50 | save_checkpoints=False, # Save Checkpoints 51 | checkpoint_interval=5, # Checkpoint Interval 52 | use_wandb=False, 53 | wandb=dict( 54 | entity="theairlab", 55 | project="SplaTAM", 56 | group=group_name, 57 | name=run_name, 58 | save_qual=False, 59 | eval_save_qual=True, 60 | ), 61 | data=dict( 62 | dataset_name="scannetpp", 63 | basedir="./data/ScanNet++/data", 64 | sequence=scene_name, 65 | ignore_bad=False, 66 | use_train_split=use_train_split, 67 | desired_image_height=584, 68 | desired_image_width=876, 69 | start=0, 70 | end=-1, 71 | stride=1, 72 | num_frames=num_frames, 73 | ), 74 | tracking=dict( 75 | use_gt_poses=False, # Use GT Poses for Tracking 76 | forward_prop=True, # Forward Propagate Poses 77 | visualize_tracking_loss=False, # Visualize Tracking Diff Images 78 | num_iters=tracking_iters, 79 | use_sil_for_loss=True, 80 | sil_thres=0.99, 81 | use_l1=True, 82 | use_depth_loss_thres=True, 83 | depth_loss_thres=20000, # Num of Tracking Iters becomes twice if this value is not met 84 | ignore_outlier_depth_loss=False, 85 | use_uncertainty_for_loss_mask=False, 86 | use_uncertainty_for_loss=False, 87 | use_chamfer=False, 88 | loss_weights=dict( 89 | im=0.5, 90 | depth=1.0, 91 | ), 92 | lrs=dict( 93 | means3D=0.0, 94 | rgb_colors=0.0, 95 | unnorm_rotations=0.0, 96 | logit_opacities=0.0, 97 | log_scales=0.0, 98 | cam_unnorm_rots=0.001, 99 | cam_trans=0.004, 100 | ), 101 | ), 102 | mapping=dict( 103 | num_iters=mapping_iters, 104 | add_new_gaussians=True, 105 | sil_thres=0.5, # For Addition of new Gaussians 106 | use_l1=True, 107 | ignore_outlier_depth_loss=False, 108 | use_sil_for_loss=False, 109 | use_uncertainty_for_loss_mask=False, 110 | use_uncertainty_for_loss=False, 111 | use_chamfer=False, 112 | loss_weights=dict( 113 | im=0.5, 114 | depth=1.0, 115 | ), 116 | lrs=dict( 117 | means3D=0.0001, 118 | rgb_colors=0.0025, 119 | unnorm_rotations=0.001, 120 | logit_opacities=0.05, 121 | log_scales=0.001, 122 | cam_unnorm_rots=0.0000, 123 | cam_trans=0.0000, 124 | ), 125 | prune_gaussians=True, # Prune Gaussians during Mapping 126 | pruning_dict=dict( # Needs to be updated based on the number of mapping iterations 127 | start_after=0, 128 | remove_big_after=0, 129 | stop_after=20, 130 | prune_every=20, 131 | removal_opacity_threshold=0.005, 
132 | final_removal_opacity_threshold=0.005, 133 | reset_opacities=False, 134 | reset_opacities_every=500, # Doesn't consider iter 0 135 | ), 136 | use_gaussian_splatting_densification=False, # Use Gaussian Splatting-based Densification during Mapping 137 | densify_dict=dict( # Needs to be updated based on the number of mapping iterations 138 | start_after=500, 139 | remove_big_after=3000, 140 | stop_after=5000, 141 | densify_every=100, 142 | grad_thresh=0.0002, 143 | num_to_split_into=2, 144 | removal_opacity_threshold=0.005, 145 | final_removal_opacity_threshold=0.005, 146 | reset_opacities_every=3000, # Doesn't consider iter 0 147 | ), 148 | ), 149 | viz=dict( 150 | render_mode='color', # ['color', 'depth' or 'centers'] 151 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 152 | show_sil=False, # Show Silhouette instead of RGB 153 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 154 | viz_w=600, viz_h=340, 155 | viz_near=0.01, viz_far=100.0, 156 | view_scale=2, 157 | viz_fps=5, # FPS for Online Recon Viz 158 | enter_interactive_post_online=True, # Enter Interactive Mode after Online Recon Viz 159 | ), 160 | ) -------------------------------------------------------------------------------- /configs/scannetpp/splatam.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | primary_device = "cuda:0" 5 | 6 | scenes = ["8b5caf3398", "b20a261fdf"] 7 | 8 | seed = 0 9 | 10 | # Export SCENE env variable before running 11 | os.environ["SCENE"] = "0" 12 | 13 | # Train Split Eval 14 | use_train_split = True 15 | 16 | # # Novel View Synthesis Eval 17 | # use_train_split = False 18 | 19 | if use_train_split: 20 | scene_num_frames = [-1, 360] 21 | else: 22 | scene_num_frames = [-1, -1] 23 | 24 | scene_name = scenes[int(os.environ["SCENE"])] 25 | num_frames = scene_num_frames[int(os.environ["SCENE"])] 26 | 27 | map_every = 1 28 | keyframe_every = 5 29 | mapping_window_size = 24 30 | tracking_iters = 200 31 | mapping_iters = 60 32 | 33 | group_name = "ScanNet++" 34 | run_name = f"{scene_name}_{seed}" 35 | 36 | config = dict( 37 | workdir=f"./experiments/{group_name}", 38 | run_name=run_name, 39 | seed=seed, 40 | primary_device=primary_device, 41 | map_every=map_every, # Mapping every nth frame 42 | keyframe_every=keyframe_every, # Keyframe every nth frame 43 | mapping_window_size=mapping_window_size, # Mapping window size 44 | report_global_progress_every=5, # Report Global Progress every nth frame 45 | eval_every=1, # Evaluate every nth frame (at end of SLAM) 46 | scene_radius_depth_ratio=3, # Max First Frame Depth to Scene Radius Ratio (For Pruning/Densification) 47 | mean_sq_dist_method="projective", # ["projective", "knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 48 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 49 | report_iter_progress=False, 50 | load_checkpoint=False, 51 | checkpoint_time_idx=0, 52 | save_checkpoints=False, # Save Checkpoints 53 | checkpoint_interval=5, # Checkpoint Interval 54 | use_wandb=True, 55 | wandb=dict( 56 | entity="theairlab", 57 | project="SplaTAM", 58 | group=group_name, 59 | name=run_name, 60 | save_qual=False, 61 | eval_save_qual=True, 62 | ), 63 | data=dict( 64 | dataset_name="scannetpp", 65 | basedir="./data/ScanNet++/data", 66 | sequence=scene_name, 67 | 
ignore_bad=False, 68 | use_train_split=use_train_split, 69 | desired_image_height=584, 70 | desired_image_width=876, 71 | start=0, 72 | end=-1, 73 | stride=1, 74 | num_frames=num_frames, 75 | ), 76 | tracking=dict( 77 | use_gt_poses=False, # Use GT Poses for Tracking 78 | forward_prop=True, # Forward Propagate Poses 79 | visualize_tracking_loss=False, # Visualize Tracking Diff Images 80 | num_iters=tracking_iters, 81 | use_sil_for_loss=True, 82 | sil_thres=0.99, 83 | use_l1=True, 84 | use_depth_loss_thres=True, 85 | depth_loss_thres=20000, # Num of Tracking Iters becomes twice if this value is not met 86 | ignore_outlier_depth_loss=False, 87 | use_uncertainty_for_loss_mask=False, 88 | use_uncertainty_for_loss=False, 89 | use_chamfer=False, 90 | loss_weights=dict( 91 | im=0.5, 92 | depth=1.0, 93 | ), 94 | lrs=dict( 95 | means3D=0.0, 96 | rgb_colors=0.0, 97 | unnorm_rotations=0.0, 98 | logit_opacities=0.0, 99 | log_scales=0.0, 100 | cam_unnorm_rots=0.001, 101 | cam_trans=0.004, 102 | ), 103 | ), 104 | mapping=dict( 105 | num_iters=mapping_iters, 106 | add_new_gaussians=True, 107 | sil_thres=0.5, # For Addition of new Gaussians 108 | use_l1=True, 109 | ignore_outlier_depth_loss=False, 110 | use_sil_for_loss=False, 111 | use_uncertainty_for_loss_mask=False, 112 | use_uncertainty_for_loss=False, 113 | use_chamfer=False, 114 | loss_weights=dict( 115 | im=0.5, 116 | depth=1.0, 117 | ), 118 | lrs=dict( 119 | means3D=0.0001, 120 | rgb_colors=0.0025, 121 | unnorm_rotations=0.001, 122 | logit_opacities=0.05, 123 | log_scales=0.001, 124 | cam_unnorm_rots=0.0000, 125 | cam_trans=0.0000, 126 | ), 127 | prune_gaussians=True, # Prune Gaussians during Mapping 128 | pruning_dict=dict( # Needs to be updated based on the number of mapping iterations 129 | start_after=0, 130 | remove_big_after=0, 131 | stop_after=20, 132 | prune_every=20, 133 | removal_opacity_threshold=0.005, 134 | final_removal_opacity_threshold=0.005, 135 | reset_opacities=False, 136 | reset_opacities_every=500, # Doesn't consider iter 0 137 | ), 138 | use_gaussian_splatting_densification=False, # Use Gaussian Splatting-based Densification during Mapping 139 | densify_dict=dict( # Needs to be updated based on the number of mapping iterations 140 | start_after=500, 141 | remove_big_after=3000, 142 | stop_after=5000, 143 | densify_every=100, 144 | grad_thresh=0.0002, 145 | num_to_split_into=2, 146 | removal_opacity_threshold=0.005, 147 | final_removal_opacity_threshold=0.005, 148 | reset_opacities_every=3000, # Doesn't consider iter 0 149 | ), 150 | ), 151 | viz=dict( 152 | render_mode='color', # ['color', 'depth' or 'centers'] 153 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 154 | show_sil=False, # Show Silhouette instead of RGB 155 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 156 | viz_w=600, viz_h=340, 157 | viz_near=0.01, viz_far=100.0, 158 | view_scale=2, 159 | viz_fps=5, # FPS for Online Recon Viz 160 | enter_interactive_post_online=True, # Enter Interactive Mode after Online Recon Viz 161 | ), 162 | ) -------------------------------------------------------------------------------- /configs/tum/splatam.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | primary_device = "cuda:0" 5 | 6 | scenes = ["freiburg1_desk", "freiburg1_desk2", "freiburg1_room", "freiburg2_xyz", "freiburg3_long_office_household"] 7 | 8 | seed = int(0) 9 | scene_name = 
scenes[int(0)] 10 | 11 | map_every = 1 12 | keyframe_every = 5 13 | mapping_window_size = 20 14 | tracking_iters = 200 15 | mapping_iters = 30 16 | scene_radius_depth_ratio = 2 17 | 18 | group_name = "TUM" 19 | run_name = f"{scene_name}_seed{seed}" 20 | 21 | config = dict( 22 | workdir=f"./experiments/{group_name}", 23 | run_name=run_name, 24 | seed=seed, 25 | primary_device=primary_device, 26 | map_every=map_every, # Mapping every nth frame 27 | keyframe_every=keyframe_every, # Keyframe every nth frame 28 | mapping_window_size=mapping_window_size, # Mapping window size 29 | report_global_progress_every=500, # Report Global Progress every nth frame 30 | eval_every=500, # Evaluate every nth frame (at end of SLAM) 31 | scene_radius_depth_ratio=scene_radius_depth_ratio, # Max First Frame Depth to Scene Radius Ratio (For Pruning/Densification) 32 | mean_sq_dist_method="projective", # ["projective", "knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 33 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 34 | report_iter_progress=False, 35 | load_checkpoint=False, 36 | checkpoint_time_idx=0, 37 | save_checkpoints=False, # Save Checkpoints 38 | checkpoint_interval=100, # Checkpoint Interval 39 | use_wandb=True, 40 | wandb=dict( 41 | entity="theairlab", 42 | project="SplaTAM", 43 | group=group_name, 44 | name=run_name, 45 | save_qual=False, 46 | eval_save_qual=True, 47 | ), 48 | data=dict( 49 | basedir="./data/TUM_RGBD", 50 | gradslam_data_cfg=f"./configs/data/TUM/{scene_name}.yaml", 51 | sequence=f"rgbd_dataset_{scene_name}", 52 | desired_image_height=480, 53 | desired_image_width=640, 54 | start=0, 55 | end=-1, 56 | stride=1, 57 | num_frames=-1, 58 | ), 59 | tracking=dict( 60 | use_gt_poses=False, # Use GT Poses for Tracking 61 | forward_prop=True, # Forward Propagate Poses 62 | num_iters=tracking_iters, 63 | use_sil_for_loss=True, 64 | sil_thres=0.99, 65 | use_l1=True, 66 | ignore_outlier_depth_loss=False, 67 | use_uncertainty_for_loss_mask=False, 68 | use_uncertainty_for_loss=False, 69 | use_chamfer=False, 70 | loss_weights=dict( 71 | im=0.5, 72 | depth=1.0, 73 | ), 74 | lrs=dict( 75 | means3D=0.0, 76 | rgb_colors=0.0, 77 | unnorm_rotations=0.0, 78 | logit_opacities=0.0, 79 | log_scales=0.0, 80 | cam_unnorm_rots=0.002, 81 | cam_trans=0.002, 82 | ), 83 | ), 84 | mapping=dict( 85 | num_iters=mapping_iters, 86 | add_new_gaussians=True, 87 | sil_thres=0.5, # For Addition of new Gaussians 88 | use_l1=True, 89 | use_sil_for_loss=False, 90 | ignore_outlier_depth_loss=False, 91 | use_uncertainty_for_loss_mask=False, 92 | use_uncertainty_for_loss=False, 93 | use_chamfer=False, 94 | loss_weights=dict( 95 | im=0.5, 96 | depth=1.0, 97 | ), 98 | lrs=dict( 99 | means3D=0.0001, 100 | rgb_colors=0.0025, 101 | unnorm_rotations=0.001, 102 | logit_opacities=0.05, 103 | log_scales=0.001, 104 | cam_unnorm_rots=0.0000, 105 | cam_trans=0.0000, 106 | ), 107 | prune_gaussians=True, # Prune Gaussians during Mapping 108 | pruning_dict=dict( # Needs to be updated based on the number of mapping iterations 109 | start_after=0, 110 | remove_big_after=0, 111 | stop_after=20, 112 | prune_every=20, 113 | removal_opacity_threshold=0.005, 114 | final_removal_opacity_threshold=0.005, 115 | reset_opacities=False, 116 | reset_opacities_every=500, # Doesn't consider iter 0 117 | ), 118 | use_gaussian_splatting_densification=False, # Use Gaussian Splatting-based Densification during Mapping 119 | densify_dict=dict( # 
Needs to be updated based on the number of mapping iterations 120 | start_after=500, 121 | remove_big_after=3000, 122 | stop_after=5000, 123 | densify_every=100, 124 | grad_thresh=0.0002, 125 | num_to_split_into=2, 126 | removal_opacity_threshold=0.005, 127 | final_removal_opacity_threshold=0.005, 128 | reset_opacities_every=3000, # Doesn't consider iter 0 129 | ), 130 | ), 131 | viz=dict( 132 | render_mode='color', # ['color', 'depth' or 'centers'] 133 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 134 | show_sil=False, # Show Silhouette instead of RGB 135 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 136 | viz_w=600, viz_h=340, 137 | viz_near=0.01, viz_far=100.0, 138 | view_scale=2, 139 | viz_fps=5, # FPS for Online Recon Viz 140 | enter_interactive_post_online=False, # Enter Interactive Mode after Online Recon Viz 141 | ), 142 | ) -------------------------------------------------------------------------------- /configs/tum/tum.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for seed in 0 1 2 4 | do 5 | SEED=${seed} 6 | export SEED 7 | for scene in 0 1 2 3 4 8 | do 9 | SCENE_NUM=${scene} 10 | export SCENE_NUM 11 | echo "Running scene number ${SCENE_NUM} with seed ${SEED}" 12 | python3 -u scripts/splatam.py configs/tum/tum_eval.py 13 | done 14 | done -------------------------------------------------------------------------------- /configs/tum/tum_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as p_join 3 | 4 | primary_device = "cuda:0" 5 | 6 | scenes = ["freiburg1_desk", "freiburg1_desk2", "freiburg1_room", "freiburg2_xyz", "freiburg3_long_office_household"] 7 | 8 | seed = int(os.environ["SEED"]) 9 | scene_name = scenes[int(os.environ["SCENE_NUM"])] 10 | 11 | map_every = 1 12 | keyframe_every = 5 13 | mapping_window_size = 20 14 | tracking_iters = 200 15 | mapping_iters = 30 16 | scene_radius_depth_ratio = 2 17 | 18 | group_name = "TUM" 19 | run_name = f"{scene_name}_seed{seed}" 20 | 21 | config = dict( 22 | workdir=f"./experiments/{group_name}", 23 | run_name=run_name, 24 | seed=seed, 25 | primary_device=primary_device, 26 | map_every=map_every, # Mapping every nth frame 27 | keyframe_every=keyframe_every, # Keyframe every nth frame 28 | mapping_window_size=mapping_window_size, # Mapping window size 29 | report_global_progress_every=500, # Report Global Progress every nth frame 30 | eval_every=500, # Evaluate every nth frame (at end of SLAM) 31 | scene_radius_depth_ratio=scene_radius_depth_ratio, # Max First Frame Depth to Scene Radius Ratio (For Pruning/Densification) 32 | mean_sq_dist_method="projective", # ["projective", "knn"] (Type of Mean Squared Distance Calculation for Scale of Gaussians) 33 | gaussian_distribution="isotropic", # ["isotropic", "anisotropic"] (Isotropic -> Spherical Covariance, Anisotropic -> Ellipsoidal Covariance) 34 | report_iter_progress=False, 35 | load_checkpoint=False, 36 | checkpoint_time_idx=0, 37 | save_checkpoints=False, # Save Checkpoints 38 | checkpoint_interval=100, # Checkpoint Interval 39 | use_wandb=True, 40 | wandb=dict( 41 | entity="theairlab", 42 | project="SplaTAM", 43 | group=group_name, 44 | name=run_name, 45 | save_qual=False, 46 | eval_save_qual=True, 47 | ), 48 | data=dict( 49 | basedir="./data/TUM_RGBD", 50 | gradslam_data_cfg=f"./configs/data/TUM/{scene_name}.yaml", 51 | 
sequence=f"rgbd_dataset_{scene_name}", 52 | desired_image_height=480, 53 | desired_image_width=640, 54 | start=0, 55 | end=-1, 56 | stride=1, 57 | num_frames=-1, 58 | ), 59 | tracking=dict( 60 | use_gt_poses=False, # Use GT Poses for Tracking 61 | forward_prop=True, # Forward Propagate Poses 62 | num_iters=tracking_iters, 63 | use_sil_for_loss=True, 64 | sil_thres=0.99, 65 | use_l1=True, 66 | ignore_outlier_depth_loss=False, 67 | use_uncertainty_for_loss_mask=False, 68 | use_uncertainty_for_loss=False, 69 | use_chamfer=False, 70 | loss_weights=dict( 71 | im=0.5, 72 | depth=1.0, 73 | ), 74 | lrs=dict( 75 | means3D=0.0, 76 | rgb_colors=0.0, 77 | unnorm_rotations=0.0, 78 | logit_opacities=0.0, 79 | log_scales=0.0, 80 | cam_unnorm_rots=0.002, 81 | cam_trans=0.002, 82 | ), 83 | ), 84 | mapping=dict( 85 | num_iters=mapping_iters, 86 | add_new_gaussians=True, 87 | sil_thres=0.5, # For Addition of new Gaussians 88 | use_l1=True, 89 | use_sil_for_loss=False, 90 | use_uncertainty_for_loss_mask=False, 91 | use_uncertainty_for_loss=False, 92 | use_chamfer=False, 93 | loss_weights=dict( 94 | im=0.5, 95 | depth=1.0, 96 | ), 97 | lrs=dict( 98 | means3D=0.0001, 99 | rgb_colors=0.0025, 100 | unnorm_rotations=0.001, 101 | logit_opacities=0.05, 102 | log_scales=0.001, 103 | cam_unnorm_rots=0.0000, 104 | cam_trans=0.0000, 105 | ), 106 | prune_gaussians=True, # Prune Gaussians during Mapping 107 | pruning_dict=dict( # Needs to be updated based on the number of mapping iterations 108 | start_after=0, 109 | remove_big_after=0, 110 | stop_after=20, 111 | prune_every=20, 112 | removal_opacity_threshold=0.005, 113 | final_removal_opacity_threshold=0.005, 114 | reset_opacities=False, 115 | reset_opacities_every=500, # Doesn't consider iter 0 116 | ), 117 | use_gaussian_splatting_densification=False, # Use Gaussian Splatting-based Densification during Mapping 118 | densify_dict=dict( # Needs to be updated based on the number of mapping iterations 119 | start_after=500, 120 | remove_big_after=3000, 121 | stop_after=5000, 122 | densify_every=100, 123 | grad_thresh=0.0002, 124 | num_to_split_into=2, 125 | removal_opacity_threshold=0.005, 126 | final_removal_opacity_threshold=0.005, 127 | reset_opacities_every=3000, # Doesn't consider iter 0 128 | ), 129 | ), 130 | viz=dict( 131 | render_mode='color', # ['color', 'depth' or 'centers'] 132 | offset_first_viz_cam=True, # Offsets the view camera back by 0.5 units along the view direction (For Final Recon Viz) 133 | show_sil=False, # Show Silhouette instead of RGB 134 | visualize_cams=True, # Visualize Camera Frustums and Trajectory 135 | viz_w=600, viz_h=340, 136 | viz_near=0.01, viz_far=100.0, 137 | view_scale=2, 138 | viz_fps=5, # FPS for Online Recon Viz 139 | enter_interactive_post_online=False, # Enter Interactive Mode after Online Recon Viz 140 | ), 141 | ) -------------------------------------------------------------------------------- /datasets/_init_.py: -------------------------------------------------------------------------------- 1 | from .gradslam_datasets import * -------------------------------------------------------------------------------- /datasets/gradslam_datasets/README.md: -------------------------------------------------------------------------------- 1 | # GradSLAM Datasets 2 | 3 | This folder contains the dataloaders used for ConceptFusion (GradSLAM). 
4 | 5 | Source Code: https://github.com/gradslam/gradslam/pull/58 -------------------------------------------------------------------------------- /datasets/gradslam_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .azure import AzureKinectDataset 2 | from .basedataset import GradSLAMDataset 3 | from .dataconfig import load_dataset_config 4 | from .datautils import * 5 | from .icl import ICLDataset 6 | from .replica import ReplicaDataset, ReplicaV2Dataset 7 | from .scannet import ScannetDataset 8 | from .ai2thor import Ai2thorDataset 9 | from .realsense import RealsenseDataset 10 | from .record3d import Record3DDataset 11 | from .tum import TUMDataset 12 | from .scannetpp import ScannetPPDataset 13 | from .nerfcapture import NeRFCaptureDataset -------------------------------------------------------------------------------- /datasets/gradslam_datasets/ai2thor.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from pathlib import Path 4 | from typing import Dict, List, Optional, Union 5 | 6 | import cv2 7 | import imageio.v2 as imageio 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | from natsort import natsorted 12 | 13 | from .basedataset import GradSLAMDataset 14 | 15 | 16 | class Ai2thorDataset(GradSLAMDataset): 17 | def __init__( 18 | self, 19 | config_dict, 20 | basedir, 21 | sequence, 22 | stride: Optional[int] = None, 23 | start: Optional[int] = 0, 24 | end: Optional[int] = -1, 25 | desired_height: Optional[int] = 968, 26 | desired_width: Optional[int] = 1296, 27 | load_embeddings: Optional[bool] = False, 28 | embedding_dir: Optional[str] = "embeddings", 29 | embedding_dim: Optional[int] = 512, 30 | **kwargs, 31 | ): 32 | self.input_folder = os.path.join(basedir, sequence) 33 | super().__init__( 34 | config_dict, 35 | stride=stride, 36 | start=start, 37 | end=end, 38 | desired_height=desired_height, 39 | desired_width=desired_width, 40 | load_embeddings=load_embeddings, 41 | embedding_dir=embedding_dir, 42 | embedding_dim=embedding_dim, 43 | **kwargs, 44 | ) 45 | 46 | def get_filepaths(self): 47 | color_paths = natsorted(glob.glob(f"{self.input_folder}/color/*.png")) 48 | depth_paths = natsorted(glob.glob(f"{self.input_folder}/depth/*.png")) 49 | embedding_paths = None 50 | if self.load_embeddings: 51 | if self.embedding_dir == "embed_semseg": 52 | # embed_semseg is stored as uint16 pngs 53 | embedding_paths = natsorted(glob.glob(f"{self.input_folder}/{self.embedding_dir}/*.png")) 54 | else: 55 | embedding_paths = natsorted(glob.glob(f"{self.input_folder}/{self.embedding_dir}/*.pt")) 56 | return color_paths, depth_paths, embedding_paths 57 | 58 | def load_poses(self): 59 | poses = [] 60 | posefiles = natsorted(glob.glob(f"{self.input_folder}/pose/*.txt")) 61 | for posefile in posefiles: 62 | _pose = torch.from_numpy(np.loadtxt(posefile)) 63 | poses.append(_pose) 64 | return poses 65 | 66 | def read_embedding_from_file(self, embedding_file_path): 67 | if self.embedding_dir == "embed_semseg": 68 | embedding = imageio.imread(embedding_file_path) # (H, W) 69 | embedding = cv2.resize( 70 | embedding, (self.desired_width, self.desired_height), interpolation=cv2.INTER_NEAREST 71 | ) 72 | embedding = torch.from_numpy(embedding).long() # (H, W) 73 | embedding = F.one_hot(embedding, num_classes=self.embedding_dim) # (H, W, C) 74 | embedding = embedding.half() # (H, W, C) 75 | embedding = embedding.permute(2, 0, 1) # (C, H, W) 76 | 
embedding = embedding.unsqueeze(0) # (1, C, H, W) 77 | else: 78 | embedding = torch.load(embedding_file_path, map_location="cpu") 79 | return embedding.permute(0, 2, 3, 1) # (1, H, W, embedding_dim) -------------------------------------------------------------------------------- /datasets/gradslam_datasets/azure.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from pathlib import Path 4 | from typing import Dict, List, Optional, Union 5 | 6 | import numpy as np 7 | import torch 8 | from natsort import natsorted 9 | 10 | from .basedataset import GradSLAMDataset 11 | 12 | 13 | class AzureKinectDataset(GradSLAMDataset): 14 | def __init__( 15 | self, 16 | config_dict, 17 | basedir, 18 | sequence, 19 | stride: Optional[int] = None, 20 | start: Optional[int] = 0, 21 | end: Optional[int] = -1, 22 | desired_height: Optional[int] = 480, 23 | desired_width: Optional[int] = 640, 24 | load_embeddings: Optional[bool] = False, 25 | embedding_dir: Optional[str] = "embeddings", 26 | embedding_dim: Optional[int] = 512, 27 | **kwargs, 28 | ): 29 | self.input_folder = os.path.join(basedir, sequence) 30 | self.pose_path = None 31 | 32 | # # check if a file named 'poses_global_dvo.txt' exists in the basedir / sequence folder 33 | # if os.path.isfile(os.path.join(basedir, sequence, "poses_global_dvo.txt")): 34 | # self.pose_path = os.path.join(basedir, sequence, "poses_global_dvo.txt") 35 | 36 | if "odomfile" in kwargs.keys(): 37 | self.pose_path = os.path.join(self.input_folder, kwargs["odomfile"]) 38 | super().__init__( 39 | config_dict, 40 | stride=stride, 41 | start=start, 42 | end=end, 43 | desired_height=desired_height, 44 | desired_width=desired_width, 45 | load_embeddings=load_embeddings, 46 | embedding_dir=embedding_dir, 47 | embedding_dim=embedding_dim, 48 | **kwargs, 49 | ) 50 | 51 | def get_filepaths(self): 52 | color_paths = natsorted(glob.glob(f"{self.input_folder}/color/*.jpg")) 53 | depth_paths = natsorted(glob.glob(f"{self.input_folder}/depth/*.png")) 54 | embedding_paths = None 55 | if self.load_embeddings: 56 | embedding_paths = natsorted(glob.glob(f"{self.input_folder}/{self.embedding_dir}/*.pt")) 57 | return color_paths, depth_paths, embedding_paths 58 | 59 | def load_poses(self): 60 | if self.pose_path is None: 61 | print("WARNING: Dataset does not contain poses. 
Returning identity transform.")
62 | return [torch.eye(4).float() for _ in range(self.num_imgs)]
63 | else:
64 | # Determine whether the posefile ends in ".log"
65 | # a .log file has the following format for each frame
66 | # frame_idx frame_idx+1
67 | # row 1 of 4x4 transform
68 | # row 2 of 4x4 transform
69 | # row 3 of 4x4 transform
70 | # row 4 of 4x4 transform
71 | # [repeat for all frames]
72 | #
73 | # on the other hand, the "poses_o3d.txt" or "poses_dvo.txt" files have the format
74 | # 16 entries of 4x4 transform
75 | # [repeat for all frames]
76 | if self.pose_path.endswith(".log"):
77 | # print("Loading poses from .log format")
78 | poses = []
79 | lines = None
80 | with open(self.pose_path, "r") as f:
81 | lines = f.readlines()
82 | if len(lines) % 5 != 0:
83 | raise ValueError(
84 | "Incorrect file format for .log odom file: " "number of non-empty lines must be a multiple of 5"
85 | )
86 | num_lines = len(lines) // 5
87 | for i in range(0, num_lines):
88 | _curpose = []
89 | _curpose.append(list(map(float, lines[5 * i + 1].split())))
90 | _curpose.append(list(map(float, lines[5 * i + 2].split())))
91 | _curpose.append(list(map(float, lines[5 * i + 3].split())))
92 | _curpose.append(list(map(float, lines[5 * i + 4].split())))
93 | _curpose = np.array(_curpose).reshape(4, 4)
94 | poses.append(torch.from_numpy(_curpose))
95 | else:
96 | poses = []
97 | lines = None
98 | with open(self.pose_path, "r") as f:
99 | lines = f.readlines()
100 | for line in lines:
101 | if len(line.split()) == 0:
102 | continue
103 | c2w = np.array(list(map(float, line.split()))).reshape(4, 4)
104 | poses.append(torch.from_numpy(c2w))
105 | return poses
106 |
107 | def read_embedding_from_file(self, embedding_file_path):
108 | embedding = torch.load(embedding_file_path)
109 | return embedding # .permute(0, 2, 3, 1) # (1, H, W, embedding_dim)
110 |
-------------------------------------------------------------------------------- /datasets/gradslam_datasets/dataconfig.py: --------------------------------------------------------------------------------
1 | import torch
2 | import yaml
3 |
4 |
5 | def load_dataset_config(path, default_path=None):
6 | """
7 | Loads config file.
8 |
9 | Args:
10 | path (str): path to config file.
11 | default_path (str, optional): path to a default config file to inherit from. Defaults to None.
12 |
13 | Returns:
14 | cfg (dict): config dict.
15 |
16 | """
17 | # load configuration from file itself
18 | with open(path, "r") as f:
19 | cfg_special = yaml.full_load(f)
20 |
21 | # check if we should inherit from a config
22 | inherit_from = cfg_special.get("inherit_from")
23 |
24 | # if yes, load this config first as default
25 | # if no, use the default_path
26 | if inherit_from is not None:
27 | cfg = load_dataset_config(inherit_from, default_path)
28 | elif default_path is not None:
29 | with open(default_path, "r") as f:
30 | cfg = yaml.full_load(f)
31 | else:
32 | cfg = dict()
33 |
34 | # include main configuration
35 | update_recursive(cfg, cfg_special)
36 |
37 | return cfg
38 |
39 |
40 | def update_recursive(dict1, dict2):
41 | """
42 | Updates dict1 recursively with the entries of dict2.
43 |
44 | Args:
45 | dict1 (dict): first dictionary, to be updated.
46 | dict2 (dict): second dictionary, whose entries should be used.
47 | """ 48 | for k, v in dict2.items(): 49 | if k not in dict1: 50 | dict1[k] = dict() 51 | if isinstance(v, dict): 52 | update_recursive(dict1[k], v) 53 | else: 54 | dict1[k] = v 55 | 56 | 57 | def common_dataset_to_batch(dataset): 58 | colors, depths, poses = [], [], [] 59 | intrinsics, embeddings = None, None 60 | for idx in range(len(dataset)): 61 | _color, _depth, intrinsics, _pose, _embedding = dataset[idx] 62 | colors.append(_color) 63 | depths.append(_depth) 64 | poses.append(_pose) 65 | if _embedding is not None: 66 | if embeddings is None: 67 | embeddings = [_embedding] 68 | else: 69 | embeddings.append(_embedding) 70 | colors = torch.stack(colors) 71 | depths = torch.stack(depths) 72 | poses = torch.stack(poses) 73 | if embeddings is not None: 74 | embeddings = torch.stack(embeddings, dim=1) 75 | # # (1, NUM_IMG, DIM_EMBED, H, W) -> (1, NUM_IMG, H, W, DIM_EMBED) 76 | # embeddings = embeddings.permute(0, 1, 3, 4, 2) 77 | colors = colors.unsqueeze(0) 78 | depths = depths.unsqueeze(0) 79 | intrinsics = intrinsics.unsqueeze(0).unsqueeze(0) 80 | poses = poses.unsqueeze(0) 81 | colors = colors.float() 82 | depths = depths.float() 83 | intrinsics = intrinsics.float() 84 | poses = poses.float() 85 | if embeddings is not None: 86 | embeddings = embeddings.float() 87 | return colors, depths, intrinsics, poses, embeddings 88 | -------------------------------------------------------------------------------- /datasets/gradslam_datasets/datautils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import warnings 3 | from collections import OrderedDict 4 | from typing import List, Union 5 | 6 | import numpy as np 7 | import torch 8 | 9 | __all__ = [ 10 | "normalize_image", 11 | "channels_first", 12 | "scale_intrinsics", 13 | "pointquaternion_to_homogeneous", 14 | "poses_to_transforms", 15 | "create_label_image", 16 | ] 17 | 18 | 19 | def normalize_image(rgb: Union[torch.Tensor, np.ndarray]): 20 | r"""Normalizes RGB image values from :math:`[0, 255]` range to :math:`[0, 1]` range. 21 | 22 | Args: 23 | rgb (torch.Tensor or numpy.ndarray): RGB image in range :math:`[0, 255]` 24 | 25 | Returns: 26 | torch.Tensor or numpy.ndarray: Normalized RGB image in range :math:`[0, 1]` 27 | 28 | Shape: 29 | - rgb: :math:`(*)` (any shape) 30 | - Output: Same shape as input :math:`(*)` 31 | """ 32 | if torch.is_tensor(rgb): 33 | return rgb.float() / 255 34 | elif isinstance(rgb, np.ndarray): 35 | return rgb.astype(float) / 255 36 | else: 37 | raise TypeError("Unsupported input rgb type: %r" % type(rgb)) 38 | 39 | 40 | def channels_first(rgb: Union[torch.Tensor, np.ndarray]): 41 | r"""Converts from channels last representation :math:`(*, H, W, C)` to channels first representation 42 | :math:`(*, C, H, W)` 43 | 44 | Args: 45 | rgb (torch.Tensor or numpy.ndarray): :math:`(*, H, W, C)` ordering `(*, height, width, channels)` 46 | 47 | Returns: 48 | torch.Tensor or numpy.ndarray: :math:`(*, C, H, W)` ordering 49 | 50 | Shape: 51 | - rgb: :math:`(*, H, W, C)` 52 | - Output: :math:`(*, C, H, W)` 53 | """ 54 | if not (isinstance(rgb, np.ndarray) or torch.is_tensor(rgb)): 55 | raise TypeError("Unsupported input rgb type {}".format(type(rgb))) 56 | 57 | if rgb.ndim < 3: 58 | raise ValueError( 59 | "Input rgb must contain atleast 3 dims, but had {} dims.".format(rgb.ndim) 60 | ) 61 | if rgb.shape[-3] < rgb.shape[-1]: 62 | msg = "Are you sure that the input is correct? 
Number of channels exceeds height of image: %r > %r" 63 | warnings.warn(msg % (rgb.shape[-1], rgb.shape[-3])) 64 | ordering = list(range(rgb.ndim)) 65 | ordering[-2], ordering[-1], ordering[-3] = ordering[-3], ordering[-2], ordering[-1] 66 | 67 | if isinstance(rgb, np.ndarray): 68 | return np.ascontiguousarray(rgb.transpose(*ordering)) 69 | elif torch.is_tensor(rgb): 70 | return rgb.permute(*ordering).contiguous() 71 | 72 | 73 | def scale_intrinsics( 74 | intrinsics: Union[np.ndarray, torch.Tensor], 75 | h_ratio: Union[float, int], 76 | w_ratio: Union[float, int], 77 | ): 78 | r"""Scales the intrinsics appropriately for resized frames where 79 | :math:`h_\text{ratio} = h_\text{new} / h_\text{old}` and :math:`w_\text{ratio} = w_\text{new} / w_\text{old}` 80 | 81 | Args: 82 | intrinsics (numpy.ndarray or torch.Tensor): Intrinsics matrix of original frame 83 | h_ratio (float or int): Ratio of new frame's height to old frame's height 84 | :math:`h_\text{ratio} = h_\text{new} / h_\text{old}` 85 | w_ratio (float or int): Ratio of new frame's width to old frame's width 86 | :math:`w_\text{ratio} = w_\text{new} / w_\text{old}` 87 | 88 | Returns: 89 | numpy.ndarray or torch.Tensor: Intrinsics matrix scaled approprately for new frame size 90 | 91 | Shape: 92 | - intrinsics: :math:`(*, 3, 3)` or :math:`(*, 4, 4)` 93 | - Output: Matches `intrinsics` shape, :math:`(*, 3, 3)` or :math:`(*, 4, 4)` 94 | 95 | """ 96 | if isinstance(intrinsics, np.ndarray): 97 | scaled_intrinsics = intrinsics.astype(np.float32).copy() 98 | elif torch.is_tensor(intrinsics): 99 | scaled_intrinsics = intrinsics.to(torch.float).clone() 100 | else: 101 | raise TypeError("Unsupported input intrinsics type {}".format(type(intrinsics))) 102 | if not (intrinsics.shape[-2:] == (3, 3) or intrinsics.shape[-2:] == (4, 4)): 103 | raise ValueError( 104 | "intrinsics must have shape (*, 3, 3) or (*, 4, 4), but had shape {} instead".format( 105 | intrinsics.shape 106 | ) 107 | ) 108 | if (intrinsics[..., -1, -1] != 1).any() or (intrinsics[..., 2, 2] != 1).any(): 109 | warnings.warn( 110 | "Incorrect intrinsics: intrinsics[..., -1, -1] and intrinsics[..., 2, 2] should be 1." 111 | ) 112 | 113 | scaled_intrinsics[..., 0, 0] *= w_ratio # fx 114 | scaled_intrinsics[..., 1, 1] *= h_ratio # fy 115 | scaled_intrinsics[..., 0, 2] *= w_ratio # cx 116 | scaled_intrinsics[..., 1, 2] *= h_ratio # cy 117 | return scaled_intrinsics 118 | 119 | 120 | def pointquaternion_to_homogeneous( 121 | pointquaternions: Union[np.ndarray, torch.Tensor], eps: float = 1e-12 122 | ): 123 | r"""Converts 3D point and unit quaternions :math:`(t_x, t_y, t_z, q_x, q_y, q_z, q_w)` to 124 | homogeneous transformations [R | t] where :math:`R` denotes the :math:`(3, 3)` rotation matrix and :math:`T` 125 | denotes the :math:`(3, 1)` translation matrix: 126 | 127 | .. math:: 128 | 129 | \left[\begin{array}{@{}c:c@{}} 130 | R & T \\ \hdashline 131 | \begin{array}{@{}ccc@{}} 132 | 0 & 0 & 0 133 | \end{array} & 1 134 | \end{array}\right] 135 | 136 | Args: 137 | pointquaternions (numpy.ndarray or torch.Tensor): 3D point positions and unit quaternions 138 | :math:`(tx, ty, tz, qx, qy, qz, qw)` where :math:`(tx, ty, tz)` is the 3D position and 139 | :math:`(qx, qy, qz, qw)` is the unit quaternion. 140 | eps (float): Small value, to avoid division by zero. Default: 1e-12 141 | 142 | Returns: 143 | numpy.ndarray or torch.Tensor: Homogeneous transformation matrices. 
144 | 145 | Shape: 146 | - pointquaternions: :math:`(*, 7)` 147 | - Output: :math:`(*, 4, 4)` 148 | 149 | """ 150 | if not ( 151 | isinstance(pointquaternions, np.ndarray) or torch.is_tensor(pointquaternions) 152 | ): 153 | raise TypeError( 154 | '"pointquaternions" must be of type "np.ndarray" or "torch.Tensor". Got {0}'.format( 155 | type(pointquaternions) 156 | ) 157 | ) 158 | if not isinstance(eps, float): 159 | raise TypeError('"eps" must be of type "float". Got {0}.'.format(type(eps))) 160 | if pointquaternions.shape[-1] != 7: 161 | raise ValueError( 162 | '"pointquaternions" must be of shape (*, 7). Got {0}.'.format( 163 | pointquaternions.shape 164 | ) 165 | ) 166 | 167 | output_shape = (*pointquaternions.shape[:-1], 4, 4) 168 | if isinstance(pointquaternions, np.ndarray): 169 | t = pointquaternions[..., :3].astype(np.float32) 170 | q = pointquaternions[..., 3:7].astype(np.float32) 171 | transform = np.zeros(output_shape, dtype=np.float32) 172 | else: 173 | t = pointquaternions[..., :3].float() 174 | q = pointquaternions[..., 3:7].float() 175 | transform = torch.zeros( 176 | output_shape, dtype=torch.float, device=pointquaternions.device 177 | ) 178 | 179 | q_norm = (0.5 * (q ** 2).sum(-1)[..., None]) ** 0.5 180 | q /= ( 181 | torch.max(q_norm, torch.tensor(eps)) 182 | if torch.is_tensor(q_norm) 183 | else np.maximum(q_norm, eps) 184 | ) 185 | 186 | if isinstance(q, np.ndarray): 187 | q = np.matmul(q[..., None], q[..., None, :]) 188 | else: 189 | q = torch.matmul(q.unsqueeze(-1), q.unsqueeze(-2)) 190 | 191 | txx = q[..., 0, 0] 192 | tyy = q[..., 1, 1] 193 | tzz = q[..., 2, 2] 194 | txy = q[..., 0, 1] 195 | txz = q[..., 0, 2] 196 | tyz = q[..., 1, 2] 197 | twx = q[..., 0, 3] 198 | twy = q[..., 1, 3] 199 | twz = q[..., 2, 3] 200 | transform[..., 0, 0] = 1.0 201 | transform[..., 1, 1] = 1.0 202 | transform[..., 2, 2] = 1.0 203 | transform[..., 3, 3] = 1.0 204 | transform[..., 0, 0] -= tyy + tzz 205 | transform[..., 0, 1] = txy - twz 206 | transform[..., 0, 2] = txz + twy 207 | transform[..., 1, 0] = txy + twz 208 | transform[..., 1, 1] -= txx + tzz 209 | transform[..., 1, 2] = tyz - twx 210 | transform[..., 2, 0] = txz - twy 211 | transform[..., 2, 1] = tyz + twx 212 | transform[..., 2, 2] -= txx + tyy 213 | transform[..., :3, 3] = t 214 | 215 | return transform 216 | 217 | 218 | def poses_to_transforms(poses: Union[np.ndarray, List[np.ndarray]]): 219 | r"""Converts poses to transformations w.r.t. the first frame in the sequence having identity pose 220 | 221 | Args: 222 | poses (numpy.ndarray or list of numpy.ndarray): Sequence of poses in `numpy.ndarray` format. 223 | 224 | Returns: 225 | numpy.ndarray or list of numpy.ndarray: Sequence of frame to frame transformations where initial 226 | frame is transformed to have identity pose. 227 | 228 | Shape: 229 | - poses: Could be `numpy.ndarray` of shape :math:`(N, 4, 4)`, or list of `numpy.ndarray`s of shape 230 | :math:`(4, 4)` 231 | - Output: Of same shape as input `poses` 232 | """ 233 | transformations = copy.deepcopy(poses) 234 | for i in range(len(poses)): 235 | if i == 0: 236 | transformations[i] = np.eye(4) 237 | else: 238 | transformations[i] = np.linalg.inv(poses[i - 1]).dot(poses[i]) 239 | return transformations 240 | 241 | 242 | def create_label_image(prediction: np.ndarray, color_palette: OrderedDict): 243 | r"""Creates a label image, given a network prediction (each pixel contains class index) and a color palette. 
244 |
245 | Args:
246 | prediction (numpy.ndarray): Predicted image where each pixel contains an integer,
247 | corresponding to its class label.
248 | color_palette (OrderedDict): Contains RGB colors (`uint8`) for each class.
249 |
250 | Returns:
251 | numpy.ndarray: Label image with the given color palette
252 |
253 | Shape:
254 | - prediction: :math:`(H, W)`
255 | - Output: :math:`(H, W, 3)`
256 | """
257 |
258 | label_image = np.zeros(
259 | (prediction.shape[0], prediction.shape[1], 3), dtype=np.uint8
260 | )
261 | for idx, color in enumerate(color_palette):
262 | label_image[prediction == idx] = color
263 | return label_image
264 |
-------------------------------------------------------------------------------- /datasets/gradslam_datasets/icl.py: --------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | from pathlib import Path
4 | from typing import Dict, List, Optional, Union
5 |
6 | import numpy as np
7 | import torch
8 | from natsort import natsorted
9 |
10 | from .basedataset import GradSLAMDataset
11 |
12 |
13 | class ICLDataset(GradSLAMDataset):
14 | def __init__(
15 | self,
16 | config_dict: Dict,
17 | basedir: Union[Path, str],
18 | sequence: Union[Path, str],
19 | stride: Optional[int] = 1,
20 | start: Optional[int] = 0,
21 | end: Optional[int] = -1,
22 | desired_height: Optional[int] = 480,
23 | desired_width: Optional[int] = 640,
24 | load_embeddings: Optional[bool] = False,
25 | embedding_dir: Optional[Union[Path, str]] = "embeddings",
26 | embedding_dim: Optional[int] = 512,
27 | embedding_file_extension: Optional[str] = "pt",
28 | **kwargs,
29 | ):
30 | self.input_folder = os.path.join(basedir, sequence)
31 | # Attempt to find pose file (*.gt.sim)
32 | self.pose_path = glob.glob(os.path.join(self.input_folder, "*.gt.sim"))
33 | if len(self.pose_path) == 0:
34 | raise ValueError("Need pose file ending in extension `*.gt.sim`")
35 | self.pose_path = self.pose_path[0]
36 | self.embedding_file_extension = embedding_file_extension
37 | super().__init__(
38 | config_dict,
39 | stride=stride,
40 | start=start,
41 | end=end,
42 | desired_height=desired_height,
43 | desired_width=desired_width,
44 | load_embeddings=load_embeddings,
45 | embedding_dir=embedding_dir,
46 | embedding_dim=embedding_dim,
47 | **kwargs,
48 | )
49 |
50 | def get_filepaths(self):
51 | color_paths = natsorted(glob.glob(f"{self.input_folder}/rgb/*.png"))
52 | depth_paths = natsorted(glob.glob(f"{self.input_folder}/depth/*.png"))
53 | embedding_paths = None
54 | if self.load_embeddings:
55 | embedding_paths = natsorted(
56 | glob.glob(f"{self.input_folder}/{self.embedding_dir}/*.{self.embedding_file_extension}")
57 | )
58 | return color_paths, depth_paths, embedding_paths
59 |
60 | def load_poses(self):
61 | poses = []
62 |
63 | lines = []
64 | with open(self.pose_path, "r") as f:
65 | lines = f.readlines()
66 |
67 | _posearr = []
68 | for line in lines:
69 | line = line.strip().split()
70 | if len(line) == 0:
71 | continue
72 | _npvec = np.asarray([float(line[0]), float(line[1]), float(line[2]), float(line[3])])
73 | _posearr.append(_npvec)
74 | _posearr = np.stack(_posearr)
75 |
76 | for pose_line_idx in range(0, _posearr.shape[0], 3):
77 | _curpose = np.zeros((4, 4))
78 | _curpose[3, 3] = 1
79 | _curpose[0] = _posearr[pose_line_idx]
80 | _curpose[1] = _posearr[pose_line_idx + 1]
81 | _curpose[2] = _posearr[pose_line_idx + 2]
82 | poses.append(torch.from_numpy(_curpose).float())
83 |
84 | return poses
85 |
86 | def read_embedding_from_file(self,
embedding_file_path):
87 | embedding = torch.load(embedding_file_path)
88 | return embedding.permute(0, 2, 3, 1) # (1, H, W, embedding_dim)
89 |
-------------------------------------------------------------------------------- /datasets/gradslam_datasets/nerfcapture.py: --------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import os
4 | from pathlib import Path
5 | from typing import Dict, List, Optional, Union
6 |
7 | import numpy as np
8 | import torch
9 | from natsort import natsorted
10 |
11 | from .basedataset import GradSLAMDataset
12 |
13 |
14 | def create_filepath_index_mapping(frames):
15 | return {frame["file_path"]: index for index, frame in enumerate(frames)}
16 |
17 |
18 | class NeRFCaptureDataset(GradSLAMDataset):
19 | def __init__(
20 | self,
21 | basedir,
22 | sequence,
23 | stride: Optional[int] = None,
24 | start: Optional[int] = 0,
25 | end: Optional[int] = -1,
26 | desired_height: Optional[int] = 1440,
27 | desired_width: Optional[int] = 1920,
28 | load_embeddings: Optional[bool] = False,
29 | embedding_dir: Optional[str] = "embeddings",
30 | embedding_dim: Optional[int] = 512,
31 | **kwargs,
32 | ):
33 | self.input_folder = os.path.join(basedir, sequence)
34 | config_dict = {}
35 | config_dict["dataset_name"] = "nerfcapture"
36 | self.pose_path = None
37 |
38 | # Load NeRFStudio format camera & poses data
39 | self.cams_metadata = self.load_cams_metadata()
40 | self.frames_metadata = self.cams_metadata["frames"]
41 | self.filepath_index_mapping = create_filepath_index_mapping(self.frames_metadata)
42 |
43 | # Load RGB & Depth filepaths
44 | self.image_names = natsorted(os.listdir(f"{self.input_folder}/rgb"))
45 | self.image_names = [f'rgb/{image_name}' for image_name in self.image_names]
46 |
47 | # Init Intrinsics
48 | config_dict["camera_params"] = {}
49 | config_dict["camera_params"]["png_depth_scale"] = 6553.5 # uint16 depth / 6553.5 -> meters (~10 m max range)
50 | config_dict["camera_params"]["image_height"] = self.cams_metadata["h"]
51 | config_dict["camera_params"]["image_width"] = self.cams_metadata["w"]
52 | config_dict["camera_params"]["fx"] = self.cams_metadata["fl_x"]
53 | config_dict["camera_params"]["fy"] = self.cams_metadata["fl_y"]
54 | config_dict["camera_params"]["cx"] = self.cams_metadata["cx"]
55 | config_dict["camera_params"]["cy"] = self.cams_metadata["cy"]
56 |
57 | super().__init__(
58 | config_dict,
59 | stride=stride,
60 | start=start,
61 | end=end,
62 | desired_height=desired_height,
63 | desired_width=desired_width,
64 | load_embeddings=load_embeddings,
65 | embedding_dir=embedding_dir,
66 | embedding_dim=embedding_dim,
67 | **kwargs,
68 | )
69 |
70 | def load_cams_metadata(self):
71 | cams_metadata_path = f"{self.input_folder}/transforms.json"
72 | cams_metadata = json.load(open(cams_metadata_path, "r"))
73 | return cams_metadata
74 |
75 | def get_filepaths(self):
76 | base_path = f"{self.input_folder}"
77 | color_paths = []
78 | depth_paths = []
79 | self.tmp_poses = []
80 | P = torch.tensor(
81 | [
82 | [1, 0, 0, 0],
83 | [0, -1, 0, 0],
84 | [0, 0, -1, 0],
85 | [0, 0, 0, 1]
86 | ]
87 | ).float()
88 | for image_name in self.image_names:
89 | # Search for image name in frames_metadata
90 | frame_metadata = self.frames_metadata[self.filepath_index_mapping.get(image_name)]
91 | # Get path of image and depth
92 | color_path = f"{base_path}/{image_name}"
93 | depth_path = f"{base_path}/{image_name.replace('rgb', 'depth')}"
94 | color_paths.append(color_path)
95 | depth_paths.append(depth_path)
96 | # Get pose of
image in GradSLAM format 97 | c2w = torch.from_numpy(np.array(frame_metadata["transform_matrix"])).float() 98 | _pose = P @ c2w @ P.T 99 | self.tmp_poses.append(_pose) 100 | embedding_paths = None 101 | if self.load_embeddings: 102 | embedding_paths = natsorted(glob.glob(f"{base_path}/{self.embedding_dir}/*.pt")) 103 | return color_paths, depth_paths, embedding_paths 104 | 105 | def load_poses(self): 106 | return self.tmp_poses 107 | 108 | def read_embedding_from_file(self, embedding_file_path): 109 | print(embedding_file_path) 110 | embedding = torch.load(embedding_file_path, map_location="cpu") 111 | return embedding.permute(0, 2, 3, 1) # (1, H, W, embedding_dim) 112 | -------------------------------------------------------------------------------- /datasets/gradslam_datasets/realsense.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from pathlib import Path 4 | from typing import Dict, List, Optional, Union 5 | 6 | import numpy as np 7 | import torch 8 | from natsort import natsorted 9 | 10 | from .basedataset import GradSLAMDataset 11 | 12 | 13 | class RealsenseDataset(GradSLAMDataset): 14 | """ 15 | Dataset class to process depth images captured by realsense camera on the tabletop manipulator 16 | """ 17 | 18 | def __init__( 19 | self, 20 | config_dict, 21 | basedir, 22 | sequence, 23 | stride: Optional[int] = None, 24 | start: Optional[int] = 0, 25 | end: Optional[int] = -1, 26 | desired_height: Optional[int] = 480, 27 | desired_width: Optional[int] = 640, 28 | load_embeddings: Optional[bool] = False, 29 | embedding_dir: Optional[str] = "embeddings", 30 | embedding_dim: Optional[int] = 512, 31 | **kwargs, 32 | ): 33 | self.input_folder = os.path.join(basedir, sequence) 34 | # only poses/images/depth corresponding to the realsense_camera_order are read/used 35 | self.pose_path = os.path.join(self.input_folder, "poses") 36 | super().__init__( 37 | config_dict, 38 | stride=stride, 39 | start=start, 40 | end=end, 41 | desired_height=desired_height, 42 | desired_width=desired_width, 43 | load_embeddings=load_embeddings, 44 | embedding_dir=embedding_dir, 45 | embedding_dim=embedding_dim, 46 | **kwargs, 47 | ) 48 | 49 | def get_filepaths(self): 50 | color_paths = natsorted(glob.glob(os.path.join(self.input_folder, "rgb", "*.jpg"))) 51 | depth_paths = natsorted(glob.glob(os.path.join(self.input_folder, "depth", "*.png"))) 52 | embedding_paths = None 53 | if self.load_embeddings: 54 | embedding_paths = natsorted(glob.glob(f"{self.input_folder}/{self.embedding_dir}/*.pt")) 55 | return color_paths, depth_paths, embedding_paths 56 | 57 | def load_poses(self): 58 | posefiles = natsorted(glob.glob(os.path.join(self.pose_path, "*.npy"))) 59 | poses = [] 60 | P = torch.tensor([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]).float() 61 | for posefile in posefiles: 62 | c2w = torch.from_numpy(np.load(posefile)).float() 63 | _R = c2w[:3, :3] 64 | _t = c2w[:3, 3] 65 | _pose = P @ c2w @ P.T 66 | poses.append(_pose) 67 | return poses 68 | 69 | def read_embedding_from_file(self, embedding_file_path): 70 | embedding = torch.load(embedding_file_path) 71 | return embedding.permute(0, 2, 3, 1) # (1, H, W, embedding_dim) -------------------------------------------------------------------------------- /datasets/gradslam_datasets/record3d.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from pathlib import Path 4 | from typing import Dict, List, Optional, Union 5 | 
6 | import numpy as np 7 | import torch 8 | from natsort import natsorted 9 | 10 | from .basedataset import GradSLAMDataset 11 | 12 | 13 | class Record3DDataset(GradSLAMDataset): 14 | """ 15 | Dataset class to read in saved files from the structure created by our 16 | `save_record3d_stream.py` script 17 | """ 18 | 19 | def __init__( 20 | self, 21 | config_dict, 22 | basedir, 23 | sequence, 24 | stride: Optional[int] = None, 25 | start: Optional[int] = 0, 26 | end: Optional[int] = -1, 27 | desired_height: Optional[int] = 480, 28 | desired_width: Optional[int] = 640, 29 | load_embeddings: Optional[bool] = False, 30 | embedding_dir: Optional[str] = "embeddings", 31 | embedding_dim: Optional[int] = 512, 32 | **kwargs, 33 | ): 34 | self.input_folder = os.path.join(basedir, sequence) 35 | self.pose_path = os.path.join(self.input_folder, "poses") 36 | super().__init__( 37 | config_dict, 38 | stride=stride, 39 | start=start, 40 | end=end, 41 | desired_height=desired_height, 42 | desired_width=desired_width, 43 | load_embeddings=load_embeddings, 44 | embedding_dir=embedding_dir, 45 | embedding_dim=embedding_dim, 46 | **kwargs, 47 | ) 48 | 49 | def get_filepaths(self): 50 | color_paths = natsorted(glob.glob(os.path.join(self.input_folder, "rgb", "*.png"))) 51 | depth_paths = natsorted(glob.glob(os.path.join(self.input_folder, "depth", "*.png"))) 52 | embedding_paths = None 53 | if self.load_embeddings: 54 | embedding_paths = natsorted(glob.glob(f"{self.input_folder}/{self.embedding_dir}/*.pt")) 55 | return color_paths, depth_paths, embedding_paths 56 | 57 | def load_poses(self): 58 | posefiles = natsorted(glob.glob(os.path.join(self.pose_path, "*.npy"))) 59 | poses = [] 60 | P = torch.tensor([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]).float() 61 | for posefile in posefiles: 62 | c2w = torch.from_numpy(np.load(posefile)).float() 63 | _R = c2w[:3, :3] 64 | _t = c2w[:3, 3] 65 | _pose = P @ c2w @ P.T 66 | poses.append(_pose) 67 | return poses 68 | 69 | def read_embedding_from_file(self, embedding_file_path): 70 | embedding = torch.load(embedding_file_path) 71 | return embedding.permute(0, 2, 3, 1) # (1, H, W, embedding_dim) -------------------------------------------------------------------------------- /datasets/gradslam_datasets/replica.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from pathlib import Path 4 | from typing import Dict, List, Optional, Union 5 | 6 | import numpy as np 7 | import torch 8 | from natsort import natsorted 9 | 10 | from .basedataset import GradSLAMDataset 11 | 12 | 13 | class ReplicaDataset(GradSLAMDataset): 14 | def __init__( 15 | self, 16 | config_dict, 17 | basedir, 18 | sequence, 19 | stride: Optional[int] = None, 20 | start: Optional[int] = 0, 21 | end: Optional[int] = -1, 22 | desired_height: Optional[int] = 480, 23 | desired_width: Optional[int] = 640, 24 | load_embeddings: Optional[bool] = False, 25 | embedding_dir: Optional[str] = "embeddings", 26 | embedding_dim: Optional[int] = 512, 27 | **kwargs, 28 | ): 29 | self.input_folder = os.path.join(basedir, sequence) 30 | self.pose_path = os.path.join(self.input_folder, "traj.txt") 31 | super().__init__( 32 | config_dict, 33 | stride=stride, 34 | start=start, 35 | end=end, 36 | desired_height=desired_height, 37 | desired_width=desired_width, 38 | load_embeddings=load_embeddings, 39 | embedding_dir=embedding_dir, 40 | embedding_dim=embedding_dim, 41 | **kwargs, 42 | ) 43 | 44 | def get_filepaths(self): 45 | color_paths = 
natsorted(glob.glob(f"{self.input_folder}/results/frame*.jpg")) 46 | depth_paths = natsorted(glob.glob(f"{self.input_folder}/results/depth*.png")) 47 | embedding_paths = None 48 | if self.load_embeddings: 49 | embedding_paths = natsorted(glob.glob(f"{self.input_folder}/{self.embedding_dir}/*.pt")) 50 | return color_paths, depth_paths, embedding_paths 51 | 52 | def load_poses(self): 53 | poses = [] 54 | with open(self.pose_path, "r") as f: 55 | lines = f.readlines() 56 | for i in range(self.num_imgs): 57 | line = lines[i] 58 | c2w = np.array(list(map(float, line.split()))).reshape(4, 4) 59 | # c2w[:3, 1] *= -1 60 | # c2w[:3, 2] *= -1 61 | c2w = torch.from_numpy(c2w).float() 62 | poses.append(c2w) 63 | return poses 64 | 65 | def read_embedding_from_file(self, embedding_file_path): 66 | embedding = torch.load(embedding_file_path) 67 | return embedding.permute(0, 2, 3, 1) # (1, H, W, embedding_dim) 68 | 69 | class ReplicaV2Dataset(GradSLAMDataset): 70 | def __init__( 71 | self, 72 | config_dict, 73 | basedir, 74 | sequence, 75 | use_train_split: Optional[bool] = True, 76 | stride: Optional[int] = None, 77 | start: Optional[int] = 0, 78 | end: Optional[int] = -1, 79 | desired_height: Optional[int] = 480, 80 | desired_width: Optional[int] = 640, 81 | load_embeddings: Optional[bool] = False, 82 | embedding_dir: Optional[str] = "embeddings", 83 | embedding_dim: Optional[int] = 512, 84 | **kwargs, 85 | ): 86 | self.use_train_split = use_train_split 87 | if self.use_train_split: 88 | self.input_folder = os.path.join(basedir, sequence, "imap/00") 89 | self.pose_path = os.path.join(self.input_folder, "traj_w_c.txt") 90 | else: 91 | self.train_input_folder = os.path.join(basedir, sequence, "imap/00") 92 | self.train_pose_path = os.path.join(self.train_input_folder, "traj_w_c.txt") 93 | self.input_folder = os.path.join(basedir, sequence, "imap/01") 94 | self.pose_path = os.path.join(self.input_folder, "traj_w_c.txt") 95 | super().__init__( 96 | config_dict, 97 | stride=stride, 98 | start=start, 99 | end=end, 100 | desired_height=desired_height, 101 | desired_width=desired_width, 102 | load_embeddings=load_embeddings, 103 | embedding_dir=embedding_dir, 104 | embedding_dim=embedding_dim, 105 | **kwargs, 106 | ) 107 | 108 | def get_filepaths(self): 109 | if self.use_train_split: 110 | color_paths = natsorted(glob.glob(f"{self.input_folder}/rgb/rgb_*.png")) 111 | depth_paths = natsorted(glob.glob(f"{self.input_folder}/depth/depth_*.png")) 112 | else: 113 | first_train_color_path = f"{self.train_input_folder}/rgb/rgb_0.png" 114 | first_train_depth_path = f"{self.train_input_folder}/depth/depth_0.png" 115 | color_paths = [first_train_color_path] + natsorted(glob.glob(f"{self.input_folder}/rgb/rgb_*.png")) 116 | depth_paths = [first_train_depth_path] + natsorted(glob.glob(f"{self.input_folder}/depth/depth_*.png")) 117 | embedding_paths = None 118 | if self.load_embeddings: 119 | embedding_paths = natsorted(glob.glob(f"{self.input_folder}/{self.embedding_dir}/*.pt")) 120 | return color_paths, depth_paths, embedding_paths 121 | 122 | def load_poses(self): 123 | poses = [] 124 | if not self.use_train_split: 125 | with open(self.train_pose_path, "r") as f: 126 | train_lines = f.readlines() 127 | first_train_frame_line = train_lines[0] 128 | first_train_frame_c2w = np.array(list(map(float, first_train_frame_line.split()))).reshape(4, 4) 129 | first_train_frame_c2w = torch.from_numpy(first_train_frame_c2w).float() 130 | poses.append(first_train_frame_c2w) 131 | with open(self.pose_path, "r") as f: 132 | lines = 
f.readlines() 133 | if self.use_train_split: 134 | num_poses = self.num_imgs 135 | else: 136 | num_poses = self.num_imgs - 1 137 | for i in range(num_poses): 138 | line = lines[i] 139 | c2w = np.array(list(map(float, line.split()))).reshape(4, 4) 140 | # c2w[:3, 1] *= -1 141 | # c2w[:3, 2] *= -1 142 | c2w = torch.from_numpy(c2w).float() 143 | poses.append(c2w) 144 | return poses 145 | 146 | def read_embedding_from_file(self, embedding_file_path): 147 | embedding = torch.load(embedding_file_path) 148 | return embedding.permute(0, 2, 3, 1) # (1, H, W, embedding_dim) 149 | -------------------------------------------------------------------------------- /datasets/gradslam_datasets/scannet.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from pathlib import Path 4 | from typing import Dict, List, Optional, Union 5 | 6 | import numpy as np 7 | import torch 8 | from natsort import natsorted 9 | 10 | from .basedataset import GradSLAMDataset 11 | 12 | 13 | class ScannetDataset(GradSLAMDataset): 14 | def __init__( 15 | self, 16 | config_dict, 17 | basedir, 18 | sequence, 19 | stride: Optional[int] = None, 20 | start: Optional[int] = 0, 21 | end: Optional[int] = -1, 22 | desired_height: Optional[int] = 968, 23 | desired_width: Optional[int] = 1296, 24 | load_embeddings: Optional[bool] = False, 25 | embedding_dir: Optional[str] = "embeddings", 26 | embedding_dim: Optional[int] = 512, 27 | **kwargs, 28 | ): 29 | self.input_folder = os.path.join(basedir, sequence) 30 | self.pose_path = None 31 | super().__init__( 32 | config_dict, 33 | stride=stride, 34 | start=start, 35 | end=end, 36 | desired_height=desired_height, 37 | desired_width=desired_width, 38 | load_embeddings=load_embeddings, 39 | embedding_dir=embedding_dir, 40 | embedding_dim=embedding_dim, 41 | **kwargs, 42 | ) 43 | 44 | def get_filepaths(self): 45 | color_paths = natsorted(glob.glob(f"{self.input_folder}/color/*.jpg")) 46 | depth_paths = natsorted(glob.glob(f"{self.input_folder}/depth/*.png")) 47 | embedding_paths = None 48 | if self.load_embeddings: 49 | embedding_paths = natsorted(glob.glob(f"{self.input_folder}/{self.embedding_dir}/*.pt")) 50 | return color_paths, depth_paths, embedding_paths 51 | 52 | def load_poses(self): 53 | poses = [] 54 | posefiles = natsorted(glob.glob(f"{self.input_folder}/pose/*.txt")) 55 | for posefile in posefiles: 56 | _pose = torch.from_numpy(np.loadtxt(posefile)) 57 | poses.append(_pose) 58 | return poses 59 | 60 | def read_embedding_from_file(self, embedding_file_path): 61 | print(embedding_file_path) 62 | embedding = torch.load(embedding_file_path, map_location="cpu") 63 | return embedding.permute(0, 2, 3, 1) # (1, H, W, embedding_dim) 64 | -------------------------------------------------------------------------------- /datasets/gradslam_datasets/scannetpp.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | from pathlib import Path 5 | from typing import Dict, List, Optional, Union 6 | 7 | import numpy as np 8 | import torch 9 | from natsort import natsorted 10 | 11 | from .basedataset import GradSLAMDataset 12 | 13 | 14 | def create_filepath_index_mapping(frames): 15 | return {frame["file_path"]: index for index, frame in enumerate(frames)} 16 | 17 | 18 | class ScannetPPDataset(GradSLAMDataset): 19 | def __init__( 20 | self, 21 | basedir, 22 | sequence, 23 | ignore_bad: Optional[bool] = False, 24 | use_train_split: Optional[bool] = True, 25 | stride: 
Optional[int] = None, 26 | start: Optional[int] = 0, 27 | end: Optional[int] = -1, 28 | desired_height: Optional[int] = 1168, 29 | desired_width: Optional[int] = 1752, 30 | load_embeddings: Optional[bool] = False, 31 | embedding_dir: Optional[str] = "embeddings", 32 | embedding_dim: Optional[int] = 512, 33 | **kwargs, 34 | ): 35 | self.input_folder = os.path.join(basedir, sequence) 36 | config_dict = {} 37 | config_dict["dataset_name"] = "scannetpp" 38 | self.pose_path = None 39 | self.ignore_bad = ignore_bad 40 | self.use_train_split = use_train_split 41 | 42 | # Load Train & Test Split 43 | self.train_test_split = json.load(open(f"{self.input_folder}/dslr/train_test_lists.json", "r")) 44 | if self.use_train_split: 45 | self.image_names = self.train_test_split["train"] 46 | else: 47 | self.image_names = self.train_test_split["test"] 48 | self.train_image_names = self.train_test_split["train"] 49 | 50 | # Load NeRFStudio format camera & poses data 51 | self.cams_metadata = self.load_cams_metadata() 52 | if self.use_train_split: 53 | self.frames_metadata = self.cams_metadata["frames"] 54 | self.filepath_index_mapping = create_filepath_index_mapping(self.frames_metadata) 55 | else: 56 | self.frames_metadata = self.cams_metadata["test_frames"] 57 | self.train_frames_metadata = self.cams_metadata["frames"] 58 | self.filepath_index_mapping = create_filepath_index_mapping(self.frames_metadata) 59 | self.train_filepath_index_mapping = create_filepath_index_mapping(self.train_frames_metadata) 60 | 61 | # Init Intrinsics 62 | config_dict["camera_params"] = {} 63 | config_dict["camera_params"]["png_depth_scale"] = 1000.0 # Depth is in mm 64 | config_dict["camera_params"]["image_height"] = self.cams_metadata["h"] 65 | config_dict["camera_params"]["image_width"] = self.cams_metadata["w"] 66 | config_dict["camera_params"]["fx"] = self.cams_metadata["fl_x"] 67 | config_dict["camera_params"]["fy"] = self.cams_metadata["fl_y"] 68 | config_dict["camera_params"]["cx"] = self.cams_metadata["cx"] 69 | config_dict["camera_params"]["cy"] = self.cams_metadata["cy"] 70 | 71 | super().__init__( 72 | config_dict, 73 | stride=stride, 74 | start=start, 75 | end=end, 76 | desired_height=desired_height, 77 | desired_width=desired_width, 78 | load_embeddings=load_embeddings, 79 | embedding_dir=embedding_dir, 80 | embedding_dim=embedding_dim, 81 | **kwargs, 82 | ) 83 | 84 | def load_cams_metadata(self): 85 | cams_metadata_path = f"{self.input_folder}/dslr/nerfstudio/transforms_undistorted.json" 86 | cams_metadata = json.load(open(cams_metadata_path, "r")) 87 | return cams_metadata 88 | 89 | def get_filepaths(self): 90 | base_path = f"{self.input_folder}/dslr" 91 | color_paths = [] 92 | depth_paths = [] 93 | self.tmp_poses = [] 94 | P = torch.tensor( 95 | [ 96 | [1, 0, 0, 0], 97 | [0, -1, 0, 0], 98 | [0, 0, -1, 0], 99 | [0, 0, 0, 1] 100 | ] 101 | ).float() 102 | if not self.use_train_split: 103 | self.first_train_image_name = self.train_image_names[0] 104 | self.first_train_image_index = self.train_filepath_index_mapping.get(self.first_train_image_name) 105 | self.first_train_frame_metadata = self.train_frames_metadata[self.first_train_image_index] 106 | # Get path of undistorted image and depth 107 | color_path = f"{base_path}/undistorted_images/{self.first_train_image_name}" 108 | depth_path = f"{base_path}/undistorted_depths/{self.first_train_image_name.replace('.JPG', '.png')}" 109 | color_paths.append(color_path) 110 | depth_paths.append(depth_path) 111 | # Get pose of first train frame in GradSLAM format 112 | c2w = 
torch.from_numpy(np.array(self.first_train_frame_metadata["transform_matrix"])).float() 113 | _pose = P @ c2w @ P.T 114 | self.tmp_poses.append(_pose) 115 | for image_name in self.image_names: 116 | # Search for image name in frames_metadata 117 | frame_metadata = self.frames_metadata[self.filepath_index_mapping.get(image_name)] 118 | # Check if frame is blurry and if it needs to be ignored 119 | if self.ignore_bad and frame_metadata['is_bad']: 120 | continue 121 | # Get path of undistorted image and depth 122 | color_path = f"{base_path}/undistorted_images/{image_name}" 123 | depth_path = f"{base_path}/undistorted_depths/{image_name.replace('.JPG', '.png')}" 124 | color_paths.append(color_path) 125 | depth_paths.append(depth_path) 126 | # Get pose of undistorted image in GradSLAM format 127 | c2w = torch.from_numpy(np.array(frame_metadata["transform_matrix"])).float() 128 | _pose = P @ c2w @ P.T 129 | self.tmp_poses.append(_pose) 130 | embedding_paths = None 131 | if self.load_embeddings: 132 | embedding_paths = natsorted(glob.glob(f"{base_path}/{self.embedding_dir}/*.pt")) 133 | return color_paths, depth_paths, embedding_paths 134 | 135 | def load_poses(self): 136 | return self.tmp_poses 137 | 138 | def read_embedding_from_file(self, embedding_file_path): 139 | print(embedding_file_path) 140 | embedding = torch.load(embedding_file_path, map_location="cpu") 141 | return embedding.permute(0, 2, 3, 1) # (1, H, W, embedding_dim) 142 | -------------------------------------------------------------------------------- /datasets/gradslam_datasets/tum.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from pathlib import Path 4 | from typing import Dict, List, Optional, Union 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils import data 9 | from natsort import natsorted 10 | 11 | from .basedataset import GradSLAMDataset 12 | 13 | class TUMDataset(GradSLAMDataset): 14 | def __init__( 15 | self, 16 | config_dict, 17 | basedir, 18 | sequence, 19 | stride: Optional[int] = None, 20 | start: Optional[int] = 0, 21 | end: Optional[int] = -1, 22 | desired_height: Optional[int] = 480, 23 | desired_width: Optional[int] = 640, 24 | load_embeddings: Optional[bool] = False, 25 | embedding_dir: Optional[str] = "embeddings", 26 | embedding_dim: Optional[int] = 512, 27 | **kwargs, 28 | ): 29 | self.input_folder = os.path.join(basedir, sequence) 30 | self.pose_path = None 31 | super().__init__( 32 | config_dict, 33 | stride=stride, 34 | start=start, 35 | end=end, 36 | desired_height=desired_height, 37 | desired_width=desired_width, 38 | load_embeddings=load_embeddings, 39 | embedding_dir=embedding_dir, 40 | embedding_dim=embedding_dim, 41 | **kwargs, 42 | ) 43 | 44 | def parse_list(self, filepath, skiprows=0): 45 | """ read list data """ 46 | data = np.loadtxt(filepath, delimiter=' ', 47 | dtype=np.unicode_, skiprows=skiprows) 48 | return data 49 | 50 | def associate_frames(self, tstamp_image, tstamp_depth, tstamp_pose, max_dt=0.08): 51 | """ pair images, depths, and poses """ 52 | associations = [] 53 | for i, t in enumerate(tstamp_image): 54 | if tstamp_pose is None: 55 | j = np.argmin(np.abs(tstamp_depth - t)) 56 | if (np.abs(tstamp_depth[j] - t) < max_dt): 57 | associations.append((i, j)) 58 | 59 | else: 60 | j = np.argmin(np.abs(tstamp_depth - t)) 61 | k = np.argmin(np.abs(tstamp_pose - t)) 62 | 63 | if (np.abs(tstamp_depth[j] - t) < max_dt) and \ 64 | (np.abs(tstamp_pose[k] - t) < max_dt): 65 | associations.append((i, j, k)) 66 
| 67 | return associations
68 |
69 | def pose_matrix_from_quaternion(self, pvec):
70 | """ convert a (tx, ty, tz, qx, qy, qz, qw) vector to a 4x4 c2w pose matrix """
71 | from scipy.spatial.transform import Rotation
72 |
73 | pose = np.eye(4)
74 | pose[:3, :3] = Rotation.from_quat(pvec[3:]).as_matrix()
75 | pose[:3, 3] = pvec[:3]
76 | return pose
77 |
78 | def get_filepaths(self):
79 | """ read video data in tum-rgbd format """
80 | frame_rate = 32
81 |
82 | if os.path.isfile(os.path.join(self.input_folder, 'groundtruth.txt')):
83 | pose_list = os.path.join(self.input_folder, 'groundtruth.txt')
84 | elif os.path.isfile(os.path.join(self.input_folder, 'pose.txt')):
85 | pose_list = os.path.join(self.input_folder, 'pose.txt')
86 |
87 | image_list = os.path.join(self.input_folder, 'rgb.txt')
88 | depth_list = os.path.join(self.input_folder, 'depth.txt')
89 |
90 | image_data = self.parse_list(image_list)
91 | depth_data = self.parse_list(depth_list)
92 | pose_data = self.parse_list(pose_list, skiprows=1)
93 | pose_vecs = pose_data[:, 1:].astype(np.float64)
94 |
95 | tstamp_image = image_data[:, 0].astype(np.float64)
96 | tstamp_depth = depth_data[:, 0].astype(np.float64)
97 | tstamp_pose = pose_data[:, 0].astype(np.float64)
98 | associations = self.associate_frames(
99 | tstamp_image, tstamp_depth, tstamp_pose)
100 |
101 | indices = [0]
102 | for i in range(1, len(associations)):
103 | t0 = tstamp_image[associations[indices[-1]][0]]
104 | t1 = tstamp_image[associations[i][0]]
105 | if t1 - t0 > 1.0 / frame_rate:
106 | indices += [i]
107 |
108 | color_paths, depth_paths = [], []
109 | for ix in indices:
110 | (i, j, k) = associations[ix]
111 | color_paths += [os.path.join(self.input_folder, image_data[i, 1])]
112 | depth_paths += [os.path.join(self.input_folder, depth_data[j, 1])]
113 |
114 | embedding_paths = None
115 |
116 | return color_paths, depth_paths, embedding_paths
117 |
118 | def load_poses(self):
119 | """ read video data in tum-rgbd format """
120 | frame_rate = 32
121 |
122 | if os.path.isfile(os.path.join(self.input_folder, 'groundtruth.txt')):
123 | pose_list = os.path.join(self.input_folder, 'groundtruth.txt')
124 | elif os.path.isfile(os.path.join(self.input_folder, 'pose.txt')):
125 | pose_list = os.path.join(self.input_folder, 'pose.txt')
126 |
127 | image_list = os.path.join(self.input_folder, 'rgb.txt')
128 | depth_list = os.path.join(self.input_folder, 'depth.txt')
129 |
130 | image_data = self.parse_list(image_list)
131 | depth_data = self.parse_list(depth_list)
132 | pose_data = self.parse_list(pose_list, skiprows=1)
133 | pose_vecs = pose_data[:, 1:].astype(np.float64)
134 |
135 | tstamp_image = image_data[:, 0].astype(np.float64)
136 | tstamp_depth = depth_data[:, 0].astype(np.float64)
137 | tstamp_pose = pose_data[:, 0].astype(np.float64)
138 | associations = self.associate_frames(
139 | tstamp_image, tstamp_depth, tstamp_pose)
140 |
141 | indices = [0]
142 | for i in range(1, len(associations)):
143 | t0 = tstamp_image[associations[indices[-1]][0]]
144 | t1 = tstamp_image[associations[i][0]]
145 | if t1 - t0 > 1.0 / frame_rate:
146 | indices += [i]
147 |
148 | color_paths, poses, depth_paths, intrinsics = [], [], [], []
149 | # poses are returned as absolute camera-to-world transforms from the groundtruth file
150 | for ix in indices:
151 | (i, j, k) = associations[ix]
152 | color_paths += [os.path.join(self.input_folder, image_data[i, 1])]
153 | depth_paths += [os.path.join(self.input_folder, depth_data[j, 1])]
154 | c2w = self.pose_matrix_from_quaternion(pose_vecs[k])
155 | c2w = torch.from_numpy(c2w).float()
156 | poses += [c2w]
157 |
158 | return poses
159 |
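# A small worked example of the association logic above (illustrative numbers,
# not taken from any real sequence): with image timestamps [0.00, 0.03, 0.06]
# and depth/pose timestamps [0.01, 0.04, 0.07], associate_frames pairs each
# image with its nearest depth and pose neighbours, giving
# [(0, 0, 0), (1, 1, 1), (2, 2, 2)] since every gap is below max_dt=0.08.
# The frame_rate=32 filter then keeps a pair only when it is more than
# 1/32 s (= 0.03125 s) after the last kept one: 0.03 - 0.00 <= 0.03125, so the
# second pair is dropped and indices ends up as [0, 2].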
160 | def read_embedding_from_file(self, embedding_file_path): 161 | embedding = torch.load(embedding_file_path, map_location="cpu") 162 | return embedding.permute(0, 2, 3, 1) 163 | 164 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: splatam 2 | channels: 3 | - nvidia/label/cuda-11.6.0 4 | - pytorch 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - cuda-toolkit 9 | - cudatoolkit=11.6 10 | - python=3.10 11 | - pytorch=1.12.1 12 | - torchaudio=0.12.1 13 | - torchvision=0.13.1 14 | - tqdm=4.65.0 15 | - Pillow 16 | - faiss-gpu 17 | - opencv 18 | - imageio 19 | - matplotlib 20 | - kornia 21 | - natsort 22 | - pyyaml 23 | - wandb 24 | - gxx_linux-64=10 25 | - pip 26 | - pip: 27 | - lpips 28 | - open3d==0.16.0 29 | - torchmetrics 30 | - cyclonedds 31 | - git+https://github.com/JonathonLuiten/diff-gaussian-rasterization-w-depth.git@cb65e4b86bc3bd8ed42174b72a62e8d3a3a71110 32 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm==4.65.0 2 | Pillow 3 | opencv-python 4 | imageio 5 | matplotlib 6 | kornia 7 | natsort 8 | pyyaml 9 | wandb 10 | lpips 11 | open3d==0.16.0 12 | torchmetrics 13 | cyclonedds 14 | pytorch-msssim 15 | plyfile==0.8.1 16 | git+https://github.com/JonathonLuiten/diff-gaussian-rasterization-w-depth.git@cb65e4b86bc3bd8ed42174b72a62e8d3a3a71110 -------------------------------------------------------------------------------- /scripts/_init_.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spla-tam/SplaTAM/da6bbcd24c248dc884ac7f49d62e91b841b26ccc/scripts/_init_.py -------------------------------------------------------------------------------- /scripts/eval_novel_view.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import sys 5 | import shutil 6 | from importlib.machinery import SourceFileLoader 7 | 8 | _BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 9 | 10 | sys.path.insert(0, _BASE_DIR) 11 | 12 | print("System Paths:") 13 | for p in sys.path: 14 | print(p) 15 | 16 | import matplotlib.pyplot as plt 17 | import cv2 18 | import numpy as np 19 | import torch 20 | from tqdm import tqdm 21 | import wandb 22 | 23 | from datasets.gradslam_datasets import (load_dataset_config, ICLDataset, ReplicaDataset, ReplicaV2Dataset, AzureKinectDataset, 24 | ScannetDataset, Ai2thorDataset, Record3DDataset, RealsenseDataset, TUMDataset, 25 | ScannetPPDataset, NeRFCaptureDataset) 26 | from utils.common_utils import seed_everything 27 | from utils.eval_helpers import eval, eval_nvs 28 | 29 | 30 | def get_dataset(config_dict, basedir, sequence, **kwargs): 31 | if config_dict["dataset_name"].lower() in ["icl"]: 32 | return ICLDataset(config_dict, basedir, sequence, **kwargs) 33 | elif config_dict["dataset_name"].lower() in ["replica"]: 34 | return ReplicaDataset(config_dict, basedir, sequence, **kwargs) 35 | elif config_dict["dataset_name"].lower() in ["replicav2"]: 36 | return ReplicaV2Dataset(config_dict, basedir, sequence, **kwargs) 37 | elif config_dict["dataset_name"].lower() in ["azure", "azurekinect"]: 38 | return AzureKinectDataset(config_dict, basedir, sequence, **kwargs) 39 | elif config_dict["dataset_name"].lower() in ["scannet"]: 
40 | return ScannetDataset(config_dict, basedir, sequence, **kwargs) 41 | elif config_dict["dataset_name"].lower() in ["ai2thor"]: 42 | return Ai2thorDataset(config_dict, basedir, sequence, **kwargs) 43 | elif config_dict["dataset_name"].lower() in ["record3d"]: 44 | return Record3DDataset(config_dict, basedir, sequence, **kwargs) 45 | elif config_dict["dataset_name"].lower() in ["realsense"]: 46 | return RealsenseDataset(config_dict, basedir, sequence, **kwargs) 47 | elif config_dict["dataset_name"].lower() in ["tum"]: 48 | return TUMDataset(config_dict, basedir, sequence, **kwargs) 49 | elif config_dict["dataset_name"].lower() in ["scannetpp"]: 50 | return ScannetPPDataset(basedir, sequence, **kwargs) 51 | elif config_dict["dataset_name"].lower() in ["nerfcapture"]: 52 | return NeRFCaptureDataset(basedir, sequence, **kwargs) 53 | else: 54 | raise ValueError(f"Unknown dataset name {config_dict['dataset_name']}") 55 | 56 | 57 | def load_scene_data(scene_path): 58 | params = dict(np.load(scene_path, allow_pickle=True)) 59 | params = {k: torch.tensor(params[k]).cuda().float().requires_grad_(True) for k in params.keys()} 60 | return params 61 | 62 | 63 | if __name__=="__main__": 64 | parser = argparse.ArgumentParser() 65 | 66 | parser.add_argument("experiment", type=str, help="Path to experiment file") 67 | 68 | args = parser.parse_args() 69 | 70 | experiment = SourceFileLoader( 71 | os.path.basename(args.experiment), args.experiment 72 | ).load_module() 73 | 74 | config = experiment.config 75 | 76 | # Set Experiment Seed 77 | seed_everything(seed=experiment.config['seed']) 78 | device = torch.device(config["primary_device"]) 79 | 80 | # Create Results Directory and Copy Config 81 | results_dir = os.path.join( 82 | experiment.config["workdir"], experiment.config["run_name"] 83 | ) 84 | if not experiment.config['load_checkpoint']: 85 | os.makedirs(results_dir, exist_ok=True) 86 | shutil.copy(args.experiment, os.path.join(results_dir, "config.py")) 87 | 88 | # Load Dataset 89 | print("Loading Dataset ...") 90 | dataset_config = config["data"] 91 | if "gradslam_data_cfg" not in dataset_config: 92 | gradslam_data_cfg = {} 93 | gradslam_data_cfg["dataset_name"] = dataset_config["dataset_name"] 94 | else: 95 | gradslam_data_cfg = load_dataset_config(dataset_config["gradslam_data_cfg"]) 96 | if "ignore_bad" not in dataset_config: 97 | dataset_config["ignore_bad"] = False 98 | if "use_train_split" not in dataset_config: 99 | dataset_config["use_train_split"] = True 100 | # Poses are relative to the first training frame 101 | dataset = get_dataset( 102 | config_dict=gradslam_data_cfg, 103 | basedir=dataset_config["basedir"], 104 | sequence=os.path.basename(dataset_config["sequence"]), 105 | start=dataset_config["start"], 106 | end=dataset_config["end"], 107 | stride=dataset_config["stride"], 108 | desired_height=dataset_config["desired_image_height"], 109 | desired_width=dataset_config["desired_image_width"], 110 | device=device, 111 | relative_pose=True, 112 | ignore_bad=dataset_config["ignore_bad"], 113 | use_train_split=dataset_config["use_train_split"], 114 | ) 115 | num_frames = dataset_config["num_frames"] 116 | if num_frames == -1: 117 | num_frames = len(dataset) 118 | 119 | scene_path = config['scene_path'] 120 | params = load_scene_data(scene_path) 121 | 122 | if dataset_config['use_train_split']: 123 | eval_dir = os.path.join(results_dir, "eval_train") 124 | wandb_name = config['wandb']['name'] + "_Train_Split" 125 | else: 126 | eval_dir = os.path.join(results_dir, "eval_nvs") 127 | 
wandb_name = config['wandb']['name'] + "_NVS_Split" 128 | 129 | # Init WandB 130 | if config['use_wandb']: 131 | wandb_time_step = 0 132 | wandb_tracking_step = 0 133 | wandb_mapping_step = 0 134 | wandb_run = wandb.init(project=config['wandb']['project'], 135 | entity=config['wandb']['entity'], 136 | group=config['wandb']['group'], 137 | name=wandb_name, 138 | config=config) 139 | 140 | # Evaluate Final Parameters 141 | with torch.no_grad(): 142 | if config['use_wandb']: 143 | if dataset_config['use_train_split']: 144 | eval(dataset, params, num_frames, eval_dir, sil_thres=config['mapping']['sil_thres'], 145 | wandb_run=wandb_run, wandb_save_qual=config['wandb']['eval_save_qual'], 146 | mapping_iters=config['mapping']['num_iters'], add_new_gaussians=config['mapping']['add_new_gaussians'], 147 | eval_every=config['eval_every'], save_frames=True) 148 | else: 149 | eval_nvs(dataset, params, num_frames, eval_dir, sil_thres=config['mapping']['sil_thres'], 150 | wandb_run=wandb_run, wandb_save_qual=config['wandb']['eval_save_qual'], 151 | mapping_iters=config['mapping']['num_iters'], add_new_gaussians=config['mapping']['add_new_gaussians'], 152 | eval_every=config['eval_every'], save_frames=True) 153 | else: 154 | if dataset_config['use_train_split']: 155 | eval(dataset, params, num_frames, eval_dir, sil_thres=config['mapping']['sil_thres'], 156 | mapping_iters=config['mapping']['num_iters'], add_new_gaussians=config['mapping']['add_new_gaussians'], 157 | eval_every=config['eval_every'], save_frames=True) 158 | else: 159 | eval_nvs(dataset, params, num_frames, eval_dir, sil_thres=config['mapping']['sil_thres'], 160 | mapping_iters=config['mapping']['num_iters'], add_new_gaussians=config['mapping']['add_new_gaussians'], 161 | eval_every=config['eval_every'], save_frames=True) 162 | 163 | # Close WandB 164 | if config['use_wandb']: 165 | wandb_run.finish() 166 | -------------------------------------------------------------------------------- /scripts/export_ply.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from importlib.machinery import SourceFileLoader 4 | 5 | import numpy as np 6 | from plyfile import PlyData, PlyElement 7 | 8 | # Spherical harmonic constant 9 | C0 = 0.28209479177387814 10 | 11 | 12 | def rgb_to_spherical_harmonic(rgb): 13 | return (rgb-0.5) / C0 14 | 15 | 16 | def spherical_harmonic_to_rgb(sh): 17 | return sh*C0 + 0.5 18 | 19 | 20 | def save_ply(path, means, scales, rotations, rgbs, opacities, normals=None): 21 | if normals is None: 22 | normals = np.zeros_like(means) 23 | 24 | colors = rgb_to_spherical_harmonic(rgbs) 25 | 26 | if scales.shape[1] == 1: 27 | scales = np.tile(scales, (1, 3)) 28 | 29 | attrs = ['x', 'y', 'z', 30 | 'nx', 'ny', 'nz', 31 | 'f_dc_0', 'f_dc_1', 'f_dc_2', 32 | 'opacity', 33 | 'scale_0', 'scale_1', 'scale_2', 34 | 'rot_0', 'rot_1', 'rot_2', 'rot_3',] 35 | 36 | dtype_full = [(attribute, 'f4') for attribute in attrs] 37 | elements = np.empty(means.shape[0], dtype=dtype_full) 38 | 39 | attributes = np.concatenate((means, normals, colors, opacities, scales, rotations), axis=1) 40 | elements[:] = list(map(tuple, attributes)) 41 | el = PlyElement.describe(elements, 'vertex') 42 | PlyData([el]).write(path) 43 | 44 | print(f"Saved PLY format Splat to {path}") 45 | 46 | 47 | def parse_args(): 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument("config", type=str, help="Path to config file.") 50 | return parser.parse_args() 51 | 52 | 53 | if __name__ == "__main__": 54 | 
args = parse_args() 55 | 56 | # Load SplaTAM config 57 | experiment = SourceFileLoader(os.path.basename(args.config), args.config).load_module() 58 | config = experiment.config 59 | work_path = config['workdir'] 60 | run_name = config['run_name'] 61 | params_path = os.path.join(work_path, run_name, "params.npz") 62 | 63 | params = dict(np.load(params_path, allow_pickle=True)) 64 | means = params['means3D'] 65 | scales = params['log_scales'] 66 | rotations = params['unnorm_rotations'] 67 | rgbs = params['rgb_colors'] 68 | opacities = params['logit_opacities'] 69 | 70 | ply_path = os.path.join(work_path, run_name, "splat.ply") 71 | 72 | save_ply(ply_path, means, scales, rotations, rgbs, opacities) -------------------------------------------------------------------------------- /scripts/nerfcapture2dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | ''' 3 | Script to capture a dataset from the NeRFCapture iOS App. Code is adapted from instant-ngp/scripts/nerfcapture2nerf.py. 4 | https://github.com/NVlabs/instant-ngp/blob/master/scripts/nerfcapture2nerf.py 5 | ''' 6 | 7 | import argparse 8 | import os 9 | import shutil 10 | import sys 11 | from pathlib import Path 12 | import json 13 | from importlib.machinery import SourceFileLoader 14 | 15 | _BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 16 | 17 | sys.path.insert(0, _BASE_DIR) 18 | 19 | import cv2 20 | import numpy as np 21 | 22 | import cyclonedds.idl as idl 23 | import cyclonedds.idl.annotations as annotate 24 | import cyclonedds.idl.types as types 25 | from dataclasses import dataclass 26 | from cyclonedds.domain import DomainParticipant, Domain 27 | from cyclonedds.core import Qos, Policy 28 | from cyclonedds.sub import DataReader 29 | from cyclonedds.topic import Topic 30 | from cyclonedds.util import duration 31 | 32 | 33 | def parse_args(): 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument("--config", default="./configs/iphone/nerfcapture.py", type=str, help="Path to config file.") 36 | return parser.parse_args() 37 | 38 | 39 | # DDS 40 | # ================================================================================================== 41 | @dataclass 42 | @annotate.final 43 | @annotate.autoid("sequential") 44 | class SplatCaptureFrame(idl.IdlStruct, typename="SplatCaptureData.SplatCaptureFrame"): 45 | id: types.uint32 46 | annotate.key("id") 47 | timestamp: types.float64 48 | fl_x: types.float32 49 | fl_y: types.float32 50 | cx: types.float32 51 | cy: types.float32 52 | transform_matrix: types.array[types.float32, 16] 53 | width: types.uint32 54 | height: types.uint32 55 | image: types.sequence[types.uint8] 56 | has_depth: bool 57 | depth_width: types.uint32 58 | depth_height: types.uint32 59 | depth_scale: types.float32 60 | depth_image: types.sequence[types.uint8] 61 | 62 | 63 | dds_config = """<?xml version="1.0" encoding="UTF-8" ?> \ 64 | <CycloneDDS xmlns="https://cdds.io/config" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="https://cdds.io/config https://raw.githubusercontent.com/eclipse-cyclonedds/cyclonedds/master/etc/cyclonedds.xsd"> \ 65 | <Domain id="any"> \ 66 | <Internal> \ 67 | <MinimumSocketReceiveBufferSize>10MB</MinimumSocketReceiveBufferSize> \ 68 | </Internal> \ 69 | <Tracing> \ 70 | <Verbosity>config</Verbosity> \ 71 | <OutputFile>stdout</OutputFile> \ 72 | </Tracing> \ 73 | </Domain> \ 74 | </CycloneDDS> \ 75 | """ 76 | # ================================================================================================== 77 | 78 | 79 | def dataset_capture_loop(reader: DataReader, save_path: Path, overwrite: bool, n_frames: int, depth_scale: float): 80 | if save_path.exists(): 81 | if overwrite: 82 | # Prompt user to confirm deletion
83 | if (input(f"warning! folder '{save_path}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y": 84 | sys.exit(1) 85 | shutil.rmtree(save_path) 86 | else: 87 | print(f"save_path {save_path} already exists") 88 | sys.exit(1) 89 | 90 | print("Waiting for frames...") 91 | # Make directory 92 | images_dir = save_path.joinpath("rgb") 93 | 94 | manifest = { 95 | "fl_x": 0.0, 96 | "fl_y": 0.0, 97 | "cx": 0.0, 98 | "cy": 0.0, 99 | "w": 0.0, 100 | "h": 0.0, 101 | "frames": [] 102 | } 103 | 104 | total_frames = 0 # Total frames received 105 | 106 | # Start DDS Loop 107 | while True: 108 | sample = reader.read_next() # Get frame from NeRFCapture 109 | if sample: 110 | print(f"{total_frames + 1}/{n_frames} frames received") 111 | 112 | if total_frames == 0: 113 | save_path.mkdir(parents=True) 114 | images_dir.mkdir() 115 | manifest["w"] = sample.width 116 | manifest["h"] = sample.height 117 | manifest["cx"] = sample.cx 118 | manifest["cy"] = sample.cy 119 | manifest["fl_x"] = sample.fl_x 120 | manifest["fl_y"] = sample.fl_y 121 | manifest["integer_depth_scale"] = float(depth_scale)/65535.0 122 | if sample.has_depth: 123 | depth_dir = save_path.joinpath("depth") 124 | depth_dir.mkdir() 125 | 126 | # RGB 127 | image = np.asarray(sample.image, dtype=np.uint8).reshape((sample.height, sample.width, 3)) 128 | cv2.imwrite(str(images_dir.joinpath(f"{total_frames}.png")), cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) 129 | 130 | # Depth if available 131 | depth = None 132 | if sample.has_depth: 133 | depth = np.asarray(sample.depth_image, dtype=np.uint8).view( 134 | dtype=np.float32).reshape((sample.depth_height, sample.depth_width)) 135 | depth = (depth*65535/float(depth_scale)).astype(np.uint16) 136 | depth = cv2.resize(depth, dsize=( 137 | sample.width, sample.height), interpolation=cv2.INTER_NEAREST) 138 | cv2.imwrite(str(depth_dir.joinpath(f"{total_frames}.png")), depth) 139 | 140 | # Transform 141 | X_WV = np.asarray(sample.transform_matrix, 142 | dtype=np.float32).reshape((4, 4)).T 143 | 144 | frame = { 145 | "transform_matrix": X_WV.tolist(), 146 | "file_path": f"rgb/{total_frames}.png", 147 | "fl_x": sample.fl_x, 148 | "fl_y": sample.fl_y, 149 | "cx": sample.cx, 150 | "cy": sample.cy, 151 | "w": sample.width, 152 | "h": sample.height 153 | } 154 | 155 | if depth is not None: 156 | frame["depth_path"] = f"depth/{total_frames}.png" 157 | 158 | manifest["frames"].append(frame) 159 | 160 | # Update index 161 | if total_frames == n_frames - 1: 162 | print("Saving manifest...") 163 | # Write manifest as json 164 | manifest_json = json.dumps(manifest, indent=4) 165 | with open(save_path.joinpath("transforms.json"), "w") as f: 166 | f.write(manifest_json) 167 | print("Done") 168 | sys.exit(0) 169 | total_frames += 1 170 | 171 | 172 | if __name__ == "__main__": 173 | args = parse_args() 174 | 175 | # Load config 176 | experiment = SourceFileLoader( 177 | os.path.basename(args.config), args.config 178 | ).load_module() 179 | 180 | # Setup DDS 181 | domain = Domain(domain_id=0, config=dds_config) 182 | participant = DomainParticipant() 183 | qos = Qos(Policy.Reliability.Reliable( 184 | max_blocking_time=duration(seconds=1))) 185 | topic = Topic(participant, "Frames", SplatCaptureFrame, qos=qos) 186 | reader = DataReader(participant, topic) 187 | 188 | config = experiment.config 189 | dataset_capture_loop(reader, Path(config['workdir']), config['overwrite'], config['num_frames'], config['depth_scale']) 190 | --------------------------------------------------------------------------------
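For orientation, the capture loop above writes a transforms.json manifest that the NeRFCapture dataset loader later consumes. A minimal sketch of reading it back, assuming a finished capture; the directory name is a hypothetical stand-in for config['workdir']:

import json
from pathlib import Path

import numpy as np

# Hypothetical capture directory; substitute your config['workdir'].
capture_dir = Path("./experiments/iphone_capture")

with open(capture_dir / "transforms.json", "r") as f:
    manifest = json.load(f)

# Global intrinsics recorded from the first received frame.
K = np.array([[manifest["fl_x"], 0.0, manifest["cx"]],
              [0.0, manifest["fl_y"], manifest["cy"]],
              [0.0, 0.0, 1.0]])

# Per-frame camera-to-world poses (the loop stores them post-transpose).
c2w = [np.array(frame["transform_matrix"]) for frame in manifest["frames"]]
print(f"Loaded {len(c2w)} frames; intrinsics:\n{K}")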
/utils/_init_.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spla-tam/SplaTAM/da6bbcd24c248dc884ac7f49d62e91b841b26ccc/utils/_init_.py -------------------------------------------------------------------------------- /utils/common_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import random 5 | import torch 6 | 7 | 8 | def seed_everything(seed=42): 9 | """ 10 | Seed the random, numpy and torch RNGs with `seed`. Also turns on 11 | deterministic execution for cudnn. 12 | 13 | Parameters: 14 | - seed: A hashable seed value 15 | """ 16 | random.seed(seed) 17 | os.environ["PYTHONHASHSEED"] = str(seed) 18 | np.random.seed(seed) 19 | torch.manual_seed(seed) 20 | torch.backends.cudnn.deterministic = True 21 | torch.backends.cudnn.benchmark = False 22 | print(f"Seed set to: {seed} (type: {type(seed)})") 23 | 24 | 25 | def params2cpu(params): 26 | res = {} 27 | for k, v in params.items(): 28 | if isinstance(v, torch.Tensor): 29 | res[k] = v.detach().cpu().contiguous().numpy() 30 | else: 31 | res[k] = v 32 | return res 33 | 34 | 35 | def save_params(output_params, output_dir): 36 | # Convert to CPU Numpy Arrays 37 | to_save = params2cpu(output_params) 38 | # Save the Parameters containing the Gaussian Trajectories 39 | os.makedirs(output_dir, exist_ok=True) 40 | print(f"Saving parameters to: {output_dir}") 41 | save_path = os.path.join(output_dir, "params.npz") 42 | np.savez(save_path, **to_save) 43 | 44 | 45 | def save_params_ckpt(output_params, output_dir, time_idx): 46 | # Convert to CPU Numpy Arrays 47 | to_save = params2cpu(output_params) 48 | # Save the Parameters containing the Gaussian Trajectories 49 | os.makedirs(output_dir, exist_ok=True) 50 | print(f"Saving parameters to: {output_dir}") 51 | save_path = os.path.join(output_dir, "params"+str(time_idx)+".npz") 52 | np.savez(save_path, **to_save) 53 | 54 | 55 | def save_seq_params(all_params, output_dir): 56 | params_to_save = {} 57 | for frame_idx, params in enumerate(all_params): 58 | params_to_save[f"frame_{frame_idx}"] = params2cpu(params) 59 | # Save the Parameters containing the Sequence of Gaussians 60 | os.makedirs(output_dir, exist_ok=True) 61 | print(f"Saving parameters to: {output_dir}") 62 | save_path = os.path.join(output_dir, "params.npz") 63 | np.savez(save_path, **params_to_save) 64 | 65 | 66 | def save_seq_params_ckpt(all_params, output_dir, time_idx): 67 | params_to_save = {} 68 | for frame_idx, params in enumerate(all_params): 69 | params_to_save[f"frame_{frame_idx}"] = params2cpu(params) 70 | # Save the Parameters containing the Sequence of Gaussians 71 | os.makedirs(output_dir, exist_ok=True) 72 | print(f"Saving parameters to: {output_dir}") 73 | save_path = os.path.join(output_dir, "params"+str(time_idx)+".npz") 74 | np.savez(save_path, **params_to_save) --------------------------------------------------------------------------------
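A brief usage sketch of these helpers. The tensors and output directory are made-up placeholders; the dict keys mirror the repo's params.npz convention:

import numpy as np
import torch

from utils.common_utils import seed_everything, save_params

seed_everything(seed=42)

# Made-up stand-ins for the optimized Gaussian parameters.
toy_params = {
    'means3D': torch.rand(100, 3, requires_grad=True),
    'rgb_colors': torch.rand(100, 3, requires_grad=True),
}
save_params(toy_params, "./experiments/toy_run")  # writes params.npz

# Round trip, mirroring load_scene_data() in the eval scripts.
loaded = dict(np.load("./experiments/toy_run/params.npz", allow_pickle=True))
assert loaded['means3D'].shape == (100, 3)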
/utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | class BasicPointCloud(NamedTuple): 18 | points: np.ndarray 19 | colors: np.ndarray 20 | normals: np.ndarray 21 | 22 | def geom_transform_points(points, transf_matrix): 23 | P, _ = points.shape 24 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 25 | points_hom = torch.cat([points, ones], dim=1) 26 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 27 | 28 | denom = points_out[..., 3:] + 0.0000001 29 | return (points_out[..., :3] / denom).squeeze(dim=0) 30 | 31 | def getWorld2View(R, t): 32 | Rt = np.zeros((4, 4)) 33 | Rt[:3, :3] = R.transpose() 34 | Rt[:3, 3] = t 35 | Rt[3, 3] = 1.0 36 | return np.float32(Rt) 37 | 38 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 39 | Rt = np.zeros((4, 4)) 40 | Rt[:3, :3] = R.transpose() 41 | Rt[:3, 3] = t 42 | Rt[3, 3] = 1.0 43 | 44 | C2W = np.linalg.inv(Rt) 45 | cam_center = C2W[:3, 3] 46 | cam_center = (cam_center + translate) * scale 47 | C2W[:3, 3] = cam_center 48 | Rt = np.linalg.inv(C2W) 49 | return np.float32(Rt) 50 | 51 | def getProjectionMatrix(znear, zfar, fovX, fovY): 52 | tanHalfFovY = math.tan((fovY / 2)) 53 | tanHalfFovX = math.tan((fovX / 2)) 54 | 55 | top = tanHalfFovY * znear 56 | bottom = -top 57 | right = tanHalfFovX * znear 58 | left = -right 59 | 60 | P = torch.zeros(4, 4) 61 | 62 | z_sign = 1.0 63 | 64 | P[0, 0] = 2.0 * znear / (right - left) 65 | P[1, 1] = 2.0 * znear / (top - bottom) 66 | P[0, 2] = (right + left) / (right - left) 67 | P[1, 2] = (top + bottom) / (top - bottom) 68 | P[3, 2] = z_sign 69 | P[2, 2] = z_sign * zfar / (zfar - znear) 70 | P[2, 3] = -(zfar * znear) / (zfar - znear) 71 | return P 72 | 73 | def fov2focal(fov, pixels): 74 | return pixels / (2 * math.tan(fov / 2)) 75 | 76 | def focal2fov(focal, pixels): 77 | return 2*math.atan(pixels/(2*focal)) --------------------------------------------------------------------------------
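As a sanity check on the conversions above, a small sketch with made-up intrinsics: fov2focal and focal2fov are exact inverses, and getProjectionMatrix consumes the resulting FOVs:

import math

from utils.graphics_utils import fov2focal, focal2fov, getProjectionMatrix

width, height = 640, 480
fx = 525.0  # made-up focal length in pixels

fovx = focal2fov(fx, width)
fovy = focal2fov(fx, height)
assert math.isclose(fov2focal(fovx, width), fx)  # exact round trip

# 4x4 OpenGL-style projection matrix as used by the rasterizer.
P = getProjectionMatrix(znear=0.01, zfar=100.0, fovX=fovx, fovY=fovy)
print(P.shape)  # torch.Size([4, 4])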
/utils/keyframe_selection.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code for Keyframe Selection based on re-projection of points from 3 | the current frame to the keyframes. 4 | """ 5 | 6 | import torch 7 | import numpy as np 8 | 9 | 10 | def get_pointcloud(depth, intrinsics, w2c, sampled_indices): 11 | CX = intrinsics[0][2] 12 | CY = intrinsics[1][2] 13 | FX = intrinsics[0][0] 14 | FY = intrinsics[1][1] 15 | 16 | # Compute indices of sampled pixels 17 | xx = (sampled_indices[:, 1] - CX)/FX 18 | yy = (sampled_indices[:, 0] - CY)/FY 19 | depth_z = depth[0, sampled_indices[:, 0], sampled_indices[:, 1]] 20 | 21 | # Initialize point cloud 22 | pts_cam = torch.stack((xx * depth_z, yy * depth_z, depth_z), dim=-1) 23 | pts4 = torch.cat([pts_cam, torch.ones_like(pts_cam[:, :1])], dim=1) 24 | c2w = torch.inverse(w2c) 25 | pts = (c2w @ pts4.T).T[:, :3] 26 | 27 | # Remove points at camera origin 28 | A = torch.abs(torch.round(pts, decimals=4)) 29 | B = torch.zeros((1, 3)).cuda().float() 30 | _, idx, counts = torch.cat([A, B], dim=0).unique( 31 | dim=0, return_inverse=True, return_counts=True) 32 | mask = torch.isin(idx, torch.where(counts.gt(1))[0]) 33 | invalid_pt_idx = mask[:len(A)] 34 | valid_pt_idx = ~invalid_pt_idx 35 | pts = pts[valid_pt_idx] 36 | 37 | return pts 38 | 39 | 40 | def keyframe_selection_overlap(gt_depth, w2c, intrinsics, keyframe_list, k, pixels=1600): 41 | """ 42 | Select overlapping keyframes to the current camera observation. 43 | 44 | Args: 45 | gt_depth (tensor): ground truth depth image of the current frame. 46 | w2c (tensor): world to camera matrix (4 x 4). 47 | keyframe_list (list): a list containing info for each keyframe. 48 | k (int): number of overlapping keyframes to select. 49 | pixels (int, optional): number of pixels to sparsely sample 50 | from the image of the current camera. Defaults to 1600. 51 | Returns: 52 | selected_keyframe_list (list): list of selected keyframe ids. 53 | """ 54 | # Randomly Sample Pixel Indices from valid depth pixels 55 | width, height = gt_depth.shape[2], gt_depth.shape[1] 56 | valid_depth_indices = torch.where(gt_depth[0] > 0) 57 | valid_depth_indices = torch.stack(valid_depth_indices, dim=1) 58 | indices = torch.randint(valid_depth_indices.shape[0], (pixels,)) 59 | sampled_indices = valid_depth_indices[indices] 60 | 61 | # Back Project the selected pixels to 3D Pointcloud 62 | pts = get_pointcloud(gt_depth, intrinsics, w2c, sampled_indices) 63 | 64 | list_keyframe = [] 65 | for keyframeid, keyframe in enumerate(keyframe_list): 66 | # Get the estimated world2cam of the keyframe 67 | est_w2c = keyframe['est_w2c'] 68 | # Transform the 3D pointcloud to the keyframe's camera space 69 | pts4 = torch.cat([pts, torch.ones_like(pts[:, :1])], dim=1) 70 | transformed_pts = (est_w2c @ pts4.T).T[:, :3] 71 | # Project the 3D pointcloud to the keyframe's image space 72 | points_2d = torch.matmul(intrinsics, transformed_pts.transpose(0, 1)) 73 | points_2d = points_2d.transpose(0, 1) 74 | points_z = points_2d[:, 2:] + 1e-5 75 | points_2d = points_2d / points_z 76 | projected_pts = points_2d[:, :2] 77 | # Filter out the points that are outside the image 78 | edge = 20 79 | mask = (projected_pts[:, 0] < width-edge)*(projected_pts[:, 0] > edge) * \ 80 | (projected_pts[:, 1] < height-edge)*(projected_pts[:, 1] > edge) 81 | mask = mask & (points_z[:, 0] > 0) 82 | # Compute the percentage of points that are inside the image 83 | percent_inside = mask.sum()/projected_pts.shape[0] 84 | list_keyframe.append( 85 | {'id': keyframeid, 'percent_inside': percent_inside}) 86 | 87 | # Sort the keyframes based on the percentage of points that are inside the image 88 | list_keyframe = sorted( 89 | list_keyframe, key=lambda i: i['percent_inside'], reverse=True) 90 | # Select the keyframes with percentage of points inside the image > 0 91 | selected_keyframe_list = [keyframe_dict['id'] 92 | for keyframe_dict in list_keyframe if keyframe_dict['percent_inside'] > 0.0] 93 | selected_keyframe_list = list(np.random.permutation( 94 | np.array(selected_keyframe_list))[:k]) 95 | 96 | return selected_keyframe_list --------------------------------------------------------------------------------
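A hedged end-to-end sketch of keyframe_selection_overlap with synthetic inputs. A CUDA device is required (get_pointcloud allocates on the GPU), and the shapes follow the docstring above; the dummy keyframes carry only the 'est_w2c' key used here:

import torch

from utils.keyframe_selection import keyframe_selection_overlap

# Synthetic current frame: (1, H, W) depth, 3x3 intrinsics, 4x4 world-to-camera.
gt_depth = torch.rand(1, 480, 640).cuda() + 0.5
intrinsics = torch.tensor([[525.0, 0.0, 320.0],
                           [0.0, 525.0, 240.0],
                           [0.0, 0.0, 1.0]]).cuda()
w2c = torch.eye(4).cuda()

# Dummy keyframes whose estimated poses coincide with the current view.
keyframe_list = [{'est_w2c': torch.eye(4).cuda()} for _ in range(2)]

selected = keyframe_selection_overlap(gt_depth, w2c, intrinsics, keyframe_list, k=1)
print(selected)  # ids with non-zero overlap, randomly permuted, truncated to k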
/utils/neighbor_search.py: -------------------------------------------------------------------------------- 1 | import faiss 2 | import faiss.contrib.torch_utils 3 | import torch 4 | 5 | 6 | def torch_3d_knn(pts, num_knn, method="l2"): 7 | # Initialize FAISS index 8 | if method == "l2": 9 | index = faiss.IndexFlatL2(pts.shape[1]) 10 | elif method == "cosine": 11 | index = faiss.IndexFlatIP(pts.shape[1]) 12 | else: 13 | raise NotImplementedError(f"Method: {method}") 14 | 15 | # Convert FAISS index to GPU 16 | if pts.get_device() != -1: 17 | res = faiss.StandardGpuResources() 18 | index = faiss.index_cpu_to_gpu(res, 0, index) 19 | 20 | # Add points to index and compute distances 21 | index.add(pts) 22 | distances, indices = index.search(pts, num_knn) 23 | return distances, indices 24 | 25 | 26 | def calculate_neighbors(params, variables, time_idx, num_knn=20): 27 | if time_idx is None: 28 | pts = params['means3D'].detach() 29 | else: 30 | pts = params['means3D'][:, :, time_idx].detach() 31 | neighbor_dist, neighbor_indices = torch_3d_knn(pts.contiguous(), num_knn) 32 | neighbor_weight = torch.exp(-2000 * torch.square(neighbor_dist)) 33 | variables["neighbor_indices"] = neighbor_indices.long().contiguous() 34 | variables["neighbor_weight"] = neighbor_weight.float().contiguous() 35 | variables["neighbor_dist"] = neighbor_dist.float().contiguous() 36 | return variables -------------------------------------------------------------------------------- /utils/recon_helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diff_gaussian_rasterization import GaussianRasterizationSettings as Camera 3 | 4 | def setup_camera(w, h, k, w2c, near=0.01, far=100): 5 | fx, fy, cx, cy = k[0][0], k[1][1], k[0][2], k[1][2] 6 | w2c = torch.tensor(w2c).cuda().float() 7 | cam_center = torch.inverse(w2c)[:3, 3] 8 | w2c = w2c.unsqueeze(0).transpose(1, 2) 9 | opengl_proj = torch.tensor([[2 * fx / w, 0.0, -(w - 2 * cx) / w, 0.0], 10 | [0.0, 2 * fy / h, -(h - 2 * cy) / h, 0.0], 11 | [0.0, 0.0, far / (far - near), -(far * near) / (far - near)], 12 | [0.0, 0.0, 1.0, 0.0]]).cuda().float().unsqueeze(0).transpose(1, 2) 13 | full_proj = w2c.bmm(opengl_proj) 14 | cam = Camera( 15 | image_height=h, 16 | image_width=w, 17 | tanfovx=w / (2 * fx), 18 | tanfovy=h / (2 * fy), 19 | bg=torch.tensor([0, 0, 0], dtype=torch.float32, device="cuda"), 20 | scale_modifier=1.0, 21 | viewmatrix=w2c, 22 | projmatrix=full_proj, 23 | sh_degree=0, 24 | campos=cam_center, 25 | prefiltered=False 26 | ) 27 | return cam 28 | --------------------------------------------------------------------------------
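A short sketch of constructing rasterization settings with setup_camera. The intrinsics are made-up; the pinned diff-gaussian-rasterization-w-depth fork and a CUDA device are required:

import numpy as np

from utils.recon_helpers import setup_camera

# Made-up 640x480 pinhole intrinsics; setup_camera indexes k like a nested
# list and tensor-wraps w2c itself, so plain Python/NumPy inputs are fine.
w, h = 640, 480
k = [[525.0, 0.0, 320.0],
     [0.0, 525.0, 240.0],
     [0.0, 0.0, 1.0]]
w2c = np.eye(4)

cam = setup_camera(w, h, k, w2c, near=0.01, far=100)
# cam is a GaussianRasterizationSettings instance, ready to be handed to the
# GaussianRasterizer from the depth-enabled rasterizer fork.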
/venv_requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm==4.65.0 2 | Pillow 3 | imageio 4 | matplotlib 5 | kornia 6 | natsort 7 | pyyaml 8 | wandb 9 | lpips 10 | torchmetrics 11 | cyclonedds 12 | pytorch-msssim 13 | plyfile==0.8.1 14 | git+https://github.com/JonathonLuiten/diff-gaussian-rasterization-w-depth.git@cb65e4b86bc3bd8ed42174b72a62e8d3a3a71110 --------------------------------------------------------------------------------