├── .gitignore
├── LICENSE
├── README.md
├── assets
└── pipeline.png
├── desplat
├── __init__.py
├── config.py
├── datamanager.py
├── dataparsers
│ ├── onthego_dataparser.py
│ ├── phototourism_dataparser.py
│ └── robustnerf_dataparser.py
├── desplat_model.py
├── field.py
└── pipeline.py
├── pyproject.toml
└── scripts
├── calculate_memory.py
├── convert.py
├── download_dataset.py
└── test_time_optimize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | outputs/
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | .pybuilder/
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | # .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # poetry
99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100 | # This is especially recommended for binary packages to ensure reproducibility, and is more
101 | # commonly ignored for libraries.
102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103 | #poetry.lock
104 |
105 | # pdm
106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107 | #pdm.lock
108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109 | # in version control.
110 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
111 | .pdm.toml
112 | .pdm-python
113 | .pdm-build/
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 |
152 | # pytype static type analyzer
153 | .pytype/
154 |
155 | # Cython debug symbols
156 | cython_debug/
157 |
158 | # PyCharm
159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161 | # and can be added to the global gitignore or merged into this file. For a more nuclear
162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163 | #.idea/
164 |
165 | /examples
166 | /gsplat
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Arno Solin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
22 |
23 | ---
24 |
25 | This is the original code of DeSplat on [NerfStudio](http://www.nerf.studio/) codebase.
26 |
27 |
28 |

29 |
Overall pipeline of our method.
30 |
31 |
32 | ## Installation
33 | Setup conda environment:
34 | ```
35 | conda create --name desplat -y python=3.8
36 | conda activate desplat
37 | pip install torch==2.1.2+cu118 torchvision==0.16.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
38 | conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit
39 | ```
40 |
41 | Install DeSplat:
42 | ```
43 | git clone https://github.com/AaltoML/desplat.git
44 | cd desplat
45 | pip install -e .
46 | ns-install-cli
47 | ```
48 |
49 | ## Download datasets
50 |
51 | We provide a convenient script to downloading compatible datasets.
52 |
53 | To download the datasets, run the `download_dataset.py` script:
54 | ```
55 | python scripts/download_dataset.py --dataset ["robustnerf" or "on-the-go"]
56 | ```
57 |
58 | Alternatively you can also run the following commands:
59 | - **RobustNeRF Dataset:**
60 | ```
61 | wget https://storage.googleapis.com/jax3d-public/projects/robustnerf/robustnerf.tar.gz
62 | tar -xvf robustnerf.tar.gz
63 | ```
64 | RobustNeRF Dataset includes pre-generated COLMAP points, we do not need to convert the data.
65 |
66 | - **On-the-go Dataset:**
67 | To download the undistorted and down-sampled on-the-go dataset, you can either access it on Hugging Face: [link](https://huggingface.co/datasets/jkulhanek/nerfonthego-undistorted/tree/main) or download and preprocess it using the following bash script.
68 |
69 | ```
70 | bash scripts/download_on-the-go_processing.sh
71 | ```
72 | This Bash script automatically downloads the data, processes it using COLMAP, and downsamples the images. For downsampling, you may need to install ImageMagick.
73 | ```
74 | # install ImageMagick if there is not on your computer:
75 | conda install -c conda-forge imagemagick
76 | # install COLMAP if there is not on your computer:
77 | conda install conda-forge::colmap
78 | ```
79 | > **Note**: The initial On-the-go dataset does not include COLMAP points, so preprocessing is required. For detailed preprocessing steps, please refer to the instructions below.
80 |
81 |
82 | Custom data
83 | We support COLMAP based datasets. Ensure your dataset is organized in the following structure:
84 | ```
85 |
86 | |---images
87 | | |---
88 | | |---
89 | | |---...
90 | |---sparse
91 | |---0
92 | |---cameras.bin
93 | |---images.bin
94 | |---points3D.bin
95 | ```
96 | For datasets like the On-the-go Dataset and custom datasets without point cloud information, you need to preprocess them using COLMAP.
97 |
98 | To prepare the images for the COLMAP processor, organize your dataset folder as follows:
99 | ```
100 |
101 | |---input
102 | |---
103 | |---
104 | |---...
105 | ```
106 | Then, run the following command:
107 | ```
108 | # install COLMAP if there is not on your computer:
109 | conda install conda-forge::colmap
110 | python scripts/convert.py -s [--resize] # If not resizing, ImageMagick is not needed
111 |
112 | # an example for on-the-go dataset could be:
113 | python scripts/convert.py -s ../data/on-the-go/patio
114 | ```
115 |
116 |
117 |
118 | ## Training
119 |
120 | For RobustNeRF data, train using:
121 | ```
122 | ns-train desplat robustnerf-data --data [path_to_robustnerf_data]
123 | ```
124 |
125 | For RobustNeRF data, train using:
126 | ```
127 | ns-train desplat onthego-data --data [path_to_onthego_data]
128 | ```
129 |
130 | For Photo Tourism data, train using:
131 | ```
132 | ns-train desplat --steps_per_save 200000 --max_num_iterations 200000 --pipeline.model.stop_split_at 100000 \
133 | --pipeline.model.enable_appearance True --pipeline.model.app_per_gauss True phototourism-data --data [path_to_onthego_data]
134 | ```
135 |
136 | You can adjust the configuration by switching options such as `--pipeline.model.use_adc`. For more details, please refer to `desplat_model.py` for a closer look.
137 |
138 |
139 | ## Known Issues
140 |
141 | Due to differences in the optimization method, running the code directly within the Nerfstudio framework may result in the following issue:
142 |
143 | Traceback error
144 |
145 | ```
146 | Traceback (most recent call last):
147 | File "/******/ns-train", line 8, in
148 | sys.exit(entrypoint())
149 | File "/******/site-packages/nerfstudio/scripts/train.py", line 262, in entrypoint
150 | main(
151 | File "/******/site-packages/nerfstudio/scripts/train.py", line 247, in main
152 | launch(
153 | File "/******/site-packages/nerfstudio/scripts/train.py", line 189, in launch
154 | main_func(local_rank=0, world_size=world_size, config=config)
155 | File "/******/site-packages/nerfstudio/scripts/train.py", line 100, in train_loop
156 | trainer.train()
157 | File "/******/site-packages/nerfstudio/engine/trainer.py", line 301, in train
158 | self.save_checkpoint(step)
159 | File "/******/site-packages/nerfstudio/utils/decorators.py", line 82, in wrapper
160 | ret = func(*args, **kwargs)
161 | File "/******/site-packages/nerfstudio/engine/trainer.py", line 467, in save_checkpoint
162 | "optimizers": {k: v.state_dict() for (k, v) in self.optimizers.optimizers.items()},
163 | File "/******/site-packages/nerfstudio/engine/trainer.py", line 467, in
164 | "optimizers": {k: v.state_dict() for (k, v) in self.optimizers.optimizers.items()},
165 | File "/******/site-packages/torch/_compile.py", line 31, in inner
166 | return disable_fn(*args, **kwargs)
167 | File "/******/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
168 | return fn(*args, **kwargs)
169 | File "/******/site-packages/torch/optim/optimizer.py", line 705, in state_dict
170 | packed_state = {
171 | File "/******/site-packages/torch/optim/optimizer.py", line 706, in
172 | (param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
173 | KeyError: ******
174 | ```
175 |
176 |
177 | To temporarily resolve this issue, my solution is to comment out the following two lines in the Nerfstudio library file `/nerfstudio/engine/trainer.py` (in Nerfstudio version 1.1.3, these are lines 467 and 468):
178 |
179 | ```
180 | "optimizers": {k: v.state_dict() for (k, v) in self.optimizers.optimizers.items()},
181 | "schedulers": {k: v.state_dict() for (k, v) in self.optimizers.schedulers.items()},
182 | ```
183 |
184 |
185 | ## Rendering
186 |
187 | To render a video, we need to use `ns-render`. The simple command is:
188 | ```
189 | ns-render camera-path/interpolate/spiral/dataset --load-config .../config.yml
190 | ```
191 | To render the images of the whole dataset for visualization, use the following command:
192 | ```
193 | ns-render dataset --split train+test --load-config --output-path
194 | ```
195 |
196 | ## Test-Time Optimization (TTO)
197 |
198 | For datasets with complex weather and lighting variations, using per-image embeddings and test-time optimization is essential. During training, you can run the following command:
199 | ```
200 | ns-train desplat --pipeline.model.enable_appearance True phototourism-data --data [data-path]
201 | ```
202 | After completing the training, you will get a `.ckpt` file and a configuration path. To perform test-time optimization, run:
203 | ```
204 | python scripts/test_time_optimize.py --load-config [config-path]
205 | ```
206 | Adding the `--save_gif` flag saves a `.gif` file for a quick visual review.
207 | Adding the `--save_all_imgs` flag saves all rendered test images.
208 | Adding the `--use_saved_embedding` flag loads the saved appearance embeddings.
209 |
210 | ## Citation
211 |
212 | ```
213 | @InProceedings{wang2024desplat,
214 | title={{DeSplat}: Decomposed {G}aussian Splatting for Distractor-Free Rendering},
215 | author={Wang, Yihao and Klasson, Marcus and Turkulainen, Matias and Wang, Shuzhe and Kannala, Juho and Solin, Arno},
216 | booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
217 | year={2025}
218 | }
219 | ```
220 |
221 | ## License
222 | This software is provided under the Apache License 2.0. See the accompanying LICENSE file for details.
223 |
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AaltoML/desplat/966b1f4ba68dc4e1fc900fc416e3f4244ee88902/assets/pipeline.png
--------------------------------------------------------------------------------
/desplat/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataparsers.onthego_dataparser import OnthegoDataParserSpecification
2 | from .dataparsers.phototourism_dataparser import (
3 | PhotoTourismDataParserSpecification,
4 | )
5 | from .dataparsers.robustnerf_dataparser import RobustNerfDataParserSpecification
6 |
7 | __all__ = [
8 | "__version__",
9 | OnthegoDataParserSpecification,
10 | PhotoTourismDataParserSpecification,
11 | RobustNerfDataParserSpecification,
12 | ]
13 |
--------------------------------------------------------------------------------
/desplat/config.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 |
5 | from desplat.datamanager import DeSplatDataManagerConfig
6 | from desplat.dataparsers.robustnerf_dataparser import RobustNerfDataParserConfig
7 | from desplat.desplat_model import DeSplatModelConfig
8 | from desplat.pipeline import DeSplatPipelineConfig
9 | from nerfstudio.configs.base_config import ViewerConfig
10 | from nerfstudio.engine.optimizers import AdamOptimizerConfig
11 | from nerfstudio.engine.schedulers import (
12 | ExponentialDecaySchedulerConfig,
13 | )
14 | from nerfstudio.engine.trainer import TrainerConfig
15 | from nerfstudio.plugins.types import MethodSpecification
16 |
17 | desplat_method = MethodSpecification(
18 | config=TrainerConfig(
19 | method_name="desplat",
20 | steps_per_eval_image=100,
21 | steps_per_eval_batch=0,
22 | steps_per_save=30000,
23 | steps_per_eval_all_images=1000,
24 | max_num_iterations=30000,
25 | mixed_precision=False,
26 | pipeline=DeSplatPipelineConfig(
27 | datamanager=DeSplatDataManagerConfig( # desplat
28 | dataparser=RobustNerfDataParserConfig(
29 | load_3D_points=True, colmap_path=Path("sparse/0")
30 | ), # , downscale_factor=2
31 | cache_images_type="uint8",
32 | ),
33 | model=DeSplatModelConfig(),
34 | ),
35 | optimizers={
36 | "means": {
37 | "optimizer": AdamOptimizerConfig(lr=1.6e-4, eps=1e-15),
38 | "scheduler": ExponentialDecaySchedulerConfig(
39 | lr_final=1.6e-6,
40 | max_steps=30000,
41 | ),
42 | },
43 | "features_dc": {
44 | "optimizer": AdamOptimizerConfig(lr=0.0025, eps=1e-15),
45 | "scheduler": None,
46 | },
47 | "features_rest": {
48 | "optimizer": AdamOptimizerConfig(lr=0.0025 / 20, eps=1e-15),
49 | "scheduler": None,
50 | },
51 | "opacities": {
52 | "optimizer": AdamOptimizerConfig(lr=0.05, eps=1e-15),
53 | "scheduler": None,
54 | },
55 | "scales": {
56 | "optimizer": AdamOptimizerConfig(lr=0.005, eps=1e-15),
57 | "scheduler": None,
58 | },
59 | "quats": {
60 | "optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15),
61 | "scheduler": None,
62 | },
63 | "embeddings": {
64 | "optimizer": AdamOptimizerConfig(lr=0.02, eps=1e-15),
65 | "scheduler": None,
66 | },
67 | "means_dyn": {
68 | "optimizer": AdamOptimizerConfig(lr=1.6e-4, eps=1e-15),
69 | "scheduler": ExponentialDecaySchedulerConfig(
70 | lr_final=1.6e-6,
71 | max_steps=30000,
72 | ),
73 | },
74 | "rgbs_dyn": {
75 | "optimizer": AdamOptimizerConfig(lr=0.025, eps=1e-15),
76 | "scheduler": None,
77 | },
78 | "opacities_dyn": {
79 | "optimizer": AdamOptimizerConfig(lr=0.05, eps=1e-15),
80 | "scheduler": None,
81 | },
82 | "scales_dyn": {
83 | "optimizer": AdamOptimizerConfig(lr=0.05, eps=1e-15),
84 | "scheduler": None,
85 | },
86 | "quats_dyn": {
87 | "optimizer": AdamOptimizerConfig(lr=0.01, eps=1e-15),
88 | "scheduler": None,
89 | },
90 | # back to original components
91 | "camera_opt": {
92 | "optimizer": AdamOptimizerConfig(lr=1e-4, eps=1e-15),
93 | "scheduler": ExponentialDecaySchedulerConfig(
94 | lr_final=5e-7, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
95 | ),
96 | },
97 | "appearance_mlp": {
98 | "optimizer": AdamOptimizerConfig(lr=0.0005, eps=1e-15),
99 | "scheduler": None,
100 | },
101 | "appearance_embeddings": {
102 | "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15),
103 | "scheduler": None,
104 | },
105 | "field_background_encoder": {
106 | "optimizer": AdamOptimizerConfig(lr=2e-3, eps=1e-15),
107 | "scheduler": ExponentialDecaySchedulerConfig(
108 | lr_final=1e-4, max_steps=100000
109 | ),
110 | },
111 | "field_background_base": {
112 | "optimizer": AdamOptimizerConfig(lr=2e-3, eps=1e-15),
113 | "scheduler": ExponentialDecaySchedulerConfig(
114 | lr_final=2e-4, max_steps=100000
115 | ),
116 | },
117 | "field_background_rest": {
118 | "optimizer": AdamOptimizerConfig(lr=2e-3 / 20, eps=1e-15),
119 | "scheduler": ExponentialDecaySchedulerConfig(
120 | lr_final=2e-4 / 20, max_steps=100000
121 | ),
122 | },
123 | },
124 | viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
125 | vis="viewer",
126 | ),
127 | description="Desplat",
128 | )
129 |
--------------------------------------------------------------------------------
/desplat/datamanager.py:
--------------------------------------------------------------------------------
1 | import random
2 | from copy import deepcopy
3 | from dataclasses import dataclass, field
4 | from typing import Dict, List, Literal, Tuple, Type, Union
5 |
6 | import torch
7 | import torch._dynamo
8 |
9 | from nerfstudio.cameras.cameras import Cameras
10 | from nerfstudio.data.datamanagers.full_images_datamanager import (
11 | FullImageDatamanager,
12 | FullImageDatamanagerConfig,
13 | )
14 |
15 | torch._dynamo.config.suppress_errors = True
16 |
17 |
18 | @dataclass
19 | class DeSplatDataManagerConfig(FullImageDatamanagerConfig):
20 | _target: Type = field(default_factory=lambda: DeSplatDataManager)
21 |
22 |
23 | class DeSplatDataManager(FullImageDatamanager):
24 | config: DeSplatDataManagerConfig
25 |
26 | def __init__(
27 | self,
28 | config: DeSplatDataManagerConfig,
29 | device: Union[torch.device, str] = "cpu",
30 | test_mode: Literal["test", "val", "inference"] = "val",
31 | world_size: int = 1,
32 | local_rank: int = 0,
33 | **kwargs, # pylint: disable=unused-argument
34 | ):
35 | self.config = config
36 | super().__init__(
37 | config=config,
38 | device=device,
39 | test_mode=test_mode,
40 | world_size=world_size,
41 | local_rank=local_rank,
42 | **kwargs,
43 | )
44 | metadata = self.train_dataparser_outputs.metadata
45 | if test_mode == "test":
46 | self.train_unseen_cameras = [i for i in range(len(self.train_dataset))]
47 |
48 | @property
49 | def fixed_indices_eval_dataloader(self) -> List[Tuple[Cameras, Dict]]:
50 | """
51 | Pretends to be the dataloader for evaluation, it returns a list of (camera, data) tuples
52 | """
53 | image_indices = [i for i in range(len(self.eval_dataset))]
54 | data = [d.copy() for d in self.cached_eval]
55 | _cameras = deepcopy(self.eval_dataset.cameras).to(self.device)
56 | cameras = []
57 | for i in image_indices:
58 | data[i]["image"] = data[i]["image"].to(self.device)
59 | if (
60 | self.dataparser_config.eval_train
61 | or self.dataparser_config.test_time_optimize
62 | ):
63 | if _cameras.metadata is None:
64 | _cameras.metadata = {}
65 | _cameras.metadata["cam_idx"] = i
66 | cameras.append(_cameras[i : i + 1])
67 |
68 | assert (
69 | len(self.eval_dataset.cameras.shape) == 1
70 | ), "Assumes single batch dimension"
71 | return list(zip(cameras, data))
72 |
73 | def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
74 | """Returns the next evaluation batch
75 |
76 | Returns a Camera instead of raybundle
77 |
78 | TODO: Make sure this logic is consistent with the vanilladatamanager"""
79 | image_idx = self.eval_unseen_cameras.pop(
80 | random.randint(0, len(self.eval_unseen_cameras) - 1)
81 | )
82 | # Make sure to re-populate the unseen cameras list if we have exhausted it
83 | if len(self.eval_unseen_cameras) == 0:
84 | self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))]
85 | data = self.cached_eval[image_idx]
86 | data = data.copy()
87 | data["image"] = data["image"].to(self.device)
88 | assert (
89 | len(self.eval_dataset.cameras.shape) == 1
90 | ), "Assumes single batch dimension"
91 | camera = self.eval_dataset.cameras[image_idx : image_idx + 1].to(self.device)
92 | # keep metadata for debugging
93 | if self.dataparser_config.eval_train:
94 | if camera.metadata is None:
95 | camera.metadata = {}
96 | camera.metadata["cam_idx"] = image_idx
97 | return camera, data
98 |
--------------------------------------------------------------------------------
/desplat/dataparsers/onthego_dataparser.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Phototourism dataset parser. Datasets and documentation here: http://phototour.cs.washington.edu/datasets/"""
16 |
17 | from __future__ import annotations
18 |
19 | import json
20 | import os
21 | from dataclasses import dataclass, field
22 | from pathlib import Path
23 | from typing import Type
24 |
25 | import numpy as np
26 |
27 | from nerfstudio.data.dataparsers.colmap_dataparser import (
28 | ColmapDataParser,
29 | ColmapDataParserConfig,
30 | )
31 |
32 | # TODO(1480) use pycolmap instead of colmap_parsing_utils
33 | # import pycolmap
34 | from nerfstudio.plugins.registry_dataparser import DataParserSpecification
35 |
36 |
37 | @dataclass
38 | class OnthegoDataParserConfig(ColmapDataParserConfig):
39 | """On-the-go dataset parser config"""
40 |
41 | _target: Type = field(default_factory=lambda: OnthegoDataParser)
42 | """target class to instantiate"""
43 | eval_train: bool = False
44 | """evaluate test set or train set, for debug"""
45 | colmap_path: Path = Path("sparse/0")
46 | """path to colmap sparse folder"""
47 | test_time_optimize: bool = False
48 | """Whether to use test-time optimization for the dataset"""
49 |
50 |
51 | @dataclass
52 | class OnthegoDataParser(ColmapDataParser):
53 | """Phototourism dataset. This is based on https://github.com/kwea123/nerf_pl/blob/nerfw/datasets/phototourism.py
54 | and uses colmap's utils file to read the poses.
55 | """
56 |
57 | config: OnthegoDataParserConfig
58 |
59 | def __init__(self, config: ColmapDataParserConfig):
60 | super().__init__(config=config)
61 | self.config = config
62 | self.data: Path = config.data
63 | self.config.downscale_factor = (
64 | 4 if self.data.name == "patio" or self.data.name == "arcdetriomphe" else 8
65 | )
66 |
67 | def _get_image_indices(self, image_filenames, split):
68 | # Load the split file to get the train/eval split
69 | with open(os.path.join(self.config.data, "split.json"), "r") as file:
70 | split_json = json.load(file)
71 |
72 | # Select the split according to the split file.
73 | all_indices = np.arange(len(image_filenames))
74 |
75 | i_eval = all_indices[split_json["extra"]]
76 |
77 | i_train = all_indices[split_json["clutter"]]
78 |
79 | if self.config.eval_train:
80 | i_eval = i_train
81 | if split == "train":
82 | indices = i_train
83 | elif split in ["val", "test"]:
84 | indices = i_eval
85 | else:
86 | raise ValueError(f"Unknown dataparser split {split}")
87 |
88 | return indices
89 |
90 |
91 | OnthegoDataParserSpecification = DataParserSpecification(
92 | config=OnthegoDataParserConfig(),
93 | description="On-the-go dataparser",
94 | )
95 |
--------------------------------------------------------------------------------
/desplat/dataparsers/phototourism_dataparser.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Phototourism dataset parser. Datasets and documentation here: http://phototour.cs.washington.edu/datasets/"""
16 |
17 | from __future__ import annotations
18 |
19 | # TODO(1480) use pycolmap instead of colmap_parsing_utils
20 | # import pycolmap
21 | import csv
22 | import os
23 | from dataclasses import dataclass, field
24 | from pathlib import Path
25 | from typing import Type
26 |
27 | import numpy as np
28 |
29 | # from nerfstudio.data.dataparsers.base_dataparser import DataParser, DataParserConfig, DataparserOutputs
30 | from nerfstudio.data.dataparsers.colmap_dataparser import (
31 | ColmapDataParser,
32 | ColmapDataParserConfig,
33 | )
34 | from nerfstudio.plugins.registry_dataparser import DataParserSpecification
35 |
36 |
37 | @dataclass
38 | class PhotoTourismDataParserConfig(ColmapDataParserConfig):
39 | """Phototourism dataset parser config"""
40 |
41 | _target: Type = field(default_factory=lambda: PhotoTourismDataParser)
42 | """target class to instantiate"""
43 | eval_train: bool = False
44 | """evaluate test set or train set, for debug"""
45 | colmap_path: Path = Path("sparse")
46 |
47 | test_time_optimize: bool = True
48 |
49 |
50 | @dataclass
51 | class PhotoTourismDataParser(ColmapDataParser):
52 | """Phototourism dataset. This is based on https://github.com/kwea123/nerf_pl/blob/nerfw/datasets/phototourism.py
53 | and uses colmap's utils file to read the poses.
54 | """
55 |
56 | config: PhotoTourismDataParserConfig
57 |
58 | def __init__(self, config: PhotoTourismDataParserConfig):
59 | super().__init__(config=config)
60 | self.data: Path = config.data
61 |
62 | def _get_image_indices(self, image_filenames, split):
63 | # Load the split file to get the train/eval split
64 | if "brandenburg_gate" in str(self.data):
65 | tsv_path = self.config.data / "brandenburg.tsv"
66 | elif "sacre_coeur" in str(self.data):
67 | tsv_path = self.config.data / "sacre.tsv"
68 | elif "trevi_fountain" in str(self.data):
69 | tsv_path = self.config.data / "trevi.tsv"
70 | else:
71 | raise ValueError(f"Unknown dataset {self.data.name}")
72 |
73 | basenames = [
74 | os.path.basename(image_filename) for image_filename in image_filenames
75 | ]
76 |
77 | train_names, test_names = set(), set()
78 |
79 | with open(tsv_path, newline="") as tsv_file:
80 | reader = csv.reader(tsv_file, delimiter="\t")
81 | next(reader)
82 | for row in reader:
83 | if row[2] == "train":
84 | train_names.add(row[0])
85 | elif row[2] == "test":
86 | test_names.add(row[0])
87 |
88 | if self.config.eval_train:
89 | split = "train"
90 |
91 | indices = [
92 | idx
93 | for idx, basename in enumerate(basenames)
94 | if (basename in train_names and split == "train")
95 | or (basename in test_names and split in ["val", "test"])
96 | ]
97 |
98 | if not indices and split not in ["train", "test", "val"]:
99 | raise ValueError(f"Unknown dataparser split {split}")
100 |
101 | return np.array(indices)
102 |
103 |
104 | PhotoTourismDataParserSpecification = DataParserSpecification(
105 | config=PhotoTourismDataParserConfig(), description="Photo tourism dataparser"
106 | )
107 |
--------------------------------------------------------------------------------
/desplat/dataparsers/robustnerf_dataparser.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Phototourism dataset parser. Datasets and documentation here: http://phototour.cs.washington.edu/datasets/"""
16 |
17 | from __future__ import annotations
18 |
19 | import os
20 | from dataclasses import dataclass, field
21 | from pathlib import Path
22 | from typing import List, Literal, Optional, Type
23 |
24 | import numpy as np
25 | import torch
26 |
27 | from nerfstudio.data.dataparsers.colmap_dataparser import (
28 | ColmapDataParser,
29 | ColmapDataParserConfig,
30 | )
31 |
32 | # TODO(1480) use pycolmap instead of colmap_parsing_utils
33 | # import pycolmap
34 | from nerfstudio.plugins.registry_dataparser import DataParserSpecification
35 |
36 |
37 | @dataclass
38 | class RobustNerfDataParserConfig(ColmapDataParserConfig):
39 | """On-the-go dataset parser config"""
40 |
41 | _target: Type = field(default_factory=lambda: RobustNerfDataParser)
42 | """target class to instantiate"""
43 | eval_train: bool = False
44 | """evaluate test set or train set, for debug"""
45 | train_split_mode: Optional[Literal["ratio", "number", "filename"]] = None
46 | """How to split the training images. If None, all cluttered images are used."""
47 | train_split_clean_clutter_ratio: float = 1.0
48 | """The percentage of the training images that are cluttered. 0.0 -> only clean images, 1.0 -> only cluttered images"""
49 | train_split_clean_clutter_number: int = 0
50 | """The number of clean images to use for training. If 0, all clean images are used."""
51 | idx_clutter: List[int] = field(default_factory=lambda: [76])
52 | """The indices of the cluttered images to use for training"""
53 | colmap_path: Path = Path("sparse/0")
54 | """path to colmap sparse folder"""
55 | downscale_factor: int = 8
56 | """How much to downscale images. If not set, images are chosen such that the max dimension is <1600px."""
57 | test_time_optimize: bool = False
58 | """Whether to use test-time optimization for the dataset"""
59 |
60 |
61 | @dataclass
62 | class RobustNerfDataParser(ColmapDataParser):
63 | """RobustNerf dataset. This is based on https://github.com/kwea123/nerf_pl/blob/nerfw/datasets/phototourism.py
64 | and uses colmap's utils file to read the poses.
65 | """
66 |
67 | config: RobustNerfDataParserConfig
68 |
69 | def __init__(self, config: ColmapDataParserConfig):
70 | super().__init__(config=config)
71 | self.config = config
72 | self.data = config.data
73 |
74 | def _get_image_indices(self, image_filenames, split):
75 | i_train, i_eval = self.get_train_eval_split_filename(
76 | image_filenames, self.config.train_split_clean_clutter_ratio
77 | )
78 |
79 | if split == "train":
80 | indices = i_train
81 | elif split in ["val", "test"]:
82 | indices = i_eval
83 | else:
84 | raise ValueError(f"Unknown dataparser split {split}")
85 |
86 | return indices
87 |
88 | def get_train_eval_split_filename(
89 | self, image_filenames: List, clean_clutter_ratio: float
90 | ) -> Tuple[np.ndarray, np.ndarray]:
91 | """
92 | Get the train/eval split based on the filename of the images.
93 |
94 | Args:
95 | image_filenames: list of image filenames
96 | """
97 |
98 | num_images = len(image_filenames)
99 | basenames = [
100 | os.path.basename(image_filename) for image_filename in image_filenames
101 | ]
102 |
103 | i_all = np.arange(num_images)
104 | i_clean = []
105 | i_clutter = []
106 |
107 | i_train_clean = []
108 | i_train_clutter = []
109 |
110 | i_train = []
111 | i_eval = []
112 |
113 | for idx, basename in zip(i_all, basenames):
114 | if "clean" in basename:
115 | i_clean.append(idx)
116 |
117 | elif "clutter" in basename:
118 | i_clutter.append(idx)
119 |
120 | elif "extra" in basename:
121 | i_eval.append(idx) # extra is always used as eval
122 | else:
123 | raise ValueError(
124 | "image frame should contain clutter/extra in its name "
125 | )
126 |
127 | if self.config.train_split_mode is None:
128 | i_train = i_clutter
129 | if self.config.eval_train:
130 | i_eval = i_train
131 | return np.array(i_train), np.array(i_eval)
132 |
133 | if len(i_clean) > len(i_clutter):
134 | i_clean = i_clean[: len(i_clutter)]
135 | elif len(i_clean) < len(i_clutter):
136 | i_clutter = i_clutter[: len(i_clean)]
137 | num_images_train = min(len(i_clean), len(i_clutter))
138 |
139 | print("Number of clean images: ", num_images_train)
140 |
141 | if self.config.train_split_mode == "number":
142 | i_perm = torch.randperm(
143 | num_images_train, generator=torch.Generator().manual_seed(2023)
144 | ).tolist()
145 | num_images_cluttered = self.config.train_split_clean_clutter_number
146 | i_train = []
147 | i_train_clutter = []
148 | # loop over permuted indices to select one image from each clean/clutter pair
149 |
150 | for k, idx in enumerate(i_perm):
151 | if k < num_images_cluttered:
152 | i_train_clutter.append(i_clutter[idx])
153 | else:
154 | i_train.append(i_clean[idx])
155 | i_train.expand(i_train_clutter)
156 | elif self.config.train_split_mode == "ratio":
157 | i_train_clutter = []
158 | if clean_clutter_ratio == 0.0:
159 | # only clean images
160 | i_train = i_clean
161 | elif clean_clutter_ratio == 1.0:
162 | # only cluttered images
163 | i_train = i_clutter
164 | elif clean_clutter_ratio > 0.0 and clean_clutter_ratio < 1.0:
165 | # pick either clutter/clean image once
166 | i_perm = torch.randperm(
167 | num_images_train, generator=torch.Generator().manual_seed(2023)
168 | ).tolist()
169 | num_images_cluttered = int(
170 | num_images_train * clean_clutter_ratio
171 | ) # rounds down
172 | i_train = []
173 | # loop over permuted indices to select one image from each clean/clutter pair
174 | for k, idx in enumerate(i_perm):
175 | if k < num_images_cluttered:
176 | i_train_clutter.append(i_clutter[idx])
177 | else:
178 | i_train.append(i_clean[idx])
179 | i_train.extend(i_train_clutter)
180 | print(
181 | "basenames of cluttered images: ",
182 | [basenames[i] for i in i_train_clutter],
183 | )
184 | else:
185 | raise ValueError(
186 | "arg train_split_clean_clutter_ratio must be between 0.0 and 1.0 "
187 | )
188 |
189 | elif self.config.train_split_mode == "filename":
190 | # manually select images
191 | i_train = i_clean
192 | i_train.extend(i_train_clutter)
193 |
194 | # Remove elements in i_train_clean from i_train
195 | for item in i_train_clean:
196 | if item in i_train:
197 | i_train.remove(item)
198 | print(
199 | "basenames of cluttered images: ",
200 | [basenames[i] for i in i_train_clutter],
201 | )
202 | else:
203 | raise ValueError("Unknown train_split_mode")
204 | print("i_train", i_train)
205 | print("i_eval", i_eval)
206 | if self.config.eval_train:
207 | i_eval = i_train
208 | return np.array(i_train), np.array(i_eval)
209 |
210 |
211 | RobustNerfDataParserSpecification = DataParserSpecification(
212 | config=RobustNerfDataParserConfig(),
213 | description="RobustNeRF dataparser",
214 | )
215 |
--------------------------------------------------------------------------------
/desplat/desplat_model.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass, field
4 | from typing import Dict, List, Optional, Tuple, Type, Union
5 |
6 | import torch
7 | import torch._dynamo
8 | from pytorch_msssim import SSIM
9 | from torch import nn
10 | from torch.nn import functional as F
11 | from torchmetrics.image import PeakSignalNoiseRatio
12 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
13 |
14 | from desplat.field import (
15 | BGField,
16 | EmbeddingModel,
17 | _get_fourier_features,
18 | )
19 | from gsplat.cuda._wrapper import spherical_harmonics
20 | from gsplat.rendering import rasterization
21 | from nerfstudio.cameras.camera_optimizers import CameraOptimizer
22 | from nerfstudio.cameras.camera_utils import normalize
23 | from nerfstudio.cameras.cameras import Cameras
24 | from nerfstudio.data.scene_box import OrientedBox
25 | from nerfstudio.engine.optimizers import Optimizers
26 | from nerfstudio.models.splatfacto import (
27 | RGB2SH,
28 | SplatfactoModel,
29 | SplatfactoModelConfig,
30 | get_viewmat,
31 | num_sh_bases,
32 | random_quat_tensor,
33 | )
34 | from nerfstudio.utils.colors import get_color
35 | from nerfstudio.utils.rich_utils import CONSOLE
36 |
37 | torch._dynamo.config.suppress_errors = True
38 | torch.set_float32_matmul_precision("high")
39 |
40 |
41 | def quat_to_rotmat(quat):
42 | assert quat.shape[-1] == 4, quat.shape
43 | w, x, y, z = torch.unbind(quat, dim=-1)
44 | mat = torch.stack(
45 | [
46 | 1 - 2 * (y**2 + z**2),
47 | 2 * (x * y - w * z),
48 | 2 * (x * z + w * y),
49 | 2 * (x * y + w * z),
50 | 1 - 2 * (x**2 + z**2),
51 | 2 * (y * z - w * x),
52 | 2 * (x * z - w * y),
53 | 2 * (y * z + w * x),
54 | 1 - 2 * (x**2 + y**2),
55 | ],
56 | dim=-1,
57 | )
58 | return mat.reshape(quat.shape[:-1] + (3, 3))
59 |
60 |
61 | @dataclass
62 | class DeSplatModelConfig(SplatfactoModelConfig):
63 | _target: Type = field(default_factory=lambda: DeSplatModel)
64 |
65 | ### Settings for DeSplat
66 | num_dynamic_points: int = 1000
67 | """Initial number of dynamic points"""
68 | refine_every_dyn: int = 10
69 | """period of steps where gaussians are culled and densified"""
70 | use_adc: bool = True
71 | """Whether to use ADC for dynamic points management"""
72 | enable_reset_alpha: bool = False
73 | """Whether to reset alpha for dynamic points"""
74 | continue_split_dyn: bool = False
75 | """Whether to continue splitting for dynamic points after step for splitting stops"""
76 | distance: float = 0.02
77 | """Distance of dynamic points from camera"""
78 |
79 | ### Settings for Regularization
80 | alpha_bg_loss_lambda: float = 0.01
81 | """Lambda for alpha background loss"""
82 | alpha_2d_loss_lambda: float = 0.01
83 | """Lambda for alpha loss of 2D dynamic Gaussians"""
84 | enable_bg_model: bool = False
85 | """Whether to enable background model"""
86 | bg_sh_degree: int = 4
87 | """Degree of SH bases for background model"""
88 | bg_num_layers: int = 3
89 | """Number of layers in the background model"""
90 | bg_layer_width: int = 128
91 | """Width of each layer in the background model"""
92 |
93 | ### Settings for Appearance Optimization
94 | enable_appearance: bool = False
95 | """Enable or disable appearance optimization"""
96 | app_per_gauss: bool = False
97 | """Whether to optimize appearance according to the per-Gaussian embedding"""
98 | appearance_embedding_dim: int = 32
99 | """Dimension of the appearance embedding"""
100 | appearance_n_fourier_freqs: int = 4
101 | """Number of Fourier frequencies for per-Gaussian embedding initialization"""
102 | appearance_init_fourier: bool = True
103 | """Whether to initialize the per-Gaussian embedding with Fourier frequencies"""
104 |
105 | class DeSplatModel(SplatfactoModel):
106 | config: DeSplatModelConfig
107 |
108 | def populate_modules(self):
109 | cameras = self.kwargs["cameras"]
110 | self.dataparser_config = self.kwargs["dataparser"]
111 |
112 | if self.seed_points is not None and not self.config.random_init:
113 | means = torch.nn.Parameter(self.seed_points[0]) # (Location, Color)
114 | else:
115 | means = torch.nn.Parameter(
116 | (torch.rand((self.config.num_random, 3)) - 0.5)
117 | * self.config.random_scale
118 | )
119 |
120 | assert cameras is not None
121 |
122 | self.xys_grad_norm = None
123 | self.max_2Dsize = None
124 | distances, _ = self.k_nearest_sklearn(means.data, 3)
125 | distances = torch.from_numpy(distances)
126 | # find the average of the three nearest neighbors for each point and use that as the scale
127 | avg_dist = distances.mean(dim=-1, keepdim=True)
128 | scales = torch.nn.Parameter(torch.log(avg_dist.repeat(1, 3)))
129 | num_points = means.shape[0]
130 |
131 | quats = torch.nn.Parameter(random_quat_tensor(num_points))
132 | dim_sh = num_sh_bases(self.config.sh_degree)
133 |
134 | if (
135 | self.seed_points is not None
136 | and not self.config.random_init
137 | # We can have colors without points.
138 | and self.seed_points[1].shape[0] > 0
139 | ):
140 | shs = torch.zeros((self.seed_points[1].shape[0], dim_sh, 3)).float().cuda()
141 | if self.config.sh_degree > 0:
142 | shs[:, 0, :3] = RGB2SH(self.seed_points[1] / 255)
143 | shs[:, 1:, 3:] = 0.0
144 | else:
145 | CONSOLE.log("use color only optimization with sigmoid activation")
146 | shs[:, 0, :3] = torch.logit(self.seed_points[1] / 255, eps=1e-10)
147 | features_dc = torch.nn.Parameter(shs[:, 0, :])
148 | features_rest = torch.nn.Parameter(shs[:, 1:, :])
149 | else:
150 | features_dc = torch.nn.Parameter(torch.rand(num_points, 3))
151 | features_rest = torch.nn.Parameter(torch.zeros((num_points, dim_sh - 1, 3)))
152 |
153 | opacities = torch.nn.Parameter(torch.logit(0.1 * torch.ones(num_points, 1)))
154 |
155 | # appearance embedding for each Gaussian
156 | embeddings = _get_fourier_features(
157 | means, num_features=self.config.appearance_n_fourier_freqs
158 | )
159 | embeddings.add_(torch.randn_like(embeddings) * 0.0001)
160 | if not self.config.appearance_init_fourier:
161 | embeddings.normal_(0, 0.01)
162 |
163 | self.gauss_params = torch.nn.ParameterDict(
164 | {
165 | "means": means,
166 | "scales": scales,
167 | "quats": quats,
168 | "features_dc": features_dc,
169 | "features_rest": features_rest,
170 | "opacities": opacities,
171 | }
172 | )
173 |
174 | if self.config.app_per_gauss:
175 | # appearance embedding for each Gaussian
176 | embeddings = _get_fourier_features(means, num_features=self.config.appearance_n_fourier_freqs)
177 | embeddings.add_(torch.randn_like(embeddings) * 0.0001)
178 | if not self.config.appearance_init_fourier:
179 | embeddings.normal_(0, 0.01)
180 | self.gauss_params["embeddings"] = embeddings
181 |
182 | self.camera_optimizer: CameraOptimizer = self.config.camera_optimizer.setup(
183 | num_cameras=self.num_train_data, device="cpu"
184 | )
185 |
186 | self.camera_idx = 0
187 | self.camera = None
188 |
189 | self.psnr = PeakSignalNoiseRatio(data_range=1.0)
190 | self.ssim = SSIM(data_range=1.0, size_average=True, channel=3)
191 | self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True)
192 | self.step = 0
193 |
194 | self.crop_box: Optional[OrientedBox] = None
195 | if self.config.background_color == "random":
196 | self.background_color = torch.tensor(
197 | [0.1490, 0.1647, 0.2157]
198 | ) # This color is the same as the default background color in Viser. This would only affect the background color when rendering.
199 | else:
200 | self.background_color = get_color(self.config.background_color)
201 |
202 | # only the storage method is dict we can use ADC
203 | self.populate_modules_dyn_dict()
204 | self.num_points_dyn = self.num_points_dyn_dict
205 |
206 | self.loss_threshold = 1.0
207 | self.max_loss = 0.0
208 | self.min_loss = 1e10
209 |
210 | # Add the appearance embedding
211 | if self.config.enable_appearance:
212 | self.appearance_embeddings = torch.nn.Embedding(
213 | self.num_train_data, self.config.appearance_embedding_dim
214 | )
215 | self.appearance_embeddings.weight.data.normal_(0, 0.01)
216 | self.appearance_mlp = EmbeddingModel(self.config)
217 | else:
218 | self.appearance_embeddings = None
219 | self.appearance_mlp = None
220 |
221 | if self.config.enable_bg_model:
222 | self.bg_model = BGField(
223 | appearance_embedding_dim=self.config.appearance_embedding_dim,
224 | implementation="torch",
225 | sh_levels=self.config.bg_sh_degree,
226 | num_layers=self.config.bg_num_layers,
227 | layer_width=self.config.bg_layer_width,
228 | )
229 | else:
230 | self.bg_model = None
231 |
232 | def populate_modules_dyn_dict(self):
233 | cameras = self.kwargs["cameras"]
234 | self.gauss_params_dyn_dict = nn.ModuleDict()
235 | num_points_dyn = self.config.num_dynamic_points
236 |
237 | self.optimizers_dyn = {}
238 |
239 | for i in range(self.num_train_data):
240 | camera = cameras[i]
241 |
242 | optimized_camera_to_world = camera.camera_to_worlds
243 |
244 | camera_rotation = optimized_camera_to_world[:3, :3]
245 | camera_position = optimized_camera_to_world[:3, 3]
246 | distance_to_cam = self.config.distance
247 | cube = torch.rand(num_points_dyn, 3)
248 | cube[:, 0] = (cube[:, 0] - 0.5) * 0.02
249 | cube[:, 1] = (cube[:, 1] - 0.5) * 0.02
250 | cube[:, 2] = distance_to_cam
251 |
252 | means_dyn = torch.nn.Parameter(
253 | camera_position.repeat(num_points_dyn, 1) - cube @ camera_rotation.T
254 | )
255 |
256 | distances_dyn, _ = self.k_nearest_sklearn(means_dyn.data, 3)
257 | distances_dyn = torch.from_numpy(distances_dyn)
258 | avg_dist_dyn = distances_dyn.mean(dim=-1, keepdim=True)
259 | scales_dyn = torch.nn.Parameter(torch.log(avg_dist_dyn.repeat(1, 3)))
260 | rgbs_dyn = torch.nn.Parameter(torch.rand(num_points_dyn, 3))
261 | quats_dyn = nn.Parameter(random_quat_tensor(num_points_dyn))
262 | opacities_dyn = nn.Parameter(
263 | torch.logit(0.1 * torch.ones(num_points_dyn, 1))
264 | )
265 |
266 | self.gauss_params_dyn_dict[str(i)] = nn.ParameterDict(
267 | {
268 | "means_dyn": means_dyn,
269 | "scales_dyn": scales_dyn,
270 | "rgbs_dyn": rgbs_dyn,
271 | "quats_dyn": quats_dyn,
272 | "opacities_dyn": opacities_dyn,
273 | }
274 | )
275 |
276 | self.xys_dyn = {}
277 | self.radii_dyn = {}
278 | self.xys_grad_norm_dyn = {str(i): None for i in range(self.num_train_data)}
279 | self.max_2Dsize_dyn = {str(i): None for i in range(self.num_train_data)}
280 | self.vis_counts_dyn = {str(i): None for i in range(self.num_train_data)}
281 |
282 | @property
283 | def embeddings(self):
284 | if self.config.app_per_gauss:
285 | return self.gauss_params["embeddings"]
286 |
287 | def load_state_dict(self, dict, **kwargs): # type: ignore
288 | # resize the parameters to match the new number of points
289 | self.step = 30000
290 | if "means" in dict:
291 | # For backwards compatibility, we remap the names of parameters from
292 | # means->gauss_params.means since old checkpoints have that format
293 | for p in [
294 | "means",
295 | "scales",
296 | "quats",
297 | "features_dc",
298 | "features_rest",
299 | "opacities",
300 | ]:
301 | dict[f"gauss_params.{p}"] = dict[p]
302 | if self.config.app_per_gauss:
303 | dict[f"gauss_params.embeddings"] = dict[p]
304 | newp = dict["gauss_params.means"].shape[0]
305 |
306 | for name, param in self.gauss_params.items():
307 | old_shape = param.shape
308 | new_shape = (newp,) + old_shape[1:]
309 | self.gauss_params[name] = torch.nn.Parameter(
310 | torch.zeros(new_shape, device=self.device)
311 | )
312 |
313 | for i in range(self.num_train_data):
314 | if "means_dyn" in dict:
315 | for p in [
316 | "means_dyn",
317 | "scales_dyn",
318 | "quats_dyn",
319 | "rgbs_dyn",
320 | "opacities_dyn",
321 | ]: # "features_dc_dyn", "features_rest_dyn",
322 | dict[f"gauss_params_dyn_dict.{str(i)}.{p}"] = dict[p]
323 | newp_dyn = dict[f"gauss_params_dyn_dict.{str(i)}.means_dyn"].shape[0]
324 | for name, param in self.gauss_params_dyn_dict[str(i)].items():
325 | old_shape = param.shape
326 | new_shape = (newp_dyn,) + old_shape[1:]
327 | self.gauss_params_dyn_dict[str(i)][name] = torch.nn.Parameter(
328 | torch.zeros(new_shape, device=self.device)
329 | )
330 |
331 | super().load_state_dict(dict, **kwargs)
332 |
333 | def num_points_dyn_dict(self, i):
334 | return self.gauss_params_dyn_dict[str(i)]["means_dyn"].shape[0]
335 |
336 | def refinement_after(self, optimizers: Optimizers, step):
337 | assert step == self.step
338 | if self.step <= self.config.warmup_length:
339 | return
340 | with torch.no_grad():
341 | # Offset all the opacity reset logic by refine_every so that we don't
342 | # save checkpoints right when the opacity is reset (saves every 2k)
343 | # then cull
344 | # only split/cull if we've seen every image since opacity reset
345 | reset_interval = self.config.reset_alpha_every * self.config.refine_every
346 | do_densification = (
347 | self.step < self.config.stop_split_at
348 | and self.step % reset_interval
349 | > self.num_train_data + self.config.refine_every
350 | )
351 |
352 | if do_densification:
353 | # then we densify
354 | # for static points
355 | assert (
356 | self.xys_grad_norm is not None
357 | and self.vis_counts is not None
358 | and self.max_2Dsize is not None
359 | )
360 | avg_grad_norm = (
361 | (self.xys_grad_norm / self.vis_counts)
362 | * 0.5
363 | * max(self.last_size[0], self.last_size[1])
364 | )
365 | high_grads = (avg_grad_norm > self.config.densify_grad_thresh).squeeze()
366 | splits = (
367 | self.scales.exp().max(dim=-1).values
368 | > self.config.densify_size_thresh
369 | ).squeeze()
370 | splits &= high_grads
371 | if self.step < self.config.stop_screen_size_at:
372 | splits |= (
373 | self.max_2Dsize > self.config.split_screen_size
374 | ).squeeze()
375 | nsamps = self.config.n_split_samples
376 | split_params = self.split_gaussians(splits, nsamps)
377 |
378 | dups = (
379 | self.scales.exp().max(dim=-1).values
380 | <= self.config.densify_size_thresh
381 | ).squeeze()
382 | dups &= high_grads
383 | dup_params = self.dup_gaussians(dups)
384 |
385 | for name, param in self.gauss_params.items():
386 | self.gauss_params[name] = torch.nn.Parameter(
387 | torch.cat(
388 | [param.detach(), split_params[name], dup_params[name]],
389 | dim=0,
390 | )
391 | )
392 |
393 | # append zeros to the max_2Dsize tensor
394 | self.max_2Dsize = torch.cat(
395 | [
396 | self.max_2Dsize,
397 | torch.zeros_like(split_params["scales"][:, 0]),
398 | torch.zeros_like(dup_params["scales"][:, 0]),
399 | ],
400 | dim=0,
401 | )
402 |
403 | split_idcs = torch.where(splits)[0]
404 | self.dup_in_all_optim(optimizers, split_idcs, nsamps)
405 |
406 | dup_idcs = torch.where(dups)[0]
407 | self.dup_in_all_optim(optimizers, dup_idcs, 1)
408 |
409 | # After a guassian is split into two new gaussians, the original one should also be pruned.
410 | splits_mask = torch.cat(
411 | (
412 | splits,
413 | torch.zeros(
414 | nsamps * splits.sum() + dups.sum(),
415 | device=self.device,
416 | dtype=torch.bool,
417 | ),
418 | )
419 | )
420 |
421 | deleted_mask = self.cull_gaussians(splits_mask)
422 |
423 | elif (
424 | self.step >= self.config.stop_split_at
425 | and self.config.continue_cull_post_densification
426 | ):
427 | deleted_mask = self.cull_gaussians()
428 |
429 | else:
430 | # if we donot allow culling post refinement, no more gaussians will be pruned.
431 | deleted_mask = None
432 |
433 | if deleted_mask is not None:
434 | self.remove_from_all_optim(optimizers, deleted_mask)
435 | if (
436 | self.step < self.config.stop_split_at
437 | and self.step % reset_interval == self.config.refine_every
438 | ):
439 | # Reset value is set to be twice of the cull_alpha_thresh
440 | reset_value = self.config.cull_alpha_thresh * 2.0
441 | self.opacities.data = torch.clamp(
442 | self.opacities.data,
443 | max=torch.logit(
444 | torch.tensor(reset_value, device=self.device)
445 | ).item(),
446 | )
447 |
448 | # reset the exp of optimizer
449 | optim = optimizers.optimizers["opacities"]
450 | param = optim.param_groups[0]["params"][0]
451 | param_state = optim.state[param]
452 | param_state["exp_avg"] = torch.zeros_like(param_state["exp_avg"])
453 | param_state["exp_avg_sq"] = torch.zeros_like(param_state["exp_avg_sq"])
454 |
455 | self.xys_grad_norm = None
456 | self.vis_counts = None
457 | self.max_2Dsize = None
458 |
459 | if self.config.use_adc:
460 | do_densification_dyn = (
461 | self.step < self.config.stop_split_at
462 | or self.config.continue_split_dyn
463 | ) and self.step % (
464 | self.num_train_data * self.config.refine_every_dyn
465 | ) < self.config.refine_every
466 | # for dynamic points
467 | if do_densification_dyn:
468 | for camera_idx in range(self.num_train_data):
469 | # if number of Gaussians == 0, skip the densification
470 | if self.num_points_dyn(camera_idx) == 0:
471 | self.xys_grad_norm_dyn[str(camera_idx)] = None
472 | self.vis_counts_dyn[str(camera_idx)] = None
473 | self.max_2Dsize_dyn[str(camera_idx)] = None
474 | continue
475 |
476 | # we densify the 2D image when every image has been seen for refine_every_dyn times
477 | # For the dynamic points, for every image, we densify the points
478 | xys_grad_norm_dyn = self.xys_grad_norm_dyn[str(camera_idx)]
479 | vis_counts_dyn = self.vis_counts_dyn[str(camera_idx)]
480 | max_2Dsize_dyn = self.max_2Dsize_dyn[str(camera_idx)]
481 |
482 | dyn_gaussians = self.gauss_params_dyn_dict[str(camera_idx)]
483 |
484 | assert (
485 | xys_grad_norm_dyn is not None
486 | and vis_counts_dyn is not None
487 | and max_2Dsize_dyn is not None
488 | )
489 | avg_grad_norm_dyn = (
490 | (xys_grad_norm_dyn / vis_counts_dyn)
491 | * 0.5
492 | * max(self.last_size[0], self.last_size[1])
493 | )
494 | high_grads_dyn = (
495 | avg_grad_norm_dyn > self.config.densify_grad_thresh
496 | ).squeeze()
497 | splits_dyn = (
498 | dyn_gaussians["scales_dyn"].exp().max(dim=-1).values
499 | > self.config.densify_size_thresh
500 | ).squeeze()
501 | splits_dyn &= high_grads_dyn
502 | if self.step < self.config.stop_screen_size_at:
503 | splits_dyn |= (
504 | max_2Dsize_dyn > self.config.split_screen_size
505 | ).squeeze()
506 | nsamps = self.config.n_split_samples
507 |
508 | split_params_dyn = self.split_gaussians_dyn(
509 | camera_idx, splits_dyn, nsamps
510 | )
511 |
512 | dups_dyn = (
513 | dyn_gaussians["scales_dyn"].exp().max(dim=-1).values
514 | <= self.config.densify_size_thresh
515 | ).squeeze()
516 | dups_dyn &= high_grads_dyn
517 | dup_params_dyn = self.dup_gaussians_dyn(camera_idx, dups_dyn)
518 |
519 | for name, param in self.gauss_params_dyn_dict[
520 | str(camera_idx)
521 | ].items():
522 | self.gauss_params_dyn_dict[str(camera_idx)][name] = (
523 | torch.nn.Parameter(
524 | torch.cat(
525 | [
526 | param.detach(),
527 | split_params_dyn[name],
528 | dup_params_dyn[name],
529 | ],
530 | dim=0,
531 | )
532 | )
533 | )
534 |
535 | # append zeros to the max_2Dsize tensor
536 | self.max_2Dsize_dyn[str(camera_idx)] = torch.cat(
537 | [
538 | self.max_2Dsize_dyn[str(camera_idx)],
539 | torch.zeros_like(split_params_dyn["scales_dyn"][:, 0]),
540 | torch.zeros_like(dup_params_dyn["scales_dyn"][:, 0]),
541 | ],
542 | dim=0,
543 | )
544 |
545 | split_idcs = torch.where(splits_dyn)[0]
546 | # log this info to logfile
547 |
548 | self.dup_in_all_optim_dyn(
549 | camera_idx, optimizers, split_idcs, nsamps
550 | )
551 | dup_idcs = torch.where(dups_dyn)[0]
552 | self.dup_in_all_optim_dyn(camera_idx, optimizers, dup_idcs, 1)
553 |
554 | # After a guassian is split into two new gaussians, the original one should also be pruned.
555 | splits_mask_dyn = torch.cat(
556 | (
557 | splits_dyn,
558 | torch.zeros(
559 | nsamps * splits_dyn.sum() + dups_dyn.sum(),
560 | device=self.device,
561 | dtype=torch.bool,
562 | ),
563 | )
564 | )
565 |
566 | vis_cull_mask = (self.radii_dyn[str(camera_idx)] < 0.01).squeeze() # cull invisible distractor Gaussians
567 | splits_mask_dyn = splits_mask_dyn | vis_cull_mask
568 |
569 | deleted_mask_dyn = self.cull_gaussians_dyn(
570 | camera_idx, splits_mask_dyn
571 | )
572 |
573 | if deleted_mask_dyn is not None:
574 | self.remove_from_all_optim_dyn(
575 | camera_idx, optimizers, deleted_mask_dyn
576 | )
577 |
578 | self.xys_grad_norm_dyn[str(camera_idx)] = None
579 | self.vis_counts_dyn[str(camera_idx)] = None
580 | self.max_2Dsize_dyn[str(camera_idx)] = None
581 |
582 | elif (
583 | self.step >= self.config.stop_split_at
584 | and self.config.continue_cull_post_densification
585 | and self.step % (
586 | self.num_train_data * self.config.refine_every_dyn
587 | ) < self.config.refine_every
588 | ):
589 | for camera_idx in range(self.num_train_data):
590 | if self.num_points_dyn(camera_idx) == 0:
591 | self.xys_grad_norm_dyn[str(camera_idx)] = None
592 | self.vis_counts_dyn[str(camera_idx)] = None
593 | self.max_2Dsize_dyn[str(camera_idx)] = None
594 | continue
595 |
596 | deleted_mask_dyn = self.cull_gaussians_dyn(camera_idx)
597 | if deleted_mask_dyn is not None:
598 | self.remove_from_all_optim_dyn(
599 | camera_idx, optimizers, deleted_mask_dyn
600 | )
601 | self.xys_grad_norm_dyn[str(camera_idx)] = None
602 | self.vis_counts_dyn[str(camera_idx)] = None
603 | self.max_2Dsize_dyn[str(camera_idx)] = None
604 |
605 |
606 | reset_dyn = (
607 | self.step < self.config.stop_split_at
608 | and self.step % reset_interval == self.config.refine_every
609 | and self.config.enable_reset_alpha
610 | )
611 | # reset
612 | if reset_dyn:
613 | # Reset value for dynamic points
614 | for camera_idx in range(self.num_train_data):
615 | reset_value = self.config.cull_alpha_thresh * 2.0
616 | dyn_gaussians = self.gauss_params_dyn_dict[str(camera_idx)]
617 |
618 | dyn_gaussians["opacities_dyn"].data = torch.clamp(
619 | dyn_gaussians["opacities_dyn"],
620 | max=torch.logit(
621 | torch.tensor(reset_value, device=self.device)
622 | ).item(),
623 | )
624 |
625 | # reset the exp of optimizer
626 | optim = optimizers.optimizers["opacities_dyn"]
627 |
628 | param = optim.param_groups[0]["params"][camera_idx]
629 | param_state = optim.state[param]
630 | if "exp_avg" in param_state:
631 | param_state["exp_avg"] = torch.zeros_like(
632 | param_state["exp_avg"]
633 | )
634 | param_state["exp_avg_sq"] = torch.zeros_like(
635 | param_state["exp_avg_sq"]
636 | )
637 |
638 | def cull_gaussians_dyn(self, i, extra_cull_mask: Optional[torch.Tensor] = None):
639 | """
640 | This function deletes gaussians with under a certain opacity threshold
641 | extra_cull_mask: a mask indicates extra gaussians to cull besides existing culling criterion
642 | """
643 | n_bef = self.num_points_dyn(i)
644 | # cull transparent ones
645 | culls = (
646 | torch.sigmoid(self.gauss_params_dyn_dict[str(i)]["opacities_dyn"])
647 | < self.config.cull_alpha_thresh
648 | ).squeeze() # self.config.cull_alpha_thresh
649 | # if the point is invisible for all the camera, cull it
650 |
651 | below_alpha_count = torch.sum(culls).item()
652 | toobigs_count = 0
653 | if extra_cull_mask is not None:
654 | culls = culls | extra_cull_mask
655 |
656 | if self.step > self.config.refine_every * self.config.reset_alpha_every:
657 | # cull huge ones
658 | toobigs = (
659 | torch.exp(self.gauss_params_dyn_dict[str(i)]["scales_dyn"])
660 | .max(dim=-1)
661 | .values
662 | > self.config.cull_scale_thresh
663 | ).squeeze()
664 | if self.step < self.config.stop_screen_size_at:
665 | # cull big screen space
666 | if self.max_2Dsize_dyn[str(i)] is not None:
667 | toobigs = (
668 | toobigs
669 | | (
670 | self.max_2Dsize_dyn[str(i)] > self.config.cull_screen_size
671 | ).squeeze()
672 | )
673 | culls = culls | toobigs
674 | toobigs_count = torch.sum(toobigs).item()
675 | for name, param in self.gauss_params_dyn_dict[str(i)].items():
676 | self.gauss_params_dyn_dict[str(i)][name] = torch.nn.Parameter(param[~culls])
677 |
678 | CONSOLE.log(
679 | f"Dynamic Culled {n_bef - self.num_points_dyn(i)} gaussians "
680 | f"({below_alpha_count} below alpha thresh, {toobigs_count} too bigs, {self.num_points_dyn(i)} remaining)"
681 | )
682 |
683 | return culls
684 |
685 | def split_gaussians_dyn(self, i, split_mask, samps):
686 | """
687 | This function splits gaussians that are too large
688 | """
689 | n_splits = split_mask.sum().item()
690 | CONSOLE.log(
691 | f"Dynamic Splitting {split_mask.sum().item()/self.num_points_dyn(i)} gaussians: {n_splits}/{self.num_points_dyn(i)}"
692 | )
693 | centered_samples = torch.randn(
694 | (samps * n_splits, 3), device=self.device
695 | ) # Nx3 of axis-aligned scales
696 | scaled_samples = (
697 | torch.exp(
698 | self.gauss_params_dyn_dict[str(i)]["scales_dyn"][split_mask].repeat(
699 | samps, 1
700 | )
701 | )
702 | * centered_samples
703 | ) # how these scales are rotated
704 | quats = self.gauss_params_dyn_dict[str(i)]["quats_dyn"][
705 | split_mask
706 | ] / self.gauss_params_dyn_dict[str(i)]["quats_dyn"][split_mask].norm(
707 | dim=-1, keepdim=True
708 | ) # normalize them first
709 | rots = quat_to_rotmat(quats.repeat(samps, 1)) # how these scales are rotated
710 | rotated_samples = torch.bmm(rots, scaled_samples[..., None]).squeeze()
711 | new_means = rotated_samples + self.gauss_params_dyn_dict[str(i)]["means_dyn"][
712 | split_mask
713 | ].repeat(samps, 1)
714 | # step 2, sample new colors
715 | new_colors = self.gauss_params_dyn_dict[str(i)]["rgbs_dyn"][split_mask].repeat(
716 | samps, 1
717 | )
718 | # new_features_dc = self.gauss_params_dyn_dict[str(i)]["features_dc_dyn"][split_mask].repeat(samps, 1)
719 | # new_features_rest = self.gauss_params_dyn_dict[str(i)]["features_rest_dyn"][split_mask].repeat(samps, 1, 1)
720 | # step 3, sample new opacities
721 | new_opacities = self.gauss_params_dyn_dict[str(i)]["opacities_dyn"][
722 | split_mask
723 | ].repeat(samps, 1)
724 | # step 4, sample new scales
725 | size_fac = 1.6
726 | new_scales = torch.log(
727 | torch.exp(self.gauss_params_dyn_dict[str(i)]["scales_dyn"][split_mask])
728 | / size_fac
729 | ).repeat(samps, 1)
730 | self.gauss_params_dyn_dict[str(i)]["scales_dyn"][split_mask] = torch.log(
731 | torch.exp(self.gauss_params_dyn_dict[str(i)]["scales_dyn"][split_mask])
732 | / size_fac
733 | )
734 | # step 5, sample new quats
735 | new_quats = self.gauss_params_dyn_dict[str(i)]["quats_dyn"][split_mask].repeat(
736 | samps, 1
737 | )
738 |
739 | out = {
740 | "means_dyn": new_means,
741 | "rgbs_dyn": new_colors,
742 | "opacities_dyn": new_opacities,
743 | "scales_dyn": new_scales,
744 | "quats_dyn": new_quats,
745 | }
746 | for name, param in self.gauss_params_dyn_dict[str(i)].items():
747 | if name not in out:
748 | out[name] = param[split_mask].repeat(samps, 1)
749 | return out
750 |
751 | def dup_gaussians_dyn(self, i, dup_mask):
752 | """
753 | This function duplicates gaussians that are too small
754 | """
755 | n_dups = dup_mask.sum().item()
756 | CONSOLE.log(
757 | f"Dynamic Duplication: Duplicating {dup_mask.sum().item()/self.num_points_dyn(i)} gaussians: {n_dups}/{self.num_points_dyn(i)}"
758 | )
759 | new_dups = {}
760 | for name, param in self.gauss_params_dyn_dict[str(i)].items():
761 | new_dups[name] = param[dup_mask]
762 | return new_dups
763 |
764 | def dup_in_all_optim_dyn(self, i, optimizers, dup_mask, n):
765 | param_groups = self.get_gaussian_param_groups_dyn_dict()
766 |
767 | for group, _ in param_groups.items():
768 | param = param_groups[group][i]
769 | self.dup_in_optim_dyn(
770 | i, optimizers.optimizers[group], dup_mask, param, n
771 | )
772 |
773 | self.radii_dyn[str(i)] = torch.cat(
774 | [self.radii_dyn[str(i)], self.radii_dyn[str(i)][dup_mask.squeeze()].repeat(n,)], dim=0
775 | )
776 |
777 | def dup_in_optim_dyn(self, i, optimizer, dup_mask, new_params, n=2):
778 | """adds the parameters to the optimizer"""
779 | param = optimizer.param_groups[0]["params"][i][0]
780 | param_state = optimizer.state[param]
781 | if "exp_avg" in param_state:
782 | repeat_dims = (n,) + tuple(
783 | 1 for _ in range(param_state["exp_avg"].dim() - 1)
784 | )
785 | param_state["exp_avg"] = torch.cat(
786 | [
787 | param_state["exp_avg"],
788 | torch.zeros_like(param_state["exp_avg"][dup_mask.squeeze()]).repeat(
789 | *repeat_dims
790 | ),
791 | ],
792 | dim=0,
793 | )
794 | param_state["exp_avg_sq"] = torch.cat(
795 | [
796 | param_state["exp_avg_sq"],
797 | torch.zeros_like(
798 | param_state["exp_avg_sq"][dup_mask.squeeze()]
799 | ).repeat(*repeat_dims),
800 | ],
801 | dim=0,
802 | )
803 | del optimizer.state[param]
804 | optimizer.state[new_params] = param_state
805 | optimizer.param_groups[0]["params"][i] = new_params
806 | del param
807 |
808 | def remove_from_all_optim_dyn(self, i, optimizers, deleted_mask):
809 | param_groups = self.get_gaussian_param_groups_dyn_dict()
810 | for group, _ in param_groups.items():
811 | param = param_groups[group][i]
812 | self.remove_from_optim_dyn(
813 | i, optimizers.optimizers[group], deleted_mask, param
814 | ) #
815 | torch.cuda.empty_cache()
816 |
817 | def remove_from_optim_dyn(self, i, optimizer, deleted_mask, new_params):
818 | """removes the deleted_mask from the optimizer provided"""
819 | param = optimizer.param_groups[0]["params"][i][0]
820 | param_state = optimizer.state[param]
821 | del optimizer.state[param]
822 |
823 | # Modify the state directly without deleting and reassigning.
824 | if "exp_avg" in param_state:
825 | param_state["exp_avg"] = param_state["exp_avg"][~deleted_mask]
826 | param_state["exp_avg_sq"] = param_state["exp_avg_sq"][~deleted_mask]
827 |
828 | # Update the parameter in the optimizer's param group.
829 | del optimizer.param_groups[0]["params"][i]
830 | optimizer.param_groups[0]["params"].insert(i, new_params)
831 | optimizer.state[new_params] = param_state
832 |
833 | def after_train(self, step: int):
834 | assert step == self.step
835 | # to save some training time, we no longer need to update those stats post refinement
836 | if self.step >= self.config.stop_split_at:
837 | if self.config.continue_split_dyn:
838 | with torch.no_grad():
839 | visible_mask_dyn = (
840 | self.radii_dyn[str(self.camera_idx)] > 0
841 | ).flatten()
842 | grads_dyn = (
843 | self.xys_dyn[str(self.camera_idx)]
844 | .absgrad[0][visible_mask_dyn]
845 | .norm(dim=-1)
846 | ) # type: ignore
847 | if self.xys_grad_norm_dyn[str(self.camera_idx)] is None:
848 | self.xys_grad_norm_dyn[str(self.camera_idx)] = torch.zeros(
849 | self.num_points_dyn(self.camera_idx),
850 | device=self.device,
851 | dtype=torch.float32,
852 | ) # + self.num_points_dyn
853 | self.vis_counts_dyn[str(self.camera_idx)] = torch.ones(
854 | self.num_points_dyn(self.camera_idx),
855 | device=self.device,
856 | dtype=torch.float32,
857 | ) # + self.num_points_dyn
858 |
859 | assert self.vis_counts_dyn[str(self.camera_idx)] is not None
860 | self.vis_counts_dyn[str(self.camera_idx)][visible_mask_dyn] += 1
861 | self.xys_grad_norm_dyn[str(self.camera_idx)][visible_mask_dyn] += (
862 | grads_dyn
863 | )
864 | # update the max screen size, as a ratio of number of pixels
865 | if self.max_2Dsize_dyn[str(self.camera_idx)] is None:
866 | self.max_2Dsize_dyn[str(self.camera_idx)] = torch.zeros_like(
867 | self.radii_dyn[str(self.camera_idx)], dtype=torch.float32
868 | )
869 | newradii_dyn = self.radii_dyn[str(self.camera_idx)].detach()[
870 | visible_mask_dyn
871 | ]
872 | self.max_2Dsize_dyn[str(self.camera_idx)][visible_mask_dyn] = (
873 | torch.maximum(
874 | self.max_2Dsize_dyn[str(self.camera_idx)][visible_mask_dyn],
875 | newradii_dyn
876 | / float(max(self.last_size[0], self.last_size[1])),
877 | )
878 | )
879 | return
880 |
881 | with torch.no_grad():
882 | # keep track of a moving average of grad norms
883 | visible_mask = (self.radii > 0).flatten()
884 | grads = self.xys.absgrad[0][visible_mask].norm(dim=-1) # type: ignore
885 | if self.xys_grad_norm is None:
886 | self.xys_grad_norm = torch.zeros(
887 | self.num_points, device=self.device, dtype=torch.float32
888 | )
889 | self.vis_counts = torch.ones(
890 | self.num_points, device=self.device, dtype=torch.float32
891 | )
892 | assert self.vis_counts is not None
893 | self.vis_counts[visible_mask] += 1
894 | self.xys_grad_norm[visible_mask] += grads
895 | # update the max screen size, as a ratio of number of pixels
896 | if self.max_2Dsize is None:
897 | self.max_2Dsize = torch.zeros_like(self.radii, dtype=torch.float32)
898 | newradii = self.radii.detach()[visible_mask]
899 | self.max_2Dsize[visible_mask] = torch.maximum(
900 | self.max_2Dsize[visible_mask],
901 | newradii / float(max(self.last_size[0], self.last_size[1])),
902 | )
903 |
904 | if self.config.use_adc:
905 | # for dynamic points
906 | visible_mask_dyn = (self.radii_dyn[str(self.camera_idx)] > 0).flatten()
907 | grads_dyn = (
908 | self.xys_dyn[str(self.camera_idx)]
909 | .absgrad[0][visible_mask_dyn]
910 | .norm(dim=-1)
911 | ) # type: ignore
912 | if self.xys_grad_norm_dyn[str(self.camera_idx)] is None:
913 | self.xys_grad_norm_dyn[str(self.camera_idx)] = torch.zeros(
914 | self.num_points_dyn(self.camera_idx),
915 | device=self.device,
916 | dtype=torch.float32,
917 | )
918 | self.vis_counts_dyn[str(self.camera_idx)] = torch.ones(
919 | self.num_points_dyn(self.camera_idx),
920 | device=self.device,
921 | dtype=torch.float32,
922 | )
923 |
924 | assert self.vis_counts_dyn[str(self.camera_idx)] is not None
925 | self.vis_counts_dyn[str(self.camera_idx)][visible_mask_dyn] += 1
926 | self.xys_grad_norm_dyn[str(self.camera_idx)][visible_mask_dyn] += (
927 | grads_dyn
928 | )
929 | # update the max screen size, as a ratio of number of pixels
930 | if self.max_2Dsize_dyn[str(self.camera_idx)] is None:
931 | self.max_2Dsize_dyn[str(self.camera_idx)] = torch.zeros_like(
932 | self.radii_dyn[str(self.camera_idx)], dtype=torch.float32
933 | )
934 | newradii_dyn = self.radii_dyn[str(self.camera_idx)].detach()[
935 | visible_mask_dyn
936 | ]
937 | self.max_2Dsize_dyn[str(self.camera_idx)][visible_mask_dyn] = (
938 | torch.maximum(
939 | self.max_2Dsize_dyn[str(self.camera_idx)][visible_mask_dyn],
940 | newradii_dyn / float(max(self.last_size[0], self.last_size[1])),
941 | )
942 | )
943 |
944 | def get_gaussian_param_groups(self) -> Dict[str, List[Parameter]]:
945 | # Here we explicitly use the means, scales as parameters so that the user can override this function and
946 | # specify more if they want to add more optimizable params to gaussians.
947 |
948 | keys = [
949 | "means",
950 | "scales",
951 | "quats",
952 | "features_dc",
953 | "features_rest",
954 | "opacities",
955 | ]
956 | if "embeddings" in self.gauss_params:
957 | keys.append("embeddings") # Add dynamically if it exists
958 |
959 | return {
960 | name: [self.gauss_params[name]]
961 | for name in keys
962 | }
963 |
964 |
965 | def get_gaussian_param_groups_dyn_dict(
966 | self,
967 | ) -> Dict[str, dict[int, List[Parameter]]]:
968 | return {
969 | name: [
970 | self.gauss_params_dyn_dict[str(i)][name]
971 | for i in range(self.num_train_data)
972 | ]
973 | for name in [
974 | "means_dyn",
975 | "scales_dyn",
976 | "quats_dyn",
977 | "rgbs_dyn",
978 | "opacities_dyn",
979 | ]
980 | }
981 |
982 | def get_param_groups(self):
983 | """Obtain the parameter groups for the optimizers
984 |
985 | Returns:
986 | Mapping of different parameter groups
987 | """
988 | gps = self.get_gaussian_param_groups()
989 | self.camera_optimizer.get_param_groups(param_groups=gps)
990 | gps_dyn = self.get_gaussian_param_groups_dyn_dict()
991 | gps_bg = {}
992 | if self.config.enable_bg_model:
993 | assert self.bg_model is not None
994 | gps_bg["field_background_encoder"] = list(self.bg_model.encoder.parameters())
995 | gps_bg["field_background_base"] = list(self.bg_model.sh_base_head.parameters())
996 | gps_bg["field_background_rest"] = list(self.bg_model.sh_rest_head.parameters())
997 |
998 | if self.config.enable_appearance:
999 | gps["appearance_mlp"] = list(self.appearance_mlp.parameters())
1000 | gps["appearance_embeddings"] = list(self.appearance_embeddings.parameters())
1001 |
1002 | return {**gps, **gps_dyn, **gps_bg}
1003 |
1004 | def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
1005 | """Takes in a Ray Bundle and returns a dictionary of outputs.
1006 |
1007 | Args:
1008 | ray_bundle: Input bundle of rays. This raybundle should have all the
1009 | needed information to compute the outputs.
1010 |
1011 | Returns:
1012 | Outputs of model. (ie. rendered colors)
1013 | """
1014 | if not isinstance(camera, Cameras):
1015 | print("Called get_outputs with not a camera")
1016 | return {}
1017 |
1018 | if self.training:
1019 | assert camera.shape[0] == 1, "Only one camera at a time"
1020 | optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera)
1021 | else:
1022 | optimized_camera_to_world = camera.camera_to_worlds
1023 |
1024 | # cropping
1025 | if self.crop_box is not None and not self.training:
1026 | crop_ids = self.crop_box.within(self.means).squeeze()
1027 | if crop_ids.sum() == 0:
1028 | return self.get_empty_outputs(
1029 | int(camera.width.item()),
1030 | int(camera.height.item()),
1031 | self.background_color,
1032 | )
1033 | else:
1034 | crop_ids = None
1035 |
1036 | if crop_ids is not None:
1037 | opacities_crop = self.opacities[crop_ids]
1038 | means_crop = self.means[crop_ids]
1039 | features_dc_crop = self.features_dc[crop_ids]
1040 | features_rest_crop = self.features_rest[crop_ids]
1041 | scales_crop = self.scales[crop_ids]
1042 | quats_crop = self.quats[crop_ids]
1043 | if self.config.app_per_gauss:
1044 | embeddings_crop = self.embeddings[crop_ids]
1045 | else:
1046 | opacities_crop = self.opacities
1047 | means_crop = self.means
1048 | features_dc_crop = self.features_dc
1049 | features_rest_crop = self.features_rest
1050 | scales_crop = self.scales
1051 | quats_crop = self.quats
1052 | if self.config.app_per_gauss:
1053 | embeddings_crop = self.embeddings
1054 |
1055 | colors_crop = torch.cat(
1056 | (features_dc_crop[:, None, :], features_rest_crop), dim=1
1057 | )
1058 |
1059 | BLOCK_WIDTH = (
1060 | 16 # this controls the tile size of rasterization, 16 is a good default
1061 | )
1062 | camera_scale_fac = self._get_downscale_factor()
1063 | camera.rescale_output_resolution(1 / camera_scale_fac)
1064 | viewmat = get_viewmat(optimized_camera_to_world)
1065 | K = camera.get_intrinsics_matrices().cuda()
1066 | W, H = int(camera.width.item()), int(camera.height.item())
1067 | self.last_size = (H, W)
1068 | # camera.rescale_output_resolution(camera_scale_fac) # type: ignore
1069 | # apply the compensation of screen space blurring to gaussians
1070 | if self.config.rasterize_mode not in ["antialiased", "classic"]:
1071 | raise ValueError("Unknown rasterize_mode: %s", self.config.rasterize_mode)
1072 | if self.config.enable_appearance:
1073 | if camera.metadata is not None and "cam_idx" in camera.metadata:
1074 | # if self.training or self.dataparser_config.eval_train:
1075 | self.camera_idx = camera.metadata["cam_idx"]
1076 | appearance_embed = self.appearance_embeddings(
1077 | torch.tensor(self.camera_idx, device=self.device)
1078 | )
1079 | else:
1080 | # appearance_embed is zero
1081 | appearance_embed = torch.zeros(
1082 | self.config.appearance_embedding_dim, device=self.device
1083 | )
1084 | assert self.appearance_mlp is not None
1085 | # assert self.embeddings is not None
1086 |
1087 | features = torch.cat(
1088 | (features_dc_crop.unsqueeze(1), features_rest_crop), dim=1
1089 | )
1090 | # offset, mul = self.appearance_mlp(appearance_embed.repeat(self.num_points, 1), features_dc_crop)
1091 | offset, mul = self.appearance_mlp(
1092 | self.embeddings if self.config.app_per_gauss else None,
1093 | appearance_embed.repeat(self.num_points, 1),
1094 | features_dc_crop,
1095 | )
1096 | colors_toned = colors_crop * mul.unsqueeze(1) + offset.unsqueeze(1)
1097 | shdim = (self.config.sh_degree + 1) ** 2
1098 | colors_toned = colors_toned.view(-1, shdim, 3).contiguous().clamp_max(1.0)
1099 | # colors_toned = eval_sh(self.active_sh_degree, colors_toned, dir_pp_normalized)
1100 | colors_toned = torch.clamp_min(colors_toned + 0.5, 0.0)
1101 | colors_crop = colors_toned
1102 |
1103 | if self.config.sh_degree > 0:
1104 | sh_degree_to_use = min(
1105 | self.step // self.config.sh_degree_interval, self.config.sh_degree
1106 | )
1107 | bg_sh_degree_to_use = min(
1108 | self.step // (self.config.sh_degree_interval // 2),
1109 | self.config.bg_sh_degree,
1110 | )
1111 | else:
1112 | colors_crop = torch.sigmoid(colors_crop).squeeze(1) # [N, 1, 3] -> [N, 3]
1113 | sh_degree_to_use = None
1114 |
1115 | render_3d, alpha_3d, info_3d = rasterization(
1116 | means=means_crop,
1117 | quats=quats_crop / quats_crop.norm(dim=-1, keepdim=True),
1118 | scales=torch.exp(scales_crop),
1119 | opacities=torch.sigmoid(opacities_crop).squeeze(-1),
1120 | colors=colors_crop,
1121 | viewmats=viewmat, # [1, 4, 4]
1122 | Ks=K, # [1, 3, 3]
1123 | width=W,
1124 | height=H,
1125 | tile_size=BLOCK_WIDTH,
1126 | packed=False,
1127 | near_plane=0.01,
1128 | far_plane=1e10,
1129 | render_mode="RGB+ED", # render_mode,
1130 | sh_degree=sh_degree_to_use,
1131 | sparse_grad=False,
1132 | absgrad=True,
1133 | rasterize_mode=self.config.rasterize_mode,
1134 | # set some threshold to disregrad small gaussians for faster rendering.
1135 | # radius_clip=3.0,
1136 | )
1137 | if self.training and info_3d["means2d"].requires_grad:
1138 | info_3d["means2d"].retain_grad()
1139 |
1140 | self.xys = info_3d["means2d"] # [1, N, 2]
1141 | self.radii = info_3d["radii"][0] # [N]
1142 |
1143 | alpha_3d = alpha_3d[:, ...]
1144 |
1145 | # Only need for one time
1146 | ### BACKGROUND MODEL
1147 | if self.config.enable_bg_model:
1148 | directions = normalize(
1149 | camera.generate_rays(camera_indices=0, keep_shape=False).directions
1150 | )
1151 |
1152 | bg_sh_coeffs = self.bg_model.get_sh_coeffs(
1153 | appearance_embedding=appearance_embed,
1154 | )
1155 |
1156 | background = spherical_harmonics(
1157 | degrees_to_use=bg_sh_degree_to_use,
1158 | coeffs=bg_sh_coeffs.repeat(directions.shape[0], 1, 1),
1159 | dirs=directions,
1160 | )
1161 | background = background.view(H, W, 3)
1162 | else:
1163 | background = self._get_background_color().view(1, 1, 3)
1164 |
1165 | rgb_static = render_3d[:, ..., :3] + (1 - alpha_3d) * background
1166 |
1167 | rgb_static = torch.clamp(rgb_static, 0.0, 1.0)
1168 |
1169 | rgb = rgb_static
1170 |
1171 | if background.shape[0] == 3 and not self.training:
1172 | background = background.expand(H, W, 3)
1173 |
1174 | returns = {}
1175 | returns["rgb"] = rgb.squeeze(0)
1176 | returns["depth"] = render_3d[..., 3].squeeze(0).unsqueeze(-1)
1177 | returns["background"] = background
1178 | returns["accumulation"] = alpha_3d.squeeze(0)
1179 |
1180 | img_dyn, alpha_2d, depth_2d = None, None, None # for debug
1181 | if camera.metadata is not None and "cam_idx" in camera.metadata:
1182 | self.camera_idx = camera.metadata["cam_idx"]
1183 |
1184 | dyn_gaussians = self.gauss_params_dyn_dict[str(self.camera_idx)]
1185 | opacities_dyn = dyn_gaussians["opacities_dyn"]
1186 | means_dyn = dyn_gaussians["means_dyn"]
1187 | rgbs_dyn = dyn_gaussians["rgbs_dyn"]
1188 | scales_dyn = dyn_gaussians["scales_dyn"]
1189 | quats_dyn = dyn_gaussians["quats_dyn"]
1190 |
1191 | colors_dyn = rgbs_dyn.clamp_min(0.0).clamp_max(1.0)
1192 |
1193 | render_2d, alpha_2d, info_2d = rasterization(
1194 | means=means_dyn,
1195 | quats=quats_dyn / quats_dyn.norm(dim=-1, keepdim=True),
1196 | scales=torch.exp(scales_dyn),
1197 | opacities=torch.sigmoid(opacities_dyn).squeeze(-1),
1198 | colors=colors_dyn, # [N, 3]
1199 | viewmats=viewmat, # [C, 4, 4]
1200 | Ks=K, # [C, 3, 3]
1201 | width=W,
1202 | height=H,
1203 | tile_size=BLOCK_WIDTH,
1204 | render_mode="RGB+ED",
1205 | sh_degree=None, # sh_degree_to_use,
1206 | packed=False,
1207 | sparse_grad=False,
1208 | absgrad=True,
1209 | rasterize_mode=self.config.rasterize_mode,
1210 | )
1211 | if info_2d["means2d"].requires_grad:
1212 | info_2d["means2d"].retain_grad()
1213 | if self.config.use_adc:
1214 | self.xys_dyn[str(self.camera_idx)] = info_2d["means2d"] # [1, N, 2]
1215 | self.radii_dyn[str(self.camera_idx)] = info_2d["radii"][0] # [N]
1216 | depth_2d = render_2d[..., 3].unsqueeze(-1) # [H, W, 1]
1217 | render_2d = render_2d[..., :3] # [H, W, 3]
1218 | img_dyn = render_2d.squeeze(0)
1219 | img_dyn = torch.clamp(img_dyn, 0.0, 1.0)
1220 |
1221 | rgb = (
1222 | render_2d
1223 | + (1 - alpha_2d) * render_3d[:, ..., :3]
1224 | + (1 - alpha_3d) * (1 - alpha_2d) * background
1225 | )
1226 | rgb = torch.clamp(rgb, 0.0, 1.0)
1227 | camera.rescale_output_resolution(camera_scale_fac) # type: ignore
1228 | alpha_2d = alpha_2d.squeeze(0)
1229 | depth_2d = depth_2d.squeeze(0)
1230 |
1231 | returns["img_dyn"] = img_dyn
1232 | returns["alpha_2d"] = alpha_2d
1233 | returns["depth_2d"] = depth_2d
1234 | returns["rgb"] = rgb.squeeze(0)
1235 | returns["rgb_static"] = rgb_static.squeeze(0)
1236 | else:
1237 | self.camera_idx = None
1238 |
1239 | return returns
1240 |
1241 | def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
1242 | """Compute and returns metrics.
1243 |
1244 | Args:
1245 | outputs: the output to compute loss dict to
1246 | batch: ground truth batch corresponding to outputs
1247 | """
1248 | gt_rgb = self.composite_with_background(
1249 | self.get_gt_img(batch["image"]), outputs["background"]
1250 | )
1251 | metrics_dict = {}
1252 | predicted_rgb = outputs["rgb"]
1253 | metrics_dict["psnr"] = self.psnr(predicted_rgb, gt_rgb)
1254 |
1255 | metrics_dict["gaussian_count_static"] = self.num_points
1256 | # metrics_dict["gaussian_count_transient"] = self.num_points_dyn
1257 |
1258 | self.camera_optimizer.get_metrics_dict(metrics_dict)
1259 | return metrics_dict
1260 |
1261 | def get_loss_dict(
1262 | self, outputs, batch, metrics_dict=None
1263 | ) -> Dict[str, torch.Tensor]:
1264 | """Computes and returns the losses dict.
1265 |
1266 | Args:
1267 | outputs: the output to compute loss dict to
1268 | batch: ground truth batch corresponding to outputs
1269 | metrics_dict: dictionary of metrics, some of which we can use for loss
1270 | """
1271 | gt_img = self.composite_with_background(
1272 | self.get_gt_img(batch["image"]), outputs["background"]
1273 | )
1274 | pred_img = outputs["rgb"]
1275 |
1276 | Ll1_img = torch.abs(gt_img - pred_img)
1277 |
1278 | Ll1 = Ll1_img.mean()
1279 | simloss = 1 - self.ssim(
1280 | gt_img.permute(2, 0, 1)[None, ...], pred_img.permute(2, 0, 1)[None, ...]
1281 | )
1282 | main_loss = (
1283 | 1 - self.config.ssim_lambda
1284 | ) * Ll1 + self.config.ssim_lambda * simloss
1285 |
1286 | if self.config.use_scale_regularization and self.step % 10 == 0:
1287 | scale_exp = torch.exp(self.scales)
1288 | scale_reg = (
1289 | torch.maximum(
1290 | scale_exp.amax(dim=-1) / scale_exp.amin(dim=-1),
1291 | torch.tensor(self.config.max_gauss_ratio),
1292 | )
1293 | - self.config.max_gauss_ratio
1294 | )
1295 | scale_reg = 0.1 * scale_reg.mean()
1296 | else:
1297 | scale_reg = torch.tensor(0.0).to(self.device)
1298 |
1299 | loss_dict = {
1300 | "main_loss": main_loss,
1301 | "scale_reg": scale_reg,
1302 | }
1303 |
1304 | # add loss to prevent the hole of static
1305 | if self.config.alpha_bg_loss_lambda > 0:
1306 | if self.config.enable_bg_model:
1307 | alpha_loss = torch.tensor(0.0).to(self.device)
1308 | background = outputs["background"]
1309 | alpha = outputs["alpha_2d"].detach() + (1 - outputs["alpha_2d"].detach()) * outputs["accumulation"]
1310 | # alpha = outputs["accumulation"]
1311 | # for those pixel are well represented by bg and has low alpha, we encourage the gaussian to be transparent
1312 | bg_mask = torch.abs(gt_img - background).mean(dim=-1, keepdim=True) < 0.003
1313 | # use a box filter to avoid penalty high frequency parts
1314 | f = 3
1315 | window = (torch.ones((f, f)).view(1, 1, f, f) / (f * f)).cuda()
1316 | bg_mask = (
1317 | torch.nn.functional.conv2d(
1318 | bg_mask.float().unsqueeze(0).permute(0, 3, 1, 2),
1319 | window,
1320 | stride=1,
1321 | padding="same",
1322 | )
1323 | .permute(0, 2, 3, 1)
1324 | .squeeze(0)
1325 | )
1326 | alpha_mask = bg_mask > 0.6
1327 | # prevent NaN
1328 | if alpha_mask.sum() != 0:
1329 | alpha_loss = alpha[alpha_mask].mean() * self.config.alpha_bg_loss_lambda # default: 0.15
1330 | else:
1331 | alpha_loss = torch.tensor(0.0).to(self.device)
1332 | loss_dict["alpha_bg_loss"] = alpha_loss
1333 | else:
1334 | alpha_bg = outputs["accumulation"]
1335 | loss_dict["alpha_bg_loss"] = self.config.alpha_bg_loss_lambda * (1 - alpha_bg).mean()
1336 |
1337 | if self.config.alpha_2d_loss_lambda > 0 and "alpha_2d" in outputs:
1338 | alpha_2d = outputs["alpha_2d"]
1339 | loss_dict["alpha_2d_loss"] = self.config.alpha_2d_loss_lambda * alpha_2d.mean()
1340 |
1341 | return loss_dict
1342 |
1343 | def get_gt_img(self, image: torch.Tensor):
1344 | """Compute groundtruth image with iteration dependent downscale factor for evaluation purpose
1345 |
1346 | Args:
1347 | image: tensor.Tensor in type uint8 or float32
1348 | """
1349 | if image.dtype == torch.uint8:
1350 | image = image.float() / 255.0
1351 | gt_img = self._downscale_if_required(image)
1352 | return gt_img.to(self.device)
1353 |
1354 | def get_image_metrics_and_images(
1355 | self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]
1356 | ) -> Tuple[Dict[str, float], Dict[str, torch.Tensor]]:
1357 | gt_rgb = self.composite_with_background(
1358 | self.get_gt_img(batch["image"]), outputs["background"]
1359 | )
1360 | predicted_rgb = outputs["rgb"]
1361 |
1362 | assert gt_rgb.shape == predicted_rgb.shape
1363 | combined_rgb = torch.cat([gt_rgb, predicted_rgb], dim=1)
1364 |
1365 | # Switch images from [H, W, C] to [1, C, H, W] for metrics computations
1366 | gt_rgb = torch.moveaxis(gt_rgb, -1, 0)[None, ...]
1367 | predicted_rgb = torch.moveaxis(predicted_rgb, -1, 0)[None, ...]
1368 |
1369 | psnr = self.psnr(gt_rgb, predicted_rgb)
1370 | ssim = self.ssim(gt_rgb, predicted_rgb)
1371 | lpips = self.lpips(gt_rgb, predicted_rgb)
1372 |
1373 | # all of these metrics will be logged as scalars
1374 | metrics_dict = {"psnr": float(psnr.item()), "ssim": float(ssim)} # type: ignore
1375 | metrics_dict["lpips"] = float(lpips)
1376 |
1377 | images_dict = {
1378 | "img": combined_rgb,
1379 | }
1380 |
1381 | return metrics_dict, images_dict
1382 |
--------------------------------------------------------------------------------
/desplat/field.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import reduce
3 | from operator import mul
4 | from typing import Literal
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torch import Tensor, nn
9 |
10 | from gsplat.cuda._wrapper import spherical_harmonics
11 | from nerfstudio.cameras.rays import RayBundle
12 | from nerfstudio.field_components import MLP
13 | from nerfstudio.fields.base_field import Field
14 |
15 |
16 | def _get_fourier_features(xyz: Tensor, num_features=3):
17 | xyz = xyz - xyz.mean(dim=0, keepdim=True)
18 | xyz = xyz / torch.quantile(xyz.abs(), 0.97, dim=0) * 0.5 + 0.5
19 | freqs = torch.repeat_interleave(
20 | 2
21 | ** torch.linspace(
22 | 0, num_features - 1, num_features, dtype=xyz.dtype, device=xyz.device
23 | ),
24 | 2,
25 | )
26 | offsets = torch.tensor(
27 | [0, 0.5 * math.pi] * num_features, dtype=xyz.dtype, device=xyz.device
28 | )
29 | feat = xyz[..., None] * freqs[None, None] * 2 * math.pi + offsets[None, None]
30 | feat = torch.sin(feat).view(-1, reduce(mul, feat.shape[1:]))
31 | return feat
32 |
33 |
34 | class EmbeddingModel(nn.Module):
35 | def __init__(self, config):
36 | super().__init__()
37 | self.config = config
38 | # sh_coeffs = 4**2
39 | self.feat_in = 3
40 | input_dim = config.appearance_embedding_dim
41 | if config.app_per_gauss:
42 | input_dim += 6 * self.config.appearance_n_fourier_freqs + self.feat_in
43 |
44 | self.mlp = nn.Sequential(
45 | nn.Linear(input_dim, 128),
46 | nn.ReLU(),
47 | nn.Linear(128, 128),
48 | nn.ReLU(),
49 | nn.Linear(128, self.feat_in * 2),
50 | )
51 |
52 | self.bg_head = nn.Linear(self.feat_in * 2, 3)
53 |
54 | def forward(self, gembedding, aembedding, features_dc, viewdir=None):
55 | del viewdir # Viewdirs interface is kept to be compatible with prev. version
56 | if self.config.app_per_gauss and gembedding is not None:
57 | inp = torch.cat((features_dc, gembedding, aembedding), dim=-1)
58 | else:
59 | inp = aembedding
60 | offset, mul = torch.split(
61 | self.mlp(inp) * 0.01, [self.feat_in, self.feat_in], dim=-1
62 | )
63 | return offset, mul
64 |
65 | def get_bg_color(self, aembedding):
66 | return self.bg_head(self.mlp(aembedding))
67 |
68 |
69 | class BackgroundModel(nn.Module):
70 | def __init__(self, config):
71 | super().__init__()
72 | self.config = config
73 | # sh_coeffs = 4**2
74 | self.feat_in = 3
75 | input_dim = config.appearance_embedding_dim
76 | if config.app_per_gauss:
77 | input_dim += 6 * self.config.appearance_n_fourier_freqs + self.feat_in
78 |
79 | self.mlp = nn.Sequential(
80 | nn.Linear(input_dim, 128),
81 | nn.ReLU(),
82 | nn.Linear(128, 128),
83 | nn.ReLU(),
84 | nn.Linear(128, self.feat_in * 2),
85 | )
86 |
87 | def forward(self, gembedding, aembedding, features_dc, viewdir=None):
88 | del viewdir # Viewdirs interface is kept to be compatible with prev. version
89 |
90 | if self.config.app_per_gauss_bg:
91 | inp = torch.cat((features_dc, gembedding, aembedding), dim=-1)
92 | else:
93 | inp = aembedding
94 | offset, mul = torch.split(
95 | self.mlp(inp) * 0.01, [self.feat_in, self.feat_in], dim=-1
96 | )
97 | return offset, mul
98 |
99 | class BGField(Field):
100 | def __init__(
101 | self,
102 | appearance_embedding_dim: int,
103 | implementation: Literal["tcnn", "torch"] = "torch",
104 | sh_levels: int = 4,
105 | layer_width: int = 128,
106 | num_layers: int = 3,
107 | ):
108 | super().__init__()
109 | self.sh_dim = (sh_levels + 1) ** 2
110 |
111 | self.encoder = MLP(
112 | in_dim=appearance_embedding_dim,
113 | num_layers=num_layers - 1,
114 | layer_width=layer_width,
115 | out_dim=layer_width,
116 | activation=nn.ReLU(),
117 | out_activation=nn.ReLU(),
118 | implementation=implementation,
119 | )
120 | self.sh_base_head = nn.Linear(layer_width, 3)
121 | self.sh_rest_head = nn.Linear(layer_width, (self.sh_dim - 1) * 3)
122 | # zero initialization
123 | self.sh_rest_head.weight.data.zero_()
124 | self.sh_rest_head.bias.data.zero_()
125 |
126 | def get_background_rgb(
127 | self, ray_bundle: RayBundle, appearance_embedding=None, num_sh=4
128 | ) -> Tensor:
129 | """Predicts background colors at infinity."""
130 | cur_sh_dim = (num_sh + 1) ** 2
131 | directions = ray_bundle.directions.view(-1, 3)
132 | x = self.encoder(appearance_embedding).float()
133 | sh_base = self.sh_base_head(x) # [batch, 3]
134 | sh_rest = self.sh_rest_head(x)[
135 | ..., : (cur_sh_dim - 1) * 3
136 | ] # [batch, 3 * (num_sh - 1)]
137 | sh_coeffs = (
138 | torch.cat([sh_base, sh_rest], dim=-1)
139 | .view(-1, cur_sh_dim, 3)
140 | .repeat(directions.shape[0], 1, 1)
141 | )
142 | colors = spherical_harmonics(
143 | degrees_to_use=num_sh, dirs=directions, coeffs=sh_coeffs
144 | )
145 |
146 | return colors
147 |
148 | def get_sh_coeffs(self, appearance_embedding=None) -> Tensor:
149 | x = self.encoder(appearance_embedding)
150 | base_color = self.sh_base_head(x)
151 | sh_rest = self.sh_rest_head(x)
152 | sh_coeffs = torch.cat([base_color, sh_rest], dim=-1).view(-1, self.sh_dim, 3)
153 | return sh_coeffs
--------------------------------------------------------------------------------
/desplat/pipeline.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from dataclasses import dataclass, field
3 | from pathlib import Path
4 | from time import time
5 | from typing import Dict, List, Literal, Optional, Type
6 |
7 | import torch
8 | import torch._dynamo
9 | import torch.distributed as dist
10 | import torchvision.utils as vutils
11 | from rich.progress import (
12 | BarColumn,
13 | MofNCompleteColumn,
14 | Progress,
15 | TextColumn,
16 | TimeElapsedColumn,
17 | )
18 | from torch.cuda.amp.grad_scaler import GradScaler
19 | from torch.nn import Parameter
20 | from torch.nn.parallel import DistributedDataParallel as DDP
21 |
22 | from desplat.datamanager import (
23 | DeSplatDataManager,
24 | DeSplatDataManagerConfig,
25 | )
26 | from desplat.desplat_model import DeSplatModel, DeSplatModelConfig
27 | from nerfstudio.data.datamanagers.base_datamanager import (
28 | DataManagerConfig,
29 | )
30 | from nerfstudio.models.base_model import ModelConfig
31 | from nerfstudio.pipelines.base_pipeline import (
32 | VanillaPipeline,
33 | VanillaPipelineConfig,
34 | )
35 | from nerfstudio.utils import profiler
36 |
37 | torch._dynamo.config.suppress_errors = True
38 |
39 |
40 | @dataclass
41 | class DeSplatPipelineConfig(VanillaPipelineConfig):
42 | _target: Type = field(default_factory=lambda: DeSplatPipeline)
43 | """target class to instantiate"""
44 |
45 | datamanager: DataManagerConfig = field(
46 | default_factory=lambda: DeSplatDataManagerConfig()
47 | )
48 | """specifies the datamanager config"""
49 | model: ModelConfig = field(default_factory=lambda: DeSplatModelConfig())
50 | """specifies the model config"""
51 | # finetune: bool = True
52 | # """Whether to mask the left half and evaluate the right half of the images"""
53 | test_time_optimize: bool = False
54 |
55 |
56 | class DeSplatPipeline(VanillaPipeline):
57 | def __init__(
58 | self,
59 | config: DeSplatPipelineConfig,
60 | device: str,
61 | test_mode: Literal["test", "val", "inference"] = "val",
62 | world_size: int = 1,
63 | local_rank: int = 0,
64 | grad_scaler: Optional[GradScaler] = None,
65 | ):
66 | super(VanillaPipeline, self).__init__()
67 | self.config = config
68 | self.test_mode = test_mode
69 | self.datamanager: DeSplatDataManager = config.datamanager.setup(
70 | device=device,
71 | test_mode=test_mode,
72 | world_size=world_size,
73 | local_rank=local_rank,
74 | )
75 | self.datamanager.to(device)
76 |
77 | seed_pts = None
78 | if (
79 | hasattr(self.datamanager, "train_dataparser_outputs")
80 | and "points3D_xyz" in self.datamanager.train_dataparser_outputs.metadata # type: ignore
81 | ):
82 | pts = self.datamanager.train_dataparser_outputs.metadata["points3D_xyz"] # type: ignore
83 | pts_rgb = self.datamanager.train_dataparser_outputs.metadata["points3D_rgb"] # type: ignore
84 |
85 | seed_pts = (pts, pts_rgb)
86 |
87 | # if not config.model.finetune:
88 | assert self.datamanager.train_dataset is not None, "Missing input dataset"
89 |
90 | self._model = config.model.setup(
91 | scene_box=self.datamanager.train_dataset.scene_box,
92 | num_train_data=len(self.datamanager.train_dataset),
93 | metadata=self.datamanager.train_dataset.metadata,
94 | device=device,
95 | grad_scaler=grad_scaler,
96 | seed_points=seed_pts,
97 | cameras=self.datamanager.train_dataset.cameras,
98 | dataparser=self.datamanager.dataparser_config,
99 | )
100 |
101 | self.model.to(device)
102 |
103 | self.world_size = world_size
104 | if world_size > 1:
105 | self._model = typing.cast(
106 | DeSplatModel,
107 | DDP(self._model, device_ids=[local_rank], find_unused_parameters=True),
108 | )
109 | dist.barrier(device_ids=[local_rank])
110 |
111 | def get_param_groups(self) -> Dict[str, List[Parameter]]:
112 | """Get the param groups for the pipeline.
113 |
114 | Returns:
115 | A list of dictionaries containing the pipeline's param groups.
116 | """
117 | datamanager_params = self.datamanager.get_param_groups()
118 | model_params = self.model.get_param_groups()
119 | # TODO(ethan): assert that key names don't overlap
120 | return {**datamanager_params, **model_params} # , **model_params_dyn
121 |
122 | @profiler.time_function
123 | def get_average_image_metrics(
124 | self,
125 | data_loader,
126 | image_prefix: str,
127 | step: Optional[int] = None,
128 | output_path: Optional[Path] = None,
129 | get_std: bool = False,
130 | ):
131 | self.eval()
132 | metrics_dict_list = []
133 | num_images = len(data_loader)
134 | if output_path is not None:
135 | output_path.mkdir(exist_ok=True, parents=True)
136 | with Progress(
137 | TextColumn("[progress.description]{task.description}"),
138 | BarColumn(),
139 | TimeElapsedColumn(),
140 | MofNCompleteColumn(),
141 | transient=True,
142 | ) as progress:
143 | task = progress.add_task(
144 | "[green]Evaluating all images...", total=num_images
145 | )
146 | idx = 0
147 | for camera, batch in data_loader:
148 | # time this the following line
149 | inner_start = time()
150 | outputs = self.model.get_outputs_for_camera(camera=camera)
151 | height, width = camera.height, camera.width
152 | num_rays = height * width
153 | metrics_dict, image_dict = self.model.get_image_metrics_and_images(
154 | outputs, batch
155 | )
156 | if output_path is not None:
157 | for key in image_dict.keys():
158 | image = image_dict[key] # [H, W, C] order
159 | vutils.save_image(
160 | image.permute(2, 0, 1).cpu(),
161 | output_path / f"{image_prefix}_{key}_{idx:04d}.png",
162 | )
163 |
164 | assert "num_rays_per_sec" not in metrics_dict
165 | metrics_dict["num_rays_per_sec"] = (
166 | num_rays / (time() - inner_start)
167 | ).item()
168 | fps_str = "fps"
169 | assert fps_str not in metrics_dict
170 | metrics_dict[fps_str] = (
171 | metrics_dict["num_rays_per_sec"] / (height * width)
172 | ).item()
173 | metrics_dict_list.append(metrics_dict)
174 | progress.advance(task)
175 | idx = idx + 1
176 |
177 | metrics_dict = {}
178 | for key in metrics_dict_list[0].keys():
179 | if get_std:
180 | key_std, key_mean = torch.std_mean(
181 | torch.tensor(
182 | [metrics_dict[key] for metrics_dict in metrics_dict_list]
183 | )
184 | )
185 | metrics_dict[key] = float(key_mean)
186 | metrics_dict[f"{key}_std"] = float(key_std)
187 | else:
188 | metrics_dict[key] = float(
189 | torch.mean(
190 | torch.tensor(
191 | [metrics_dict[key] for metrics_dict in metrics_dict_list]
192 | )
193 | )
194 | )
195 |
196 | if self.test_mode == "inference":
197 | print("Now we are in the test mode.")
198 | metrics_dict["full_results"] = metrics_dict_list
199 | del image_dict["depth"]
200 |
201 | self.train()
202 | return metrics_dict
203 |
204 | @profiler.time_function
205 | def get_average_eval_half_image_metrics(
206 | self,
207 | step: Optional[int] = None,
208 | output_path: Optional[Path] = None,
209 | get_std: bool = False,
210 | ):
211 | """Get the average metrics for evaluation images."""
212 | # assert hasattr(
213 | # self.datamanager, "fixed_indices_eval_dataloader"
214 | # ), "datamanager must have 'fixed_indices_eval_dataloader' attribute"
215 | image_prefix = "eval"
216 | return self.get_average_half_image_metrics(
217 | self.datamanager.fixed_indices_eval_dataloader,
218 | image_prefix,
219 | step,
220 | output_path,
221 | get_std,
222 | )
223 |
224 | @profiler.time_function
225 | def get_average_half_image_metrics(
226 | self,
227 | data_loader,
228 | image_prefix: str,
229 | step: Optional[int] = None,
230 | output_path: Optional[Path] = None,
231 | get_std: bool = False,
232 | ):
233 | """Iterate over all the images in the dataset and get the average.
234 |
235 | Args:
236 | data_loader: the data loader to iterate over
237 | image_prefix: prefix to use for the saved image filenames
238 | step: current training step
239 | output_path: optional path to save rendered images to
240 | get_std: Set True if you want to return std with the mean metric.
241 |
242 | Returns:
243 | metrics_dict: dictionary of metrics
244 | """
245 | self.eval()
246 | metrics_dict_list = []
247 | num_images = len(data_loader)
248 | if output_path is not None:
249 | output_path.mkdir(exist_ok=True, parents=True)
250 | with Progress(
251 | TextColumn("[progress.description]{task.description}"),
252 | BarColumn(),
253 | TimeElapsedColumn(),
254 | MofNCompleteColumn(),
255 | transient=True,
256 | ) as progress:
257 | task = progress.add_task(
258 | "[green]Evaluating all images...", total=num_images
259 | )
260 | idx = 0
261 | for camera, batch in data_loader:
262 | # time this the following line
263 | inner_start = time()
264 | outputs = self.model.get_outputs_for_camera(camera=camera)
265 | height, width = camera.height, camera.width
266 |
267 | half_width = width // 2
268 | for key, img in outputs.items():
269 | if img.dim() == 3: # (H, W, C)
270 | masked_img = img.clone()
271 | masked_img = img[:, half_width:, :].clone()
272 | outputs[key] = masked_img
273 | right_half = {}
274 | right_half["image"] = batch["image"][:, half_width:, :]
275 |
276 | num_rays = height * width
277 | metrics_dict, image_dict = self.model.get_image_metrics_and_images(
278 | outputs, right_half
279 | )
280 | if output_path is not None:
281 | for key in image_dict.keys():
282 | image = image_dict[key] # [H, W, C] order
283 | vutils.save_image(
284 | image.permute(2, 0, 1).cpu(),
285 | output_path / f"{image_prefix}_{key}_{idx:04d}.png",
286 | )
287 |
288 | metrics_dict_list.append(metrics_dict)
289 | progress.advance(task)
290 | idx = idx + 1
291 |
292 | metrics_dict = {}
293 | for key in metrics_dict_list[0].keys():
294 | if get_std:
295 | key_std, key_mean = torch.std_mean(
296 | torch.tensor(
297 | [metrics_dict[key] for metrics_dict in metrics_dict_list]
298 | )
299 | )
300 | metrics_dict[key] = float(key_mean)
301 | metrics_dict[f"{key}_std"] = float(key_std)
302 | else:
303 | metrics_dict[key] = float(
304 | torch.mean(
305 | torch.tensor(
306 | [metrics_dict[key] for metrics_dict in metrics_dict_list]
307 | )
308 | )
309 | )
310 |
311 | self.train()
312 | return metrics_dict
313 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "desplat"
3 | description = "DeSplat: Decomposed Gaussian Splatting for Distractor-Free Rendering"
4 | version = "0.0.1"
5 |
6 | dependencies = [
7 | "nerfstudio >= 1.1.3",
8 | "gsplat == 1.0.0",
9 | "numpy == 1.24.4",
10 | "ruff",
11 | "pyyaml",
12 | "tyro==0.8.12"]
13 |
14 | [tool.setuptools.packages.find]
15 | include = ["desplat"]
16 |
17 | [project.entry-points.'nerfstudio.method_configs']
18 | test = 'desplat.config:desplat_method'
19 |
20 | [project.entry-points.'nerfstudio.dataparser_configs']
21 | onthego-data = 'desplat.dataparsers.onthego_dataparser:OnthegoDataParserSpecification'
22 | robustnerf-data = 'desplat.dataparsers.robustnerf_dataparser:RobustNerfDataParserSpecification'
23 | phototourism-data = 'desplat.dataparsers.phototourism_dataparser:PhotoTourismDataParserSpecification'
24 |
--------------------------------------------------------------------------------
/scripts/calculate_memory.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from pathlib import Path
4 | from typing import Callable, Literal, Optional, Tuple
5 |
6 | import torch
7 | import yaml
8 |
9 | from nerfstudio.configs.method_configs import all_methods
10 | from nerfstudio.engine.trainer import TrainerConfig
11 | from nerfstudio.pipelines.base_pipeline import Pipeline
12 | from nerfstudio.utils.rich_utils import CONSOLE
13 |
14 |
15 | def load_checkpoint_for_eval(
16 | config: TrainerConfig, pipeline: Pipeline
17 | ) -> Tuple[Path, int]:
18 | """Load a checkpointed pipeline for evaluation.
19 |
20 | Args:
21 | config (TrainerConfig): Configuration for loading the pipeline.
22 | pipeline (Pipeline): The pipeline instance to load weights into.
23 |
24 | Returns:
25 | Tuple containing the path to the loaded checkpoint and the step at which it was saved.
26 | """
27 | assert config.load_dir is not None, "Checkpoint directory must be specified."
28 | if config.load_step is None:
29 | CONSOLE.print("Loading latest checkpoint from specified directory.")
30 | if not os.path.exists(config.load_dir):
31 | CONSOLE.rule("Error", style="red")
32 | CONSOLE.print(
33 | f"Checkpoint directory not found at {config.load_dir}", justify="center"
34 | )
35 | CONSOLE.print(
36 | "Ensure checkpoints were generated during training.", justify="center"
37 | )
38 | sys.exit(1)
39 | load_step = max(
40 | int(f.split("-")[1].split(".")[0])
41 | for f in os.listdir(config.load_dir)
42 | if "step-" in f
43 | )
44 | else:
45 | load_step = config.load_step
46 |
47 | load_path = config.load_dir / f"step-{load_step:09d}.ckpt"
48 | assert load_path.exists(), f"Checkpoint {load_path} does not exist."
49 |
50 | # Load checkpoint
51 | loaded_state = torch.load(load_path, map_location="cpu")
52 | pipeline.load_pipeline(loaded_state["pipeline"], loaded_state["step"])
53 | CONSOLE.print(f":white_check_mark: Successfully loaded checkpoint from {load_path}")
54 |
55 | return load_path, load_step
56 |
57 |
58 | def setup_evaluation(
59 | config_path: Path,
60 | eval_num_rays_per_chunk: Optional[int] = None,
61 | test_mode: Literal["test", "val", "inference"] = "test",
62 | update_config_callback: Optional[Callable[[TrainerConfig], TrainerConfig]] = None,
63 | ) -> Tuple[TrainerConfig, Pipeline, Path, int, float]:
64 | """Set up pipeline loading for evaluation, with an option to calculate model size.
65 |
66 | Args:
67 | config_path: Path to the configuration YAML file.
68 | eval_num_rays_per_chunk: Rays per forward pass (optional).
69 | test_mode: Data loading mode ('test', 'val', or 'inference').
70 | update_config_callback: Optional function to modify config before loading the pipeline.
71 |
72 | Returns:
73 | Config, loaded pipeline, checkpoint path, step, and model size in MB.
74 | """
75 | # Load and validate configuration
76 | config = yaml.load(config_path.read_text(), Loader=yaml.Loader)
77 | assert isinstance(config, TrainerConfig)
78 |
79 | config.pipeline.datamanager._target = all_methods[
80 | config.method_name
81 | ].pipeline.datamanager._target
82 | if eval_num_rays_per_chunk:
83 | config.pipeline.model.eval_num_rays_per_chunk = eval_num_rays_per_chunk
84 |
85 | if update_config_callback:
86 | config = update_config_callback(config)
87 |
88 | # Define the checkpoint directory
89 | config.load_dir = config.get_checkpoint_dir()
90 |
91 | # Initialize the pipeline
92 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
93 | pipeline = config.pipeline.setup(device=device, test_mode=test_mode)
94 | pipeline.eval()
95 |
96 | # Load the checkpoint
97 | checkpoint_path, step = load_checkpoint_for_eval(config, pipeline)
98 |
99 | # Calculate the size of the loaded model
100 | model_size_mb = calculate_model_size(pipeline)
101 | CONSOLE.print(f"Model size: {model_size_mb:.2f} MB")
102 |
103 | return config, pipeline, checkpoint_path, step, model_size_mb
104 |
105 |
106 | def calculate_model_size(model: torch.nn.Module) -> float:
107 | """Calculate the size of a PyTorch model in MB.
108 |
109 | Args:
110 | model: The PyTorch model for which to calculate the memory usage.
111 |
112 | Returns:
113 | Model size in megabytes (MB).
114 | """
115 | dynamic_param_size = 0
116 | static_param_size = 0
117 |
118 | for name, p in model.named_parameters():
119 | # Determine if the parameter is dynamic or static
120 | if "gauss_params_dyn_dict" in name:
121 | dynamic_param_size += p.nelement() * p.element_size()
122 | else:
123 | static_param_size += p.nelement() * p.element_size()
124 |
125 | buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
126 |
127 | # Print size of each parameter category
128 | print(
129 | "Total dynamic parameters size: {:.3f} MB".format(dynamic_param_size / 1024**2)
130 | )
131 | print("Total static parameters size: {:.3f} MB".format(static_param_size / 1024**2))
132 | print(
133 | "Total model size: {:.3f} MB".format(
134 | (dynamic_param_size + static_param_size) / 1024**2
135 | )
136 | )
137 |
138 | total_size_mb = (
139 | dynamic_param_size + static_param_size + buffer_size
140 | ) / 1024**2 # Convert to MB
141 | return total_size_mb
142 |
143 |
144 | if __name__ == "__main__":
145 | import tyro
146 |
147 | tyro.cli(setup_evaluation)
148 |
--------------------------------------------------------------------------------
/scripts/convert.py:
--------------------------------------------------------------------------------
1 | # This code is copied from the 3D Gaussian Splatter repository, available at https://github.com/graphdeco-inria/gaussian-splatting.
2 | #
3 | # Copyright (C) 2023, Inria
4 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
5 | # All rights reserved.
6 | #
7 | # This software is free for non-commercial, research and evaluation use
8 | # under the terms of the LICENSE.md file.
9 | #
10 | # For inquiries contact george.drettakis@inria.fr
11 | #
12 |
13 | import os
14 | import logging
15 | from argparse import ArgumentParser
16 | import shutil
17 | import math
18 | from PIL import Image
19 |
20 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository.
21 | parser = ArgumentParser("Colmap converter")
22 | parser.add_argument("--no_gpu", action="store_true")
23 | parser.add_argument("--skip_matching", action="store_true")
24 | parser.add_argument("--source_path", "-s", required=True, type=str)
25 | parser.add_argument("--camera", default="OPENCV", type=str)
26 | parser.add_argument("--colmap_executable", default="", type=str)
27 | parser.add_argument("--resize", action="store_true")
28 | parser.add_argument("--magick_executable", default="", type=str)
29 | args = parser.parse_args()
30 | colmap_command = (
31 | '"{}"'.format(args.colmap_executable)
32 | if len(args.colmap_executable) > 0
33 | else "colmap"
34 | )
35 | magick_command = (
36 | '"{}"'.format(args.magick_executable)
37 | if len(args.magick_executable) > 0
38 | else "magick"
39 | )
40 | use_gpu = 1 if not args.no_gpu else 0
41 |
42 | # rename images folder to input
43 | images_path = os.path.join(args.source_path, "images")
44 | input_path = os.path.join(args.source_path, "input")
45 |
46 | if os.path.exists(images_path):
47 | os.rename(images_path, input_path)
48 | print(f"'{images_path}' has been renamed to '{input_path}'")
49 | else:
50 | print(f"The folder '{images_path}' does not exist.")
51 |
52 | if not args.skip_matching:
53 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True)
54 | ## Feature extraction
55 | feat_extracton_cmd = (
56 | colmap_command + " feature_extractor "
57 | "--database_path "
58 | + args.source_path
59 | + "/distorted/database.db \
60 | --image_path "
61 | + args.source_path
62 | + "/input \
63 | --ImageReader.single_camera 1 \
64 | --ImageReader.camera_model "
65 | + args.camera
66 | + " \
67 | --SiftExtraction.use_gpu "
68 | + str(use_gpu)
69 | )
70 | exit_code = os.system(feat_extracton_cmd)
71 | if exit_code != 0:
72 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.")
73 | exit(exit_code)
74 |
75 | ## Feature matching
76 | feat_matching_cmd = (
77 | colmap_command
78 | + " exhaustive_matcher \
79 | --database_path "
80 | + args.source_path
81 | + "/distorted/database.db \
82 | --SiftMatching.use_gpu "
83 | + str(use_gpu)
84 | )
85 | exit_code = os.system(feat_matching_cmd)
86 | if exit_code != 0:
87 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.")
88 | exit(exit_code)
89 |
90 | ### Bundle adjustment
91 | # The default Mapper tolerance is unnecessarily large,
92 | # decreasing it speeds up bundle adjustment steps.
93 | mapper_cmd = (
94 | colmap_command
95 | + " mapper \
96 | --database_path "
97 | + args.source_path
98 | + "/distorted/database.db \
99 | --image_path "
100 | + args.source_path
101 | + "/input \
102 | --output_path "
103 | + args.source_path
104 | + "/distorted/sparse \
105 | --Mapper.ba_global_function_tolerance=0.000001"
106 | )
107 | exit_code = os.system(mapper_cmd)
108 | if exit_code != 0:
109 | logging.error(f"Mapper failed with code {exit_code}. Exiting.")
110 | exit(exit_code)
111 |
112 | ### Image undistortion
113 | ## We need to undistort our images into ideal pinhole intrinsics.
114 | img_undist_cmd = (
115 | colmap_command
116 | + " image_undistorter \
117 | --image_path "
118 | + args.source_path
119 | + "/input \
120 | --input_path "
121 | + args.source_path
122 | + "/distorted/sparse/0 \
123 | --output_path "
124 | + args.source_path
125 | + "\
126 | --output_type COLMAP"
127 | )
128 | exit_code = os.system(img_undist_cmd)
129 | if exit_code != 0:
130 | logging.error(f"Mapper failed with code {exit_code}. Exiting.")
131 | exit(exit_code)
132 |
133 | files = os.listdir(args.source_path + "/sparse")
134 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True)
135 | # Copy each file from the source directory to the destination directory
136 | for file in files:
137 | if file == "0":
138 | continue
139 | source_file = os.path.join(args.source_path, "sparse", file)
140 | destination_file = os.path.join(args.source_path, "sparse", "0", file)
141 | shutil.move(source_file, destination_file)
142 |
143 | if args.resize:
144 | print("Copying and resizing...")
145 |
146 | # Resize images.
147 | # for patio scene, we resize by 25%
148 | if args.source_path.endswith("patio") or args.source_path.endswith("arcdetriomphe"):
149 | resize_factor = 4
150 | else:
151 | resize_factor = 8
152 | os.makedirs(args.source_path + "/images_" + str(resize_factor), exist_ok=True)
153 | # Get the list of files in the source directory
154 | files = os.listdir(args.source_path + "/images")
155 | # Copy each file from the source directory to the destination directory
156 | for file in files:
157 | source_file = os.path.join(args.source_path, "images", file)
158 |
159 | with Image.open(source_file) as img:
160 | width, height = img.size
161 | target_width = math.floor(width / resize_factor)
162 | target_height = math.floor(height / resize_factor)
163 |
164 | destination_file = os.path.join(
165 | args.source_path, f"images_{resize_factor}", file
166 | )
167 | shutil.copy2(source_file, destination_file)
168 | exit_code = os.system(
169 | f"{magick_command} mogrify -resize {target_width}x{target_height}! {destination_file}"
170 | )
171 | if exit_code != 0:
172 | logging.error(f"resize failed with code {exit_code}. Exiting.")
173 | exit(exit_code)
174 |
175 | print("Done.")
176 |
--------------------------------------------------------------------------------
/scripts/download_dataset.py:
--------------------------------------------------------------------------------
1 | """Script to download benchmark dataset(s)"""
2 |
3 | import os
4 | import subprocess
5 | from dataclasses import dataclass
6 | from pathlib import Path
7 | from typing import Literal
8 |
9 | import tyro
10 |
11 | # dataset names
12 | dataset_names = Literal[
13 | "robustnerf",
14 | "on-the-go",
15 | ]
16 |
17 | # dataset urls
18 | urls = {
19 | "robustnerf": "https://storage.googleapis.com/jax3d-public/projects/robustnerf/robustnerf.tar.gz",
20 | "on-the-go": "https://cvg-data.inf.ethz.ch/on-the-go.zip",
21 | }
22 |
23 | # rename maps
24 | dataset_rename_map = {
25 | "robustnerf": "robustnerf",
26 | "on-the-go": "on-the-go",
27 | }
28 |
29 |
30 | @dataclass
31 | class DownloadData:
32 | dataset: dataset_names = "robustnerf"
33 | save_dir: Path = Path(os.getcwd() + "/data")
34 |
35 | def main(self):
36 | self.save_dir.mkdir(parents=True, exist_ok=True)
37 | self.dataset_download(self.dataset)
38 |
39 | def dataset_download(self, dataset: dataset_names):
40 | (self.save_dir / dataset_rename_map[dataset]).mkdir(parents=True, exist_ok=True)
41 |
42 | file_name = Path(urls[dataset]).name
43 |
44 | # download
45 | download_command = [
46 | "curl",
47 | "-o",
48 | str(self.save_dir / dataset_rename_map[dataset] / file_name),
49 | urls[dataset],
50 | ]
51 | try:
52 | subprocess.run(download_command, check=True)
53 | print("File file downloaded succesfully.")
54 | except subprocess.CalledProcessError as e:
55 | print(f"Error downloading file: {e}")
56 |
57 | # if .zip
58 | if Path(urls[dataset]).suffix == ".zip":
59 | if os.name == "nt": # Windows doesn't have 'unzip' but 'tar' works
60 | extract_command = [
61 | "tar",
62 | "-xvf",
63 | self.save_dir / dataset_rename_map[dataset] / file_name,
64 | "-C",
65 | self.save_dir / dataset_rename_map[dataset],
66 | ]
67 | else:
68 | extract_command = [
69 | "unzip",
70 | self.save_dir / dataset_rename_map[dataset] / file_name,
71 | "-d",
72 | self.save_dir / dataset_rename_map[dataset],
73 | ]
74 | # if .tar
75 | else:
76 | extract_command = [
77 | "tar",
78 | "-xvf",
79 | self.save_dir / dataset_rename_map[dataset] / file_name,
80 | "-C",
81 | self.save_dir / dataset_rename_map[dataset],
82 | ]
83 |
84 | # extract
85 | try:
86 | subprocess.run(extract_command, check=True)
87 | os.remove(self.save_dir / dataset_rename_map[dataset] / file_name)
88 | print("Extraction complete.")
89 | except subprocess.CalledProcessError as e:
90 | print(f"Extraction failed: {e}")
91 |
92 |
93 | if __name__ == "__main__":
94 | tyro.cli(DownloadData).main()
95 |
--------------------------------------------------------------------------------
/scripts/test_time_optimize.py:
--------------------------------------------------------------------------------
1 | """test time camera appearance optimization
2 |
3 | Usage:
4 |
5 | python scripts/test_time_optimize.py --load-config [path_to_script]
6 | """
7 |
8 | import functools
9 | import json
10 | import os
11 | from dataclasses import dataclass
12 | from pathlib import Path
13 |
14 | import numpy as np
15 | import torch
16 | import tyro
17 | from PIL import Image
18 | from torchvision.utils import save_image
19 |
20 | from nerfstudio.cameras.cameras import Cameras
21 | from nerfstudio.data.datasets.base_dataset import InputDataset
22 | from nerfstudio.models.splatfacto import SplatfactoModel
23 | from nerfstudio.utils.eval_utils import eval_setup
24 |
25 |
26 | @dataclass
27 | class AppearanceModelConfigs:
28 | # TODO: only works when there are no per-gauss features atm
29 | app_per_gauss: bool = False
30 | appearance_embedding_dim: int = 32
31 | appearance_n_fourier_freqs: int = 4
32 | appearance_init_fourier: bool = True
33 |
34 |
35 | @dataclass
36 | class TestTimeOpt:
37 | load_config: Path = Path("")
38 | """Path to the config YAML file."""
39 | train_iters: int = 128
40 |
41 | """train iters"""
42 | save_gif: bool = False
43 | """save a training gif"""
44 | lr_app_emb: float = 0.01
45 | """learning rate for appearance embedding"""
46 | metrics_output_path: Path = Path("./test_time_metrics/")
47 | """Output path of test time opt eval metrics"""
48 | use_saved_embedding: bool = False
49 | """Use saved embedding module"""
50 | save_all_imgs: bool = False
51 | """Save all images"""
52 |
53 | def main(self):
54 | if "brandenburg_gate" in str(self.load_config) or "unnamed" in str(self.load_config):
55 | scene = "brandenburg_gate"
56 | elif "sacre_coeur" in str(self.load_config):
57 | scene = "sacre_coeur"
58 | elif "trevi_fountain" in str(self.load_config):
59 | scene = "trevi_fountain"
60 | else:
61 | raise ValueError(f"Unknown dataset")
62 |
63 | if not self.metrics_output_path.exists():
64 | self.metrics_output_path.mkdir(parents=True)
65 |
66 | config, pipeline, _, _ = eval_setup(self.load_config)
67 | pipeline.test_time_optimize = True
68 | pipeline.train()
69 | pipeline.cuda()
70 | assert isinstance(pipeline.model, SplatfactoModel)
71 |
72 | model: SplatfactoModel = pipeline.model
73 | train_dataset: InputDataset = pipeline.datamanager.train_dataset
74 | eval_dataset: InputDataset = pipeline.datamanager.eval_dataset
75 | cameras: Cameras = pipeline.datamanager.eval_dataset.cameras # type: ignore
76 | # init app model
77 | app_config = AppearanceModelConfigs()
78 |
79 | if self.use_saved_embedding:
80 | model.appearance_embeddings = torch.load("embedding_"+ scene + ".pth")
81 |
82 | else:
83 | # define app model optimizers
84 | model.appearance_embeddings = torch.nn.Embedding(
85 | len(eval_dataset), app_config.appearance_embedding_dim
86 | ).cuda()
87 | model.appearance_embeddings.weight.data.normal_(0, 0.01)
88 |
89 | optimizer = torch.optim.Adam(
90 | model.appearance_embeddings.parameters(),
91 | lr=self.lr_app_emb,
92 | )
93 |
94 | # Force model to have appearance
95 | model.config.enable_appearance = True
96 |
97 | # train eval dataset
98 | gif_frames = []
99 |
100 | # before test time metrics:
101 | before_test_time_metrics = pipeline.get_average_eval_half_image_metrics(
102 | step=0, output_path=None
103 | )
104 | print("Metrics before test-time: ", before_test_time_metrics)
105 |
106 | for epoch in range(self.train_iters):
107 | for image_idx, data in enumerate(
108 | pipeline.datamanager.cached_eval # Undistorted images
109 | ): # type: ignore
110 | # process batch gt data
111 | # process pred outputs
112 | camera = cameras[image_idx : image_idx + 1]
113 | camera.metadata = {}
114 | camera.metadata["cam_idx"] = image_idx
115 | camera = camera.to("cuda")
116 |
117 | height, width = camera.height, camera.width
118 |
119 | outputs = model.get_outputs(camera=camera) # type: dict
120 | outputs_left_half = {}
121 | # mask the right half of the image
122 | for key, img in outputs.items():
123 | half_width = width // 2
124 | if key == 'background':
125 | outputs_left_half[key] = img
126 | if img.dim() == 3: # (H, W, C)
127 | masked_img = img[:, :half_width, :].clone()
128 |
129 | # masked_img[:, half_width:, :] = 0
130 | outputs_left_half[key] = masked_img
131 | left_half = {}
132 | left_half["image"] = data["image"][:, :half_width, :]
133 |
134 | loss_dict = model.get_loss_dict(outputs=outputs_left_half, batch=left_half)
135 | if image_idx == 1:
136 | # choose the right side of the image
137 | rgb = outputs["rgb"][:, width // 2:, :]
138 | gt_img = data["image"][:, width // 2:, :]
139 | save_image(rgb.permute(2, 0, 1), "rgb.jpg")
140 | print("Epoch: ", epoch, "loss of img_0:", loss_dict["main_loss"])
141 | if self.save_gif and epoch % 1 == 0:
142 | gif_frames.append(
143 | (rgb.detach().cpu().numpy() * 255).astype(np.uint8)
144 | )
145 |
146 | loss = functools.reduce(torch.add, loss_dict.values())
147 | loss.backward()
148 |
149 | optimizer.step()
150 | optimizer.zero_grad(set_to_none=True)
151 |
152 | # save pth
153 | torch.save(model.appearance_embeddings, "embedding_"+ scene + ".pth")
154 |
155 | # Get eval metrics after
156 | after_test_time_metrics = pipeline.get_average_eval_half_image_metrics(
157 | step=0, output_path=None
158 | )
159 |
160 | print("Metrics after test-time: ", after_test_time_metrics)
161 |
162 | output_dir = f"{self.metrics_output_path}/{scene}"
163 |
164 | os.makedirs(output_dir, exist_ok=True)
165 |
166 | metrics_path = f"{output_dir}/metrics.json"
167 | with open(metrics_path, "w") as f:
168 | json.dump(after_test_time_metrics, f)
169 |
170 | if self.save_gif:
171 | gif_frames = [Image.fromarray(frame) for frame in gif_frames]
172 | out_dir = os.path.join(os.getcwd(), f"renders/{scene}")
173 | os.makedirs(out_dir, exist_ok=True)
174 | print(f"saving depth gif to {out_dir}/training.gif")
175 | gif_frames[0].save(
176 | f"{out_dir}/training.gif",
177 | save_all=True,
178 | append_images=gif_frames[1:],
179 | optimize=False,
180 | duration=5,
181 | loop=0,
182 | )
183 |
184 | if self.save_all_imgs:
185 | for image_idx, data in enumerate(
186 | pipeline.datamanager.cached_eval # Undistorted images
187 | ): # type: ignore
188 | # process batch gt data
189 | # process pred outputs
190 | camera = cameras[image_idx : image_idx + 1]
191 | camera.metadata = {}
192 | camera.metadata["cam_idx"] = image_idx
193 | camera = camera.to("cuda")
194 |
195 | height, width = camera.height, camera.width
196 |
197 | outputs = model.get_outputs(camera=camera) # type: dict
198 |
199 | # Define the output directory
200 | out_dir = os.path.join(os.getcwd(), f"renders/{scene}")
201 |
202 | # Create the directory if it doesn't exist
203 | os.makedirs(out_dir, exist_ok=True)
204 |
205 | # Define the full path for the image file
206 | image_path = os.path.join(out_dir, f"render_{image_idx}.jpg")
207 |
208 | # out_dir = os.path.join(os.getcwd(), f"renders/{scene}/render_{image_idx}.jpg")
209 | save_image(outputs["rgb"].permute(2, 0, 1), image_path)
210 |
211 | if __name__ == "__main__":
212 | tyro.cli(TestTimeOpt).main()
213 |
--------------------------------------------------------------------------------