├── .gitignore ├── LICENSE ├── README.md ├── assets └── pipeline.png ├── desplat ├── __init__.py ├── config.py ├── datamanager.py ├── dataparsers │ ├── onthego_dataparser.py │ ├── phototourism_dataparser.py │ └── robustnerf_dataparser.py ├── desplat_model.py ├── field.py └── pipeline.py ├── pyproject.toml └── scripts ├── calculate_memory.py ├── convert.py ├── download_dataset.py └── test_time_optimize.py /.gitignore: -------------------------------------------------------------------------------- 1 | outputs/ 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 111 | .pdm.toml 112 | .pdm-python 113 | .pdm-build/ 114 | 115 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ 164 | 165 | /examples 166 | /gsplat -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Arno Solin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # DeSplat: Decomposed Gaussian Splatting for Distractor-Free Rendering 4 | 5 |

6 | Yihao Wang · 7 | Marcus Klasson · 8 | Matias Turkulainen · 9 | Shuzhe Wang · 10 | Juho Kannala · 11 | Arno Solin 12 |

13 | 14 |

CVPR 2025

15 | 16 |

17 | Paper | 18 | Project Page 19 |

20 | 21 |
22 | 23 | --- 24 | 25 | This is the original code of DeSplat on [NerfStudio](http://www.nerf.studio/) codebase. 26 | 27 |
28 | Pipeline 29 |

Overall pipeline of our method.

30 |
31 | 
32 | ## Installation
33 | Setup conda environment:
34 | ```
35 | conda create --name desplat -y python=3.8
36 | conda activate desplat
37 | pip install torch==2.1.2+cu118 torchvision==0.16.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
38 | conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit
39 | ```
40 | 
41 | Install DeSplat:
42 | ```
43 | git clone https://github.com/AaltoML/desplat.git
44 | cd desplat
45 | pip install -e .
46 | ns-install-cli
47 | ```
48 | 
49 | ## Download datasets
50 | 
51 | We provide a convenient script to download compatible datasets.
52 | 
53 | To download the datasets, run the `download_dataset.py` script:
54 | ```
55 | python scripts/download_dataset.py --dataset ["robustnerf" or "on-the-go"]
56 | ```
57 | 
58 | Alternatively, you can run the following commands:
59 | - **RobustNeRF Dataset:**
60 | ```
61 | wget https://storage.googleapis.com/jax3d-public/projects/robustnerf/robustnerf.tar.gz
62 | tar -xvf robustnerf.tar.gz
63 | ```
64 | The RobustNeRF dataset includes pre-generated COLMAP points, so we do not need to convert the data.
65 | 
66 | - **On-the-go Dataset:**
67 | To download the undistorted and down-sampled on-the-go dataset, you can either access it on Hugging Face: [link](https://huggingface.co/datasets/jkulhanek/nerfonthego-undistorted/tree/main) or download and preprocess it using the following bash script.
68 | 
69 | ```
70 | bash scripts/download_on-the-go_processing.sh
71 | ```
72 | This Bash script automatically downloads the data, processes it using COLMAP, and downsamples the images. For downsampling, you may need to install ImageMagick.
73 | ```
74 | # install ImageMagick if it is not already on your computer:
75 | conda install -c conda-forge imagemagick
76 | # install COLMAP if it is not already on your computer:
77 | conda install conda-forge::colmap
78 | ```
79 | > **Note**: The initial On-the-go dataset does not include COLMAP points, so preprocessing is required. For detailed preprocessing steps, please refer to the instructions below.
80 | 
81 | 
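For reference, the downsampling step boils down to a standard ImageMagick resize. A minimal sketch is shown below; the 8x factor and the `images_8` output folder are assumptions for illustration, and the bash script above handles this for you.
```
# resize all training images to 1/8 resolution into images_8/ (illustrative only)
mkdir -p images_8
mogrify -path images_8 -resize 12.5% images/*.jpg
```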
82 | Custom data
83 | We support COLMAP-based datasets. Ensure your dataset is organized in the following structure:
84 | ```
85 | [dataset-path]
86 | |---images
87 | |   |---[image 0]
88 | |   |---[image 1]
89 | |   |---...
90 | |---sparse
91 |     |---0
92 |         |---cameras.bin
93 |         |---images.bin
94 |         |---points3D.bin
95 | ```
96 | For datasets like the On-the-go Dataset and custom datasets without point cloud information, you need to preprocess them using COLMAP.
97 | 
98 | To prepare the images for the COLMAP processor, organize your dataset folder as follows:
99 | ```
100 | [dataset-path]
101 | |---input
102 |     |---[image 0]
103 |     |---[image 1]
104 |     |---...
105 | ```
106 | Then, run the following command:
107 | ```
108 | # install COLMAP if it is not already on your computer:
109 | conda install conda-forge::colmap
110 | python scripts/convert.py -s [dataset-path] [--resize] # If not resizing, ImageMagick is not needed
111 | 
112 | # an example for the on-the-go dataset could be:
113 | python scripts/convert.py -s ../data/on-the-go/patio
114 | ```
115 | 
116 | 
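For a high-resolution custom capture, a run with resizing enabled might look like the following (`data/my_scene` is a placeholder path used only for illustration); afterwards the `sparse/0` folder should contain the three `.bin` files listed above:
```
python scripts/convert.py -s data/my_scene --resize
ls data/my_scene/sparse/0   # expect cameras.bin  images.bin  points3D.bin
```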
117 | 
118 | ## Training
119 | 
120 | For RobustNeRF data, train using:
121 | ```
122 | ns-train desplat robustnerf-data --data [path_to_robustnerf_data]
123 | ```
124 | 
125 | For On-the-go data, train using:
126 | ```
127 | ns-train desplat onthego-data --data [path_to_onthego_data]
128 | ```
129 | 
130 | For Photo Tourism data, train using:
131 | ```
132 | ns-train desplat --steps_per_save 200000 --max_num_iterations 200000 --pipeline.model.stop_split_at 100000 \
133 | --pipeline.model.enable_appearance True --pipeline.model.app_per_gauss True phototourism-data --data [path_to_phototourism_data]
134 | ```
135 | 
136 | You can adjust the configuration by toggling options such as `--pipeline.model.use_adc`. For more details, please take a closer look at `desplat_model.py`.
137 | 
138 | 
139 | ## Known Issues
140 | 
141 | Due to differences in the optimization method, running the code directly within the Nerfstudio framework may result in the following issue:
142 | 
143 | Traceback error 144 | 145 | ``` 146 | Traceback (most recent call last): 147 | File "/******/ns-train", line 8, in 148 | sys.exit(entrypoint()) 149 | File "/******/site-packages/nerfstudio/scripts/train.py", line 262, in entrypoint 150 | main( 151 | File "/******/site-packages/nerfstudio/scripts/train.py", line 247, in main 152 | launch( 153 | File "/******/site-packages/nerfstudio/scripts/train.py", line 189, in launch 154 | main_func(local_rank=0, world_size=world_size, config=config) 155 | File "/******/site-packages/nerfstudio/scripts/train.py", line 100, in train_loop 156 | trainer.train() 157 | File "/******/site-packages/nerfstudio/engine/trainer.py", line 301, in train 158 | self.save_checkpoint(step) 159 | File "/******/site-packages/nerfstudio/utils/decorators.py", line 82, in wrapper 160 | ret = func(*args, **kwargs) 161 | File "/******/site-packages/nerfstudio/engine/trainer.py", line 467, in save_checkpoint 162 | "optimizers": {k: v.state_dict() for (k, v) in self.optimizers.optimizers.items()}, 163 | File "/******/site-packages/nerfstudio/engine/trainer.py", line 467, in 164 | "optimizers": {k: v.state_dict() for (k, v) in self.optimizers.optimizers.items()}, 165 | File "/******/site-packages/torch/_compile.py", line 31, in inner 166 | return disable_fn(*args, **kwargs) 167 | File "/******/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn 168 | return fn(*args, **kwargs) 169 | File "/******/site-packages/torch/optim/optimizer.py", line 705, in state_dict 170 | packed_state = { 171 | File "/******/site-packages/torch/optim/optimizer.py", line 706, in 172 | (param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v 173 | KeyError: ****** 174 | ``` 175 |
176 | 
177 | To temporarily resolve this issue, comment out the following two lines in the Nerfstudio library file `/nerfstudio/engine/trainer.py` (in Nerfstudio version 1.1.3, these are lines 467 and 468):
178 | 
179 | ```
180 | "optimizers": {k: v.state_dict() for (k, v) in self.optimizers.optimizers.items()},
181 | "schedulers": {k: v.state_dict() for (k, v) in self.optimizers.schedulers.items()},
182 | ```
183 | 
184 | 
185 | ## Rendering
186 | 
187 | To render a video, use `ns-render`. The basic command is:
188 | ```
189 | ns-render camera-path/interpolate/spiral/dataset --load-config .../config.yml
190 | ```
191 | To render the images of the whole dataset for visualization, use the following command:
192 | ```
193 | ns-render dataset --split train+test --load-config [config-path] --output-path [output-path]
194 | ```
195 | 
196 | ## Test-Time Optimization (TTO)
197 | 
198 | For datasets with complex weather and lighting variations, using per-image embeddings and test-time optimization is essential. During training, you can run the following command:
199 | ```
200 | ns-train desplat --pipeline.model.enable_appearance True phototourism-data --data [data-path]
201 | ```
202 | After completing the training, you will get a `.ckpt` file and a configuration path. To perform test-time optimization, run:
203 | ```
204 | python scripts/test_time_optimize.py --load-config [config-path]
205 | ```
206 | Adding the `--save_gif` flag saves a `.gif` file for a quick visual review.
207 | Adding the `--save_all_imgs` flag saves all rendered test images.
208 | Adding the `--use_saved_embedding` flag loads the saved appearance embeddings.
209 | 
210 | ## Citation
211 | 
212 | ```
213 | @InProceedings{wang2024desplat,
214 |   title={{DeSplat}: Decomposed {G}aussian Splatting for Distractor-Free Rendering},
215 |   author={Wang, Yihao and Klasson, Marcus and Turkulainen, Matias and Wang, Shuzhe and Kannala, Juho and Solin, Arno},
216 |   booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
217 |   year={2025}
218 | }
219 | ```
220 | 
221 | ## License
222 | This software is provided under the MIT License. See the accompanying LICENSE file for details.
223 | -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoML/desplat/966b1f4ba68dc4e1fc900fc416e3f4244ee88902/assets/pipeline.png -------------------------------------------------------------------------------- /desplat/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataparsers.onthego_dataparser import OnthegoDataParserSpecification 2 | from .dataparsers.phototourism_dataparser import ( 3 | PhotoTourismDataParserSpecification, 4 | ) 5 | from .dataparsers.robustnerf_dataparser import RobustNerfDataParserSpecification 6 | 7 | __all__ = [ 8 | "__version__", 9 | OnthegoDataParserSpecification, 10 | PhotoTourismDataParserSpecification, 11 | RobustNerfDataParserSpecification, 12 | ] 13 | -------------------------------------------------------------------------------- /desplat/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | 5 | from desplat.datamanager import DeSplatDataManagerConfig 6 | from desplat.dataparsers.robustnerf_dataparser import RobustNerfDataParserConfig 7 | from desplat.desplat_model import DeSplatModelConfig 8 | from desplat.pipeline import DeSplatPipelineConfig 9 | from nerfstudio.configs.base_config import ViewerConfig 10 | from nerfstudio.engine.optimizers import AdamOptimizerConfig 11 | from nerfstudio.engine.schedulers import ( 12 | ExponentialDecaySchedulerConfig, 13 | ) 14 | from nerfstudio.engine.trainer import TrainerConfig 15 | from nerfstudio.plugins.types import MethodSpecification 16 | 17 | desplat_method = MethodSpecification( 18 | config=TrainerConfig( 19 | method_name="desplat", 20 | steps_per_eval_image=100, 21 | steps_per_eval_batch=0, 22 | steps_per_save=30000, 23 | steps_per_eval_all_images=1000, 24 | max_num_iterations=30000, 25 | mixed_precision=False, 26 | pipeline=DeSplatPipelineConfig( 27 | datamanager=DeSplatDataManagerConfig( # desplat 28 | dataparser=RobustNerfDataParserConfig( 29 | load_3D_points=True, colmap_path=Path("sparse/0") 30 | ), # , downscale_factor=2 31 | cache_images_type="uint8", 32 | ), 33 | model=DeSplatModelConfig(), 34 | ), 35 | optimizers={ 36 | "means": { 37 | "optimizer": AdamOptimizerConfig(lr=1.6e-4, eps=1e-15), 38 | "scheduler": ExponentialDecaySchedulerConfig( 39 | lr_final=1.6e-6, 40 | max_steps=30000, 41 | ), 42 | }, 43 | "features_dc": { 44 | "optimizer": AdamOptimizerConfig(lr=0.0025, eps=1e-15), 45 | "scheduler": None, 46 | }, 47 | "features_rest": { 48 | "optimizer": AdamOptimizerConfig(lr=0.0025 / 20, eps=1e-15), 49 | "scheduler": None, 50 | }, 51 | "opacities": { 52 | "optimizer": AdamOptimizerConfig(lr=0.05, eps=1e-15), 53 | "scheduler": None, 54 | }, 55 | "scales": { 56 | "optimizer": AdamOptimizerConfig(lr=0.005, eps=1e-15), 57 | "scheduler": None, 58 | }, 59 | "quats": { 60 | "optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), 61 | "scheduler": None, 62 | }, 63 | "embeddings": { 64 | "optimizer": AdamOptimizerConfig(lr=0.02, eps=1e-15), 65 | "scheduler": None, 66 | }, 67 | "means_dyn": { 68 | "optimizer": AdamOptimizerConfig(lr=1.6e-4, eps=1e-15), 69 | "scheduler": ExponentialDecaySchedulerConfig( 70 | lr_final=1.6e-6, 71 | max_steps=30000, 72 | ), 73 | }, 74 | "rgbs_dyn": { 75 | "optimizer": AdamOptimizerConfig(lr=0.025, eps=1e-15), 76 | "scheduler": None, 77 | }, 78 | 
"opacities_dyn": { 79 | "optimizer": AdamOptimizerConfig(lr=0.05, eps=1e-15), 80 | "scheduler": None, 81 | }, 82 | "scales_dyn": { 83 | "optimizer": AdamOptimizerConfig(lr=0.05, eps=1e-15), 84 | "scheduler": None, 85 | }, 86 | "quats_dyn": { 87 | "optimizer": AdamOptimizerConfig(lr=0.01, eps=1e-15), 88 | "scheduler": None, 89 | }, 90 | # back to original components 91 | "camera_opt": { 92 | "optimizer": AdamOptimizerConfig(lr=1e-4, eps=1e-15), 93 | "scheduler": ExponentialDecaySchedulerConfig( 94 | lr_final=5e-7, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0 95 | ), 96 | }, 97 | "appearance_mlp": { 98 | "optimizer": AdamOptimizerConfig(lr=0.0005, eps=1e-15), 99 | "scheduler": None, 100 | }, 101 | "appearance_embeddings": { 102 | "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15), 103 | "scheduler": None, 104 | }, 105 | "field_background_encoder": { 106 | "optimizer": AdamOptimizerConfig(lr=2e-3, eps=1e-15), 107 | "scheduler": ExponentialDecaySchedulerConfig( 108 | lr_final=1e-4, max_steps=100000 109 | ), 110 | }, 111 | "field_background_base": { 112 | "optimizer": AdamOptimizerConfig(lr=2e-3, eps=1e-15), 113 | "scheduler": ExponentialDecaySchedulerConfig( 114 | lr_final=2e-4, max_steps=100000 115 | ), 116 | }, 117 | "field_background_rest": { 118 | "optimizer": AdamOptimizerConfig(lr=2e-3 / 20, eps=1e-15), 119 | "scheduler": ExponentialDecaySchedulerConfig( 120 | lr_final=2e-4 / 20, max_steps=100000 121 | ), 122 | }, 123 | }, 124 | viewer=ViewerConfig(num_rays_per_chunk=1 << 15), 125 | vis="viewer", 126 | ), 127 | description="Desplat", 128 | ) 129 | -------------------------------------------------------------------------------- /desplat/datamanager.py: -------------------------------------------------------------------------------- 1 | import random 2 | from copy import deepcopy 3 | from dataclasses import dataclass, field 4 | from typing import Dict, List, Literal, Tuple, Type, Union 5 | 6 | import torch 7 | import torch._dynamo 8 | 9 | from nerfstudio.cameras.cameras import Cameras 10 | from nerfstudio.data.datamanagers.full_images_datamanager import ( 11 | FullImageDatamanager, 12 | FullImageDatamanagerConfig, 13 | ) 14 | 15 | torch._dynamo.config.suppress_errors = True 16 | 17 | 18 | @dataclass 19 | class DeSplatDataManagerConfig(FullImageDatamanagerConfig): 20 | _target: Type = field(default_factory=lambda: DeSplatDataManager) 21 | 22 | 23 | class DeSplatDataManager(FullImageDatamanager): 24 | config: DeSplatDataManagerConfig 25 | 26 | def __init__( 27 | self, 28 | config: DeSplatDataManagerConfig, 29 | device: Union[torch.device, str] = "cpu", 30 | test_mode: Literal["test", "val", "inference"] = "val", 31 | world_size: int = 1, 32 | local_rank: int = 0, 33 | **kwargs, # pylint: disable=unused-argument 34 | ): 35 | self.config = config 36 | super().__init__( 37 | config=config, 38 | device=device, 39 | test_mode=test_mode, 40 | world_size=world_size, 41 | local_rank=local_rank, 42 | **kwargs, 43 | ) 44 | metadata = self.train_dataparser_outputs.metadata 45 | if test_mode == "test": 46 | self.train_unseen_cameras = [i for i in range(len(self.train_dataset))] 47 | 48 | @property 49 | def fixed_indices_eval_dataloader(self) -> List[Tuple[Cameras, Dict]]: 50 | """ 51 | Pretends to be the dataloader for evaluation, it returns a list of (camera, data) tuples 52 | """ 53 | image_indices = [i for i in range(len(self.eval_dataset))] 54 | data = [d.copy() for d in self.cached_eval] 55 | _cameras = deepcopy(self.eval_dataset.cameras).to(self.device) 56 | cameras = [] 57 | for i in 
image_indices: 58 | data[i]["image"] = data[i]["image"].to(self.device) 59 | if ( 60 | self.dataparser_config.eval_train 61 | or self.dataparser_config.test_time_optimize 62 | ): 63 | if _cameras.metadata is None: 64 | _cameras.metadata = {} 65 | _cameras.metadata["cam_idx"] = i 66 | cameras.append(_cameras[i : i + 1]) 67 | 68 | assert ( 69 | len(self.eval_dataset.cameras.shape) == 1 70 | ), "Assumes single batch dimension" 71 | return list(zip(cameras, data)) 72 | 73 | def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]: 74 | """Returns the next evaluation batch 75 | 76 | Returns a Camera instead of raybundle 77 | 78 | TODO: Make sure this logic is consistent with the vanilladatamanager""" 79 | image_idx = self.eval_unseen_cameras.pop( 80 | random.randint(0, len(self.eval_unseen_cameras) - 1) 81 | ) 82 | # Make sure to re-populate the unseen cameras list if we have exhausted it 83 | if len(self.eval_unseen_cameras) == 0: 84 | self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))] 85 | data = self.cached_eval[image_idx] 86 | data = data.copy() 87 | data["image"] = data["image"].to(self.device) 88 | assert ( 89 | len(self.eval_dataset.cameras.shape) == 1 90 | ), "Assumes single batch dimension" 91 | camera = self.eval_dataset.cameras[image_idx : image_idx + 1].to(self.device) 92 | # keep metadata for debugging 93 | if self.dataparser_config.eval_train: 94 | if camera.metadata is None: 95 | camera.metadata = {} 96 | camera.metadata["cam_idx"] = image_idx 97 | return camera, data 98 | -------------------------------------------------------------------------------- /desplat/dataparsers/onthego_dataparser.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Phototourism dataset parser. 
Datasets and documentation here: http://phototour.cs.washington.edu/datasets/""" 16 | 17 | from __future__ import annotations 18 | 19 | import json 20 | import os 21 | from dataclasses import dataclass, field 22 | from pathlib import Path 23 | from typing import Type 24 | 25 | import numpy as np 26 | 27 | from nerfstudio.data.dataparsers.colmap_dataparser import ( 28 | ColmapDataParser, 29 | ColmapDataParserConfig, 30 | ) 31 | 32 | # TODO(1480) use pycolmap instead of colmap_parsing_utils 33 | # import pycolmap 34 | from nerfstudio.plugins.registry_dataparser import DataParserSpecification 35 | 36 | 37 | @dataclass 38 | class OnthegoDataParserConfig(ColmapDataParserConfig): 39 | """On-the-go dataset parser config""" 40 | 41 | _target: Type = field(default_factory=lambda: OnthegoDataParser) 42 | """target class to instantiate""" 43 | eval_train: bool = False 44 | """evaluate test set or train set, for debug""" 45 | colmap_path: Path = Path("sparse/0") 46 | """path to colmap sparse folder""" 47 | test_time_optimize: bool = False 48 | """Whether to use test-time optimization for the dataset""" 49 | 50 | 51 | @dataclass 52 | class OnthegoDataParser(ColmapDataParser): 53 | """Phototourism dataset. This is based on https://github.com/kwea123/nerf_pl/blob/nerfw/datasets/phototourism.py 54 | and uses colmap's utils file to read the poses. 55 | """ 56 | 57 | config: OnthegoDataParserConfig 58 | 59 | def __init__(self, config: ColmapDataParserConfig): 60 | super().__init__(config=config) 61 | self.config = config 62 | self.data: Path = config.data 63 | self.config.downscale_factor = ( 64 | 4 if self.data.name == "patio" or self.data.name == "arcdetriomphe" else 8 65 | ) 66 | 67 | def _get_image_indices(self, image_filenames, split): 68 | # Load the split file to get the train/eval split 69 | with open(os.path.join(self.config.data, "split.json"), "r") as file: 70 | split_json = json.load(file) 71 | 72 | # Select the split according to the split file. 73 | all_indices = np.arange(len(image_filenames)) 74 | 75 | i_eval = all_indices[split_json["extra"]] 76 | 77 | i_train = all_indices[split_json["clutter"]] 78 | 79 | if self.config.eval_train: 80 | i_eval = i_train 81 | if split == "train": 82 | indices = i_train 83 | elif split in ["val", "test"]: 84 | indices = i_eval 85 | else: 86 | raise ValueError(f"Unknown dataparser split {split}") 87 | 88 | return indices 89 | 90 | 91 | OnthegoDataParserSpecification = DataParserSpecification( 92 | config=OnthegoDataParserConfig(), 93 | description="On-the-go dataparser", 94 | ) 95 | -------------------------------------------------------------------------------- /desplat/dataparsers/phototourism_dataparser.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Phototourism dataset parser. 
Datasets and documentation here: http://phototour.cs.washington.edu/datasets/""" 16 | 17 | from __future__ import annotations 18 | 19 | # TODO(1480) use pycolmap instead of colmap_parsing_utils 20 | # import pycolmap 21 | import csv 22 | import os 23 | from dataclasses import dataclass, field 24 | from pathlib import Path 25 | from typing import Type 26 | 27 | import numpy as np 28 | 29 | # from nerfstudio.data.dataparsers.base_dataparser import DataParser, DataParserConfig, DataparserOutputs 30 | from nerfstudio.data.dataparsers.colmap_dataparser import ( 31 | ColmapDataParser, 32 | ColmapDataParserConfig, 33 | ) 34 | from nerfstudio.plugins.registry_dataparser import DataParserSpecification 35 | 36 | 37 | @dataclass 38 | class PhotoTourismDataParserConfig(ColmapDataParserConfig): 39 | """Phototourism dataset parser config""" 40 | 41 | _target: Type = field(default_factory=lambda: PhotoTourismDataParser) 42 | """target class to instantiate""" 43 | eval_train: bool = False 44 | """evaluate test set or train set, for debug""" 45 | colmap_path: Path = Path("sparse") 46 | 47 | test_time_optimize: bool = True 48 | 49 | 50 | @dataclass 51 | class PhotoTourismDataParser(ColmapDataParser): 52 | """Phototourism dataset. This is based on https://github.com/kwea123/nerf_pl/blob/nerfw/datasets/phototourism.py 53 | and uses colmap's utils file to read the poses. 54 | """ 55 | 56 | config: PhotoTourismDataParserConfig 57 | 58 | def __init__(self, config: PhotoTourismDataParserConfig): 59 | super().__init__(config=config) 60 | self.data: Path = config.data 61 | 62 | def _get_image_indices(self, image_filenames, split): 63 | # Load the split file to get the train/eval split 64 | if "brandenburg_gate" in str(self.data): 65 | tsv_path = self.config.data / "brandenburg.tsv" 66 | elif "sacre_coeur" in str(self.data): 67 | tsv_path = self.config.data / "sacre.tsv" 68 | elif "trevi_fountain" in str(self.data): 69 | tsv_path = self.config.data / "trevi.tsv" 70 | else: 71 | raise ValueError(f"Unknown dataset {self.data.name}") 72 | 73 | basenames = [ 74 | os.path.basename(image_filename) for image_filename in image_filenames 75 | ] 76 | 77 | train_names, test_names = set(), set() 78 | 79 | with open(tsv_path, newline="") as tsv_file: 80 | reader = csv.reader(tsv_file, delimiter="\t") 81 | next(reader) 82 | for row in reader: 83 | if row[2] == "train": 84 | train_names.add(row[0]) 85 | elif row[2] == "test": 86 | test_names.add(row[0]) 87 | 88 | if self.config.eval_train: 89 | split = "train" 90 | 91 | indices = [ 92 | idx 93 | for idx, basename in enumerate(basenames) 94 | if (basename in train_names and split == "train") 95 | or (basename in test_names and split in ["val", "test"]) 96 | ] 97 | 98 | if not indices and split not in ["train", "test", "val"]: 99 | raise ValueError(f"Unknown dataparser split {split}") 100 | 101 | return np.array(indices) 102 | 103 | 104 | PhotoTourismDataParserSpecification = DataParserSpecification( 105 | config=PhotoTourismDataParserConfig(), description="Photo tourism dataparser" 106 | ) 107 | -------------------------------------------------------------------------------- /desplat/dataparsers/robustnerf_dataparser.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Phototourism dataset parser. Datasets and documentation here: http://phototour.cs.washington.edu/datasets/""" 16 | 17 | from __future__ import annotations 18 | 19 | import os 20 | from dataclasses import dataclass, field 21 | from pathlib import Path 22 | from typing import List, Literal, Optional, Type 23 | 24 | import numpy as np 25 | import torch 26 | 27 | from nerfstudio.data.dataparsers.colmap_dataparser import ( 28 | ColmapDataParser, 29 | ColmapDataParserConfig, 30 | ) 31 | 32 | # TODO(1480) use pycolmap instead of colmap_parsing_utils 33 | # import pycolmap 34 | from nerfstudio.plugins.registry_dataparser import DataParserSpecification 35 | 36 | 37 | @dataclass 38 | class RobustNerfDataParserConfig(ColmapDataParserConfig): 39 | """On-the-go dataset parser config""" 40 | 41 | _target: Type = field(default_factory=lambda: RobustNerfDataParser) 42 | """target class to instantiate""" 43 | eval_train: bool = False 44 | """evaluate test set or train set, for debug""" 45 | train_split_mode: Optional[Literal["ratio", "number", "filename"]] = None 46 | """How to split the training images. If None, all cluttered images are used.""" 47 | train_split_clean_clutter_ratio: float = 1.0 48 | """The percentage of the training images that are cluttered. 0.0 -> only clean images, 1.0 -> only cluttered images""" 49 | train_split_clean_clutter_number: int = 0 50 | """The number of clean images to use for training. If 0, all clean images are used.""" 51 | idx_clutter: List[int] = field(default_factory=lambda: [76]) 52 | """The indices of the cluttered images to use for training""" 53 | colmap_path: Path = Path("sparse/0") 54 | """path to colmap sparse folder""" 55 | downscale_factor: int = 8 56 | """How much to downscale images. If not set, images are chosen such that the max dimension is <1600px.""" 57 | test_time_optimize: bool = False 58 | """Whether to use test-time optimization for the dataset""" 59 | 60 | 61 | @dataclass 62 | class RobustNerfDataParser(ColmapDataParser): 63 | """RobustNerf dataset. This is based on https://github.com/kwea123/nerf_pl/blob/nerfw/datasets/phototourism.py 64 | and uses colmap's utils file to read the poses. 65 | """ 66 | 67 | config: RobustNerfDataParserConfig 68 | 69 | def __init__(self, config: ColmapDataParserConfig): 70 | super().__init__(config=config) 71 | self.config = config 72 | self.data = config.data 73 | 74 | def _get_image_indices(self, image_filenames, split): 75 | i_train, i_eval = self.get_train_eval_split_filename( 76 | image_filenames, self.config.train_split_clean_clutter_ratio 77 | ) 78 | 79 | if split == "train": 80 | indices = i_train 81 | elif split in ["val", "test"]: 82 | indices = i_eval 83 | else: 84 | raise ValueError(f"Unknown dataparser split {split}") 85 | 86 | return indices 87 | 88 | def get_train_eval_split_filename( 89 | self, image_filenames: List, clean_clutter_ratio: float 90 | ) -> Tuple[np.ndarray, np.ndarray]: 91 | """ 92 | Get the train/eval split based on the filename of the images. 
93 | 94 | Args: 95 | image_filenames: list of image filenames 96 | """ 97 | 98 | num_images = len(image_filenames) 99 | basenames = [ 100 | os.path.basename(image_filename) for image_filename in image_filenames 101 | ] 102 | 103 | i_all = np.arange(num_images) 104 | i_clean = [] 105 | i_clutter = [] 106 | 107 | i_train_clean = [] 108 | i_train_clutter = [] 109 | 110 | i_train = [] 111 | i_eval = [] 112 | 113 | for idx, basename in zip(i_all, basenames): 114 | if "clean" in basename: 115 | i_clean.append(idx) 116 | 117 | elif "clutter" in basename: 118 | i_clutter.append(idx) 119 | 120 | elif "extra" in basename: 121 | i_eval.append(idx) # extra is always used as eval 122 | else: 123 | raise ValueError( 124 | "image frame should contain clutter/extra in its name " 125 | ) 126 | 127 | if self.config.train_split_mode is None: 128 | i_train = i_clutter 129 | if self.config.eval_train: 130 | i_eval = i_train 131 | return np.array(i_train), np.array(i_eval) 132 | 133 | if len(i_clean) > len(i_clutter): 134 | i_clean = i_clean[: len(i_clutter)] 135 | elif len(i_clean) < len(i_clutter): 136 | i_clutter = i_clutter[: len(i_clean)] 137 | num_images_train = min(len(i_clean), len(i_clutter)) 138 | 139 | print("Number of clean images: ", num_images_train) 140 | 141 | if self.config.train_split_mode == "number": 142 | i_perm = torch.randperm( 143 | num_images_train, generator=torch.Generator().manual_seed(2023) 144 | ).tolist() 145 | num_images_cluttered = self.config.train_split_clean_clutter_number 146 | i_train = [] 147 | i_train_clutter = [] 148 | # loop over permuted indices to select one image from each clean/clutter pair 149 | 150 | for k, idx in enumerate(i_perm): 151 | if k < num_images_cluttered: 152 | i_train_clutter.append(i_clutter[idx]) 153 | else: 154 | i_train.append(i_clean[idx]) 155 | i_train.expand(i_train_clutter) 156 | elif self.config.train_split_mode == "ratio": 157 | i_train_clutter = [] 158 | if clean_clutter_ratio == 0.0: 159 | # only clean images 160 | i_train = i_clean 161 | elif clean_clutter_ratio == 1.0: 162 | # only cluttered images 163 | i_train = i_clutter 164 | elif clean_clutter_ratio > 0.0 and clean_clutter_ratio < 1.0: 165 | # pick either clutter/clean image once 166 | i_perm = torch.randperm( 167 | num_images_train, generator=torch.Generator().manual_seed(2023) 168 | ).tolist() 169 | num_images_cluttered = int( 170 | num_images_train * clean_clutter_ratio 171 | ) # rounds down 172 | i_train = [] 173 | # loop over permuted indices to select one image from each clean/clutter pair 174 | for k, idx in enumerate(i_perm): 175 | if k < num_images_cluttered: 176 | i_train_clutter.append(i_clutter[idx]) 177 | else: 178 | i_train.append(i_clean[idx]) 179 | i_train.extend(i_train_clutter) 180 | print( 181 | "basenames of cluttered images: ", 182 | [basenames[i] for i in i_train_clutter], 183 | ) 184 | else: 185 | raise ValueError( 186 | "arg train_split_clean_clutter_ratio must be between 0.0 and 1.0 " 187 | ) 188 | 189 | elif self.config.train_split_mode == "filename": 190 | # manually select images 191 | i_train = i_clean 192 | i_train.extend(i_train_clutter) 193 | 194 | # Remove elements in i_train_clean from i_train 195 | for item in i_train_clean: 196 | if item in i_train: 197 | i_train.remove(item) 198 | print( 199 | "basenames of cluttered images: ", 200 | [basenames[i] for i in i_train_clutter], 201 | ) 202 | else: 203 | raise ValueError("Unknown train_split_mode") 204 | print("i_train", i_train) 205 | print("i_eval", i_eval) 206 | if self.config.eval_train: 207 | 
i_eval = i_train 208 | return np.array(i_train), np.array(i_eval) 209 | 210 | 211 | RobustNerfDataParserSpecification = DataParserSpecification( 212 | config=RobustNerfDataParserConfig(), 213 | description="RobustNeRF dataparser", 214 | ) 215 | -------------------------------------------------------------------------------- /desplat/desplat_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass, field 4 | from typing import Dict, List, Optional, Tuple, Type, Union 5 | 6 | import torch 7 | import torch._dynamo 8 | from pytorch_msssim import SSIM 9 | from torch import nn 10 | from torch.nn import functional as F 11 | from torchmetrics.image import PeakSignalNoiseRatio 12 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity 13 | 14 | from desplat.field import ( 15 | BGField, 16 | EmbeddingModel, 17 | _get_fourier_features, 18 | ) 19 | from gsplat.cuda._wrapper import spherical_harmonics 20 | from gsplat.rendering import rasterization 21 | from nerfstudio.cameras.camera_optimizers import CameraOptimizer 22 | from nerfstudio.cameras.camera_utils import normalize 23 | from nerfstudio.cameras.cameras import Cameras 24 | from nerfstudio.data.scene_box import OrientedBox 25 | from nerfstudio.engine.optimizers import Optimizers 26 | from nerfstudio.models.splatfacto import ( 27 | RGB2SH, 28 | SplatfactoModel, 29 | SplatfactoModelConfig, 30 | get_viewmat, 31 | num_sh_bases, 32 | random_quat_tensor, 33 | ) 34 | from nerfstudio.utils.colors import get_color 35 | from nerfstudio.utils.rich_utils import CONSOLE 36 | 37 | torch._dynamo.config.suppress_errors = True 38 | torch.set_float32_matmul_precision("high") 39 | 40 | 41 | def quat_to_rotmat(quat): 42 | assert quat.shape[-1] == 4, quat.shape 43 | w, x, y, z = torch.unbind(quat, dim=-1) 44 | mat = torch.stack( 45 | [ 46 | 1 - 2 * (y**2 + z**2), 47 | 2 * (x * y - w * z), 48 | 2 * (x * z + w * y), 49 | 2 * (x * y + w * z), 50 | 1 - 2 * (x**2 + z**2), 51 | 2 * (y * z - w * x), 52 | 2 * (x * z - w * y), 53 | 2 * (y * z + w * x), 54 | 1 - 2 * (x**2 + y**2), 55 | ], 56 | dim=-1, 57 | ) 58 | return mat.reshape(quat.shape[:-1] + (3, 3)) 59 | 60 | 61 | @dataclass 62 | class DeSplatModelConfig(SplatfactoModelConfig): 63 | _target: Type = field(default_factory=lambda: DeSplatModel) 64 | 65 | ### Settings for DeSplat 66 | num_dynamic_points: int = 1000 67 | """Initial number of dynamic points""" 68 | refine_every_dyn: int = 10 69 | """period of steps where gaussians are culled and densified""" 70 | use_adc: bool = True 71 | """Whether to use ADC for dynamic points management""" 72 | enable_reset_alpha: bool = False 73 | """Whether to reset alpha for dynamic points""" 74 | continue_split_dyn: bool = False 75 | """Whether to continue splitting for dynamic points after step for splitting stops""" 76 | distance: float = 0.02 77 | """Distance of dynamic points from camera""" 78 | 79 | ### Settings for Regularization 80 | alpha_bg_loss_lambda: float = 0.01 81 | """Lambda for alpha background loss""" 82 | alpha_2d_loss_lambda: float = 0.01 83 | """Lambda for alpha loss of 2D dynamic Gaussians""" 84 | enable_bg_model: bool = False 85 | """Whether to enable background model""" 86 | bg_sh_degree: int = 4 87 | """Degree of SH bases for background model""" 88 | bg_num_layers: int = 3 89 | """Number of layers in the background model""" 90 | bg_layer_width: int = 128 91 | """Width of each layer in the background model""" 92 | 93 | ### 
Settings for Appearance Optimization 94 | enable_appearance: bool = False 95 | """Enable or disable appearance optimization""" 96 | app_per_gauss: bool = False 97 | """Whether to optimize appearance according to the per-Gaussian embedding""" 98 | appearance_embedding_dim: int = 32 99 | """Dimension of the appearance embedding""" 100 | appearance_n_fourier_freqs: int = 4 101 | """Number of Fourier frequencies for per-Gaussian embedding initialization""" 102 | appearance_init_fourier: bool = True 103 | """Whether to initialize the per-Gaussian embedding with Fourier frequencies""" 104 | 105 | class DeSplatModel(SplatfactoModel): 106 | config: DeSplatModelConfig 107 | 108 | def populate_modules(self): 109 | cameras = self.kwargs["cameras"] 110 | self.dataparser_config = self.kwargs["dataparser"] 111 | 112 | if self.seed_points is not None and not self.config.random_init: 113 | means = torch.nn.Parameter(self.seed_points[0]) # (Location, Color) 114 | else: 115 | means = torch.nn.Parameter( 116 | (torch.rand((self.config.num_random, 3)) - 0.5) 117 | * self.config.random_scale 118 | ) 119 | 120 | assert cameras is not None 121 | 122 | self.xys_grad_norm = None 123 | self.max_2Dsize = None 124 | distances, _ = self.k_nearest_sklearn(means.data, 3) 125 | distances = torch.from_numpy(distances) 126 | # find the average of the three nearest neighbors for each point and use that as the scale 127 | avg_dist = distances.mean(dim=-1, keepdim=True) 128 | scales = torch.nn.Parameter(torch.log(avg_dist.repeat(1, 3))) 129 | num_points = means.shape[0] 130 | 131 | quats = torch.nn.Parameter(random_quat_tensor(num_points)) 132 | dim_sh = num_sh_bases(self.config.sh_degree) 133 | 134 | if ( 135 | self.seed_points is not None 136 | and not self.config.random_init 137 | # We can have colors without points. 
138 | and self.seed_points[1].shape[0] > 0 139 | ): 140 | shs = torch.zeros((self.seed_points[1].shape[0], dim_sh, 3)).float().cuda() 141 | if self.config.sh_degree > 0: 142 | shs[:, 0, :3] = RGB2SH(self.seed_points[1] / 255) 143 | shs[:, 1:, 3:] = 0.0 144 | else: 145 | CONSOLE.log("use color only optimization with sigmoid activation") 146 | shs[:, 0, :3] = torch.logit(self.seed_points[1] / 255, eps=1e-10) 147 | features_dc = torch.nn.Parameter(shs[:, 0, :]) 148 | features_rest = torch.nn.Parameter(shs[:, 1:, :]) 149 | else: 150 | features_dc = torch.nn.Parameter(torch.rand(num_points, 3)) 151 | features_rest = torch.nn.Parameter(torch.zeros((num_points, dim_sh - 1, 3))) 152 | 153 | opacities = torch.nn.Parameter(torch.logit(0.1 * torch.ones(num_points, 1))) 154 | 155 | # appearance embedding for each Gaussian 156 | embeddings = _get_fourier_features( 157 | means, num_features=self.config.appearance_n_fourier_freqs 158 | ) 159 | embeddings.add_(torch.randn_like(embeddings) * 0.0001) 160 | if not self.config.appearance_init_fourier: 161 | embeddings.normal_(0, 0.01) 162 | 163 | self.gauss_params = torch.nn.ParameterDict( 164 | { 165 | "means": means, 166 | "scales": scales, 167 | "quats": quats, 168 | "features_dc": features_dc, 169 | "features_rest": features_rest, 170 | "opacities": opacities, 171 | } 172 | ) 173 | 174 | if self.config.app_per_gauss: 175 | # appearance embedding for each Gaussian 176 | embeddings = _get_fourier_features(means, num_features=self.config.appearance_n_fourier_freqs) 177 | embeddings.add_(torch.randn_like(embeddings) * 0.0001) 178 | if not self.config.appearance_init_fourier: 179 | embeddings.normal_(0, 0.01) 180 | self.gauss_params["embeddings"] = embeddings 181 | 182 | self.camera_optimizer: CameraOptimizer = self.config.camera_optimizer.setup( 183 | num_cameras=self.num_train_data, device="cpu" 184 | ) 185 | 186 | self.camera_idx = 0 187 | self.camera = None 188 | 189 | self.psnr = PeakSignalNoiseRatio(data_range=1.0) 190 | self.ssim = SSIM(data_range=1.0, size_average=True, channel=3) 191 | self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True) 192 | self.step = 0 193 | 194 | self.crop_box: Optional[OrientedBox] = None 195 | if self.config.background_color == "random": 196 | self.background_color = torch.tensor( 197 | [0.1490, 0.1647, 0.2157] 198 | ) # This color is the same as the default background color in Viser. This would only affect the background color when rendering. 
199 | else: 200 | self.background_color = get_color(self.config.background_color) 201 | 202 | # only the storage method is dict we can use ADC 203 | self.populate_modules_dyn_dict() 204 | self.num_points_dyn = self.num_points_dyn_dict 205 | 206 | self.loss_threshold = 1.0 207 | self.max_loss = 0.0 208 | self.min_loss = 1e10 209 | 210 | # Add the appearance embedding 211 | if self.config.enable_appearance: 212 | self.appearance_embeddings = torch.nn.Embedding( 213 | self.num_train_data, self.config.appearance_embedding_dim 214 | ) 215 | self.appearance_embeddings.weight.data.normal_(0, 0.01) 216 | self.appearance_mlp = EmbeddingModel(self.config) 217 | else: 218 | self.appearance_embeddings = None 219 | self.appearance_mlp = None 220 | 221 | if self.config.enable_bg_model: 222 | self.bg_model = BGField( 223 | appearance_embedding_dim=self.config.appearance_embedding_dim, 224 | implementation="torch", 225 | sh_levels=self.config.bg_sh_degree, 226 | num_layers=self.config.bg_num_layers, 227 | layer_width=self.config.bg_layer_width, 228 | ) 229 | else: 230 | self.bg_model = None 231 | 232 | def populate_modules_dyn_dict(self): 233 | cameras = self.kwargs["cameras"] 234 | self.gauss_params_dyn_dict = nn.ModuleDict() 235 | num_points_dyn = self.config.num_dynamic_points 236 | 237 | self.optimizers_dyn = {} 238 | 239 | for i in range(self.num_train_data): 240 | camera = cameras[i] 241 | 242 | optimized_camera_to_world = camera.camera_to_worlds 243 | 244 | camera_rotation = optimized_camera_to_world[:3, :3] 245 | camera_position = optimized_camera_to_world[:3, 3] 246 | distance_to_cam = self.config.distance 247 | cube = torch.rand(num_points_dyn, 3) 248 | cube[:, 0] = (cube[:, 0] - 0.5) * 0.02 249 | cube[:, 1] = (cube[:, 1] - 0.5) * 0.02 250 | cube[:, 2] = distance_to_cam 251 | 252 | means_dyn = torch.nn.Parameter( 253 | camera_position.repeat(num_points_dyn, 1) - cube @ camera_rotation.T 254 | ) 255 | 256 | distances_dyn, _ = self.k_nearest_sklearn(means_dyn.data, 3) 257 | distances_dyn = torch.from_numpy(distances_dyn) 258 | avg_dist_dyn = distances_dyn.mean(dim=-1, keepdim=True) 259 | scales_dyn = torch.nn.Parameter(torch.log(avg_dist_dyn.repeat(1, 3))) 260 | rgbs_dyn = torch.nn.Parameter(torch.rand(num_points_dyn, 3)) 261 | quats_dyn = nn.Parameter(random_quat_tensor(num_points_dyn)) 262 | opacities_dyn = nn.Parameter( 263 | torch.logit(0.1 * torch.ones(num_points_dyn, 1)) 264 | ) 265 | 266 | self.gauss_params_dyn_dict[str(i)] = nn.ParameterDict( 267 | { 268 | "means_dyn": means_dyn, 269 | "scales_dyn": scales_dyn, 270 | "rgbs_dyn": rgbs_dyn, 271 | "quats_dyn": quats_dyn, 272 | "opacities_dyn": opacities_dyn, 273 | } 274 | ) 275 | 276 | self.xys_dyn = {} 277 | self.radii_dyn = {} 278 | self.xys_grad_norm_dyn = {str(i): None for i in range(self.num_train_data)} 279 | self.max_2Dsize_dyn = {str(i): None for i in range(self.num_train_data)} 280 | self.vis_counts_dyn = {str(i): None for i in range(self.num_train_data)} 281 | 282 | @property 283 | def embeddings(self): 284 | if self.config.app_per_gauss: 285 | return self.gauss_params["embeddings"] 286 | 287 | def load_state_dict(self, dict, **kwargs): # type: ignore 288 | # resize the parameters to match the new number of points 289 | self.step = 30000 290 | if "means" in dict: 291 | # For backwards compatibility, we remap the names of parameters from 292 | # means->gauss_params.means since old checkpoints have that format 293 | for p in [ 294 | "means", 295 | "scales", 296 | "quats", 297 | "features_dc", 298 | "features_rest", 299 | "opacities", 
300 | ]: 301 | dict[f"gauss_params.{p}"] = dict[p] 302 | if self.config.app_per_gauss: 303 | dict[f"gauss_params.embeddings"] = dict[p] 304 | newp = dict["gauss_params.means"].shape[0] 305 | 306 | for name, param in self.gauss_params.items(): 307 | old_shape = param.shape 308 | new_shape = (newp,) + old_shape[1:] 309 | self.gauss_params[name] = torch.nn.Parameter( 310 | torch.zeros(new_shape, device=self.device) 311 | ) 312 | 313 | for i in range(self.num_train_data): 314 | if "means_dyn" in dict: 315 | for p in [ 316 | "means_dyn", 317 | "scales_dyn", 318 | "quats_dyn", 319 | "rgbs_dyn", 320 | "opacities_dyn", 321 | ]: # "features_dc_dyn", "features_rest_dyn", 322 | dict[f"gauss_params_dyn_dict.{str(i)}.{p}"] = dict[p] 323 | newp_dyn = dict[f"gauss_params_dyn_dict.{str(i)}.means_dyn"].shape[0] 324 | for name, param in self.gauss_params_dyn_dict[str(i)].items(): 325 | old_shape = param.shape 326 | new_shape = (newp_dyn,) + old_shape[1:] 327 | self.gauss_params_dyn_dict[str(i)][name] = torch.nn.Parameter( 328 | torch.zeros(new_shape, device=self.device) 329 | ) 330 | 331 | super().load_state_dict(dict, **kwargs) 332 | 333 | def num_points_dyn_dict(self, i): 334 | return self.gauss_params_dyn_dict[str(i)]["means_dyn"].shape[0] 335 | 336 | def refinement_after(self, optimizers: Optimizers, step): 337 | assert step == self.step 338 | if self.step <= self.config.warmup_length: 339 | return 340 | with torch.no_grad(): 341 | # Offset all the opacity reset logic by refine_every so that we don't 342 | # save checkpoints right when the opacity is reset (saves every 2k) 343 | # then cull 344 | # only split/cull if we've seen every image since opacity reset 345 | reset_interval = self.config.reset_alpha_every * self.config.refine_every 346 | do_densification = ( 347 | self.step < self.config.stop_split_at 348 | and self.step % reset_interval 349 | > self.num_train_data + self.config.refine_every 350 | ) 351 | 352 | if do_densification: 353 | # then we densify 354 | # for static points 355 | assert ( 356 | self.xys_grad_norm is not None 357 | and self.vis_counts is not None 358 | and self.max_2Dsize is not None 359 | ) 360 | avg_grad_norm = ( 361 | (self.xys_grad_norm / self.vis_counts) 362 | * 0.5 363 | * max(self.last_size[0], self.last_size[1]) 364 | ) 365 | high_grads = (avg_grad_norm > self.config.densify_grad_thresh).squeeze() 366 | splits = ( 367 | self.scales.exp().max(dim=-1).values 368 | > self.config.densify_size_thresh 369 | ).squeeze() 370 | splits &= high_grads 371 | if self.step < self.config.stop_screen_size_at: 372 | splits |= ( 373 | self.max_2Dsize > self.config.split_screen_size 374 | ).squeeze() 375 | nsamps = self.config.n_split_samples 376 | split_params = self.split_gaussians(splits, nsamps) 377 | 378 | dups = ( 379 | self.scales.exp().max(dim=-1).values 380 | <= self.config.densify_size_thresh 381 | ).squeeze() 382 | dups &= high_grads 383 | dup_params = self.dup_gaussians(dups) 384 | 385 | for name, param in self.gauss_params.items(): 386 | self.gauss_params[name] = torch.nn.Parameter( 387 | torch.cat( 388 | [param.detach(), split_params[name], dup_params[name]], 389 | dim=0, 390 | ) 391 | ) 392 | 393 | # append zeros to the max_2Dsize tensor 394 | self.max_2Dsize = torch.cat( 395 | [ 396 | self.max_2Dsize, 397 | torch.zeros_like(split_params["scales"][:, 0]), 398 | torch.zeros_like(dup_params["scales"][:, 0]), 399 | ], 400 | dim=0, 401 | ) 402 | 403 | split_idcs = torch.where(splits)[0] 404 | self.dup_in_all_optim(optimizers, split_idcs, nsamps) 405 | 406 | dup_idcs = 
torch.where(dups)[0] 407 | self.dup_in_all_optim(optimizers, dup_idcs, 1) 408 | 409 | # After a guassian is split into two new gaussians, the original one should also be pruned. 410 | splits_mask = torch.cat( 411 | ( 412 | splits, 413 | torch.zeros( 414 | nsamps * splits.sum() + dups.sum(), 415 | device=self.device, 416 | dtype=torch.bool, 417 | ), 418 | ) 419 | ) 420 | 421 | deleted_mask = self.cull_gaussians(splits_mask) 422 | 423 | elif ( 424 | self.step >= self.config.stop_split_at 425 | and self.config.continue_cull_post_densification 426 | ): 427 | deleted_mask = self.cull_gaussians() 428 | 429 | else: 430 | # if we donot allow culling post refinement, no more gaussians will be pruned. 431 | deleted_mask = None 432 | 433 | if deleted_mask is not None: 434 | self.remove_from_all_optim(optimizers, deleted_mask) 435 | if ( 436 | self.step < self.config.stop_split_at 437 | and self.step % reset_interval == self.config.refine_every 438 | ): 439 | # Reset value is set to be twice of the cull_alpha_thresh 440 | reset_value = self.config.cull_alpha_thresh * 2.0 441 | self.opacities.data = torch.clamp( 442 | self.opacities.data, 443 | max=torch.logit( 444 | torch.tensor(reset_value, device=self.device) 445 | ).item(), 446 | ) 447 | 448 | # reset the exp of optimizer 449 | optim = optimizers.optimizers["opacities"] 450 | param = optim.param_groups[0]["params"][0] 451 | param_state = optim.state[param] 452 | param_state["exp_avg"] = torch.zeros_like(param_state["exp_avg"]) 453 | param_state["exp_avg_sq"] = torch.zeros_like(param_state["exp_avg_sq"]) 454 | 455 | self.xys_grad_norm = None 456 | self.vis_counts = None 457 | self.max_2Dsize = None 458 | 459 | if self.config.use_adc: 460 | do_densification_dyn = ( 461 | self.step < self.config.stop_split_at 462 | or self.config.continue_split_dyn 463 | ) and self.step % ( 464 | self.num_train_data * self.config.refine_every_dyn 465 | ) < self.config.refine_every 466 | # for dynamic points 467 | if do_densification_dyn: 468 | for camera_idx in range(self.num_train_data): 469 | # if number of Gaussians == 0, skip the densification 470 | if self.num_points_dyn(camera_idx) == 0: 471 | self.xys_grad_norm_dyn[str(camera_idx)] = None 472 | self.vis_counts_dyn[str(camera_idx)] = None 473 | self.max_2Dsize_dyn[str(camera_idx)] = None 474 | continue 475 | 476 | # we densify the 2D image when every image has been seen for refine_every_dyn times 477 | # For the dynamic points, for every image, we densify the points 478 | xys_grad_norm_dyn = self.xys_grad_norm_dyn[str(camera_idx)] 479 | vis_counts_dyn = self.vis_counts_dyn[str(camera_idx)] 480 | max_2Dsize_dyn = self.max_2Dsize_dyn[str(camera_idx)] 481 | 482 | dyn_gaussians = self.gauss_params_dyn_dict[str(camera_idx)] 483 | 484 | assert ( 485 | xys_grad_norm_dyn is not None 486 | and vis_counts_dyn is not None 487 | and max_2Dsize_dyn is not None 488 | ) 489 | avg_grad_norm_dyn = ( 490 | (xys_grad_norm_dyn / vis_counts_dyn) 491 | * 0.5 492 | * max(self.last_size[0], self.last_size[1]) 493 | ) 494 | high_grads_dyn = ( 495 | avg_grad_norm_dyn > self.config.densify_grad_thresh 496 | ).squeeze() 497 | splits_dyn = ( 498 | dyn_gaussians["scales_dyn"].exp().max(dim=-1).values 499 | > self.config.densify_size_thresh 500 | ).squeeze() 501 | splits_dyn &= high_grads_dyn 502 | if self.step < self.config.stop_screen_size_at: 503 | splits_dyn |= ( 504 | max_2Dsize_dyn > self.config.split_screen_size 505 | ).squeeze() 506 | nsamps = self.config.n_split_samples 507 | 508 | split_params_dyn = self.split_gaussians_dyn( 509 | 
camera_idx, splits_dyn, nsamps 510 | ) 511 | 512 | dups_dyn = ( 513 | dyn_gaussians["scales_dyn"].exp().max(dim=-1).values 514 | <= self.config.densify_size_thresh 515 | ).squeeze() 516 | dups_dyn &= high_grads_dyn 517 | dup_params_dyn = self.dup_gaussians_dyn(camera_idx, dups_dyn) 518 | 519 | for name, param in self.gauss_params_dyn_dict[ 520 | str(camera_idx) 521 | ].items(): 522 | self.gauss_params_dyn_dict[str(camera_idx)][name] = ( 523 | torch.nn.Parameter( 524 | torch.cat( 525 | [ 526 | param.detach(), 527 | split_params_dyn[name], 528 | dup_params_dyn[name], 529 | ], 530 | dim=0, 531 | ) 532 | ) 533 | ) 534 | 535 | # append zeros to the max_2Dsize tensor 536 | self.max_2Dsize_dyn[str(camera_idx)] = torch.cat( 537 | [ 538 | self.max_2Dsize_dyn[str(camera_idx)], 539 | torch.zeros_like(split_params_dyn["scales_dyn"][:, 0]), 540 | torch.zeros_like(dup_params_dyn["scales_dyn"][:, 0]), 541 | ], 542 | dim=0, 543 | ) 544 | 545 | split_idcs = torch.where(splits_dyn)[0] 546 | # log this info to logfile 547 | 548 | self.dup_in_all_optim_dyn( 549 | camera_idx, optimizers, split_idcs, nsamps 550 | ) 551 | dup_idcs = torch.where(dups_dyn)[0] 552 | self.dup_in_all_optim_dyn(camera_idx, optimizers, dup_idcs, 1) 553 | 554 | # After a guassian is split into two new gaussians, the original one should also be pruned. 555 | splits_mask_dyn = torch.cat( 556 | ( 557 | splits_dyn, 558 | torch.zeros( 559 | nsamps * splits_dyn.sum() + dups_dyn.sum(), 560 | device=self.device, 561 | dtype=torch.bool, 562 | ), 563 | ) 564 | ) 565 | 566 | vis_cull_mask = (self.radii_dyn[str(camera_idx)] < 0.01).squeeze() # cull invisible distractor Gaussians 567 | splits_mask_dyn = splits_mask_dyn | vis_cull_mask 568 | 569 | deleted_mask_dyn = self.cull_gaussians_dyn( 570 | camera_idx, splits_mask_dyn 571 | ) 572 | 573 | if deleted_mask_dyn is not None: 574 | self.remove_from_all_optim_dyn( 575 | camera_idx, optimizers, deleted_mask_dyn 576 | ) 577 | 578 | self.xys_grad_norm_dyn[str(camera_idx)] = None 579 | self.vis_counts_dyn[str(camera_idx)] = None 580 | self.max_2Dsize_dyn[str(camera_idx)] = None 581 | 582 | elif ( 583 | self.step >= self.config.stop_split_at 584 | and self.config.continue_cull_post_densification 585 | and self.step % ( 586 | self.num_train_data * self.config.refine_every_dyn 587 | ) < self.config.refine_every 588 | ): 589 | for camera_idx in range(self.num_train_data): 590 | if self.num_points_dyn(camera_idx) == 0: 591 | self.xys_grad_norm_dyn[str(camera_idx)] = None 592 | self.vis_counts_dyn[str(camera_idx)] = None 593 | self.max_2Dsize_dyn[str(camera_idx)] = None 594 | continue 595 | 596 | deleted_mask_dyn = self.cull_gaussians_dyn(camera_idx) 597 | if deleted_mask_dyn is not None: 598 | self.remove_from_all_optim_dyn( 599 | camera_idx, optimizers, deleted_mask_dyn 600 | ) 601 | self.xys_grad_norm_dyn[str(camera_idx)] = None 602 | self.vis_counts_dyn[str(camera_idx)] = None 603 | self.max_2Dsize_dyn[str(camera_idx)] = None 604 | 605 | 606 | reset_dyn = ( 607 | self.step < self.config.stop_split_at 608 | and self.step % reset_interval == self.config.refine_every 609 | and self.config.enable_reset_alpha 610 | ) 611 | # reset 612 | if reset_dyn: 613 | # Reset value for dynamic points 614 | for camera_idx in range(self.num_train_data): 615 | reset_value = self.config.cull_alpha_thresh * 2.0 616 | dyn_gaussians = self.gauss_params_dyn_dict[str(camera_idx)] 617 | 618 | dyn_gaussians["opacities_dyn"].data = torch.clamp( 619 | dyn_gaussians["opacities_dyn"], 620 | max=torch.logit( 621 | torch.tensor(reset_value, 
device=self.device) 622 | ).item(), 623 | ) 624 | 625 | # reset the exp of optimizer 626 | optim = optimizers.optimizers["opacities_dyn"] 627 | 628 | param = optim.param_groups[0]["params"][camera_idx] 629 | param_state = optim.state[param] 630 | if "exp_avg" in param_state: 631 | param_state["exp_avg"] = torch.zeros_like( 632 | param_state["exp_avg"] 633 | ) 634 | param_state["exp_avg_sq"] = torch.zeros_like( 635 | param_state["exp_avg_sq"] 636 | ) 637 | 638 | def cull_gaussians_dyn(self, i, extra_cull_mask: Optional[torch.Tensor] = None): 639 | """ 640 | This function deletes gaussians with under a certain opacity threshold 641 | extra_cull_mask: a mask indicates extra gaussians to cull besides existing culling criterion 642 | """ 643 | n_bef = self.num_points_dyn(i) 644 | # cull transparent ones 645 | culls = ( 646 | torch.sigmoid(self.gauss_params_dyn_dict[str(i)]["opacities_dyn"]) 647 | < self.config.cull_alpha_thresh 648 | ).squeeze() # self.config.cull_alpha_thresh 649 | # if the point is invisible for all the camera, cull it 650 | 651 | below_alpha_count = torch.sum(culls).item() 652 | toobigs_count = 0 653 | if extra_cull_mask is not None: 654 | culls = culls | extra_cull_mask 655 | 656 | if self.step > self.config.refine_every * self.config.reset_alpha_every: 657 | # cull huge ones 658 | toobigs = ( 659 | torch.exp(self.gauss_params_dyn_dict[str(i)]["scales_dyn"]) 660 | .max(dim=-1) 661 | .values 662 | > self.config.cull_scale_thresh 663 | ).squeeze() 664 | if self.step < self.config.stop_screen_size_at: 665 | # cull big screen space 666 | if self.max_2Dsize_dyn[str(i)] is not None: 667 | toobigs = ( 668 | toobigs 669 | | ( 670 | self.max_2Dsize_dyn[str(i)] > self.config.cull_screen_size 671 | ).squeeze() 672 | ) 673 | culls = culls | toobigs 674 | toobigs_count = torch.sum(toobigs).item() 675 | for name, param in self.gauss_params_dyn_dict[str(i)].items(): 676 | self.gauss_params_dyn_dict[str(i)][name] = torch.nn.Parameter(param[~culls]) 677 | 678 | CONSOLE.log( 679 | f"Dynamic Culled {n_bef - self.num_points_dyn(i)} gaussians " 680 | f"({below_alpha_count} below alpha thresh, {toobigs_count} too bigs, {self.num_points_dyn(i)} remaining)" 681 | ) 682 | 683 | return culls 684 | 685 | def split_gaussians_dyn(self, i, split_mask, samps): 686 | """ 687 | This function splits gaussians that are too large 688 | """ 689 | n_splits = split_mask.sum().item() 690 | CONSOLE.log( 691 | f"Dynamic Splitting {split_mask.sum().item()/self.num_points_dyn(i)} gaussians: {n_splits}/{self.num_points_dyn(i)}" 692 | ) 693 | centered_samples = torch.randn( 694 | (samps * n_splits, 3), device=self.device 695 | ) # Nx3 of axis-aligned scales 696 | scaled_samples = ( 697 | torch.exp( 698 | self.gauss_params_dyn_dict[str(i)]["scales_dyn"][split_mask].repeat( 699 | samps, 1 700 | ) 701 | ) 702 | * centered_samples 703 | ) # how these scales are rotated 704 | quats = self.gauss_params_dyn_dict[str(i)]["quats_dyn"][ 705 | split_mask 706 | ] / self.gauss_params_dyn_dict[str(i)]["quats_dyn"][split_mask].norm( 707 | dim=-1, keepdim=True 708 | ) # normalize them first 709 | rots = quat_to_rotmat(quats.repeat(samps, 1)) # how these scales are rotated 710 | rotated_samples = torch.bmm(rots, scaled_samples[..., None]).squeeze() 711 | new_means = rotated_samples + self.gauss_params_dyn_dict[str(i)]["means_dyn"][ 712 | split_mask 713 | ].repeat(samps, 1) 714 | # step 2, sample new colors 715 | new_colors = self.gauss_params_dyn_dict[str(i)]["rgbs_dyn"][split_mask].repeat( 716 | samps, 1 717 | ) 718 | # 
new_features_dc = self.gauss_params_dyn_dict[str(i)]["features_dc_dyn"][split_mask].repeat(samps, 1) 719 | # new_features_rest = self.gauss_params_dyn_dict[str(i)]["features_rest_dyn"][split_mask].repeat(samps, 1, 1) 720 | # step 3, sample new opacities 721 | new_opacities = self.gauss_params_dyn_dict[str(i)]["opacities_dyn"][ 722 | split_mask 723 | ].repeat(samps, 1) 724 | # step 4, sample new scales 725 | size_fac = 1.6 726 | new_scales = torch.log( 727 | torch.exp(self.gauss_params_dyn_dict[str(i)]["scales_dyn"][split_mask]) 728 | / size_fac 729 | ).repeat(samps, 1) 730 | self.gauss_params_dyn_dict[str(i)]["scales_dyn"][split_mask] = torch.log( 731 | torch.exp(self.gauss_params_dyn_dict[str(i)]["scales_dyn"][split_mask]) 732 | / size_fac 733 | ) 734 | # step 5, sample new quats 735 | new_quats = self.gauss_params_dyn_dict[str(i)]["quats_dyn"][split_mask].repeat( 736 | samps, 1 737 | ) 738 | 739 | out = { 740 | "means_dyn": new_means, 741 | "rgbs_dyn": new_colors, 742 | "opacities_dyn": new_opacities, 743 | "scales_dyn": new_scales, 744 | "quats_dyn": new_quats, 745 | } 746 | for name, param in self.gauss_params_dyn_dict[str(i)].items(): 747 | if name not in out: 748 | out[name] = param[split_mask].repeat(samps, 1) 749 | return out 750 | 751 | def dup_gaussians_dyn(self, i, dup_mask): 752 | """ 753 | This function duplicates gaussians that are too small 754 | """ 755 | n_dups = dup_mask.sum().item() 756 | CONSOLE.log( 757 | f"Dynamic Duplication: Duplicating {dup_mask.sum().item()/self.num_points_dyn(i)} gaussians: {n_dups}/{self.num_points_dyn(i)}" 758 | ) 759 | new_dups = {} 760 | for name, param in self.gauss_params_dyn_dict[str(i)].items(): 761 | new_dups[name] = param[dup_mask] 762 | return new_dups 763 | 764 | def dup_in_all_optim_dyn(self, i, optimizers, dup_mask, n): 765 | param_groups = self.get_gaussian_param_groups_dyn_dict() 766 | 767 | for group, _ in param_groups.items(): 768 | param = param_groups[group][i] 769 | self.dup_in_optim_dyn( 770 | i, optimizers.optimizers[group], dup_mask, param, n 771 | ) 772 | 773 | self.radii_dyn[str(i)] = torch.cat( 774 | [self.radii_dyn[str(i)], self.radii_dyn[str(i)][dup_mask.squeeze()].repeat(n,)], dim=0 775 | ) 776 | 777 | def dup_in_optim_dyn(self, i, optimizer, dup_mask, new_params, n=2): 778 | """adds the parameters to the optimizer""" 779 | param = optimizer.param_groups[0]["params"][i][0] 780 | param_state = optimizer.state[param] 781 | if "exp_avg" in param_state: 782 | repeat_dims = (n,) + tuple( 783 | 1 for _ in range(param_state["exp_avg"].dim() - 1) 784 | ) 785 | param_state["exp_avg"] = torch.cat( 786 | [ 787 | param_state["exp_avg"], 788 | torch.zeros_like(param_state["exp_avg"][dup_mask.squeeze()]).repeat( 789 | *repeat_dims 790 | ), 791 | ], 792 | dim=0, 793 | ) 794 | param_state["exp_avg_sq"] = torch.cat( 795 | [ 796 | param_state["exp_avg_sq"], 797 | torch.zeros_like( 798 | param_state["exp_avg_sq"][dup_mask.squeeze()] 799 | ).repeat(*repeat_dims), 800 | ], 801 | dim=0, 802 | ) 803 | del optimizer.state[param] 804 | optimizer.state[new_params] = param_state 805 | optimizer.param_groups[0]["params"][i] = new_params 806 | del param 807 | 808 | def remove_from_all_optim_dyn(self, i, optimizers, deleted_mask): 809 | param_groups = self.get_gaussian_param_groups_dyn_dict() 810 | for group, _ in param_groups.items(): 811 | param = param_groups[group][i] 812 | self.remove_from_optim_dyn( 813 | i, optimizers.optimizers[group], deleted_mask, param 814 | ) # 815 | torch.cuda.empty_cache() 816 | 817 | def remove_from_optim_dyn(self, 
i, optimizer, deleted_mask, new_params): 818 | """removes the deleted_mask from the optimizer provided""" 819 | param = optimizer.param_groups[0]["params"][i][0] 820 | param_state = optimizer.state[param] 821 | del optimizer.state[param] 822 | 823 | # Modify the state directly without deleting and reassigning. 824 | if "exp_avg" in param_state: 825 | param_state["exp_avg"] = param_state["exp_avg"][~deleted_mask] 826 | param_state["exp_avg_sq"] = param_state["exp_avg_sq"][~deleted_mask] 827 | 828 | # Update the parameter in the optimizer's param group. 829 | del optimizer.param_groups[0]["params"][i] 830 | optimizer.param_groups[0]["params"].insert(i, new_params) 831 | optimizer.state[new_params] = param_state 832 | 833 | def after_train(self, step: int): 834 | assert step == self.step 835 | # to save some training time, we no longer need to update those stats post refinement 836 | if self.step >= self.config.stop_split_at: 837 | if self.config.continue_split_dyn: 838 | with torch.no_grad(): 839 | visible_mask_dyn = ( 840 | self.radii_dyn[str(self.camera_idx)] > 0 841 | ).flatten() 842 | grads_dyn = ( 843 | self.xys_dyn[str(self.camera_idx)] 844 | .absgrad[0][visible_mask_dyn] 845 | .norm(dim=-1) 846 | ) # type: ignore 847 | if self.xys_grad_norm_dyn[str(self.camera_idx)] is None: 848 | self.xys_grad_norm_dyn[str(self.camera_idx)] = torch.zeros( 849 | self.num_points_dyn(self.camera_idx), 850 | device=self.device, 851 | dtype=torch.float32, 852 | ) # + self.num_points_dyn 853 | self.vis_counts_dyn[str(self.camera_idx)] = torch.ones( 854 | self.num_points_dyn(self.camera_idx), 855 | device=self.device, 856 | dtype=torch.float32, 857 | ) # + self.num_points_dyn 858 | 859 | assert self.vis_counts_dyn[str(self.camera_idx)] is not None 860 | self.vis_counts_dyn[str(self.camera_idx)][visible_mask_dyn] += 1 861 | self.xys_grad_norm_dyn[str(self.camera_idx)][visible_mask_dyn] += ( 862 | grads_dyn 863 | ) 864 | # update the max screen size, as a ratio of number of pixels 865 | if self.max_2Dsize_dyn[str(self.camera_idx)] is None: 866 | self.max_2Dsize_dyn[str(self.camera_idx)] = torch.zeros_like( 867 | self.radii_dyn[str(self.camera_idx)], dtype=torch.float32 868 | ) 869 | newradii_dyn = self.radii_dyn[str(self.camera_idx)].detach()[ 870 | visible_mask_dyn 871 | ] 872 | self.max_2Dsize_dyn[str(self.camera_idx)][visible_mask_dyn] = ( 873 | torch.maximum( 874 | self.max_2Dsize_dyn[str(self.camera_idx)][visible_mask_dyn], 875 | newradii_dyn 876 | / float(max(self.last_size[0], self.last_size[1])), 877 | ) 878 | ) 879 | return 880 | 881 | with torch.no_grad(): 882 | # keep track of a moving average of grad norms 883 | visible_mask = (self.radii > 0).flatten() 884 | grads = self.xys.absgrad[0][visible_mask].norm(dim=-1) # type: ignore 885 | if self.xys_grad_norm is None: 886 | self.xys_grad_norm = torch.zeros( 887 | self.num_points, device=self.device, dtype=torch.float32 888 | ) 889 | self.vis_counts = torch.ones( 890 | self.num_points, device=self.device, dtype=torch.float32 891 | ) 892 | assert self.vis_counts is not None 893 | self.vis_counts[visible_mask] += 1 894 | self.xys_grad_norm[visible_mask] += grads 895 | # update the max screen size, as a ratio of number of pixels 896 | if self.max_2Dsize is None: 897 | self.max_2Dsize = torch.zeros_like(self.radii, dtype=torch.float32) 898 | newradii = self.radii.detach()[visible_mask] 899 | self.max_2Dsize[visible_mask] = torch.maximum( 900 | self.max_2Dsize[visible_mask], 901 | newradii / float(max(self.last_size[0], self.last_size[1])), 902 | ) 903 | 
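            # The branch below mirrors the static-Gaussian bookkeeping above, but keeps the
            # running statistics (visibility counts, absolute 2D-mean gradient norms, and the
            # maximum screen-space radius) per training camera for the transient distractor
            # Gaussians. A minimal sketch of how the refinement step is assumed to use them
            # (attribute names as above, criteria illustrative):
            #   avg_grad_dyn = xys_grad_norm_dyn[cam] / vis_counts_dyn[cam]
            #   split Gaussians with high avg_grad_dyn whose scale exceeds densify_size_thresh,
            #   duplicate the small high-gradient ones, and cull by opacity / projected radius.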
904 | if self.config.use_adc: 905 | # for dynamic points 906 | visible_mask_dyn = (self.radii_dyn[str(self.camera_idx)] > 0).flatten() 907 | grads_dyn = ( 908 | self.xys_dyn[str(self.camera_idx)] 909 | .absgrad[0][visible_mask_dyn] 910 | .norm(dim=-1) 911 | ) # type: ignore 912 | if self.xys_grad_norm_dyn[str(self.camera_idx)] is None: 913 | self.xys_grad_norm_dyn[str(self.camera_idx)] = torch.zeros( 914 | self.num_points_dyn(self.camera_idx), 915 | device=self.device, 916 | dtype=torch.float32, 917 | ) 918 | self.vis_counts_dyn[str(self.camera_idx)] = torch.ones( 919 | self.num_points_dyn(self.camera_idx), 920 | device=self.device, 921 | dtype=torch.float32, 922 | ) 923 | 924 | assert self.vis_counts_dyn[str(self.camera_idx)] is not None 925 | self.vis_counts_dyn[str(self.camera_idx)][visible_mask_dyn] += 1 926 | self.xys_grad_norm_dyn[str(self.camera_idx)][visible_mask_dyn] += ( 927 | grads_dyn 928 | ) 929 | # update the max screen size, as a ratio of number of pixels 930 | if self.max_2Dsize_dyn[str(self.camera_idx)] is None: 931 | self.max_2Dsize_dyn[str(self.camera_idx)] = torch.zeros_like( 932 | self.radii_dyn[str(self.camera_idx)], dtype=torch.float32 933 | ) 934 | newradii_dyn = self.radii_dyn[str(self.camera_idx)].detach()[ 935 | visible_mask_dyn 936 | ] 937 | self.max_2Dsize_dyn[str(self.camera_idx)][visible_mask_dyn] = ( 938 | torch.maximum( 939 | self.max_2Dsize_dyn[str(self.camera_idx)][visible_mask_dyn], 940 | newradii_dyn / float(max(self.last_size[0], self.last_size[1])), 941 | ) 942 | ) 943 | 944 | def get_gaussian_param_groups(self) -> Dict[str, List[Parameter]]: 945 | # Here we explicitly use the means, scales as parameters so that the user can override this function and 946 | # specify more if they want to add more optimizable params to gaussians. 
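        # The keys below are assumed to match the optimizer names in the method config
        # (nerfstudio builds one optimizer per parameter group), so the returned mapping
        # looks roughly like {"means": [Parameter(...)], "scales": [Parameter(...)], ...}.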
947 | 948 | keys = [ 949 | "means", 950 | "scales", 951 | "quats", 952 | "features_dc", 953 | "features_rest", 954 | "opacities", 955 | ] 956 | if "embeddings" in self.gauss_params: 957 | keys.append("embeddings") # Add dynamically if it exists 958 | 959 | return { 960 | name: [self.gauss_params[name]] 961 | for name in keys 962 | } 963 | 964 | 965 | def get_gaussian_param_groups_dyn_dict( 966 | self, 967 | ) -> Dict[str, List[Parameter]]: 968 | return { 969 | name: [ 970 | self.gauss_params_dyn_dict[str(i)][name] 971 | for i in range(self.num_train_data) 972 | ] 973 | for name in [ 974 | "means_dyn", 975 | "scales_dyn", 976 | "quats_dyn", 977 | "rgbs_dyn", 978 | "opacities_dyn", 979 | ] 980 | } 981 | 982 | def get_param_groups(self): 983 | """Obtain the parameter groups for the optimizers. 984 | 985 | Returns: 986 | Mapping of different parameter groups 987 | """ 988 | gps = self.get_gaussian_param_groups() 989 | self.camera_optimizer.get_param_groups(param_groups=gps) 990 | gps_dyn = self.get_gaussian_param_groups_dyn_dict() 991 | gps_bg = {} 992 | if self.config.enable_bg_model: 993 | assert self.bg_model is not None 994 | gps_bg["field_background_encoder"] = list(self.bg_model.encoder.parameters()) 995 | gps_bg["field_background_base"] = list(self.bg_model.sh_base_head.parameters()) 996 | gps_bg["field_background_rest"] = list(self.bg_model.sh_rest_head.parameters()) 997 | 998 | if self.config.enable_appearance: 999 | gps["appearance_mlp"] = list(self.appearance_mlp.parameters()) 1000 | gps["appearance_embeddings"] = list(self.appearance_embeddings.parameters()) 1001 | 1002 | return {**gps, **gps_dyn, **gps_bg} 1003 | 1004 | def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: 1005 | """Takes in a camera and returns a dictionary of outputs. 1006 | 1007 | Args: 1008 | camera: Input camera. It should carry all the 1009 | information needed to compute the outputs. 1010 | 1011 | Returns: 1012 | Outputs of the model (i.e.
rendered colors) 1013 | """ 1014 | if not isinstance(camera, Cameras): 1015 | print("Called get_outputs with not a camera") 1016 | return {} 1017 | 1018 | if self.training: 1019 | assert camera.shape[0] == 1, "Only one camera at a time" 1020 | optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera) 1021 | else: 1022 | optimized_camera_to_world = camera.camera_to_worlds 1023 | 1024 | # cropping 1025 | if self.crop_box is not None and not self.training: 1026 | crop_ids = self.crop_box.within(self.means).squeeze() 1027 | if crop_ids.sum() == 0: 1028 | return self.get_empty_outputs( 1029 | int(camera.width.item()), 1030 | int(camera.height.item()), 1031 | self.background_color, 1032 | ) 1033 | else: 1034 | crop_ids = None 1035 | 1036 | if crop_ids is not None: 1037 | opacities_crop = self.opacities[crop_ids] 1038 | means_crop = self.means[crop_ids] 1039 | features_dc_crop = self.features_dc[crop_ids] 1040 | features_rest_crop = self.features_rest[crop_ids] 1041 | scales_crop = self.scales[crop_ids] 1042 | quats_crop = self.quats[crop_ids] 1043 | if self.config.app_per_gauss: 1044 | embeddings_crop = self.embeddings[crop_ids] 1045 | else: 1046 | opacities_crop = self.opacities 1047 | means_crop = self.means 1048 | features_dc_crop = self.features_dc 1049 | features_rest_crop = self.features_rest 1050 | scales_crop = self.scales 1051 | quats_crop = self.quats 1052 | if self.config.app_per_gauss: 1053 | embeddings_crop = self.embeddings 1054 | 1055 | colors_crop = torch.cat( 1056 | (features_dc_crop[:, None, :], features_rest_crop), dim=1 1057 | ) 1058 | 1059 | BLOCK_WIDTH = ( 1060 | 16 # this controls the tile size of rasterization, 16 is a good default 1061 | ) 1062 | camera_scale_fac = self._get_downscale_factor() 1063 | camera.rescale_output_resolution(1 / camera_scale_fac) 1064 | viewmat = get_viewmat(optimized_camera_to_world) 1065 | K = camera.get_intrinsics_matrices().cuda() 1066 | W, H = int(camera.width.item()), int(camera.height.item()) 1067 | self.last_size = (H, W) 1068 | # camera.rescale_output_resolution(camera_scale_fac) # type: ignore 1069 | # apply the compensation of screen space blurring to gaussians 1070 | if self.config.rasterize_mode not in ["antialiased", "classic"]: 1071 | raise ValueError("Unknown rasterize_mode: %s", self.config.rasterize_mode) 1072 | if self.config.enable_appearance: 1073 | if camera.metadata is not None and "cam_idx" in camera.metadata: 1074 | # if self.training or self.dataparser_config.eval_train: 1075 | self.camera_idx = camera.metadata["cam_idx"] 1076 | appearance_embed = self.appearance_embeddings( 1077 | torch.tensor(self.camera_idx, device=self.device) 1078 | ) 1079 | else: 1080 | # appearance_embed is zero 1081 | appearance_embed = torch.zeros( 1082 | self.config.appearance_embedding_dim, device=self.device 1083 | ) 1084 | assert self.appearance_mlp is not None 1085 | # assert self.embeddings is not None 1086 | 1087 | features = torch.cat( 1088 | (features_dc_crop.unsqueeze(1), features_rest_crop), dim=1 1089 | ) 1090 | # offset, mul = self.appearance_mlp(appearance_embed.repeat(self.num_points, 1), features_dc_crop) 1091 | offset, mul = self.appearance_mlp( 1092 | self.embeddings if self.config.app_per_gauss else None, 1093 | appearance_embed.repeat(self.num_points, 1), 1094 | features_dc_crop, 1095 | ) 1096 | colors_toned = colors_crop * mul.unsqueeze(1) + offset.unsqueeze(1) 1097 | shdim = (self.config.sh_degree + 1) ** 2 1098 | colors_toned = colors_toned.view(-1, shdim, 3).contiguous().clamp_max(1.0) 1099 | # 
colors_toned = eval_sh(self.active_sh_degree, colors_toned, dir_pp_normalized) 1100 | colors_toned = torch.clamp_min(colors_toned + 0.5, 0.0) 1101 | colors_crop = colors_toned 1102 | 1103 | if self.config.sh_degree > 0: 1104 | sh_degree_to_use = min( 1105 | self.step // self.config.sh_degree_interval, self.config.sh_degree 1106 | ) 1107 | bg_sh_degree_to_use = min( 1108 | self.step // (self.config.sh_degree_interval // 2), 1109 | self.config.bg_sh_degree, 1110 | ) 1111 | else: 1112 | colors_crop = torch.sigmoid(colors_crop).squeeze(1) # [N, 1, 3] -> [N, 3] 1113 | sh_degree_to_use = None 1114 | 1115 | render_3d, alpha_3d, info_3d = rasterization( 1116 | means=means_crop, 1117 | quats=quats_crop / quats_crop.norm(dim=-1, keepdim=True), 1118 | scales=torch.exp(scales_crop), 1119 | opacities=torch.sigmoid(opacities_crop).squeeze(-1), 1120 | colors=colors_crop, 1121 | viewmats=viewmat, # [1, 4, 4] 1122 | Ks=K, # [1, 3, 3] 1123 | width=W, 1124 | height=H, 1125 | tile_size=BLOCK_WIDTH, 1126 | packed=False, 1127 | near_plane=0.01, 1128 | far_plane=1e10, 1129 | render_mode="RGB+ED", # render_mode, 1130 | sh_degree=sh_degree_to_use, 1131 | sparse_grad=False, 1132 | absgrad=True, 1133 | rasterize_mode=self.config.rasterize_mode, 1134 | # set some threshold to disregrad small gaussians for faster rendering. 1135 | # radius_clip=3.0, 1136 | ) 1137 | if self.training and info_3d["means2d"].requires_grad: 1138 | info_3d["means2d"].retain_grad() 1139 | 1140 | self.xys = info_3d["means2d"] # [1, N, 2] 1141 | self.radii = info_3d["radii"][0] # [N] 1142 | 1143 | alpha_3d = alpha_3d[:, ...] 1144 | 1145 | # Only need for one time 1146 | ### BACKGROUND MODEL 1147 | if self.config.enable_bg_model: 1148 | directions = normalize( 1149 | camera.generate_rays(camera_indices=0, keep_shape=False).directions 1150 | ) 1151 | 1152 | bg_sh_coeffs = self.bg_model.get_sh_coeffs( 1153 | appearance_embedding=appearance_embed, 1154 | ) 1155 | 1156 | background = spherical_harmonics( 1157 | degrees_to_use=bg_sh_degree_to_use, 1158 | coeffs=bg_sh_coeffs.repeat(directions.shape[0], 1, 1), 1159 | dirs=directions, 1160 | ) 1161 | background = background.view(H, W, 3) 1162 | else: 1163 | background = self._get_background_color().view(1, 1, 3) 1164 | 1165 | rgb_static = render_3d[:, ..., :3] + (1 - alpha_3d) * background 1166 | 1167 | rgb_static = torch.clamp(rgb_static, 0.0, 1.0) 1168 | 1169 | rgb = rgb_static 1170 | 1171 | if background.shape[0] == 3 and not self.training: 1172 | background = background.expand(H, W, 3) 1173 | 1174 | returns = {} 1175 | returns["rgb"] = rgb.squeeze(0) 1176 | returns["depth"] = render_3d[..., 3].squeeze(0).unsqueeze(-1) 1177 | returns["background"] = background 1178 | returns["accumulation"] = alpha_3d.squeeze(0) 1179 | 1180 | img_dyn, alpha_2d, depth_2d = None, None, None # for debug 1181 | if camera.metadata is not None and "cam_idx" in camera.metadata: 1182 | self.camera_idx = camera.metadata["cam_idx"] 1183 | 1184 | dyn_gaussians = self.gauss_params_dyn_dict[str(self.camera_idx)] 1185 | opacities_dyn = dyn_gaussians["opacities_dyn"] 1186 | means_dyn = dyn_gaussians["means_dyn"] 1187 | rgbs_dyn = dyn_gaussians["rgbs_dyn"] 1188 | scales_dyn = dyn_gaussians["scales_dyn"] 1189 | quats_dyn = dyn_gaussians["quats_dyn"] 1190 | 1191 | colors_dyn = rgbs_dyn.clamp_min(0.0).clamp_max(1.0) 1192 | 1193 | render_2d, alpha_2d, info_2d = rasterization( 1194 | means=means_dyn, 1195 | quats=quats_dyn / quats_dyn.norm(dim=-1, keepdim=True), 1196 | scales=torch.exp(scales_dyn), 1197 | 
opacities=torch.sigmoid(opacities_dyn).squeeze(-1), 1198 | colors=colors_dyn, # [N, 3] 1199 | viewmats=viewmat, # [C, 4, 4] 1200 | Ks=K, # [C, 3, 3] 1201 | width=W, 1202 | height=H, 1203 | tile_size=BLOCK_WIDTH, 1204 | render_mode="RGB+ED", 1205 | sh_degree=None, # sh_degree_to_use, 1206 | packed=False, 1207 | sparse_grad=False, 1208 | absgrad=True, 1209 | rasterize_mode=self.config.rasterize_mode, 1210 | ) 1211 | if info_2d["means2d"].requires_grad: 1212 | info_2d["means2d"].retain_grad() 1213 | if self.config.use_adc: 1214 | self.xys_dyn[str(self.camera_idx)] = info_2d["means2d"] # [1, N, 2] 1215 | self.radii_dyn[str(self.camera_idx)] = info_2d["radii"][0] # [N] 1216 | depth_2d = render_2d[..., 3].unsqueeze(-1) # [H, W, 1] 1217 | render_2d = render_2d[..., :3] # [H, W, 3] 1218 | img_dyn = render_2d.squeeze(0) 1219 | img_dyn = torch.clamp(img_dyn, 0.0, 1.0) 1220 | 1221 | rgb = ( 1222 | render_2d 1223 | + (1 - alpha_2d) * render_3d[:, ..., :3] 1224 | + (1 - alpha_3d) * (1 - alpha_2d) * background 1225 | ) 1226 | rgb = torch.clamp(rgb, 0.0, 1.0) 1227 | camera.rescale_output_resolution(camera_scale_fac) # type: ignore 1228 | alpha_2d = alpha_2d.squeeze(0) 1229 | depth_2d = depth_2d.squeeze(0) 1230 | 1231 | returns["img_dyn"] = img_dyn 1232 | returns["alpha_2d"] = alpha_2d 1233 | returns["depth_2d"] = depth_2d 1234 | returns["rgb"] = rgb.squeeze(0) 1235 | returns["rgb_static"] = rgb_static.squeeze(0) 1236 | else: 1237 | self.camera_idx = None 1238 | 1239 | return returns 1240 | 1241 | def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]: 1242 | """Compute and returns metrics. 1243 | 1244 | Args: 1245 | outputs: the output to compute loss dict to 1246 | batch: ground truth batch corresponding to outputs 1247 | """ 1248 | gt_rgb = self.composite_with_background( 1249 | self.get_gt_img(batch["image"]), outputs["background"] 1250 | ) 1251 | metrics_dict = {} 1252 | predicted_rgb = outputs["rgb"] 1253 | metrics_dict["psnr"] = self.psnr(predicted_rgb, gt_rgb) 1254 | 1255 | metrics_dict["gaussian_count_static"] = self.num_points 1256 | # metrics_dict["gaussian_count_transient"] = self.num_points_dyn 1257 | 1258 | self.camera_optimizer.get_metrics_dict(metrics_dict) 1259 | return metrics_dict 1260 | 1261 | def get_loss_dict( 1262 | self, outputs, batch, metrics_dict=None 1263 | ) -> Dict[str, torch.Tensor]: 1264 | """Computes and returns the losses dict. 1265 | 1266 | Args: 1267 | outputs: the output to compute loss dict to 1268 | batch: ground truth batch corresponding to outputs 1269 | metrics_dict: dictionary of metrics, some of which we can use for loss 1270 | """ 1271 | gt_img = self.composite_with_background( 1272 | self.get_gt_img(batch["image"]), outputs["background"] 1273 | ) 1274 | pred_img = outputs["rgb"] 1275 | 1276 | Ll1_img = torch.abs(gt_img - pred_img) 1277 | 1278 | Ll1 = Ll1_img.mean() 1279 | simloss = 1 - self.ssim( 1280 | gt_img.permute(2, 0, 1)[None, ...], pred_img.permute(2, 0, 1)[None, ...] 
1281 | ) 1282 | main_loss = ( 1283 | 1 - self.config.ssim_lambda 1284 | ) * Ll1 + self.config.ssim_lambda * simloss 1285 | 1286 | if self.config.use_scale_regularization and self.step % 10 == 0: 1287 | scale_exp = torch.exp(self.scales) 1288 | scale_reg = ( 1289 | torch.maximum( 1290 | scale_exp.amax(dim=-1) / scale_exp.amin(dim=-1), 1291 | torch.tensor(self.config.max_gauss_ratio), 1292 | ) 1293 | - self.config.max_gauss_ratio 1294 | ) 1295 | scale_reg = 0.1 * scale_reg.mean() 1296 | else: 1297 | scale_reg = torch.tensor(0.0).to(self.device) 1298 | 1299 | loss_dict = { 1300 | "main_loss": main_loss, 1301 | "scale_reg": scale_reg, 1302 | } 1303 | 1304 | # add loss to prevent the hole of static 1305 | if self.config.alpha_bg_loss_lambda > 0: 1306 | if self.config.enable_bg_model: 1307 | alpha_loss = torch.tensor(0.0).to(self.device) 1308 | background = outputs["background"] 1309 | alpha = outputs["alpha_2d"].detach() + (1 - outputs["alpha_2d"].detach()) * outputs["accumulation"] 1310 | # alpha = outputs["accumulation"] 1311 | # for those pixel are well represented by bg and has low alpha, we encourage the gaussian to be transparent 1312 | bg_mask = torch.abs(gt_img - background).mean(dim=-1, keepdim=True) < 0.003 1313 | # use a box filter to avoid penalty high frequency parts 1314 | f = 3 1315 | window = (torch.ones((f, f)).view(1, 1, f, f) / (f * f)).cuda() 1316 | bg_mask = ( 1317 | torch.nn.functional.conv2d( 1318 | bg_mask.float().unsqueeze(0).permute(0, 3, 1, 2), 1319 | window, 1320 | stride=1, 1321 | padding="same", 1322 | ) 1323 | .permute(0, 2, 3, 1) 1324 | .squeeze(0) 1325 | ) 1326 | alpha_mask = bg_mask > 0.6 1327 | # prevent NaN 1328 | if alpha_mask.sum() != 0: 1329 | alpha_loss = alpha[alpha_mask].mean() * self.config.alpha_bg_loss_lambda # default: 0.15 1330 | else: 1331 | alpha_loss = torch.tensor(0.0).to(self.device) 1332 | loss_dict["alpha_bg_loss"] = alpha_loss 1333 | else: 1334 | alpha_bg = outputs["accumulation"] 1335 | loss_dict["alpha_bg_loss"] = self.config.alpha_bg_loss_lambda * (1 - alpha_bg).mean() 1336 | 1337 | if self.config.alpha_2d_loss_lambda > 0 and "alpha_2d" in outputs: 1338 | alpha_2d = outputs["alpha_2d"] 1339 | loss_dict["alpha_2d_loss"] = self.config.alpha_2d_loss_lambda * alpha_2d.mean() 1340 | 1341 | return loss_dict 1342 | 1343 | def get_gt_img(self, image: torch.Tensor): 1344 | """Compute groundtruth image with iteration dependent downscale factor for evaluation purpose 1345 | 1346 | Args: 1347 | image: tensor.Tensor in type uint8 or float32 1348 | """ 1349 | if image.dtype == torch.uint8: 1350 | image = image.float() / 255.0 1351 | gt_img = self._downscale_if_required(image) 1352 | return gt_img.to(self.device) 1353 | 1354 | def get_image_metrics_and_images( 1355 | self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor] 1356 | ) -> Tuple[Dict[str, float], Dict[str, torch.Tensor]]: 1357 | gt_rgb = self.composite_with_background( 1358 | self.get_gt_img(batch["image"]), outputs["background"] 1359 | ) 1360 | predicted_rgb = outputs["rgb"] 1361 | 1362 | assert gt_rgb.shape == predicted_rgb.shape 1363 | combined_rgb = torch.cat([gt_rgb, predicted_rgb], dim=1) 1364 | 1365 | # Switch images from [H, W, C] to [1, C, H, W] for metrics computations 1366 | gt_rgb = torch.moveaxis(gt_rgb, -1, 0)[None, ...] 1367 | predicted_rgb = torch.moveaxis(predicted_rgb, -1, 0)[None, ...] 
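        # self.psnr / self.ssim / self.lpips are assumed to be torchmetrics-style modules
        # that expect batched [1, C, H, W] tensors of float RGB values in [0, 1].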
1368 | 1369 | psnr = self.psnr(gt_rgb, predicted_rgb) 1370 | ssim = self.ssim(gt_rgb, predicted_rgb) 1371 | lpips = self.lpips(gt_rgb, predicted_rgb) 1372 | 1373 | # all of these metrics will be logged as scalars 1374 | metrics_dict = {"psnr": float(psnr.item()), "ssim": float(ssim)} # type: ignore 1375 | metrics_dict["lpips"] = float(lpips) 1376 | 1377 | images_dict = { 1378 | "img": combined_rgb, 1379 | } 1380 | 1381 | return metrics_dict, images_dict 1382 | -------------------------------------------------------------------------------- /desplat/field.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import reduce 3 | from operator import mul 4 | from typing import Literal 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import Tensor, nn 9 | 10 | from gsplat.cuda._wrapper import spherical_harmonics 11 | from nerfstudio.cameras.rays import RayBundle 12 | from nerfstudio.field_components import MLP 13 | from nerfstudio.fields.base_field import Field 14 | 15 | 16 | def _get_fourier_features(xyz: Tensor, num_features=3): 17 | xyz = xyz - xyz.mean(dim=0, keepdim=True) 18 | xyz = xyz / torch.quantile(xyz.abs(), 0.97, dim=0) * 0.5 + 0.5 19 | freqs = torch.repeat_interleave( 20 | 2 21 | ** torch.linspace( 22 | 0, num_features - 1, num_features, dtype=xyz.dtype, device=xyz.device 23 | ), 24 | 2, 25 | ) 26 | offsets = torch.tensor( 27 | [0, 0.5 * math.pi] * num_features, dtype=xyz.dtype, device=xyz.device 28 | ) 29 | feat = xyz[..., None] * freqs[None, None] * 2 * math.pi + offsets[None, None] 30 | feat = torch.sin(feat).view(-1, reduce(mul, feat.shape[1:])) 31 | return feat 32 | 33 | 34 | class EmbeddingModel(nn.Module): 35 | def __init__(self, config): 36 | super().__init__() 37 | self.config = config 38 | # sh_coeffs = 4**2 39 | self.feat_in = 3 40 | input_dim = config.appearance_embedding_dim 41 | if config.app_per_gauss: 42 | input_dim += 6 * self.config.appearance_n_fourier_freqs + self.feat_in 43 | 44 | self.mlp = nn.Sequential( 45 | nn.Linear(input_dim, 128), 46 | nn.ReLU(), 47 | nn.Linear(128, 128), 48 | nn.ReLU(), 49 | nn.Linear(128, self.feat_in * 2), 50 | ) 51 | 52 | self.bg_head = nn.Linear(self.feat_in * 2, 3) 53 | 54 | def forward(self, gembedding, aembedding, features_dc, viewdir=None): 55 | del viewdir # Viewdirs interface is kept to be compatible with prev. version 56 | if self.config.app_per_gauss and gembedding is not None: 57 | inp = torch.cat((features_dc, gembedding, aembedding), dim=-1) 58 | else: 59 | inp = aembedding 60 | offset, mul = torch.split( 61 | self.mlp(inp) * 0.01, [self.feat_in, self.feat_in], dim=-1 62 | ) 63 | return offset, mul 64 | 65 | def get_bg_color(self, aembedding): 66 | return self.bg_head(self.mlp(aembedding)) 67 | 68 | 69 | class BackgroundModel(nn.Module): 70 | def __init__(self, config): 71 | super().__init__() 72 | self.config = config 73 | # sh_coeffs = 4**2 74 | self.feat_in = 3 75 | input_dim = config.appearance_embedding_dim 76 | if config.app_per_gauss: 77 | input_dim += 6 * self.config.appearance_n_fourier_freqs + self.feat_in 78 | 79 | self.mlp = nn.Sequential( 80 | nn.Linear(input_dim, 128), 81 | nn.ReLU(), 82 | nn.Linear(128, 128), 83 | nn.ReLU(), 84 | nn.Linear(128, self.feat_in * 2), 85 | ) 86 | 87 | def forward(self, gembedding, aembedding, features_dc, viewdir=None): 88 | del viewdir # Viewdirs interface is kept to be compatible with prev. 
version 89 | 90 | if self.config.app_per_gauss_bg: 91 | inp = torch.cat((features_dc, gembedding, aembedding), dim=-1) 92 | else: 93 | inp = aembedding 94 | offset, mul = torch.split( 95 | self.mlp(inp) * 0.01, [self.feat_in, self.feat_in], dim=-1 96 | ) 97 | return offset, mul 98 | 99 | class BGField(Field): 100 | def __init__( 101 | self, 102 | appearance_embedding_dim: int, 103 | implementation: Literal["tcnn", "torch"] = "torch", 104 | sh_levels: int = 4, 105 | layer_width: int = 128, 106 | num_layers: int = 3, 107 | ): 108 | super().__init__() 109 | self.sh_dim = (sh_levels + 1) ** 2 110 | 111 | self.encoder = MLP( 112 | in_dim=appearance_embedding_dim, 113 | num_layers=num_layers - 1, 114 | layer_width=layer_width, 115 | out_dim=layer_width, 116 | activation=nn.ReLU(), 117 | out_activation=nn.ReLU(), 118 | implementation=implementation, 119 | ) 120 | self.sh_base_head = nn.Linear(layer_width, 3) 121 | self.sh_rest_head = nn.Linear(layer_width, (self.sh_dim - 1) * 3) 122 | # zero initialization 123 | self.sh_rest_head.weight.data.zero_() 124 | self.sh_rest_head.bias.data.zero_() 125 | 126 | def get_background_rgb( 127 | self, ray_bundle: RayBundle, appearance_embedding=None, num_sh=4 128 | ) -> Tensor: 129 | """Predicts background colors at infinity.""" 130 | cur_sh_dim = (num_sh + 1) ** 2 131 | directions = ray_bundle.directions.view(-1, 3) 132 | x = self.encoder(appearance_embedding).float() 133 | sh_base = self.sh_base_head(x) # [batch, 3] 134 | sh_rest = self.sh_rest_head(x)[ 135 | ..., : (cur_sh_dim - 1) * 3 136 | ] # [batch, 3 * (num_sh - 1)] 137 | sh_coeffs = ( 138 | torch.cat([sh_base, sh_rest], dim=-1) 139 | .view(-1, cur_sh_dim, 3) 140 | .repeat(directions.shape[0], 1, 1) 141 | ) 142 | colors = spherical_harmonics( 143 | degrees_to_use=num_sh, dirs=directions, coeffs=sh_coeffs 144 | ) 145 | 146 | return colors 147 | 148 | def get_sh_coeffs(self, appearance_embedding=None) -> Tensor: 149 | x = self.encoder(appearance_embedding) 150 | base_color = self.sh_base_head(x) 151 | sh_rest = self.sh_rest_head(x) 152 | sh_coeffs = torch.cat([base_color, sh_rest], dim=-1).view(-1, self.sh_dim, 3) 153 | return sh_coeffs -------------------------------------------------------------------------------- /desplat/pipeline.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from dataclasses import dataclass, field 3 | from pathlib import Path 4 | from time import time 5 | from typing import Dict, List, Literal, Optional, Type 6 | 7 | import torch 8 | import torch._dynamo 9 | import torch.distributed as dist 10 | import torchvision.utils as vutils 11 | from rich.progress import ( 12 | BarColumn, 13 | MofNCompleteColumn, 14 | Progress, 15 | TextColumn, 16 | TimeElapsedColumn, 17 | ) 18 | from torch.cuda.amp.grad_scaler import GradScaler 19 | from torch.nn import Parameter 20 | from torch.nn.parallel import DistributedDataParallel as DDP 21 | 22 | from desplat.datamanager import ( 23 | DeSplatDataManager, 24 | DeSplatDataManagerConfig, 25 | ) 26 | from desplat.desplat_model import DeSplatModel, DeSplatModelConfig 27 | from nerfstudio.data.datamanagers.base_datamanager import ( 28 | DataManagerConfig, 29 | ) 30 | from nerfstudio.models.base_model import ModelConfig 31 | from nerfstudio.pipelines.base_pipeline import ( 32 | VanillaPipeline, 33 | VanillaPipelineConfig, 34 | ) 35 | from nerfstudio.utils import profiler 36 | 37 | torch._dynamo.config.suppress_errors = True 38 | 39 | 40 | @dataclass 41 | class 
DeSplatPipelineConfig(VanillaPipelineConfig): 42 | _target: Type = field(default_factory=lambda: DeSplatPipeline) 43 | """target class to instantiate""" 44 | 45 | datamanager: DataManagerConfig = field( 46 | default_factory=lambda: DeSplatDataManagerConfig() 47 | ) 48 | """specifies the datamanager config""" 49 | model: ModelConfig = field(default_factory=lambda: DeSplatModelConfig()) 50 | """specifies the model config""" 51 | # finetune: bool = True 52 | # """Whether to mask the left half and evaluate the right half of the images""" 53 | test_time_optimize: bool = False 54 | 55 | 56 | class DeSplatPipeline(VanillaPipeline): 57 | def __init__( 58 | self, 59 | config: DeSplatPipelineConfig, 60 | device: str, 61 | test_mode: Literal["test", "val", "inference"] = "val", 62 | world_size: int = 1, 63 | local_rank: int = 0, 64 | grad_scaler: Optional[GradScaler] = None, 65 | ): 66 | super(VanillaPipeline, self).__init__() 67 | self.config = config 68 | self.test_mode = test_mode 69 | self.datamanager: DeSplatDataManager = config.datamanager.setup( 70 | device=device, 71 | test_mode=test_mode, 72 | world_size=world_size, 73 | local_rank=local_rank, 74 | ) 75 | self.datamanager.to(device) 76 | 77 | seed_pts = None 78 | if ( 79 | hasattr(self.datamanager, "train_dataparser_outputs") 80 | and "points3D_xyz" in self.datamanager.train_dataparser_outputs.metadata # type: ignore 81 | ): 82 | pts = self.datamanager.train_dataparser_outputs.metadata["points3D_xyz"] # type: ignore 83 | pts_rgb = self.datamanager.train_dataparser_outputs.metadata["points3D_rgb"] # type: ignore 84 | 85 | seed_pts = (pts, pts_rgb) 86 | 87 | # if not config.model.finetune: 88 | assert self.datamanager.train_dataset is not None, "Missing input dataset" 89 | 90 | self._model = config.model.setup( 91 | scene_box=self.datamanager.train_dataset.scene_box, 92 | num_train_data=len(self.datamanager.train_dataset), 93 | metadata=self.datamanager.train_dataset.metadata, 94 | device=device, 95 | grad_scaler=grad_scaler, 96 | seed_points=seed_pts, 97 | cameras=self.datamanager.train_dataset.cameras, 98 | dataparser=self.datamanager.dataparser_config, 99 | ) 100 | 101 | self.model.to(device) 102 | 103 | self.world_size = world_size 104 | if world_size > 1: 105 | self._model = typing.cast( 106 | DeSplatModel, 107 | DDP(self._model, device_ids=[local_rank], find_unused_parameters=True), 108 | ) 109 | dist.barrier(device_ids=[local_rank]) 110 | 111 | def get_param_groups(self) -> Dict[str, List[Parameter]]: 112 | """Get the param groups for the pipeline. 113 | 114 | Returns: 115 | A list of dictionaries containing the pipeline's param groups. 
116 | """ 117 | datamanager_params = self.datamanager.get_param_groups() 118 | model_params = self.model.get_param_groups() 119 | # TODO(ethan): assert that key names don't overlap 120 | return {**datamanager_params, **model_params} # , **model_params_dyn 121 | 122 | @profiler.time_function 123 | def get_average_image_metrics( 124 | self, 125 | data_loader, 126 | image_prefix: str, 127 | step: Optional[int] = None, 128 | output_path: Optional[Path] = None, 129 | get_std: bool = False, 130 | ): 131 | self.eval() 132 | metrics_dict_list = [] 133 | num_images = len(data_loader) 134 | if output_path is not None: 135 | output_path.mkdir(exist_ok=True, parents=True) 136 | with Progress( 137 | TextColumn("[progress.description]{task.description}"), 138 | BarColumn(), 139 | TimeElapsedColumn(), 140 | MofNCompleteColumn(), 141 | transient=True, 142 | ) as progress: 143 | task = progress.add_task( 144 | "[green]Evaluating all images...", total=num_images 145 | ) 146 | idx = 0 147 | for camera, batch in data_loader: 148 | # time this the following line 149 | inner_start = time() 150 | outputs = self.model.get_outputs_for_camera(camera=camera) 151 | height, width = camera.height, camera.width 152 | num_rays = height * width 153 | metrics_dict, image_dict = self.model.get_image_metrics_and_images( 154 | outputs, batch 155 | ) 156 | if output_path is not None: 157 | for key in image_dict.keys(): 158 | image = image_dict[key] # [H, W, C] order 159 | vutils.save_image( 160 | image.permute(2, 0, 1).cpu(), 161 | output_path / f"{image_prefix}_{key}_{idx:04d}.png", 162 | ) 163 | 164 | assert "num_rays_per_sec" not in metrics_dict 165 | metrics_dict["num_rays_per_sec"] = ( 166 | num_rays / (time() - inner_start) 167 | ).item() 168 | fps_str = "fps" 169 | assert fps_str not in metrics_dict 170 | metrics_dict[fps_str] = ( 171 | metrics_dict["num_rays_per_sec"] / (height * width) 172 | ).item() 173 | metrics_dict_list.append(metrics_dict) 174 | progress.advance(task) 175 | idx = idx + 1 176 | 177 | metrics_dict = {} 178 | for key in metrics_dict_list[0].keys(): 179 | if get_std: 180 | key_std, key_mean = torch.std_mean( 181 | torch.tensor( 182 | [metrics_dict[key] for metrics_dict in metrics_dict_list] 183 | ) 184 | ) 185 | metrics_dict[key] = float(key_mean) 186 | metrics_dict[f"{key}_std"] = float(key_std) 187 | else: 188 | metrics_dict[key] = float( 189 | torch.mean( 190 | torch.tensor( 191 | [metrics_dict[key] for metrics_dict in metrics_dict_list] 192 | ) 193 | ) 194 | ) 195 | 196 | if self.test_mode == "inference": 197 | print("Now we are in the test mode.") 198 | metrics_dict["full_results"] = metrics_dict_list 199 | del image_dict["depth"] 200 | 201 | self.train() 202 | return metrics_dict 203 | 204 | @profiler.time_function 205 | def get_average_eval_half_image_metrics( 206 | self, 207 | step: Optional[int] = None, 208 | output_path: Optional[Path] = None, 209 | get_std: bool = False, 210 | ): 211 | """Get the average metrics for evaluation images.""" 212 | # assert hasattr( 213 | # self.datamanager, "fixed_indices_eval_dataloader" 214 | # ), "datamanager must have 'fixed_indices_eval_dataloader' attribute" 215 | image_prefix = "eval" 216 | return self.get_average_half_image_metrics( 217 | self.datamanager.fixed_indices_eval_dataloader, 218 | image_prefix, 219 | step, 220 | output_path, 221 | get_std, 222 | ) 223 | 224 | @profiler.time_function 225 | def get_average_half_image_metrics( 226 | self, 227 | data_loader, 228 | image_prefix: str, 229 | step: Optional[int] = None, 230 | output_path: 
Optional[Path] = None, 231 | get_std: bool = False, 232 | ): 233 | """Iterate over all the images in the dataset and get the average. 234 | 235 | Args: 236 | data_loader: the data loader to iterate over 237 | image_prefix: prefix to use for the saved image filenames 238 | step: current training step 239 | output_path: optional path to save rendered images to 240 | get_std: Set True if you want to return std with the mean metric. 241 | 242 | Returns: 243 | metrics_dict: dictionary of metrics 244 | """ 245 | self.eval() 246 | metrics_dict_list = [] 247 | num_images = len(data_loader) 248 | if output_path is not None: 249 | output_path.mkdir(exist_ok=True, parents=True) 250 | with Progress( 251 | TextColumn("[progress.description]{task.description}"), 252 | BarColumn(), 253 | TimeElapsedColumn(), 254 | MofNCompleteColumn(), 255 | transient=True, 256 | ) as progress: 257 | task = progress.add_task( 258 | "[green]Evaluating all images...", total=num_images 259 | ) 260 | idx = 0 261 | for camera, batch in data_loader: 262 | # time this the following line 263 | inner_start = time() 264 | outputs = self.model.get_outputs_for_camera(camera=camera) 265 | height, width = camera.height, camera.width 266 | 267 | half_width = width // 2 268 | for key, img in outputs.items(): 269 | if img.dim() == 3: # (H, W, C) 270 | masked_img = img.clone() 271 | masked_img = img[:, half_width:, :].clone() 272 | outputs[key] = masked_img 273 | right_half = {} 274 | right_half["image"] = batch["image"][:, half_width:, :] 275 | 276 | num_rays = height * width 277 | metrics_dict, image_dict = self.model.get_image_metrics_and_images( 278 | outputs, right_half 279 | ) 280 | if output_path is not None: 281 | for key in image_dict.keys(): 282 | image = image_dict[key] # [H, W, C] order 283 | vutils.save_image( 284 | image.permute(2, 0, 1).cpu(), 285 | output_path / f"{image_prefix}_{key}_{idx:04d}.png", 286 | ) 287 | 288 | metrics_dict_list.append(metrics_dict) 289 | progress.advance(task) 290 | idx = idx + 1 291 | 292 | metrics_dict = {} 293 | for key in metrics_dict_list[0].keys(): 294 | if get_std: 295 | key_std, key_mean = torch.std_mean( 296 | torch.tensor( 297 | [metrics_dict[key] for metrics_dict in metrics_dict_list] 298 | ) 299 | ) 300 | metrics_dict[key] = float(key_mean) 301 | metrics_dict[f"{key}_std"] = float(key_std) 302 | else: 303 | metrics_dict[key] = float( 304 | torch.mean( 305 | torch.tensor( 306 | [metrics_dict[key] for metrics_dict in metrics_dict_list] 307 | ) 308 | ) 309 | ) 310 | 311 | self.train() 312 | return metrics_dict 313 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "desplat" 3 | description = "DeSplat: Decomposed Gaussian Splatting for Distractor-Free Rendering" 4 | version = "0.0.1" 5 | 6 | dependencies = [ 7 | "nerfstudio >= 1.1.3", 8 | "gsplat == 1.0.0", 9 | "numpy == 1.24.4", 10 | "ruff", 11 | "pyyaml", 12 | "tyro==0.8.12"] 13 | 14 | [tool.setuptools.packages.find] 15 | include = ["desplat"] 16 | 17 | [project.entry-points.'nerfstudio.method_configs'] 18 | test = 'desplat.config:desplat_method' 19 | 20 | [project.entry-points.'nerfstudio.dataparser_configs'] 21 | onthego-data = 'desplat.dataparsers.onthego_dataparser:OnthegoDataParserSpecification' 22 | robustnerf-data = 'desplat.dataparsers.robustnerf_dataparser:RobustNerfDataParserSpecification' 23 | phototourism-data = 
'desplat.dataparsers.phototourism_dataparser:PhotoTourismDataParserSpecification' 24 | -------------------------------------------------------------------------------- /scripts/calculate_memory.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from pathlib import Path 4 | from typing import Callable, Literal, Optional, Tuple 5 | 6 | import torch 7 | import yaml 8 | 9 | from nerfstudio.configs.method_configs import all_methods 10 | from nerfstudio.engine.trainer import TrainerConfig 11 | from nerfstudio.pipelines.base_pipeline import Pipeline 12 | from nerfstudio.utils.rich_utils import CONSOLE 13 | 14 | 15 | def load_checkpoint_for_eval( 16 | config: TrainerConfig, pipeline: Pipeline 17 | ) -> Tuple[Path, int]: 18 | """Load a checkpointed pipeline for evaluation. 19 | 20 | Args: 21 | config (TrainerConfig): Configuration for loading the pipeline. 22 | pipeline (Pipeline): The pipeline instance to load weights into. 23 | 24 | Returns: 25 | Tuple containing the path to the loaded checkpoint and the step at which it was saved. 26 | """ 27 | assert config.load_dir is not None, "Checkpoint directory must be specified." 28 | if config.load_step is None: 29 | CONSOLE.print("Loading latest checkpoint from specified directory.") 30 | if not os.path.exists(config.load_dir): 31 | CONSOLE.rule("Error", style="red") 32 | CONSOLE.print( 33 | f"Checkpoint directory not found at {config.load_dir}", justify="center" 34 | ) 35 | CONSOLE.print( 36 | "Ensure checkpoints were generated during training.", justify="center" 37 | ) 38 | sys.exit(1) 39 | load_step = max( 40 | int(f.split("-")[1].split(".")[0]) 41 | for f in os.listdir(config.load_dir) 42 | if "step-" in f 43 | ) 44 | else: 45 | load_step = config.load_step 46 | 47 | load_path = config.load_dir / f"step-{load_step:09d}.ckpt" 48 | assert load_path.exists(), f"Checkpoint {load_path} does not exist." 49 | 50 | # Load checkpoint 51 | loaded_state = torch.load(load_path, map_location="cpu") 52 | pipeline.load_pipeline(loaded_state["pipeline"], loaded_state["step"]) 53 | CONSOLE.print(f":white_check_mark: Successfully loaded checkpoint from {load_path}") 54 | 55 | return load_path, load_step 56 | 57 | 58 | def setup_evaluation( 59 | config_path: Path, 60 | eval_num_rays_per_chunk: Optional[int] = None, 61 | test_mode: Literal["test", "val", "inference"] = "test", 62 | update_config_callback: Optional[Callable[[TrainerConfig], TrainerConfig]] = None, 63 | ) -> Tuple[TrainerConfig, Pipeline, Path, int, float]: 64 | """Set up pipeline loading for evaluation, with an option to calculate model size. 65 | 66 | Args: 67 | config_path: Path to the configuration YAML file. 68 | eval_num_rays_per_chunk: Rays per forward pass (optional). 69 | test_mode: Data loading mode ('test', 'val', or 'inference'). 70 | update_config_callback: Optional function to modify config before loading the pipeline. 71 | 72 | Returns: 73 | Config, loaded pipeline, checkpoint path, step, and model size in MB. 
74 | """ 75 | # Load and validate configuration 76 | config = yaml.load(config_path.read_text(), Loader=yaml.Loader) 77 | assert isinstance(config, TrainerConfig) 78 | 79 | config.pipeline.datamanager._target = all_methods[ 80 | config.method_name 81 | ].pipeline.datamanager._target 82 | if eval_num_rays_per_chunk: 83 | config.pipeline.model.eval_num_rays_per_chunk = eval_num_rays_per_chunk 84 | 85 | if update_config_callback: 86 | config = update_config_callback(config) 87 | 88 | # Define the checkpoint directory 89 | config.load_dir = config.get_checkpoint_dir() 90 | 91 | # Initialize the pipeline 92 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 93 | pipeline = config.pipeline.setup(device=device, test_mode=test_mode) 94 | pipeline.eval() 95 | 96 | # Load the checkpoint 97 | checkpoint_path, step = load_checkpoint_for_eval(config, pipeline) 98 | 99 | # Calculate the size of the loaded model 100 | model_size_mb = calculate_model_size(pipeline) 101 | CONSOLE.print(f"Model size: {model_size_mb:.2f} MB") 102 | 103 | return config, pipeline, checkpoint_path, step, model_size_mb 104 | 105 | 106 | def calculate_model_size(model: torch.nn.Module) -> float: 107 | """Calculate the size of a PyTorch model in MB. 108 | 109 | Args: 110 | model: The PyTorch model for which to calculate the memory usage. 111 | 112 | Returns: 113 | Model size in megabytes (MB). 114 | """ 115 | dynamic_param_size = 0 116 | static_param_size = 0 117 | 118 | for name, p in model.named_parameters(): 119 | # Determine if the parameter is dynamic or static 120 | if "gauss_params_dyn_dict" in name: 121 | dynamic_param_size += p.nelement() * p.element_size() 122 | else: 123 | static_param_size += p.nelement() * p.element_size() 124 | 125 | buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers()) 126 | 127 | # Print size of each parameter category 128 | print( 129 | "Total dynamic parameters size: {:.3f} MB".format(dynamic_param_size / 1024**2) 130 | ) 131 | print("Total static parameters size: {:.3f} MB".format(static_param_size / 1024**2)) 132 | print( 133 | "Total model size: {:.3f} MB".format( 134 | (dynamic_param_size + static_param_size) / 1024**2 135 | ) 136 | ) 137 | 138 | total_size_mb = ( 139 | dynamic_param_size + static_param_size + buffer_size 140 | ) / 1024**2 # Convert to MB 141 | return total_size_mb 142 | 143 | 144 | if __name__ == "__main__": 145 | import tyro 146 | 147 | tyro.cli(setup_evaluation) 148 | -------------------------------------------------------------------------------- /scripts/convert.py: -------------------------------------------------------------------------------- 1 | # This code is copied from the 3D Gaussian Splatter repository, available at https://github.com/graphdeco-inria/gaussian-splatting. 2 | # 3 | # Copyright (C) 2023, Inria 4 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 5 | # All rights reserved. 6 | # 7 | # This software is free for non-commercial, research and evaluation use 8 | # under the terms of the LICENSE.md file. 9 | # 10 | # For inquiries contact george.drettakis@inria.fr 11 | # 12 | 13 | import os 14 | import logging 15 | from argparse import ArgumentParser 16 | import shutil 17 | import math 18 | from PIL import Image 19 | 20 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 
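# A typical invocation might look like the following (the dataset path is illustrative;
# COLMAP, and ImageMagick when --resize is given, are assumed to be on PATH, otherwise
# point to them with --colmap_executable / --magick_executable):
#   python scripts/convert.py -s data/on-the-go/patio --resize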
21 | parser = ArgumentParser("Colmap converter") 22 | parser.add_argument("--no_gpu", action="store_true") 23 | parser.add_argument("--skip_matching", action="store_true") 24 | parser.add_argument("--source_path", "-s", required=True, type=str) 25 | parser.add_argument("--camera", default="OPENCV", type=str) 26 | parser.add_argument("--colmap_executable", default="", type=str) 27 | parser.add_argument("--resize", action="store_true") 28 | parser.add_argument("--magick_executable", default="", type=str) 29 | args = parser.parse_args() 30 | colmap_command = ( 31 | '"{}"'.format(args.colmap_executable) 32 | if len(args.colmap_executable) > 0 33 | else "colmap" 34 | ) 35 | magick_command = ( 36 | '"{}"'.format(args.magick_executable) 37 | if len(args.magick_executable) > 0 38 | else "magick" 39 | ) 40 | use_gpu = 1 if not args.no_gpu else 0 41 | 42 | # rename images folder to input 43 | images_path = os.path.join(args.source_path, "images") 44 | input_path = os.path.join(args.source_path, "input") 45 | 46 | if os.path.exists(images_path): 47 | os.rename(images_path, input_path) 48 | print(f"'{images_path}' has been renamed to '{input_path}'") 49 | else: 50 | print(f"The folder '{images_path}' does not exist.") 51 | 52 | if not args.skip_matching: 53 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) 54 | ## Feature extraction 55 | feat_extracton_cmd = ( 56 | colmap_command + " feature_extractor " 57 | "--database_path " 58 | + args.source_path 59 | + "/distorted/database.db \ 60 | --image_path " 61 | + args.source_path 62 | + "/input \ 63 | --ImageReader.single_camera 1 \ 64 | --ImageReader.camera_model " 65 | + args.camera 66 | + " \ 67 | --SiftExtraction.use_gpu " 68 | + str(use_gpu) 69 | ) 70 | exit_code = os.system(feat_extracton_cmd) 71 | if exit_code != 0: 72 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.") 73 | exit(exit_code) 74 | 75 | ## Feature matching 76 | feat_matching_cmd = ( 77 | colmap_command 78 | + " exhaustive_matcher \ 79 | --database_path " 80 | + args.source_path 81 | + "/distorted/database.db \ 82 | --SiftMatching.use_gpu " 83 | + str(use_gpu) 84 | ) 85 | exit_code = os.system(feat_matching_cmd) 86 | if exit_code != 0: 87 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.") 88 | exit(exit_code) 89 | 90 | ### Bundle adjustment 91 | # The default Mapper tolerance is unnecessarily large, 92 | # decreasing it speeds up bundle adjustment steps. 93 | mapper_cmd = ( 94 | colmap_command 95 | + " mapper \ 96 | --database_path " 97 | + args.source_path 98 | + "/distorted/database.db \ 99 | --image_path " 100 | + args.source_path 101 | + "/input \ 102 | --output_path " 103 | + args.source_path 104 | + "/distorted/sparse \ 105 | --Mapper.ba_global_function_tolerance=0.000001" 106 | ) 107 | exit_code = os.system(mapper_cmd) 108 | if exit_code != 0: 109 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 110 | exit(exit_code) 111 | 112 | ### Image undistortion 113 | ## We need to undistort our images into ideal pinhole intrinsics. 114 | img_undist_cmd = ( 115 | colmap_command 116 | + " image_undistorter \ 117 | --image_path " 118 | + args.source_path 119 | + "/input \ 120 | --input_path " 121 | + args.source_path 122 | + "/distorted/sparse/0 \ 123 | --output_path " 124 | + args.source_path 125 | + "\ 126 | --output_type COLMAP" 127 | ) 128 | exit_code = os.system(img_undist_cmd) 129 | if exit_code != 0: 130 | logging.error(f"Mapper failed with code {exit_code}. 
Exiting.") 131 | exit(exit_code) 132 | 133 | files = os.listdir(args.source_path + "/sparse") 134 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True) 135 | # Copy each file from the source directory to the destination directory 136 | for file in files: 137 | if file == "0": 138 | continue 139 | source_file = os.path.join(args.source_path, "sparse", file) 140 | destination_file = os.path.join(args.source_path, "sparse", "0", file) 141 | shutil.move(source_file, destination_file) 142 | 143 | if args.resize: 144 | print("Copying and resizing...") 145 | 146 | # Resize images. 147 | # for patio scene, we resize by 25% 148 | if args.source_path.endswith("patio") or args.source_path.endswith("arcdetriomphe"): 149 | resize_factor = 4 150 | else: 151 | resize_factor = 8 152 | os.makedirs(args.source_path + "/images_" + str(resize_factor), exist_ok=True) 153 | # Get the list of files in the source directory 154 | files = os.listdir(args.source_path + "/images") 155 | # Copy each file from the source directory to the destination directory 156 | for file in files: 157 | source_file = os.path.join(args.source_path, "images", file) 158 | 159 | with Image.open(source_file) as img: 160 | width, height = img.size 161 | target_width = math.floor(width / resize_factor) 162 | target_height = math.floor(height / resize_factor) 163 | 164 | destination_file = os.path.join( 165 | args.source_path, f"images_{resize_factor}", file 166 | ) 167 | shutil.copy2(source_file, destination_file) 168 | exit_code = os.system( 169 | f"{magick_command} mogrify -resize {target_width}x{target_height}! {destination_file}" 170 | ) 171 | if exit_code != 0: 172 | logging.error(f"resize failed with code {exit_code}. Exiting.") 173 | exit(exit_code) 174 | 175 | print("Done.") 176 | -------------------------------------------------------------------------------- /scripts/download_dataset.py: -------------------------------------------------------------------------------- 1 | """Script to download benchmark dataset(s)""" 2 | 3 | import os 4 | import subprocess 5 | from dataclasses import dataclass 6 | from pathlib import Path 7 | from typing import Literal 8 | 9 | import tyro 10 | 11 | # dataset names 12 | dataset_names = Literal[ 13 | "robustnerf", 14 | "on-the-go", 15 | ] 16 | 17 | # dataset urls 18 | urls = { 19 | "robustnerf": "https://storage.googleapis.com/jax3d-public/projects/robustnerf/robustnerf.tar.gz", 20 | "on-the-go": "https://cvg-data.inf.ethz.ch/on-the-go.zip", 21 | } 22 | 23 | # rename maps 24 | dataset_rename_map = { 25 | "robustnerf": "robustnerf", 26 | "on-the-go": "on-the-go", 27 | } 28 | 29 | 30 | @dataclass 31 | class DownloadData: 32 | dataset: dataset_names = "robustnerf" 33 | save_dir: Path = Path(os.getcwd() + "/data") 34 | 35 | def main(self): 36 | self.save_dir.mkdir(parents=True, exist_ok=True) 37 | self.dataset_download(self.dataset) 38 | 39 | def dataset_download(self, dataset: dataset_names): 40 | (self.save_dir / dataset_rename_map[dataset]).mkdir(parents=True, exist_ok=True) 41 | 42 | file_name = Path(urls[dataset]).name 43 | 44 | # download 45 | download_command = [ 46 | "curl", 47 | "-o", 48 | str(self.save_dir / dataset_rename_map[dataset] / file_name), 49 | urls[dataset], 50 | ] 51 | try: 52 | subprocess.run(download_command, check=True) 53 | print("File file downloaded succesfully.") 54 | except subprocess.CalledProcessError as e: 55 | print(f"Error downloading file: {e}") 56 | 57 | # if .zip 58 | if Path(urls[dataset]).suffix == ".zip": 59 | if os.name == "nt": # Windows doesn't have 
60 |                 extract_command = [
61 |                     "tar",
62 |                     "-xvf",
63 |                     self.save_dir / dataset_rename_map[dataset] / file_name,
64 |                     "-C",
65 |                     self.save_dir / dataset_rename_map[dataset],
66 |                 ]
67 |             else:
68 |                 extract_command = [
69 |                     "unzip",
70 |                     self.save_dir / dataset_rename_map[dataset] / file_name,
71 |                     "-d",
72 |                     self.save_dir / dataset_rename_map[dataset],
73 |                 ]
74 |         # if .tar
75 |         else:
76 |             extract_command = [
77 |                 "tar",
78 |                 "-xvf",
79 |                 self.save_dir / dataset_rename_map[dataset] / file_name,
80 |                 "-C",
81 |                 self.save_dir / dataset_rename_map[dataset],
82 |             ]
83 | 
84 |         # extract
85 |         try:
86 |             subprocess.run(extract_command, check=True)
87 |             os.remove(self.save_dir / dataset_rename_map[dataset] / file_name)
88 |             print("Extraction complete.")
89 |         except subprocess.CalledProcessError as e:
90 |             print(f"Extraction failed: {e}")
91 | 
92 | 
93 | if __name__ == "__main__":
94 |     tyro.cli(DownloadData).main()
95 | 
--------------------------------------------------------------------------------
/scripts/test_time_optimize.py:
--------------------------------------------------------------------------------
1 | """Test-time camera appearance optimization.
2 | 
3 | Usage:
4 | 
5 | python scripts/test_time_optimize.py --load-config [path_to_config]
6 | """
7 | 
8 | import functools
9 | import json
10 | import os
11 | from dataclasses import dataclass
12 | from pathlib import Path
13 | 
14 | import numpy as np
15 | import torch
16 | import tyro
17 | from PIL import Image
18 | from torchvision.utils import save_image
19 | 
20 | from nerfstudio.cameras.cameras import Cameras
21 | from nerfstudio.data.datasets.base_dataset import InputDataset
22 | from nerfstudio.models.splatfacto import SplatfactoModel
23 | from nerfstudio.utils.eval_utils import eval_setup
24 | 
25 | 
26 | @dataclass
27 | class AppearanceModelConfigs:
28 |     # TODO: only works when there are no per-gauss features atm
29 |     app_per_gauss: bool = False
30 |     appearance_embedding_dim: int = 32
31 |     appearance_n_fourier_freqs: int = 4
32 |     appearance_init_fourier: bool = True
33 | 
34 | 
35 | @dataclass
36 | class TestTimeOpt:
37 |     load_config: Path = Path("")
38 |     """Path to the config YAML file."""
39 |     train_iters: int = 128
40 |     """Number of test-time training iterations."""
41 | 
42 |     save_gif: bool = False
43 |     """save a training gif"""
44 |     lr_app_emb: float = 0.01
45 |     """learning rate for appearance embedding"""
46 |     metrics_output_path: Path = Path("./test_time_metrics/")
47 |     """Output path of test time opt eval metrics"""
48 |     use_saved_embedding: bool = False
49 |     """Use saved embedding module"""
50 |     save_all_imgs: bool = False
51 |     """Save all images"""
52 | 
53 |     def main(self):
54 |         if "brandenburg_gate" in str(self.load_config) or "unnamed" in str(self.load_config):
55 |             scene = "brandenburg_gate"
56 |         elif "sacre_coeur" in str(self.load_config):
57 |             scene = "sacre_coeur"
58 |         elif "trevi_fountain" in str(self.load_config):
59 |             scene = "trevi_fountain"
60 |         else:
61 |             raise ValueError(f"Unknown dataset: {self.load_config}")
62 | 
63 |         if not self.metrics_output_path.exists():
64 |             self.metrics_output_path.mkdir(parents=True)
65 | 
66 |         config, pipeline, _, _ = eval_setup(self.load_config)
67 |         pipeline.test_time_optimize = True
68 |         pipeline.train()
69 |         pipeline.cuda()
70 |         assert isinstance(pipeline.model, SplatfactoModel)
71 | 
72 |         model: SplatfactoModel = pipeline.model
73 |         train_dataset: InputDataset = pipeline.datamanager.train_dataset
74 |         eval_dataset: InputDataset = pipeline.datamanager.eval_dataset
75 |         cameras: Cameras = pipeline.datamanager.eval_dataset.cameras  # type: ignore
76 |         # init appearance model config
77 |         app_config = AppearanceModelConfigs()
78 | 
79 |         if self.use_saved_embedding:
80 |             model.appearance_embeddings = torch.load("embedding_" + scene + ".pth")
81 | 
82 |         else:
83 |             # define the per-image appearance embeddings to optimize
84 |             model.appearance_embeddings = torch.nn.Embedding(
85 |                 len(eval_dataset), app_config.appearance_embedding_dim
86 |             ).cuda()
87 |             model.appearance_embeddings.weight.data.normal_(0, 0.01)
88 | 
89 |         optimizer = torch.optim.Adam(
90 |             model.appearance_embeddings.parameters(),
91 |             lr=self.lr_app_emb,
92 |         )
93 | 
94 |         # Force model to have appearance
95 |         model.config.enable_appearance = True
96 | 
97 |         # train on the eval dataset
98 |         gif_frames = []
99 | 
100 |         # metrics before test-time optimization:
101 |         before_test_time_metrics = pipeline.get_average_eval_half_image_metrics(
102 |             step=0, output_path=None
103 |         )
104 |         print("Metrics before test-time: ", before_test_time_metrics)
105 | 
106 |         for epoch in range(self.train_iters):
107 |             for image_idx, data in enumerate(
108 |                 pipeline.datamanager.cached_eval  # Undistorted images
109 |             ):  # type: ignore
110 |                 # process batch gt data
111 |                 # process pred outputs
112 |                 camera = cameras[image_idx : image_idx + 1]
113 |                 camera.metadata = {}
114 |                 camera.metadata["cam_idx"] = image_idx
115 |                 camera = camera.to("cuda")
116 | 
117 |                 height, width = camera.height, camera.width
118 | 
119 |                 outputs = model.get_outputs(camera=camera)  # type: dict
120 |                 outputs_left_half = {}
121 |                 # keep only the left half of each output; the right half is held out for eval
122 |                 for key, img in outputs.items():
123 |                     half_width = width // 2
124 |                     if key == 'background':
125 |                         outputs_left_half[key] = img
126 |                     if img.dim() == 3:  # (H, W, C)
127 |                         masked_img = img[:, :half_width, :].clone()
128 | 
129 |                         # masked_img[:, half_width:, :] = 0
130 |                         outputs_left_half[key] = masked_img
131 |                 left_half = {}
132 |                 left_half["image"] = data["image"][:, :half_width, :]
133 | 
134 |                 loss_dict = model.get_loss_dict(outputs=outputs_left_half, batch=left_half)
135 |                 if image_idx == 1:
136 |                     # inspect the right (held-out) side of the image
137 |                     rgb = outputs["rgb"][:, width // 2:, :]
138 |                     gt_img = data["image"][:, width // 2:, :]
139 |                     save_image(rgb.permute(2, 0, 1), "rgb.jpg")
140 |                     print("Epoch: ", epoch, "loss of img_0:", loss_dict["main_loss"])
141 |                     if self.save_gif and epoch % 1 == 0:
142 |                         gif_frames.append(
143 |                             (rgb.detach().cpu().numpy() * 255).astype(np.uint8)
144 |                         )
145 | 
146 |                 loss = functools.reduce(torch.add, loss_dict.values())
147 |                 loss.backward()
148 | 
149 |                 optimizer.step()
150 |                 optimizer.zero_grad(set_to_none=True)
151 | 
152 |         # save the optimized appearance embeddings
153 |         torch.save(model.appearance_embeddings, "embedding_" + scene + ".pth")
154 | 
155 |         # Get eval metrics after test-time optimization
156 |         after_test_time_metrics = pipeline.get_average_eval_half_image_metrics(
157 |             step=0, output_path=None
158 |         )
159 | 
160 |         print("Metrics after test-time: ", after_test_time_metrics)
161 | 
162 |         output_dir = f"{self.metrics_output_path}/{scene}"
163 | 
164 |         os.makedirs(output_dir, exist_ok=True)
165 | 
166 |         metrics_path = f"{output_dir}/metrics.json"
167 |         with open(metrics_path, "w") as f:
168 |             json.dump(after_test_time_metrics, f)
169 | 
170 |         if self.save_gif:
171 |             gif_frames = [Image.fromarray(frame) for frame in gif_frames]
172 |             out_dir = os.path.join(os.getcwd(), f"renders/{scene}")
173 |             os.makedirs(out_dir, exist_ok=True)
174 |             print(f"saving training gif to {out_dir}/training.gif")
175 |             gif_frames[0].save(
176 |                 f"{out_dir}/training.gif",
177 |                 save_all=True,
178 |                 append_images=gif_frames[1:],
179 |                 optimize=False,
180 |                 duration=5,
181 |                 loop=0,
182 |             )
183 | 
184 |         if self.save_all_imgs:
185 |             for image_idx, data in enumerate(
186 |                 pipeline.datamanager.cached_eval  # Undistorted images
187 |             ):  # type: ignore
188 |                 # process batch gt data
189 |                 # process pred outputs
190 |                 camera = cameras[image_idx : image_idx + 1]
191 |                 camera.metadata = {}
192 |                 camera.metadata["cam_idx"] = image_idx
193 |                 camera = camera.to("cuda")
194 | 
195 |                 height, width = camera.height, camera.width
196 | 
197 |                 outputs = model.get_outputs(camera=camera)  # type: dict
198 | 
199 |                 # Define the output directory
200 |                 out_dir = os.path.join(os.getcwd(), f"renders/{scene}")
201 | 
202 |                 # Create the directory if it doesn't exist
203 |                 os.makedirs(out_dir, exist_ok=True)
204 | 
205 |                 # Define the full path for the image file
206 |                 image_path = os.path.join(out_dir, f"render_{image_idx}.jpg")
207 | 
208 |                 # out_dir = os.path.join(os.getcwd(), f"renders/{scene}/render_{image_idx}.jpg")
209 |                 save_image(outputs["rgb"].permute(2, 0, 1), image_path)
210 | 
211 | if __name__ == "__main__":
212 |     tyro.cli(TestTimeOpt).main()
213 | 
--------------------------------------------------------------------------------
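
Independent usage sketches for the three scripts above, not taken from the repository itself: the scene path "data/on-the-go/patio" and the "outputs/<run>/config.yml" placeholder are illustrative assumptions, convert.py additionally requires COLMAP and ImageMagick to be installed, and test_time_optimize.py expects the config path to contain one of the scene names it checks for (brandenburg_gate, sacre_coeur, trevi_fountain).

    # Download a benchmark dataset into ./data (tyro exposes the DownloadData fields as CLI flags)
    python scripts/download_dataset.py --dataset on-the-go
    # Run COLMAP feature extraction, matching, mapping, and undistortion on one scene folder,
    # optionally writing resized image copies (scene folder layout is an assumption)
    python scripts/convert.py -s data/on-the-go/patio --resize
    # After training a model, run test-time appearance optimization against its config
    python scripts/test_time_optimize.py --load-config outputs/<run>/config.yml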