├── .gitignore
├── LICENSE
├── README.md
├── bad_gaussians
├── bad_camera_optimizer.py
├── bad_config_dataparser.py
├── bad_config_method.py
├── bad_gaussians.py
├── bad_losses.py
├── bad_utils.py
├── bad_viewer.py
├── deblur_nerf_dataparser.py
├── image_restoration_dataloader.py
├── image_restoration_full_image_datamanager.py
├── image_restoration_pipeline.py
├── image_restoration_trainer.py
├── spline.py
└── spline_functor.py
├── pyproject.toml
├── scripts
└── tools
│ ├── export_poses_from_ckpt.py
│ ├── export_poses_from_colmap.py
│ ├── export_poses_from_npy.py
│ ├── interpolate_traj.py
│ ├── kitti_to_tum.py
│ └── tum_to_kitti.py
└── tests
├── data
└── traj.txt
└── test_spline.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | **/__pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | !scripts/downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | .pybuilder/
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | # .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # poetry
99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100 | # This is especially recommended for binary packages to ensure reproducibility, and is more
101 | # commonly ignored for libraries.
102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103 | #poetry.lock
104 |
105 | # pdm
106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107 | #pdm.lock
108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109 | # in version control.
110 | # https://pdm.fming.dev/#use-with-ide
111 | .pdm.toml
112 |
113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
114 | __pypackages__/
115 |
116 | # Celery stuff
117 | celerybeat-schedule
118 | celerybeat.pid
119 |
120 | # SageMath parsed files
121 | *.sage.py
122 |
123 | # Environments
124 | .env
125 | .envrc
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
164 | # Experiments and outputs
165 | outputs/
166 | exports/
167 | renders/
168 | # tensorboard log files
169 | events.out.*
170 |
171 | # Data
172 | data
173 | !*/data
174 |
175 | # Misc
176 | .vscode/
177 | .idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
😈BAD-Gaussians: Bundle-Adjusted Deblur Gaussian Splatting
2 |
3 |
4 |
5 |
6 | This is an official implementation of our arXiv 2024 paper
7 | [**BAD-Gaussians**: Bundle Adjusted Deblur Gaussian Splatting](https://lingzhezhao.github.io/BAD-Gaussians/), based on the [nerfstudio](https://github.com/nerfstudio-project/nerfstudio) framework.
8 |
9 | ## Demo
10 |
11 | Deblurring & novel-view synthesis results on [Deblur-NeRF](https://github.com/limacv/Deblur-NeRF/)'s real-world motion-blurred data:
12 |
13 |
14 |
15 | > Left: BAD-Gaussians deblurred novel-view renderings;
16 | >
17 | > Right: Input images.
18 |
19 |
20 | ## Quickstart
21 |
22 | ### 1. Installation
23 |
24 | You may check out the original [`nerfstudio`](https://github.com/nerfstudio-project/nerfstudio) repo for prerequisites and dependencies.
25 | Currently, our codebase is tested with nerfstudio v1.0.3.
26 |
27 | TL;DR: You can install `nerfstudio` with:
28 |
29 | ```bash
30 | # (Optional) create a fresh conda env
31 | conda create --name nerfstudio -y "python<3.11"
32 | conda activate nerfstudio
33 |
34 | # install dependencies
35 | pip install --upgrade pip setuptools
36 | pip install "torch==2.1.2+cu118" "torchvision==0.16.2+cu118" --extra-index-url https://download.pytorch.org/whl/cu118
37 |
38 | conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit
39 | pip install ninja git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
40 |
41 | # install nerfstudio!
42 | pip install nerfstudio==1.0.3
43 | ```
44 |
45 | Then you can install this repo as a Python package with:
46 |
47 | ```bash
48 | pip install git+https://github.com/WU-CVGL/BAD-Gaussians
49 | ```
50 |
51 | ### 2. Prepare the dataset
52 |
53 | #### Deblur-NeRF Synthetic Dataset (Re-rendered)
54 |
55 | As described in the previous BAD-NeRF paper, we re-rendered Deblur-NeRF's synthetic dataset with 51 interpolations per blurry image.
56 |
57 | Additionally, in the previous BAD-NeRF paper, we directly ran COLMAP on blurry images only, with neither ground-truth
58 | camera intrinsics nor sharp novel-view images. We found this quite challenging for COLMAP: it may fail to
59 | reconstruct the scene, so COLMAP may need to be re-run several times. To this end, we provide a new set of data,
60 | where we ran COLMAP with ground-truth camera intrinsics over both blurry and sharp novel-view images,
61 | named `bad-nerf-gtK-colmap-nvs`:
62 |
63 | [Download link](https://westlakeu-my.sharepoint.com/:f:/g/personal/cvgl_westlake_edu_cn/EoCe3vaC9V5Fl74DjbGriwcBKj1nbB0HQFSWnVTLX7qT9A)
64 |
65 | #### Deblur-NeRF Real Dataset
66 |
67 | You can directly download the `real_camera_motion_blur` folder from [Deblur-NeRF](https://limacv.github.io/deblurnerf/).
68 |
69 | #### Your Custom Dataset
70 |
71 | 1. Use the [`ns-process-data` tool from Nerfstudio](https://docs.nerf.studio/reference/cli/ns_process_data.html)
72 | to process deblur-nerf training images.
73 |
74 | For example, if the
75 | [dataset from BAD-NeRF](https://westlakeu-my.sharepoint.com/:f:/g/personal/cvgl_westlake_edu_cn/EsgdW2cRic5JqerhNbTsxtkBqy9m6cbnb2ugYZtvaib3qA?e=bjK7op)
76 | is in `llff_data`, execute:
77 |
78 | ```
79 | ns-process-data images \
80 | --data llff_data/blurtanabata/images \
81 | --output-dir data/my_data/blurtanabata
82 | ```
83 |
84 | 2. The folder `data/my_data/blurtanabata` is ready.
85 |
86 | > Note: Although nerfstudio does not model the NDC scene contraction for LLFF data,
87 | > we found that `scale_factor = 0.25` works well on LLFF datasets.
88 | > If your data is captured in an [LLFF fashion](https://github.com/Fyusion/LLFF#using-your-own-input-images-for-view-synthesis) (i.e. forward-facing),
89 | > instead of object-centric like Mip-NeRF 360,
90 | > you can pass the `scale_factor = 0.25` parameter to the nerfstudio dataparser (which is already set to default in our `DeblurNerfDataParser`),
91 | > e.g., `ns-train bad-gaussians --data data/my_data/my_seq --vis viewer+tensorboard nerfstudio-data --scale_factor 0.25`
92 |
93 | ### 3. Training
94 |
95 | 1. For `Deblur-NeRF synthetic` dataset, train with:
96 |
97 | ```bash
98 | ns-train bad-gaussians \
99 | --data data/bad-nerf-gtK-colmap-nvs/blurtanabata \
100 | --pipeline.model.camera-optimizer.mode "linear" \
101 | --vis viewer+tensorboard \
102 | deblur-nerf-data
103 | ```
104 | where
105 | - `--data data/bad-nerf-gtK-colmap-nvs/blurtanabata` is the relative path of the data sequence;
106 | - `--pipeline.model.camera-optimizer.mode "linear"` enables linear camera pose interpolation
107 | - `--vis viewer+tensorboard` enables both the viewer and the tensorboard metrics saving
108 | - `deblur-nerf-data` chooses the DeblurNerfDataparser
109 |
110 | 2. For `Deblur-NeRF real` dataset with `downscale_factor=4`, train with:
111 | ```bash
112 | ns-train bad-gaussians \
113 | --data data/real_camera_motion_blur/blurdecoration \
114 | --pipeline.model.camera-optimizer.mode "cubic" \
115 | --vis viewer+tensorboard \
116 | deblur-nerf-data \
117 | --downscale_factor 4
118 | ```
119 | where
120 | - `--pipeline.model.camera-optimizer.mode "cubic"` enables cubic B-spline;
121 | - `--downscale_factor 4` after the `deblur-nerf-data` tells the DeblurNerfDataparser to downscale the images' width and height to `1/4` of their original size.
122 |
123 | 3. For `Deblur-NeRF real` dataset with *full resolution*, train with:
124 | ```bash
125 | ns-train bad-gaussians \
126 | --data data/real_camera_motion_blur/blurdecoration \
127 | --pipeline.model.camera-optimizer.mode "cubic" \
128 | --pipeline.model.camera-optimizer.num_virtual_views 15 \
129 | --pipeline.model.num_downscales 2 \
130 | --pipeline.model.resolution_schedule 3000 \
131 | --vis viewer+tensorboard \
132 | deblur-nerf-data
133 | ```
134 | where
135 | - `--pipeline.model.camera-optimizer.mode "cubic"` enables cubic B-spline;
136 | - `--pipeline.model.camera-optimizer.num_virtual_views 15` increases the number of virtual cameras to 15;
137 | - `--pipeline.model.num_downscales 2` and `--pipeline.model.resolution_schedule 3000` enable coarse-to-fine training.
138 |
139 | 4. For custom data processed with `ns-process-data`, train with:
140 |
141 | ```bash
142 | ns-train bad-gaussians \
143 | --data data/my_data/blurtanabata \
144 | --vis viewer+tensorboard \
145 | nerfstudio-data --eval_mode "all"
146 | ```
147 |
148 | > Note: To improve reconstruction quality on your custom dataset, you may need to add
149 | some of the parameters to enable *cubic B-spline*, *more virtual cameras* and
150 | *coarse-to-fine training*, as shown in the examples above.
151 |
152 | ### 4. Render videos
153 |
154 | This command generates a trajectory from the camera poses of the training images, keeping their original order and interpolating 10 frames between adjacent images, at a frame rate of 30. It loads the `config.yml` and saves the video to `renders/<your_filename>.mp4`.
155 |
156 | ```bash
157 | ns-render interpolate \
158 | --load-config outputs/blurtanabata/bad-gaussians/<your_experiment_date_time>/config.yml \
159 | --pose-source train \
160 | --frame-rate 30 \
161 | --interpolation-steps 10 \
162 | --output-path renders/<your_filename>.mp4
163 | ```
164 |
165 | > Note1: You can add the `--render-nearest-camera True` option to compare with the blurry inputs, but it will slow down the rendering process significantly.
166 | >
167 | > Note2: The working directory when executing this command must be the parent of `outputs`, i.e. the same directory when training.
168 | >
169 | > Note3: You can find more information of this command in the [nerfstudio docs](https://docs.nerf.studio/reference/cli/ns_render.html#ns-render).
170 |
171 | ### 5. Export the 3D Gaussians
172 |
173 | This command will load the `config.yml` and export a `splat.ply` into the same folder:
174 |
175 | ```bash
176 | ns-export gaussian-splat \
177 | --load-config outputs/blurtanabata/bad-gaussians/<your_experiment_date_time>/config.yml \
178 | --output-dir outputs/blurtanabata/bad-gaussians/<your_experiment_date_time>
179 | ```
180 |
181 | > Note1: We use `rasterize_mode = antialiased` by default. However, if you want to export the 3D Gaussians, since the `antialiased` mode (i.e. *Mip-Splatting*) is not supported by most 3D-GS viewers, it is better to turn it off during training with: `--pipeline.model.rasterize_mode "classic"`
182 | >
183 | > Note2: The working directory when executing this command must be the parent of `outputs`, i.e. the same directory when training.
184 |
185 | Then you can visualize this file with any viewer, for example the [WebGL Viewer](https://antimatter15.com/splat/).
186 |
187 | ### 6. Debug with your IDE
188 |
189 | Open this repo with your IDE, create a configuration, and set the executing python script path to
190 | `/nerfstudio/scripts/train.py`, with the parameters above.
191 |
192 |
193 | ## Citation
194 |
195 | If you find this useful, please consider citing:
196 |
197 | ```bibtex
198 | @inproceedings{zhao2024badgaussians,
199 | author = {Zhao, Lingzhe and Wang, Peng and Liu, Peidong},
200 | title = {Bad-gaussians: Bundle adjusted deblur gaussian splatting},
201 | booktitle = {European Conference on Computer Vision (ECCV)},
202 | year = {2024}
203 | }
204 | ```
205 |
206 | ## Acknowledgments
207 |
208 | - Kudos to the [Nerfstudio](https://github.com/nerfstudio-project/) and [gsplat](https://github.com/nerfstudio-project/gsplat) contributors for their amazing works:
209 |
210 | ```bibtex
211 | @inproceedings{nerfstudio,
212 | title = {Nerfstudio: A Modular Framework for Neural Radiance Field Development},
213 | author = {
214 | Tancik, Matthew and Weber, Ethan and Ng, Evonne and Li, Ruilong and Yi, Brent
215 | and Kerr, Justin and Wang, Terrance and Kristoffersen, Alexander and Austin,
216 | Jake and Salahi, Kamyar and Ahuja, Abhik and McAllister, David and Kanazawa,
217 | Angjoo
218 | },
219 | year = 2023,
220 | booktitle = {ACM SIGGRAPH 2023 Conference Proceedings},
221 | series = {SIGGRAPH '23}
222 | }
223 |
224 | @software{Ye_gsplat,
225 | author = {Ye, Vickie and Turkulainen, Matias, and the Nerfstudio team},
226 | title = {{gsplat}},
227 | url = {https://github.com/nerfstudio-project/gsplat}
228 | }
229 |
230 | @misc{ye2023mathematical,
231 | title={Mathematical Supplement for the $\texttt{gsplat}$ Library},
232 | author={Vickie Ye and Angjoo Kanazawa},
233 | year={2023},
234 | eprint={2312.02121},
235 | archivePrefix={arXiv},
236 | primaryClass={cs.MS}
237 | }
238 | ```
239 |
240 | - Kudos to the [pypose](https://github.com/pypose/pypose) contributors for their amazing library:
241 |
242 | ```bibtex
243 | @inproceedings{wang2023pypose,
244 | title = {{PyPose}: A Library for Robot Learning with Physics-based Optimization},
245 | author = {Wang, Chen and Gao, Dasong and Xu, Kuan and Geng, Junyi and Hu, Yaoyu and Qiu, Yuheng and Li, Bowen and Yang, Fan and Moon, Brady and Pandey, Abhinav and Aryan and Xu, Jiahe and Wu, Tianhao and He, Haonan and Huang, Daning and Ren, Zhongqiang and Zhao, Shibo and Fu, Taimeng and Reddy, Pranay and Lin, Xiao and Wang, Wenshan and Shi, Jingnan and Talak, Rajat and Cao, Kun and Du, Yi and Wang, Han and Yu, Huai and Wang, Shanzhao and Chen, Siyu and Kashyap, Ananth and Bandaru, Rohan and Dantu, Karthik and Wu, Jiajun and Xie, Lihua and Carlone, Luca and Hutter, Marco and Scherer, Sebastian},
246 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
247 | year = {2023}
248 | }
249 | ```
250 |
--------------------------------------------------------------------------------
/bad_gaussians/bad_camera_optimizer.py:
--------------------------------------------------------------------------------
1 | """
2 | Pose and Intrinsics Optimizers
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | import functools
8 | from copy import deepcopy
9 | from typing import List, Literal, Optional, Type, Union
10 |
11 | import pypose as pp
12 | import torch
13 | from dataclasses import dataclass, field
14 | from jaxtyping import Float, Int
15 | from pypose import LieTensor
16 | from torch import Tensor
17 | from typing_extensions import assert_never
18 |
19 | from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig
20 | from nerfstudio.cameras.cameras import Cameras
21 | from nerfstudio.cameras.rays import RayBundle
22 |
23 | from bad_gaussians.spline_functor import (
24 | bezier_interpolation,
25 | cubic_bspline_interpolation,
26 | linear_interpolation,
27 | linear_interpolation_mid,
28 | )
29 |
30 |
TrajSamplingMode = Literal[
    "uniform",
    "start",
    "mid",
    "end",
]
"""Where to sample poses on each camera's trajectory.

uniform: ``num_virtual_views`` poses at evenly spaced parameters u in [0, 1];
start / mid / end: the pose at the start, middle or end of the trajectory.
"""
33 |
34 |
@dataclass
class BadCameraOptimizerConfig(CameraOptimizerConfig):
    """Configuration of BAD-Gaussians camera optimizer."""

    _target: Type = field(default_factory=lambda: BadCameraOptimizer)
    """The target class to be instantiated."""

    mode: Literal["off", "linear", "cubic", "bezier"] = "linear"
    """Pose optimization strategy to use.
    off: no pose optimization (no trajectory parameters are created);
    linear: linear interpolation on SE(3);
    cubic: cubic B-spline interpolation on SE(3);
    bezier: Bezier curve interpolation on SE(3).
    """

    bezier_degree: int = 9
    """Degree of the Bezier curve. Only used when mode is bezier.
    Note: the optimizer allocates exactly this many control knots per camera."""

    trans_l2_penalty: float = 0.0
    """L2 penalty on translation parameters."""

    rot_l2_penalty: float = 0.0
    """L2 penalty on rotation parameters."""

    num_virtual_views: int = 10
    """The number of samples used to model the motion-blurring,
    i.e. the number of poses drawn along each trajectory in "uniform" sampling mode."""

    initial_noise_se3_std: float = 1e-5
    """Initial perturbation to pose delta on se(3). Must be non-zero to prevent NaNs."""
63 |
64 |
65 | class BadCameraOptimizer(CameraOptimizer):
66 | """Optimization for BAD-Gaussians virtual camera trajectories."""
67 |
68 | config: BadCameraOptimizerConfig
69 |
def __init__(
    self,
    config: BadCameraOptimizerConfig,
    num_cameras: int,
    device: Union[torch.device, str],
    non_trainable_camera_indices: Optional[Int[Tensor, "num_non_trainable_cameras"]] = None,
    **kwargs,
) -> None:
    """Create the per-camera trajectory parameters.

    Args:
        config: BAD-Gaussians camera optimizer configuration; ``config.mode``
            selects the trajectory model and thus the number of control knots.
        num_cameras: number of cameras whose trajectories are represented.
        device: device on which the learnable parameters are allocated.
        non_trainable_camera_indices: optional camera ids that ``forward``
            treats as non-trainable.
        **kwargs: accepted for interface compatibility; ignored.
    """
    # The parent is initialized with a default config; ours replaces it below.
    super().__init__(CameraOptimizerConfig(), num_cameras, device)
    self.config = config
    self.num_cameras = num_cameras
    self.device = device
    self.non_trainable_camera_indices = non_trainable_camera_indices
    # Degrees of freedom of the manifold, i.e. dimension of the tangent space se(3).
    self.dof = 6
    # Dimensions of the pose parameterization: 3 for translation, 4-tuple quaternion.
    self.dim = 7

    traj_mode = self.config.mode
    if traj_mode == "off":
        # Nothing to learn: no control knots and no pose_adjustment parameter.
        return

    if traj_mode == "linear":
        knots = 2
    elif traj_mode == "cubic":
        knots = 4
    elif traj_mode == "bezier":
        knots = self.config.bezier_degree
    else:
        assert_never(traj_mode)
    self.num_control_knots = knots

    # Learnable se(3) deltas of shape (num_cameras, num_control_knots), seeded with
    # small random noise (std = initial_noise_se3_std) rather than exact zeros.
    initial_deltas = pp.randn_se3(
        (num_cameras, knots),
        sigma=self.config.initial_noise_se3_std,
        device=device,
    )
    self.pose_adjustment = pp.Parameter(initial_deltas)
107 |
def forward(
    self,
    indices: Int[Tensor, "camera_indices"],
    mode: TrajSamplingMode = "mid",
) -> Float[LieTensor, "camera_indices self.num_control_knots self.dof"]:
    """Indexing into camera adjustments.

    The learned se(3) control knots of each requested camera are mapped to SE(3)
    with ``Exp()`` and sampled along the trajectory according to ``mode``. Each
    unique camera is processed once and the result is gathered back per entry
    of ``indices``.

    Args:
        indices: indices of Cameras to optimize (duplicates allowed).
        mode: interpolate between start and end, or return start / mid / end.

    Returns:
        Transformation matrices from optimized camera coordinates
        to given camera coordinates. When the optimizer mode is ``"off"``, an
        identity SE(3) LieTensor whose leading shape equals ``indices.shape``.
    """
    outputs = []

    # Apply learned transformation delta.
    if self.config.mode != "off":
        indices = indices.int()
        unique_indices, lut = torch.unique(indices, return_inverse=True)
        camera_opt = self.pose_adjustment[unique_indices].Exp()
        outputs.append(self._interpolate(camera_opt, mode)[lut])

    # Stop gradients for entries whose camera id is marked non-trainable.
    # `non_trainable_camera_indices` holds camera ids, so they are matched against
    # `indices` instead of being used as positional indices into the output.
    # Skipped when there is no learned output (mode "off": no pose_adjustment exists).
    if outputs and torch.is_grad_enabled() and self.non_trainable_camera_indices is not None:
        if self.non_trainable_camera_indices.device != self.pose_adjustment.device:
            self.non_trainable_camera_indices = self.non_trainable_camera_indices.to(self.pose_adjustment.device)
        adjusted = outputs[0]
        ids = indices.to(adjusted.device)
        frozen = torch.isin(ids, self.non_trainable_camera_indices.to(device=adjusted.device, dtype=ids.dtype))
        if frozen.any():
            adjusted[frozen] = adjusted[frozen].clone().detach()

    # Return: identity if no transforms are needed, otherwise composite transforms together.
    if len(outputs) == 0:
        return pp.identity_SE3(*indices.shape, device=self.device)
    return functools.reduce(pp.mul, outputs)
149 |
150 | def _interpolate(
151 | self,
152 | camera_opt: Float[LieTensor, "*batch_size self.num_control_knots self.dof"],
153 | mode: TrajSamplingMode
154 | ) -> Float[Tensor, "*batch_size interpolations self.dof"]:
155 | if mode == "uniform":
156 | u = torch.linspace(
157 | start=0,
158 | end=1,
159 | steps=self.config.num_virtual_views,
160 | device=camera_opt.device,
161 | )
162 | if self.config.mode == "linear":
163 | return linear_interpolation(camera_opt, u)
164 | elif self.config.mode == "cubic":
165 | return cubic_bspline_interpolation(camera_opt, u)
166 | elif self.config.mode == "bezier":
167 | return bezier_interpolation(camera_opt, u)
168 | else:
169 | assert_never(self.config.mode)
170 | elif mode == "mid":
171 | if self.config.mode == "linear":
172 | return linear_interpolation_mid(camera_opt)
173 | elif self.config.mode == "cubic":
174 | return cubic_bspline_interpolation(
175 | camera_opt,
176 | torch.tensor([0.5], device=camera_opt.device)
177 | ).squeeze(1)
178 | elif self.config.mode == "bezier":
179 | return bezier_interpolation(camera_opt, torch.tensor([0.5], device=camera_opt.device)).squeeze(1)
180 | else:
181 | assert_never(self.config.mode)
182 | elif mode == "start":
183 | if self.config.mode == "linear":
184 | return camera_opt[..., 0, :]
185 | elif self.config.mode == "cubic":
186 | return cubic_bspline_interpolation(
187 | camera_opt,
188 | torch.tensor([0.0], device=camera_opt.device)
189 | ).squeeze(1)
190 | elif self.config.mode == "bezier":
191 | return bezier_interpolation(camera_opt, torch.tensor([0.0], device=camera_opt.device)).squeeze(1)
192 | else:
193 | assert_never(self.config.mode)
194 | elif mode == "end":
195 | if self.config.mode == "linear":
196 | return camera_opt[..., 1, :]
197 | elif self.config.mode == "cubic":
198 | return cubic_bspline_interpolation(
199 | camera_opt,
200 | torch.tensor([1.0], device=camera_opt.device)
201 | ).squeeze(1)
202 | elif self.config.mode == "bezier":
203 | return bezier_interpolation(camera_opt, torch.tensor([1.0], device=camera_opt.device)).squeeze(1)
204 | else:
205 | assert_never(self.config.mode)
206 | else:
207 | assert_never(mode)
208 |
209 | def apply_to_raybundle(self, *args, **kwargs):
210 | """Not implemented. Should not be called."""
211 | raise NotImplementedError("Not implemented in BAD-Gaussians. Please checkout https://github.com/WU-CVGL/Bad-RFs")
212 |
213 | def apply_to_camera(self, camera: Cameras, mode: TrajSamplingMode) -> List[Cameras]:
214 | """Apply pose correction to the camera"""
215 | # assert camera.metadata is not None, "Must provide camera metadata"
216 | # assert "cam_idx" in camera.metadata, "Must provide id of camera in its metadata"
217 | if self.config.mode == "off" or camera.metadata is None or not ("cam_idx" in camera.metadata):
218 | # print("[WARN] Cannot get cam_idx in camera.metadata")
219 | return [deepcopy(camera)]
220 |
221 | camera_idx = camera.metadata["cam_idx"]
222 | c2w = camera.camera_to_worlds # shape: (1, 4, 4)
223 | if c2w.shape[1] == 3:
224 | c2w = torch.cat([c2w, torch.tensor([0, 0, 0, 1], device=c2w.device).view(1, 1, 4)], dim=1)
225 |
226 | poses_delta = self((torch.tensor([camera_idx])), mode)
227 |
228 | if mode == "uniform":
229 | c2ws = c2w.tile((self.config.num_virtual_views, 1, 1)) # shape: (num_virtual_views, 4, 4)
230 | c2ws_adjusted = torch.bmm(c2ws, poses_delta.matrix().squeeze())
231 | cameras_list = [deepcopy(camera) for _ in range(self.config.num_virtual_views)]
232 | for i in range(self.config.num_virtual_views):
233 | cameras_list[i].camera_to_worlds = c2ws_adjusted[None, i, :, :]
234 | else:
235 | c2w_adjusted = torch.bmm(c2w, poses_delta.matrix())
236 | cameras_list = [deepcopy(camera)]
237 | cameras_list[0].camera_to_worlds = c2w_adjusted
238 |
239 | assert len(cameras_list)
240 | return cameras_list
241 |
242 | def get_metrics_dict(self, metrics_dict: dict) -> None:
243 | """Get camera optimizer metrics"""
244 | if self.config.mode != "off":
245 | metrics_dict["camera_opt_trajectory_translation"] = (
246 | self.pose_adjustment[:, 1, :3] - self.pose_adjustment[:, 0, :3]).norm()
247 | metrics_dict["camera_opt_trajectory_rotation"] = (
248 | self.pose_adjustment[:, 1, 3:] - self.pose_adjustment[:, 0, 3:]).norm()
249 | metrics_dict["camera_opt_translation"] = 0
250 | metrics_dict["camera_opt_rotation"] = 0
251 | for i in range(self.num_control_knots):
252 | metrics_dict["camera_opt_translation"] += self.pose_adjustment[:, i, :3].norm()
253 | metrics_dict["camera_opt_rotation"] += self.pose_adjustment[:, i, 3:].norm()
254 |
255 | def get_loss_dict(self, loss_dict: dict) -> None:
256 | """Add regularization"""
257 | pass
258 |
--------------------------------------------------------------------------------
/bad_gaussians/bad_config_dataparser.py:
--------------------------------------------------------------------------------
1 | """
2 | BAD-Gaussians dataparser configs.
3 | """
4 |
5 | from nerfstudio.plugins.registry_dataparser import DataParserSpecification
6 |
7 | from bad_gaussians.deblur_nerf_dataparser import DeblurNerfDataParserConfig
8 |
# Dataparser specification for Deblur-NeRF COLMAP datasets, using the default
# DeblurNerfDataParserConfig (presumably referenced by a nerfstudio plugin entry point).
DeblurNerfDataParser = DataParserSpecification(
    config=DeblurNerfDataParserConfig(),
)
10 |
--------------------------------------------------------------------------------
/bad_gaussians/bad_config_method.py:
--------------------------------------------------------------------------------
1 | """
2 | BAD-Gaussians configs.
3 | """
4 |
5 | from nerfstudio.configs.base_config import ViewerConfig
6 | from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig
7 | from nerfstudio.engine.optimizers import AdamOptimizerConfig
8 | from nerfstudio.engine.schedulers import ExponentialDecaySchedulerConfig
9 | from nerfstudio.plugins.types import MethodSpecification
10 |
11 | from bad_gaussians.bad_camera_optimizer import BadCameraOptimizerConfig
12 | from bad_gaussians.bad_gaussians import BadGaussiansModelConfig
13 | from bad_gaussians.image_restoration_full_image_datamanager import ImageRestorationFullImageDataManagerConfig
14 | from bad_gaussians.image_restoration_pipeline import ImageRestorationPipelineConfig
15 | from bad_gaussians.image_restoration_trainer import ImageRestorationTrainerConfig
16 |
17 |
bad_gaussians = MethodSpecification(
    description="Implementation of BAD-Gaussians",
    config=ImageRestorationTrainerConfig(
        method_name="bad-gaussians",
        # Schedule: length, checkpoint and evaluation cadence.
        max_num_iterations=30001,
        steps_per_save=2000,
        steps_per_eval_batch=500,
        steps_per_eval_image=500,
        steps_per_eval_all_images=500,
        # Full precision; the "camera_opt" param group uses 25 gradient-accumulation steps.
        mixed_precision=False,
        use_grad_scaler=False,
        gradient_accumulation_steps={"camera_opt": 25},
        pipeline=ImageRestorationPipelineConfig(
            eval_render_start_end=True,
            eval_render_estimated=True,
            # Option kept from upstream notes: cache_images="gpu" may reduce CPU usage
            # (possibly caused by pin_memory()).
            datamanager=ImageRestorationFullImageDataManagerConfig(
                dataparser=NerfstudioDataParserConfig(
                    eval_interval=8,
                    eval_mode="interval",
                    load_3D_points=True,
                ),
            ),
            model=BadGaussiansModelConfig(
                camera_optimizer=BadCameraOptimizerConfig(num_virtual_views=10, mode="linear"),
                num_downscales=0,
                resolution_schedule=250,
                densify_grad_thresh=4e-4,
                cull_alpha_thresh=5e-3,
                continue_cull_post_densification=False,
                use_scale_regularization=True,
                tv_loss_lambda=None,
            ),
        ),
        optimizers={
            "means": {
                "optimizer": AdamOptimizerConfig(lr=1.6e-4, eps=1e-15),
                "scheduler": ExponentialDecaySchedulerConfig(lr_final=1.6e-6, max_steps=30000),
            },
            # Gaussian attributes trained with a constant learning rate (no scheduler).
            **{
                group: {"optimizer": AdamOptimizerConfig(lr=lr, eps=1e-15), "scheduler": None}
                for group, lr in (
                    ("features_dc", 0.0025),
                    ("features_rest", 0.0025 / 20),
                    ("opacities", 0.05),
                    ("scales", 0.005),
                    ("quats", 0.001),
                )
            },
            "camera_opt": {
                "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15),
                "scheduler": ExponentialDecaySchedulerConfig(lr_final=1e-5, max_steps=30000),
            },
        },
        viewer=ViewerConfig(num_rays_per_chunk=2**15),
        vis="viewer",
    ),
)
92 |
--------------------------------------------------------------------------------
/bad_gaussians/bad_gaussians.py:
--------------------------------------------------------------------------------
1 | """
2 | BAD-Gaussians model.
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | from dataclasses import dataclass, field
8 | from typing import Dict, List, Literal, Optional, Tuple, Type, Union
9 |
10 | import torch
11 |
12 | from gsplat.project_gaussians import project_gaussians
13 | from gsplat.rasterize import rasterize_gaussians
14 | from gsplat.sh import spherical_harmonics
15 |
16 | from nerfstudio.cameras.cameras import Cameras
17 | from nerfstudio.data.scene_box import OrientedBox
18 | from nerfstudio.models.splatfacto import SplatfactoModel, SplatfactoModelConfig
19 | from nerfstudio.model_components import renderers
20 |
21 | from bad_gaussians.bad_camera_optimizer import (
22 | BadCameraOptimizer,
23 | BadCameraOptimizerConfig,
24 | TrajSamplingMode,
25 | )
26 | from bad_gaussians.bad_losses import EdgeAwareVariationLoss
27 |
28 |
@dataclass
class BadGaussiansModelConfig(SplatfactoModelConfig):
    """BAD-Gaussians Model config.

    Splatfacto configuration extended with a BadCameraOptimizerConfig and a few
    defaults retuned for BAD-Gaussians (alpha culling, densification threshold,
    resolution schedule, optional total-variation loss).
    """

    _target: Type = field(default_factory=lambda: BadGaussiansModel)
    """The target class to be instantiated."""

    rasterize_mode: Literal["classic", "antialiased"] = "antialiased"
    """
    Classic mode of rendering will use the EWA volume splatting with a [0.3, 0.3] screen space blurring kernel. This
    approach is however not suitable to render tiny gaussians at higher or lower resolution than the captured, which
    results "aliasing-like" artifacts. The antialiased mode overcomes this limitation by calculating compensation factors
    and apply them to the opacities of gaussians to preserve the total integrated density of splats.

    However, PLY exported with antialiased rasterize mode is not compatible with classic mode. Thus many web viewers that
    were implemented for classic mode can not render antialiased mode PLY properly without modifications.
    Refs:
    1. https://github.com/nerfstudio-project/gsplat/pull/117
    2. https://github.com/nerfstudio-project/nerfstudio/pull/2888
    3. Yu, Zehao, et al. "Mip-Splatting: Alias-free 3D Gaussian Splatting." arXiv preprint arXiv:2311.16493 (2023).
    """

    # Its num_virtual_views also sets how many virtual views are averaged per rendered image.
    camera_optimizer: BadCameraOptimizerConfig = field(default_factory=BadCameraOptimizerConfig)
    """Config of the camera optimizer to use"""

    cull_alpha_thresh: float = 0.005
    """Threshold for alpha to cull gaussians. Default: 0.1 in splatfacto, 0.005 in splatfacto-big."""

    # BadGaussiansModel divides this value by camera_optimizer.num_virtual_views when it is constructed.
    densify_grad_thresh: float = 4e-4
    """[IMPORTANT] Threshold for gradient to densify gaussians. Default: 4e-4. Tune it smaller with complex scenes."""

    continue_cull_post_densification: bool = False
    """Whether to continue culling after densification. Default: True in splatfacto, False in splatfacto-big."""

    resolution_schedule: int = 250
    """training starts at 1/d resolution, every n steps this is doubled.
    Default: 250. Use 3000 with high resolution images (e.g. higher than 1920x1080).
    """

    num_downscales: int = 0
    """at the beginning, resolution is 1/2^d, where d is this number. Default: 0. Use 2 with high resolution images."""

    # When True, BadGaussiansModel.after_train accumulates xys.absgrad instead of xys.grad.
    enable_absgrad: bool = False
    """Whether to enable absgrad for gaussians. (It affects param tuning of densify_grad_thresh)
    Default: False. Ref: (https://github.com/nerfstudio-project/nerfstudio/pull/3113)
    """

    # None disables the TV term; otherwise it multiplies the edge-aware variation loss of the rendered rgb.
    tv_loss_lambda: Optional[float] = None
    """weight of total variation loss"""
78 |
79 |
class BadGaussiansModel(SplatfactoModel):
    """BAD-Gaussians Model.

    Splatfacto variant that renders an image as the mean of the views seen by
    the virtual cameras produced by ``BadCameraOptimizer.apply_to_camera``.

    Args:
        config: configuration to instantiate model
    """

    config: BadGaussiansModelConfig
    camera_optimizer: BadCameraOptimizer

    def __init__(self, config: BadGaussiansModelConfig, **kwargs) -> None:
        from dataclasses import replace

        # Scale densify_grad_thresh by the number of virtual views. Work on a shallow
        # copy: dividing the caller's config in place would compound the division each
        # time a model is built from the same config object.
        config = replace(
            config,
            densify_grad_thresh=config.densify_grad_thresh / config.camera_optimizer.num_virtual_views,
        )
        super().__init__(config=config, **kwargs)
        # (Experimental) Total variation loss
        self.tv_loss = EdgeAwareVariationLoss(in1_nc=3)

    def populate_modules(self) -> None:
        """Build the splatfacto modules plus the BAD camera optimizer."""
        super().populate_modules()
        self.camera_optimizer: BadCameraOptimizer = self.config.camera_optimizer.setup(
            num_cameras=self.num_train_data, device="cpu"
        )

    def forward(
        self,
        camera: Cameras,
        mode: TrajSamplingMode = "uniform",
    ) -> Dict[str, Union[torch.Tensor, List]]:
        """Alias for ``get_outputs``."""
        return self.get_outputs(camera, mode)

    def get_outputs(
        self, camera: Cameras,
        mode: TrajSamplingMode = "uniform",
    ) -> Dict[str, Union[torch.Tensor, List]]:
        """Takes in a Camera and returns a dictionary of outputs.

        The camera is expanded into virtual cameras according to ``mode``; each
        one is rasterized and the rgb / alpha images are averaged.

        Args:
            camera: Input camera (exactly one). This camera should have all the needed information to compute the outputs.
            mode: Trajectory sampling mode forwarded to the camera optimizer.

        Returns:
            Outputs of model (rgb, depth, accumulation, background). Depth is only
            rasterized outside training (``None`` otherwise, except on the empty-scene early returns).

        Raises:
            ValueError: if ``config.rasterize_mode`` is neither "antialiased" nor "classic".
        """
        if not isinstance(camera, Cameras):
            print("Called get_outputs with not a camera")
            return {}
        assert camera.shape[0] == 1, "Only one camera at a time"

        is_training = self.training and torch.is_grad_enabled()

        # BAD-Gaussians: get virtual cameras
        virtual_cameras = self.camera_optimizer.apply_to_camera(camera, mode)

        if is_training:
            if self.config.background_color == "random":
                background = torch.rand(3, device=self.device)
            elif self.config.background_color == "white":
                background = torch.ones(3, device=self.device)
            elif self.config.background_color == "black":
                background = torch.zeros(3, device=self.device)
            else:
                background = self.background_color.to(self.device)
        else:
            # logic for setting the background of the scene
            if renderers.BACKGROUND_COLOR_OVERRIDE is not None:
                background = renderers.BACKGROUND_COLOR_OVERRIDE.to(self.device)
            else:
                background = self.background_color.to(self.device)
        if self.crop_box is not None and not is_training:
            crop_ids = self.crop_box.within(self.means).squeeze()
            if crop_ids.sum() == 0:
                rgb = background.repeat(int(camera.height.item()), int(camera.width.item()), 1)
                depth = background.new_ones(*rgb.shape[:2], 1) * 10
                accumulation = background.new_zeros(*rgb.shape[:2], 1)
                return {"rgb": rgb, "depth": depth, "accumulation": accumulation, "background": background}
        else:
            crop_ids = None

        camera_downscale = self._get_downscale_factor()

        for cam in virtual_cameras:
            cam.rescale_output_resolution(1 / camera_downscale)

        # The gaussian parameters do not depend on the virtual view: crop them once.
        if crop_ids is not None:
            opacities_crop = self.opacities[crop_ids]
            means_crop = self.means[crop_ids]
            features_dc_crop = self.features_dc[crop_ids]
            features_rest_crop = self.features_rest[crop_ids]
            scales_crop = self.scales[crop_ids]
            quats_crop = self.quats[crop_ids]
        else:
            opacities_crop = self.opacities
            means_crop = self.means
            features_dc_crop = self.features_dc
            features_rest_crop = self.features_rest
            scales_crop = self.scales
            quats_crop = self.quats

        colors_crop = torch.cat((features_dc_crop[:, None, :], features_rest_crop), dim=1)
        BLOCK_WIDTH = 16  # this controls the tile size of rasterization, 16 is a good default

        # BAD-Gaussians: render virtual views
        virtual_views_rgb = []
        virtual_views_alpha = []
        for cam in virtual_cameras:
            # shift the camera to center of scene looking at center
            R = cam.camera_to_worlds[0, :3, :3]  # 3 x 3
            T = cam.camera_to_worlds[0, :3, 3:4]  # 3 x 1
            # flip the z axis to align with gsplat conventions
            R_edit = torch.diag(torch.tensor([1, -1, -1], device=self.device, dtype=R.dtype))
            R = R @ R_edit
            # analytic matrix inverse to get world2camera matrix
            R_inv = R.T
            T_inv = -R_inv @ T
            viewmat = torch.eye(4, device=R.device, dtype=R.dtype)
            viewmat[:3, :3] = R_inv
            viewmat[:3, 3:4] = T_inv
            # update last_size
            W, H = int(cam.width.item()), int(cam.height.item())
            self.last_size = (H, W)

            self.xys, depths, self.radii, conics, comp, num_tiles_hit, cov3d = project_gaussians(
                means_crop,
                torch.exp(scales_crop),
                1,
                quats_crop / quats_crop.norm(dim=-1, keepdim=True),
                viewmat.squeeze()[:3, :],
                cam.fx.item(),
                cam.fy.item(),
                cam.cx.item(),
                cam.cy.item(),
                H,
                W,
                BLOCK_WIDTH,
            )  # type: ignore

            # rescale the camera back to original dimensions before returning
            cam.rescale_output_resolution(camera_downscale)

            if (self.radii).sum() == 0:
                rgb = background.repeat(H, W, 1)
                depth = background.new_ones(*rgb.shape[:2], 1) * 10
                accumulation = background.new_zeros(*rgb.shape[:2], 1)

                return {"rgb": rgb, "depth": depth, "accumulation": accumulation, "background": background}

            # Important to allow xys grads to populate properly
            if is_training:
                self.xys.retain_grad()

            if self.config.sh_degree > 0:
                viewdirs = means_crop.detach() - cam.camera_to_worlds.detach()[..., :3, 3]  # (N, 3)
                viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
                n = min(self.step // self.config.sh_degree_interval, self.config.sh_degree)
                rgbs = spherical_harmonics(n, viewdirs, colors_crop)
                rgbs = torch.clamp(rgbs + 0.5, min=0.0)  # type: ignore
            else:
                rgbs = torch.sigmoid(colors_crop[:, 0, :])

            assert (num_tiles_hit > 0).any()  # type: ignore

            # apply the compensation of screen space blurring to gaussians
            if self.config.rasterize_mode == "antialiased":
                alphas = torch.sigmoid(opacities_crop) * comp[:, None]
            elif self.config.rasterize_mode == "classic":
                alphas = torch.sigmoid(opacities_crop)
            else:
                raise ValueError(f"Unknown rasterize_mode: {self.config.rasterize_mode!r}")
            rgb, alpha = rasterize_gaussians(  # type: ignore
                self.xys,
                depths,
                self.radii,
                conics,
                num_tiles_hit,  # type: ignore
                rgbs,
                alphas,
                H,
                W,
                BLOCK_WIDTH,
                background=background,
                return_alpha=True,
            )  # type: ignore
            alpha = alpha[..., None]
            rgb = torch.clamp(rgb, max=1.0)  # type: ignore
            virtual_views_rgb.append(rgb)
            virtual_views_alpha.append(alpha)
        depth_im = None
        rgb = torch.stack(virtual_views_rgb, dim=0).mean(dim=0)
        alpha = torch.stack(virtual_views_alpha, dim=0).mean(dim=0)

        # eval: depth uses the projection of the last virtual view rendered above
        # (the only one when mode is not "uniform").
        if not is_training:
            depth_im = rasterize_gaussians(  # type: ignore
                self.xys,
                depths,
                self.radii,
                conics,
                num_tiles_hit,  # type: ignore
                depths[:, None].repeat(1, 3),
                torch.sigmoid(opacities_crop),
                H,
                W,
                BLOCK_WIDTH,
                background=torch.zeros(3, device=self.device),
            )[..., 0:1]  # type: ignore
            depth_im = torch.where(alpha > 0, depth_im / alpha, depth_im.detach().max())

        return {"rgb": rgb, "depth": depth_im, "accumulation": alpha, "background": background}  # type: ignore

    def after_train(self, step: int):
        """Accumulate screen-space gradient and 2D-size statistics used for densification."""
        assert step == self.step
        # to save some training time, we no longer need to update those stats post refinement
        if self.step >= self.config.stop_split_at:
            return
        with torch.no_grad():
            # keep track of a moving average of grad norms
            visible_mask = (self.radii > 0).flatten()
            # BAD-Gaussians: use absgrad if enabled
            if self.config.enable_absgrad:
                assert self.xys.absgrad is not None  # type: ignore
                grads = self.xys.absgrad.detach().norm(dim=-1)  # type: ignore
            else:
                assert self.xys.grad is not None
                grads = self.xys.grad.detach().norm(dim=-1)
            if self.xys_grad_norm is None:
                self.xys_grad_norm = grads
                self.vis_counts = torch.ones_like(self.xys_grad_norm)
            else:
                assert self.vis_counts is not None
                self.vis_counts[visible_mask] = self.vis_counts[visible_mask] + 1
                self.xys_grad_norm[visible_mask] = grads[visible_mask] + self.xys_grad_norm[visible_mask]
            # update the max screen size, as a ratio of number of pixels
            if self.max_2Dsize is None:
                self.max_2Dsize = torch.zeros_like(self.radii, dtype=torch.float32)
            newradii = self.radii.detach()[visible_mask]
            self.max_2Dsize[visible_mask] = torch.maximum(
                self.max_2Dsize[visible_mask],
                newradii / float(max(self.last_size[0], self.last_size[1])),
            )

    @torch.no_grad()
    def get_outputs_for_camera(
        self,
        camera: Cameras,
        obb_box: Optional[OrientedBox] = None,
        mode: TrajSamplingMode = "mid",
    ) -> Dict[str, torch.Tensor]:
        """Takes in a camera, generates the raybundle, and computes the output of the model.
        Overridden for a camera-based gaussian model.
        """
        assert camera is not None, "must provide camera to gaussian model"
        self.set_crop(obb_box)
        # BAD-Gaussians: camera.to(device) will drop metadata
        metadata = camera.metadata
        camera = camera.to(self.device)
        camera.metadata = metadata
        outs = self.get_outputs(camera, mode=mode)
        return outs  # type: ignore

    def get_loss_dict(self, outputs, batch, metrics_dict=None):
        """Splatfacto losses plus the optional TV loss and camera-optimizer terms."""
        loss_dict = super().get_loss_dict(outputs, batch, metrics_dict)
        # Add total variation loss (only build the NCHW view when it is used)
        if self.config.tv_loss_lambda is not None:
            rgb = outputs["rgb"].permute(2, 0, 1).unsqueeze(0)  # H, W, 3 to 1, 3, H, W
            loss_dict["tv_loss"] = self.tv_loss(rgb) * self.config.tv_loss_lambda
        # Add loss from camera optimizer
        self.camera_optimizer.get_loss_dict(loss_dict)
        return loss_dict

    def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
        """Splatfacto metrics plus camera-optimizer metrics."""
        metrics_dict = super().get_metrics_dict(outputs, batch)
        # Add metrics from camera optimizer
        self.camera_optimizer.get_metrics_dict(metrics_dict)
        return metrics_dict

    def get_param_groups(self) -> Dict[str, List[torch.nn.Parameter]]:
        """Splatfacto parameter groups plus the camera optimizer's."""
        param_groups = super().get_param_groups()
        self.camera_optimizer.get_param_groups(param_groups=param_groups)
        return param_groups
358 |
--------------------------------------------------------------------------------
/bad_gaussians/bad_losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
class GridGradientCentralDiff:
    """Per-channel finite differences of an image batch via fixed 2x2 convolutions.

    For an input ``x`` of shape ``(N, nc, H, W)`` it returns
    ``dx[..., i, j] = x[..., i, j] - x[..., i, j + 1]`` and
    ``dy[..., i, j] = x[..., i, j] - x[..., i + 1, j]``, plus
    ``dxy[..., i, j] = x[..., i, j] - x[..., i + 1, j + 1]`` when
    ``diagonal=True`` (despite the class name these are one-sided differences).
    With ``padding=True`` the right column / bottom row are replicated so the
    outputs keep the input's spatial size. Channels never mix.

    This object is not an ``nn.Module``, so ``model.to(device)`` cannot move
    its kernels; they are created on the CPU and moved lazily to the device
    (and floating dtype) of the input on each call.
    """

    def __init__(self, nc, padding=True, diagonal=False):
        self.conv_x = nn.Conv2d(nc, nc, kernel_size=2, stride=1, bias=False)
        self.conv_y = nn.Conv2d(nc, nc, kernel_size=2, stride=1, bias=False)
        self.conv_xy = None
        if diagonal:
            self.conv_xy = nn.Conv2d(nc, nc, kernel_size=2, stride=1, bias=False)

        self.padding = None
        if padding:
            # (left, right, top, bottom): replicate one column right and one row below.
            self.padding = nn.ReplicationPad2d([0, 1, 0, 1])

        self._set_kernel(self.conv_x, [[1.0, -1.0], [0.0, 0.0]], nc)
        self._set_kernel(self.conv_y, [[1.0, 0.0], [-1.0, 0.0]], nc)
        if self.conv_xy is not None:
            self._set_kernel(self.conv_xy, [[1.0, 0.0], [0.0, -1.0]], nc)

    @staticmethod
    def _set_kernel(conv, kernel, nc):
        """Install a fixed, non-trainable 2x2 ``kernel`` applied to each channel independently."""
        # weight[o, i] == kernel if o == i else 0, i.e. no mixing between channels.
        weight = torch.eye(nc)[:, :, None, None] * torch.tensor(kernel)
        conv.weight = nn.Parameter(weight, requires_grad=False)

    def _match_input(self, image):
        """Move the kernels onto ``image``'s device (and dtype, for floating inputs) if needed."""
        dtype = image.dtype if image.is_floating_point() else self.conv_x.weight.dtype
        for conv in (self.conv_x, self.conv_y, self.conv_xy):
            if conv is not None and (conv.weight.device != image.device or conv.weight.dtype != dtype):
                conv.to(device=image.device, dtype=dtype)

    def __call__(self, grid_2d):
        """Return ``(dx, dy)`` or ``(dx, dy, dxy)`` for a ``(N, nc, H, W)`` tensor."""
        image = grid_2d
        self._match_input(image)
        if self.padding is not None:
            image = self.padding(image)
        dx = self.conv_x(image)
        dy = self.conv_y(image)

        if self.conv_xy is not None:
            dxy = self.conv_xy(image)
            return dx, dy, dxy
        return dx, dy
49 |
50 |
class EdgeAwareVariationLoss(nn.Module):
    """Edge-aware total-variation penalty of a single image batch.

    The finite differences of ``in1`` produced by ``grad_fn`` are summed in
    absolute value over the channel axis; each magnitude ``m`` contributes
    ``exp(-m) * m``, so strong edges are penalised less than small changes.
    """

    def __init__(self, in1_nc, grad_fn=GridGradientCentralDiff):
        super().__init__()
        self.in1_grad_fn = grad_fn(in1_nc)

    def forward(self, in1, mean=False):
        """Return the summed variation, or its mean when ``mean`` is not ``False``."""
        grad_x, grad_y = self.in1_grad_fn(in1)

        magnitude_x = grad_x.abs().sum(dim=1, keepdim=True)
        magnitude_y = grad_y.abs().sum(dim=1, keepdim=True)

        variation = torch.exp(-magnitude_x) * magnitude_x + torch.exp(-magnitude_y) * magnitude_y

        if mean == False:  # noqa: E712 - same test as the historical `mean != False`
            return variation.sum()
        return variation.mean()
71 |
72 |
class GrayEdgeAwareVariationLoss(nn.Module):
    """Variation of ``in1`` weighted by the edge strength of a guide image ``in2``.

    Each input is differenced by its own ``grad_fn`` instance. The channel-summed
    absolute differences of ``in1`` are weighted by ``exp(-|d in2|)``, so
    variation in ``in1`` is penalised less where ``in2`` has strong edges.
    """

    def __init__(self, in1_nc, in2_nc, grad_fn=GridGradientCentralDiff):
        super().__init__()
        self.in1_grad_fn = grad_fn(in1_nc)  # Gray
        self.in2_grad_fn = grad_fn(in2_nc)  # Sharp

    def forward(self, in1, in2, mean=False):
        """Return the summed weighted variation, or its mean when ``mean`` is not ``False``."""
        target_dx, target_dy = self.in1_grad_fn(in1)
        guide_dx, guide_dy = self.in2_grad_fn(in2)

        def channel_abs_sum(t):
            return t.abs().sum(dim=1, keepdim=True)

        target_x, target_y = channel_abs_sum(target_dx), channel_abs_sum(target_dy)
        guide_x, guide_y = channel_abs_sum(guide_dx), channel_abs_sum(guide_dy)

        variation = torch.exp(-guide_x) * target_x + torch.exp(-guide_y) * target_y

        return variation.sum() if mean == False else variation.mean()  # noqa: E712
93 |
--------------------------------------------------------------------------------
/bad_gaussians/bad_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | BAD-Gaussians utils.
3 | """
4 | from __future__ import annotations
5 |
6 | from pathlib import Path
7 | from typing import Tuple
8 |
9 | import pypose as pp
10 | import torch
11 | from jaxtyping import Float
12 | from pypose import LieTensor
13 | from torch import Tensor
14 |
15 |
class TrajectoryIO:
    """Read and write camera trajectories in TUM and KITTI text formats."""

    @staticmethod
    def _read_rows(filename: Path) -> list[list[float]]:
        """Parse a whitespace-separated numeric text file into rows of floats.

        Blank lines and lines starting with ``#`` are skipped.
        """
        with open(filename, 'r', encoding="UTF-8") as f:
            return [
                [float(val) for val in line.split()]
                for line in f
                if line.strip() and not line.lstrip().startswith('#')
            ]

    @staticmethod
    def load_tum_trajectory(
        filename: Path
    ) -> Tuple[
        Float[Tensor, "num_poses"],
        Float[LieTensor, "num_poses 7"]
    ]:
        """Load a TUM trajectory file (``timestamp tx ty tz qx qy qz qw`` per line).

        Returns:
            ``(timestamps, poses)``; ``poses`` is a ``pp.SE3`` LieTensor on the CPU.

        Raises:
            ValueError: if the file holds no poses or a row does not have 8 values.
        """
        rows = TrajectoryIO._read_rows(filename)
        if not rows:
            raise ValueError(f"No poses found in TUM trajectory file: {filename}")
        for line_no, row in enumerate(rows, start=1):
            if len(row) != 8:
                raise ValueError(f"Expected 8 values per TUM row, got {len(row)} in pose {line_no} of {filename}")
        # NOTE(review): torch.tensor(float) uses the default dtype (float32 unless changed),
        # which cannot represent large Unix timestamps to sub-second precision.
        timestamps = torch.stack([torch.tensor(row[0]) for row in rows])
        poses = pp.SE3(torch.stack([torch.tensor(row[1:]) for row in rows]))
        return timestamps, poses

    @staticmethod
    def load_kitti_trajectory(
        filename: Path,
        device: torch.device | str | None = None,
    ) -> Float[LieTensor, "num_poses 7"]:
        """Load a KITTI trajectory file.

        Each line is either the standard 12 values (top 3x4 of the pose matrix,
        row-major) or a full 16-value 4x4 matrix as written by
        ``write_kitti_trajectory``.

        Args:
            filename: Path of the trajectory file.
            device: Device of the returned poses. Defaults to CUDA when available
                (the previous unconditional behavior), otherwise CPU.

        Raises:
            ValueError: if the file holds no poses or a row has neither 12 nor 16 values.
        """
        rows = TrajectoryIO._read_rows(filename)
        if not rows:
            raise ValueError(f"No poses found in KITTI trajectory file: {filename}")
        matrices = []
        for line_no, row in enumerate(rows, start=1):
            if len(row) == 12:
                row = row + [0.0, 0.0, 0.0, 1.0]  # append homogeneous bottom row
            elif len(row) != 16:
                raise ValueError(f"Expected 12 or 16 values per KITTI row, got {len(row)} in pose {line_no} of {filename}")
            matrices.append(torch.tensor(row).reshape(4, 4))
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        return pp.mat2SE3(torch.stack(matrices).to(device))

    @staticmethod
    def write_tum_trajectory(
        filename: Path,
        timestamps: Float[Tensor, "num_poses"],
        poses: Float[LieTensor, "num_poses 7"] | Float[Tensor, "num_poses 7"]
    ):
        """Write a TUM trajectory file (``timestamp tx ty tz qx qy qz qw`` per line)."""
        # Convert before opening so a bad input does not leave a truncated file behind.
        if pp.is_lietensor(poses):
            poses = poses.tensor()
        with open(filename, 'w', encoding="UTF-8") as f:
            for timestamp, pose in zip(timestamps, poses):
                f.write(f'{timestamp.item()} {pose[0]} {pose[1]} {pose[2]} {pose[3]} {pose[4]} {pose[5]} {pose[6]}\n')

    @staticmethod
    def write_kitti_trajectory(
        filename: Path,
        poses: Float[LieTensor, "num_poses 7"] | Float[Tensor, "num_poses 7"],
        with_last_row: bool = True,
    ):
        """Write a KITTI trajectory file, one flattened pose matrix per line.

        Args:
            filename: Output path.
            poses: SE(3) poses (LieTensor or raw ``(N, 7)`` tensor).
            with_last_row: ``True`` (default, previous behavior) writes all 16
                values of the 4x4 matrix; ``False`` writes the standard 12-value
                KITTI format (top 3x4 block).
        """
        matrices = pp.SE3(poses).matrix()  # (N, 4, 4)
        if not with_last_row:
            matrices = matrices[:, :3, :]
        with open(filename, 'w', encoding="UTF-8") as f:
            for pose in matrices:
                f.write(f"{' '.join([str(p.item()) for p in pose.flatten()])}\n")
77 |
--------------------------------------------------------------------------------
/bad_gaussians/bad_viewer.py:
--------------------------------------------------------------------------------
1 | '''Viewer of BAD-Gaussians'''
2 | import numpy as np
3 | import torch
4 | import viser.transforms as vtf
5 |
6 | from nerfstudio.viewer.viewer import Viewer, VISER_NERFSTUDIO_SCALE_RATIO
7 |
8 | from bad_gaussians.bad_camera_optimizer import BadCameraOptimizer
9 |
10 |
class BadViewer(Viewer):
    """Nerfstudio viewer whose train-camera frustums follow BadCameraOptimizer poses."""

    # BAD-Gaussians: Overriding original update_camera_poses because BadCameraOptimizer returns LieTensor
    def update_camera_poses(self):
        """Update the train camera frustums from the optimized (default "mid") pose deltas.

        The camera optimizer is looked up on the datamanager first, then on the
        model; nothing happens if neither has one or no camera handles exist.
        """
        # TODO this fn accounts for like ~5% of total train time
        assert self.camera_handles is not None
        if hasattr(self.pipeline.datamanager, "train_camera_optimizer"):
            camera_optimizer = self.pipeline.datamanager.train_camera_optimizer
        elif hasattr(self.pipeline.model, "camera_optimizer"):
            camera_optimizer = self.pipeline.model.camera_optimizer
        else:
            return
        idxs = list(self.camera_handles.keys())
        if not idxs:
            return
        with torch.no_grad():
            assert isinstance(camera_optimizer, BadCameraOptimizer)
            c2ws_delta = camera_optimizer(torch.tensor(idxs, device=camera_optimizer.device))
            # Convert all deltas to 4x4 and copy to host once, rather than one
            # device->host transfer per camera inside the loop.
            deltas = c2ws_delta.matrix().cpu().numpy()
        for key, delta in zip(idxs, deltas):
            # both are numpy arrays
            c2w = self.original_c2w[key] @ delta
            R = vtf.SO3.from_matrix(c2w[:3, :3])  # type: ignore
            R = R @ vtf.SO3.from_x_radians(np.pi)
            self.camera_handles[key].position = c2w[:3, 3] * VISER_NERFSTUDIO_SCALE_RATIO
            self.camera_handles[key].wxyz = R.wxyz
36 |
--------------------------------------------------------------------------------
/bad_gaussians/deblur_nerf_dataparser.py:
--------------------------------------------------------------------------------
1 | """
2 | Data parser for Deblur-NeRF COLMAP datasets.
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | import glob
8 | import os
9 | from dataclasses import dataclass, field
10 | from pathlib import Path
11 | from typing import List, Literal, Optional, Type
12 |
13 | import cv2
14 | import torch
15 | from nerfstudio.data.dataparsers.colmap_dataparser import ColmapDataParser, ColmapDataParserConfig
16 |
17 |
def _find_files(directory: Path, exts: List[str]) -> List[Path]:
    """Find all files in a directory that match any of the given glob patterns.

    Args:
        directory: The directory to search for files. Must exist.
        exts: A list of file extensions to search for. Each file extension should be in the form '*.ext'
            (e.g. ['*.png', '*.jpg', '*.JPG', '*.PNG']).

    Returns:
        A list of file paths for all the files that were found, de-duplicated and sorted alphabetically.
        An empty list if ``directory`` is not a directory or nothing matches.
    """
    assert directory.exists()
    if not os.path.isdir(directory):
        return []
    # Escape the directory part so glob metacharacters in it (e.g. "scene[1]") are matched literally;
    # only the extension patterns are meant to be interpreted as globs.
    escaped_dir = glob.escape(str(directory))
    files_grabbed = set()
    for ext in exts:
        files_grabbed.update(glob.glob(os.path.join(escaped_dir, ext)))
    # A set removes duplicates from overlapping patterns; sort on the string form for a stable order.
    return [Path(f) for f in sorted(files_grabbed)]
39 |
40 |
@dataclass
class DeblurNerfDataParserConfig(ColmapDataParserConfig):
    """Deblur-NeRF dataset config"""

    _target: Type = field(default_factory=lambda: DeblurNerfDataParser)
    """target class to instantiate"""
    eval_mode: Literal["fraction", "filename", "interval", "all"] = "interval"
    """
    The method to use for splitting the dataset into train and eval.
    Fraction splits based on a percentage for train and the remaining for eval.
    Filename splits based on filenames containing train/eval.
    Interval uses every nth frame for eval (used by most academic papers, e.g. MipNerf360, GSplat).
    All uses all the images for any split.
    """
    eval_interval: int = 8
    """The interval between frames to use for eval. Only used when eval_mode is "interval".
    If a file named `hold=n` exists in the data directory, n overrides this value.
    """
    images_path: Path = Path("images")
    """Path to images directory relative to the data path."""
    downscale_factor: Optional[int] = 1
    """The downscale factor for the images. Default: 1."""
    poses_bounds_path: Path = Path("poses_bounds.npy")
    """Path to the poses bounds file relative to the data path."""
    colmap_path: Path = Path("sparse/0")
    """Path to the colmap reconstruction directory relative to the data path."""
    drop_distortion: bool = False
    """Whether to drop the camera distortion parameters. Default: False."""
    scale_factor: float = 0.25
    """[IMPORTANT] How much to scale the camera origins by.
    Default: 0.25 suggested for LLFF datasets with COLMAP.
    """
71 |
72 |
@dataclass
class DeblurNerfDataParser(ColmapDataParser):
    """Deblur-NeRF COLMAP dataset parser"""

    config: DeblurNerfDataParserConfig
    _downscale_factor: Optional[int] = None

    def _get_all_images_and_cameras(self, recon_dir: Path):
        """Load frames/cameras via the COLMAP parser, then sort the frames by file path.

        Sorting makes the frame order deterministic, so it lines up index-by-index with the
        alphabetically sorted ground-truth images that replace the filenames for non-train splits.
        """
        out = super()._get_all_images_and_cameras(recon_dir)
        out["frames"] = sorted(out["frames"], key=lambda x: x["file_path"])
        return out

    def _check_outputs(self, outputs):
        """
        Check if the colmap outputs are estimated on downscaled data. If so, correct the camera parameters.

        The principal point of the first camera is compared with the centre of the first image on disk
        (relative tolerance 0.3). On mismatch, the focal length, principal point and image size along that
        axis are scaled by the nearest integer factor.

        Raises:
            FileNotFoundError: if the first image cannot be read by OpenCV.
        """
        # load the first image to get the image size
        image_path = outputs.image_filenames[0]
        image = cv2.imread(str(image_path))
        if image is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image {image_path} to check the camera intrinsics.")
        h, w = image.shape[:2]
        cameras = outputs.cameras
        # check if the cx and cy are in the correct range
        cx = cameras.cx[0]
        cy = cameras.cy[0]
        ideal_cx = torch.tensor(w / 2)
        ideal_cy = torch.tensor(h / 2)
        if not torch.allclose(cx, ideal_cx, rtol=0.3):
            x_scale = cx / ideal_cx
            print(f"[WARN] cx is not at the center of the image, correcting... cx scale: {x_scale}")
            if x_scale < 1:
                factor = round(1 / x_scale.item())
                cameras.fx *= factor
                cameras.cx *= factor
                cameras.width *= factor
            else:
                factor = round(x_scale.item())
                cameras.fx /= factor
                cameras.cx /= factor
                cameras.width //= factor

        if not torch.allclose(cy, ideal_cy, rtol=0.3):
            y_scale = cy / ideal_cy
            print(f"[WARN] cy is not at the center of the image, correcting... cy scale: {y_scale}")
            if y_scale < 1:
                factor = round(1 / y_scale.item())
                cameras.fy *= factor
                cameras.cy *= factor
                cameras.height *= factor
            else:
                factor = round(y_scale.item())
                cameras.fy /= factor
                cameras.cy /= factor
                cameras.height //= factor

        return outputs

    def _check_suffixes(self, filenames):
        """
        Check if the file path exists. if not, check if the file path with the correct suffix exists.

        Missing entries are replaced in place (and the same list is returned) by the first existing
        sibling with a .png/.PNG/.jpg/.JPG suffix; a warning is printed if none exists.
        """
        exts = [".png", ".PNG", ".jpg", ".JPG"]
        for i, filename in enumerate(filenames):
            if filename.exists():
                continue
            for ext in exts:
                new_filename = filename.with_suffix(ext)
                if new_filename.exists():
                    filenames[i] = new_filename
                    break
            else:
                print(f"[WARN] {filename} not found in the images directory.")

        return filenames

    def _generate_dataparser_outputs(self, split="train"):
        """Generate dataparser outputs for ``split``.

        If an ``images_test`` folder exists, the train-split cameras are used for every split and, for
        non-train splits, the image filenames are replaced by the sharp ground-truth images found there.
        """
        assert self.config.data.exists(), f"Data directory {self.config.data} does not exist."

        if self.config.eval_mode == "interval":
            # find the file named `hold=n` , n is the eval_interval to be recognized
            # (sorted so the choice is deterministic if several such files exist)
            hold_file = sorted(f for f in os.listdir(self.config.data) if f.startswith('hold='))
            if len(hold_file) == 0:
                print(f"[INFO] defaulting hold={self.config.eval_interval}")
            else:
                self.config.eval_interval = int(hold_file[0].split('=')[-1])
            # A non-positive interval cannot select eval frames; fall back to using all images.
            if self.config.eval_interval < 1:
                self.config.eval_mode = "all"

        gt_folder_path = self.config.data / "images_test"
        if gt_folder_path.exists():
            outputs = super()._generate_dataparser_outputs("train")
            if split != "train":
                gt_image_filenames = _find_files(gt_folder_path, exts=["*.png", "*.jpg", "*.JPG", "*.PNG"])
                num_gt_images = len(gt_image_filenames)
                print(f"[INFO] Found {num_gt_images} ground truth sharp images.")
                # number of GT sharp testing images should be equal to the number of degraded training images
                assert num_gt_images == len(outputs.image_filenames), (
                    f"Found {num_gt_images} ground truth images but {len(outputs.image_filenames)} training images."
                )
                outputs.image_filenames = gt_image_filenames
        else:
            print("[INFO] No ground truth sharp images found.")
            outputs = super()._generate_dataparser_outputs(split)

        if self.config.drop_distortion:
            # Assign on the batched Cameras object itself: iterating/indexing a Cameras yields new
            # copies, so setting the attribute per camera would not persist.
            outputs.cameras.distortion_params = None
        outputs.image_filenames = self._check_suffixes(outputs.image_filenames)
        outputs = self._check_outputs(outputs)

        return outputs
177 |
--------------------------------------------------------------------------------
/bad_gaussians/image_restoration_dataloader.py:
--------------------------------------------------------------------------------
1 | """
2 | Image Restoration Dataloaders.
3 | """
4 |
5 | from typing import Dict, Optional, Tuple, Union
6 |
7 | import torch
8 |
9 | from nerfstudio.cameras.cameras import Cameras
10 | from nerfstudio.data.datasets.base_dataset import InputDataset
11 | from nerfstudio.data.utils.dataloaders import RandIndicesEvalDataloader, FixedIndicesEvalDataloader
12 | from nerfstudio.utils.misc import get_dict_to_torch
13 |
14 |
class ImageRestorationRandIndicesEvalDataloader(RandIndicesEvalDataloader):
    """Random-index eval dataloader that also returns the degraded training image.

    Args:
        input_dataset: Ground-truth images for evaluation
        degraded_dataset: Corresponding training images with degradation.
        device: Device to load data to.
    """

    def __init__(
        self,
        input_dataset: InputDataset,
        degraded_dataset: InputDataset,
        device: Union[torch.device, str] = "cpu",
        **kwargs,
    ):
        super().__init__(input_dataset, device, **kwargs)
        self.degraded_dataset = degraded_dataset

    def get_camera(self, image_idx: int = 0) -> Tuple[Cameras, Dict]:
        """Return the camera and the data batch for one image index.

        Every entry except ``"image"`` is moved to ``self.device``. The degraded training image with
        the same index is added under ``"degraded"``, and the camera metadata records ``cam_idx``.

        Args:
            image_idx: Camera image index
        """
        cam = self.cameras[image_idx : image_idx + 1]
        sample = get_dict_to_torch(self.input_dataset[image_idx], device=self.device, exclude=["image"])
        sample["degraded"] = self.degraded_dataset[image_idx]["image"]
        assert isinstance(sample, dict)
        if cam.metadata is None:
            cam.metadata = {}
        cam.metadata["cam_idx"] = image_idx
        return cam, sample
49 |
50 |
class ImageRestorationFixedIndicesEvalDataloader(FixedIndicesEvalDataloader):
    """Fixed-index eval dataloader that also returns the degraded training image.

    Args:
        input_dataset: Ground-truth images for evaluation
        degraded_dataset: Corresponding training images with degradation.
        image_indices: List of image indices to load data from. If None, then use all images.
        device: Device to load data to
    """

    def __init__(
        self,
        input_dataset: InputDataset,
        degraded_dataset: InputDataset,
        image_indices: Optional[Tuple[int]] = None,
        device: Union[torch.device, str] = "cpu",
        **kwargs,
    ):
        super().__init__(input_dataset, image_indices, device, **kwargs)
        self.degraded_dataset = degraded_dataset

    def get_camera(self, image_idx: int = 0) -> Tuple[Cameras, Dict]:
        """Return the camera and the data batch for one image index.

        Every entry except ``"image"`` is moved to ``self.device``. The degraded training image with
        the same index is added under ``"degraded"``, and the camera metadata records ``cam_idx``.

        Args:
            image_idx: Camera image index
        """
        cam = self.cameras[image_idx : image_idx + 1]
        sample = get_dict_to_torch(self.input_dataset[image_idx], device=self.device, exclude=["image"])
        sample["degraded"] = self.degraded_dataset[image_idx]["image"]
        assert isinstance(sample, dict)
        if cam.metadata is None:
            cam.metadata = {}
        cam.metadata["cam_idx"] = image_idx
        return cam, sample
87 |
--------------------------------------------------------------------------------
/bad_gaussians/image_restoration_full_image_datamanager.py:
--------------------------------------------------------------------------------
1 | """
2 | Full image datamanager for image restoration.
3 | """
4 | from __future__ import annotations
5 |
6 | import random
7 | from copy import deepcopy
8 | from dataclasses import dataclass, field
9 | from typing import Any, Callable, Literal, Type, Union, cast, Tuple, Dict
10 |
11 | import torch
12 |
13 | from nerfstudio.cameras.cameras import Cameras
14 | from nerfstudio.data.datamanagers.base_datamanager import variable_res_collate
15 | from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanager, FullImageDatamanagerConfig
16 | from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate
17 | from nerfstudio.utils.rich_utils import CONSOLE
18 |
19 | from bad_gaussians.image_restoration_dataloader import ImageRestorationFixedIndicesEvalDataloader, ImageRestorationRandIndicesEvalDataloader
20 |
21 |
@dataclass
class ImageRestorationFullImageDataManagerConfig(FullImageDatamanagerConfig):
    """Configuration of the full-image datamanager used for image restoration."""

    # Resolved lazily because the datamanager class is defined after this config.
    _target: Type = field(default_factory=lambda: ImageRestorationFullImageDataManager)
    """Target class to instantiate."""
    # Wrapped in staticmethod so the function is not bound as a method when read via an instance.
    collate_fn: Callable[[Any], Any] = cast(Any, staticmethod(nerfstudio_collate))
    """Specifies the collate function to use for the train and eval dataloaders."""
30 |
31 |
class ImageRestorationFullImageDataManager(FullImageDatamanager):  # pylint: disable=abstract-method
    """Data manager implementation for image restoration

    Eval dataloaders pair each ground-truth eval image with the degraded training image of the
    same index (exposed under the ``"degraded"`` key of the batch).

    Args:
        config: the DataManagerConfig used to instantiate class
        device: device the eval data is loaded to
        test_mode: forwarded to FullImageDatamanager
        world_size: number of processes; also scales the eval dataloader workers
        local_rank: forwarded to FullImageDatamanager
    """

    config: ImageRestorationFullImageDataManagerConfig

    def __init__(
        self,
        config: ImageRestorationFullImageDataManagerConfig,
        device: Union[torch.device, str] = "cpu",
        test_mode: Literal["test", "val", "inference"] = "val",
        world_size: int = 1,
        local_rank: int = 0,
        **kwargs,
    ):
        super().__init__(config, device, test_mode, world_size, local_rank, **kwargs)
        if self.train_dataparser_outputs is not None:
            cameras = self.train_dataparser_outputs.cameras
            # Compare every camera's resolution against the first one with a single vectorized check,
            # instead of materializing a new Cameras object for each index.
            if len(cameras) > 1 and (
                bool(torch.any(cameras.width != cameras.width[0]))
                or bool(torch.any(cameras.height != cameras.height[0]))
            ):
                CONSOLE.print("Variable resolution, using variable_res_collate")
                self.config.collate_fn = variable_res_collate

        self._fixed_indices_eval_dataloader = ImageRestorationFixedIndicesEvalDataloader(
            input_dataset=self.eval_dataset,
            degraded_dataset=self.train_dataset,
            device=self.device,
            num_workers=self.world_size * 4,
        )
        self.eval_dataloader = ImageRestorationRandIndicesEvalDataloader(
            input_dataset=self.eval_dataset,
            degraded_dataset=self.train_dataset,
            device=self.device,
            num_workers=self.world_size * 4,
        )

    @property
    def fixed_indices_eval_dataloader(self):
        """Returns the fixed indices eval dataloader"""
        return self._fixed_indices_eval_dataloader

    def next_eval(self, step: int) -> Tuple[Cameras, Dict]:
        """Returns the next evaluation batch. Returns a Camera instead of raybundle

        A random not-yet-seen eval index is drawn; once all have been used, the pool is refilled.
        """
        image_idx = self.eval_unseen_cameras.pop(random.randint(0, len(self.eval_unseen_cameras) - 1))
        # Make sure to re-populate the unseen cameras list if we have exhausted it
        if len(self.eval_unseen_cameras) == 0:
            self.eval_unseen_cameras = list(range(len(self.eval_dataset)))
        data = deepcopy(self.cached_eval[image_idx])
        data["image"] = data["image"].to(self.device)
        assert len(self.eval_dataset.cameras.shape) == 1, "Assumes single batch dimension"
        camera = self.eval_dataset.cameras[image_idx : image_idx + 1].to(self.device)
        # BAD-Gaussians: pass camera index to BadNerfCameraOptimizer
        if camera.metadata is None:
            camera.metadata = {}
        camera.metadata["cam_idx"] = image_idx
        return camera, data

    def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
        """Return the first (camera, batch) pair produced by the random-index eval dataloader."""
        for camera, batch in self.eval_dataloader:
            assert camera.shape[0] == 1
            return camera, batch
        raise ValueError("No more eval images")
98 |
--------------------------------------------------------------------------------
/bad_gaussians/image_restoration_pipeline.py:
--------------------------------------------------------------------------------
1 | """Image restoration pipeline."""
2 | from __future__ import annotations
3 |
4 | import os
5 | from pathlib import Path
6 | from time import time
7 | from typing import Optional, Type
8 |
9 | import torch
10 | from dataclasses import dataclass, field
11 | from rich.progress import BarColumn, MofNCompleteColumn, Progress, TextColumn, TimeElapsedColumn
12 |
13 | from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager
14 | from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanager
15 | from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager
16 | from nerfstudio.pipelines.base_pipeline import VanillaPipeline, VanillaPipelineConfig
17 | from nerfstudio.utils import profiler
18 | from nerfstudio.utils.writer import to8b
19 |
20 | from bad_gaussians.bad_gaussians import BadGaussiansModel
21 |
22 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
23 | import cv2
24 |
25 |
@dataclass
class ImageRestorationPipelineConfig(VanillaPipelineConfig):
    """Configuration of the image restoration pipeline, with extra eval-rendering switches."""

    # Resolved lazily because the pipeline class is defined after this config.
    _target: Type = field(default_factory=lambda: ImageRestorationPipeline)
    """The target class to be instantiated."""

    # Adds the "start" and "end" render modes to full-dataset evaluation.
    eval_render_start_end: bool = False
    """whether to render and save the starting and ending virtual sharp images in eval"""

    # Adds the "uniform" render mode to full-dataset evaluation.
    eval_render_estimated: bool = False
    """whether to render and save the estimated degraded images with learned trajectory in eval.
    Note: Slow & VRAM hungry! Reduce VRAM consumption by passing argument
    `--pipeline.model.eval_num_rays_per_chunk=16384` or less.
    """
41 |
42 |
class ImageRestorationPipeline(VanillaPipeline):
    """Image restoration pipeline"""

    config: ImageRestorationPipelineConfig

    @profiler.time_function
    def get_average_eval_image_metrics(
        self, step: Optional[int] = None, output_path: Optional[Path] = None, get_std: bool = False
    ):
        """Iterate over all the images in the eval dataset and get the average.
        Also saves the rendered images to disk if output_path is provided.

        Args:
            step: current training step (used to name the output sub-directory)
            output_path: optional path to save rendered images to
            get_std: Set True if you want to return std with the mean metric.

        Returns:
            metrics_dict: dictionary of metrics (empty if there are no eval images)
        """
        self.eval()
        try:
            metrics_dict_list = self._evaluate_all_images(step, output_path)
        finally:
            # Restore training mode even if rendering or saving fails.
            self.train()
        return self._average_metrics(metrics_dict_list, get_std)

    def _evaluate_all_images(self, step: Optional[int], output_path: Optional[Path]) -> list:
        """Render every fixed-index eval image, compute its metrics and optionally save the renders."""
        render_list = ["mid"]
        if self.config.eval_render_start_end:
            render_list += ["start", "end"]
        if self.config.eval_render_estimated:
            render_list += ["uniform"]
        assert isinstance(self.datamanager, (VanillaDataManager, ParallelDataManager, FullImageDatamanager))
        num_images = len(self.datamanager.fixed_indices_eval_dataloader)
        metrics_dict_list = []
        with Progress(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TimeElapsedColumn(),
            MofNCompleteColumn(),
            transient=True,
        ) as progress:
            task = progress.add_task("[green]Evaluating all eval images...", total=num_images)
            for camera, batch in self.datamanager.fixed_indices_eval_dataloader:
                inner_start = time()
                image_idx = batch['image_idx']
                images_dict = {
                    f"{image_idx:04}_input": batch["degraded"][:, :, :3],
                    f"{image_idx:04}_gt": batch["image"][:, :, :3],
                }
                if isinstance(self.model, BadGaussiansModel):
                    for mode in render_list:
                        outputs = self.model.get_outputs_for_camera(camera, mode=mode)
                        for key, value in outputs.items():
                            if "uniform" == mode:
                                filename = f"{image_idx:04}_estimated"
                            else:
                                filename = f"{image_idx:04}_{key}_{mode}"
                            if "rgb" in key:
                                images_dict[filename] = value
                            if "depth" in key and "uniform" != mode:
                                images_dict[filename] = value
                        # Metrics are always computed on the "mid" rendering.
                        if "mid" == mode:
                            metrics_dict, _ = self.model.get_image_metrics_and_images(outputs, batch)
                else:
                    outputs = self.model.get_outputs_for_camera(camera)
                    metrics_dict, images_dict = self.model.get_image_metrics_and_images(outputs, batch)
                # Stop the clock before any disk I/O so the throughput metrics measure rendering and
                # metric computation only, independent of whether images are being saved.
                elapsed = time() - inner_start

                if output_path is not None:
                    self._save_eval_images(output_path / f"{step:06}", images_dict)

                assert "num_rays_per_sec" not in metrics_dict
                height, width = camera.height, camera.width
                num_rays = height * width
                metrics_dict["num_rays_per_sec"] = (num_rays / elapsed).item()
                fps_str = "fps"
                assert fps_str not in metrics_dict
                metrics_dict[fps_str] = (metrics_dict["num_rays_per_sec"] / (height * width)).item()
                metrics_dict_list.append(metrics_dict)
                progress.advance(task)
        return metrics_dict_list

    @staticmethod
    def _save_eval_images(image_dir: Path, images_dict: dict) -> None:
        """Write every image of images_dict into image_dir, using its key as the file name.

        Keys containing "rgb", "input", "gt", "estimated" or "mask" are written as 8-bit PNG (via to8b,
        converted RGB -> BGR for OpenCV); everything else (e.g. depth) is written as float EXR.
        """
        image_dir.mkdir(parents=True, exist_ok=True)
        for filename, data in images_dict.items():
            data = data.detach().cpu()
            if any(tag in filename for tag in ("rgb", "input", "gt", "estimated", "mask")):
                path = str((image_dir / f"{filename}.png").resolve())
                cv2.imwrite(path, cv2.cvtColor(to8b(data).numpy(), cv2.COLOR_RGB2BGR))
            else:
                path = str((image_dir / f"{filename}.exr").resolve())
                cv2.imwrite(path, data.numpy())

    @staticmethod
    def _average_metrics(metrics_dict_list: list, get_std: bool) -> dict:
        """Average each metric over all images; with get_std also report `<key>_std`."""
        if not metrics_dict_list:
            # No eval images: nothing to average.
            return {}
        averaged = {}
        for key in metrics_dict_list[0]:
            values = torch.tensor([metrics[key] for metrics in metrics_dict_list])
            if get_std:
                key_std, key_mean = torch.std_mean(values)
                averaged[key] = float(key_mean)
                averaged[f"{key}_std"] = float(key_std)
            else:
                averaged[key] = float(torch.mean(values))
        return averaged
146 |
--------------------------------------------------------------------------------
/bad_gaussians/image_restoration_trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | Image restoration trainer.
3 | """
4 | from __future__ import annotations
5 |
6 | import dataclasses
7 | import functools
8 | from typing import Literal, Type
9 | from typing_extensions import assert_never
10 |
11 | import torch
12 | from dataclasses import dataclass, field
13 | from nerfstudio.engine.callbacks import TrainingCallbackAttributes
14 | from nerfstudio.engine.trainer import Trainer, TrainerConfig
15 | from nerfstudio.utils import profiler, writer
16 | from nerfstudio.utils.decorators import check_eval_enabled
17 | from nerfstudio.utils.misc import step_check
18 | from nerfstudio.utils.writer import EventName, TimeWriter
19 |
20 | from bad_gaussians.image_restoration_pipeline import ImageRestorationPipeline, ImageRestorationPipelineConfig
21 | from bad_gaussians.bad_viewer import BadViewer
22 |
23 |
@dataclass
class ImageRestorationTrainerConfig(TrainerConfig):
    """Trainer settings for image restoration, defaulting to the image restoration pipeline."""

    # Resolved lazily because the trainer class is defined after this config.
    _target: Type = field(default_factory=lambda: ImageRestorationTrainer)
    """The target class to be instantiated."""

    # Narrowed from the base TrainerConfig so the restoration-specific pipeline is used.
    pipeline: ImageRestorationPipelineConfig = field(default_factory=ImageRestorationPipelineConfig)
    """Image restoration pipeline configuration"""
32 |
33 |
class ImageRestorationTrainer(Trainer):
    """Trainer for image restoration.

    Uses BadViewer as the viewer and asks the pipeline to save rendered images when evaluating
    all eval images.
    """
    config: ImageRestorationTrainerConfig
    pipeline: ImageRestorationPipeline

    def setup(self, test_mode: Literal["test", "val", "inference"] = "val") -> None:
        """Build the pipeline, optimizers, viewer, callbacks, writers and profiler.

        Args:
            test_mode: The test mode to use.
        """
        # BAD-Gaussians: Overriding original setup since we want to use our BadNerfViewer
        self.pipeline = self.config.pipeline.setup(
            device=self.device,
            test_mode=test_mode,
            world_size=self.world_size,
            local_rank=self.local_rank,
            grad_scaler=self.grad_scaler,
        )
        self.optimizers = self.setup_optimizers()

        # viewer: only on the rank-0 process; the legacy viewer is rejected outright
        viewer_log_file = self.base_dir / self.config.viewer.relative_log_filename
        self.viewer_state, banner_messages = None, None
        on_rank_zero = self.local_rank == 0
        if self.config.is_viewer_legacy_enabled() and on_rank_zero:
            assert_never(self.config.vis)
        if self.config.is_viewer_enabled() and on_rank_zero:
            data_dir = self.config.data if self.config.data is not None else self.base_dir
            self.viewer_state = BadViewer(
                self.config.viewer,
                log_filename=viewer_log_file,
                datapath=data_dir,
                pipeline=self.pipeline,
                trainer=self,
                train_lock=self.train_lock,
                share=self.config.viewer.make_share_url,
            )
            banner_messages = self.viewer_state.viewer_info
        self._check_viewer_warnings()

        self._load_checkpoint()

        callback_attributes = TrainingCallbackAttributes(
            optimizers=self.optimizers, grad_scaler=self.grad_scaler, pipeline=self.pipeline, trainer=self
        )
        self.callbacks = self.pipeline.get_training_callbacks(callback_attributes)

        # writers / profiler
        log_dir = self.base_dir / self.config.logging.relative_log_dir
        writer.setup_event_writer(
            self.config.is_wandb_enabled(),
            self.config.is_tensorboard_enabled(),
            self.config.is_comet_enabled(),
            log_dir=log_dir,
            experiment_name=self.config.experiment_name,
            project_name=self.config.project_name,
        )
        writer.setup_local_writer(
            self.config.logging, max_iter=self.config.max_num_iterations, banner_messages=banner_messages
        )
        writer.put_config(name="config", config_dict=dataclasses.asdict(self.config), step=0)
        profiler.setup_profiler(self.config.logging, log_dir)

        # BAD-Gaussians: disable eval if no eval images
        if self.pipeline.datamanager.eval_dataset.cameras is None:
            never = int(9e9)
            self.config.steps_per_eval_all_images = never
            self.config.steps_per_eval_batch = never
            self.config.steps_per_eval_image = never

    @check_eval_enabled
    @profiler.time_function
    def eval_iteration(self, step: int) -> None:
        """Run one iteration with different batch/image/all image evaluations depending on step size.
        Args:
            step: Current training step.
        """
        # a batch of eval rays
        if step_check(step, self.config.steps_per_eval_batch):
            _, loss_dict, batch_metrics = self.pipeline.get_eval_loss_dict(step=step)
            total_loss = functools.reduce(torch.add, loss_dict.values())
            writer.put_scalar(name="Eval Loss", scalar=total_loss, step=step)
            writer.put_dict(name="Eval Loss Dict", scalar_dict=loss_dict, step=step)
            writer.put_dict(name="Eval Metrics Dict", scalar_dict=batch_metrics, step=step)

        # one eval image
        if step_check(step, self.config.steps_per_eval_image):
            with TimeWriter(writer, EventName.TEST_RAYS_PER_SEC, write=False) as timer:
                image_metrics, images = self.pipeline.get_eval_image_metrics_and_images(step=step)
            writer.put_time(
                name=EventName.TEST_RAYS_PER_SEC,
                duration=image_metrics["num_rays"] / timer.duration,
                step=step,
                avg_over_steps=True,
            )
            writer.put_dict(name="Eval Images Metrics", scalar_dict=image_metrics, step=step)
            for image_name, image in images.items():
                writer.put_image(name=f"Eval Images/{image_name}", image=image, step=step)

        # all eval images
        if step_check(step, self.config.steps_per_eval_all_images):
            # BAD-Gaussians: pass output_path to save rendered images
            all_metrics = self.pipeline.get_average_eval_image_metrics(step=step, output_path=self.base_dir)
            writer.put_dict(name="Eval Images Metrics Dict (all images)", scalar_dict=all_metrics, step=step)
140 |
--------------------------------------------------------------------------------
/bad_gaussians/spline.py:
--------------------------------------------------------------------------------
1 | """
2 | SE(3) B-spline trajectory
3 |
4 | Created by lzzhao on 2023.09.29
5 | """
6 | from __future__ import annotations
7 |
8 | from dataclasses import dataclass, field
9 | from typing import Tuple, Type
10 |
11 | import pypose as pp
12 | import torch
13 | from jaxtyping import Float
14 | from pypose import LieTensor
15 | from torch import nn, Tensor
16 | from typing_extensions import assert_never
17 |
18 | from nerfstudio.configs.base_config import InstantiateConfig
19 |
20 | from bad_gaussians.spline_functor import linear_interpolation, cubic_bspline_interpolation
21 |
22 |
@dataclass
class SplineConfig(InstantiateConfig):
    """Settings used to build a `Spline` trajectory."""

    # Resolved lazily because `Spline` is defined after this config.
    _target: Type = field(default_factory=lambda: Spline)
    """Target class to instantiate."""

    # Only 1 and 3 are handled by Spline; other values hit assert_never.
    degree: int = 1
    """Degree of the spline. 1 for linear spline, 3 for cubic spline."""

    sampling_interval: float = 0.1
    """Sampling interval of the control knots."""

    start_time: float = 0.0
    """Starting timestamp of the spline."""
38 |
39 |
class Spline(nn.Module):
    """SE(3) spline trajectory.

    Args:
        config: the SplineConfig used to instantiate class
    """

    config: SplineConfig
    data: Float[LieTensor, "num_knots 7"]
    start_time: float
    end_time: float
    t_lower_bound: float
    t_upper_bound: float

    def __init__(self, config: SplineConfig):
        super().__init__()
        self.config = config
        self.data = pp.identity_SE3(0)
        # Order of the spline, i.e. control knots per segment, 2 for linear, 4 for cubic
        self.order = self.config.degree + 1

        # Also computes end_time / t_upper_bound from the (currently empty) knot data.
        self.set_start_time(config.start_time)

    def __len__(self):
        return self.data.shape[0]

    def forward(self, timestamps: Float[Tensor, "*batch_size"]) -> Float[LieTensor, "*batch_size 7"]:
        """Interpolate the spline at the given timestamps.

        Args:
            timestamps: Timestamps to interpolate the spline at. Range: [t_lower_bound, t_upper_bound].

        Returns:
            poses: The interpolated pose. Note: all size-1 dimensions are squeezed.
        """
        segment, u = self.get_segment(timestamps)
        u = u[..., None]  # (*batch_size) to (*batch_size, interpolations=1)
        if self.config.degree == 1:
            poses = linear_interpolation(segment, u)
        elif self.config.degree == 3:
            poses = cubic_bspline_interpolation(segment, u)
        else:
            assert_never(self.config.degree)
        return poses.squeeze()

    def get_segment(
        self,
        timestamps: Float[Tensor, "*batch_size"]
    ) -> Tuple[
        Float[LieTensor, "*batch_size self.order 7"],
        Float[Tensor, "*batch_size"]
    ]:
        """Get the spline segment and normalized position on segment at the given timestamp.

        Timestamps exactly at ``t_upper_bound`` (or at ``t_lower_bound`` after floating-point error)
        are evaluated on the last (first) valid segment, with ``u`` close to 1 (0).

        Args:
            timestamps: Timestamps to get the spline segment and normalized position at.

        Returns:
            segment: The spline segment.
            u: The normalized position on the segment.
        """
        assert torch.all(timestamps >= self.t_lower_bound)
        assert torch.all(timestamps <= self.t_upper_bound)
        batch_size = timestamps.shape
        relative_time = timestamps - self.start_time
        normalized_time = relative_time / self.config.sampling_interval
        # A cubic segment starts one knot before the knot interval containing the timestamp.
        offset = 1 if self.config.degree == 3 else 0
        # Clamp to the valid segment range: without this, t == t_upper_bound selects a segment that
        # reaches one knot past the end of `self.data` and makes the gather below go out of bounds.
        start_index = torch.floor(normalized_time).int().clamp(min=offset, max=len(self) - self.order + offset)
        u = normalized_time - start_index
        start_index = start_index - offset

        knot_offsets = torch.arange(self.order, device=start_index.device)
        # (*batch_size, order) knot indices, broadcast to all 7 pose components for gather.
        indices = start_index[..., None] + knot_offsets
        indices = indices[..., None].expand(*indices.shape, 7)
        segment = pp.SE3(torch.gather(self.data.expand(*batch_size, -1, -1), -2, indices))

        return segment, u

    def insert(self, pose: Float[LieTensor, "1 7"]):
        """Insert a control knot"""
        self.data = pp.SE3(torch.cat([self.data, pose]))
        self.update_end_time()

    def set_data(self, data: Float[LieTensor, "num_knots 7"] | pp.Parameter):
        """Set the spline data."""
        self.data = data
        self.update_end_time()

    def set_start_time(self, start_time: float):
        """Set the starting timestamp of the spline and refresh the dependent time bounds."""
        self.start_time = start_time
        if self.config.degree == 1:
            self.t_lower_bound = self.start_time
        elif self.config.degree == 3:
            self.t_lower_bound = self.start_time + self.config.sampling_interval
        else:
            assert_never(self.config.degree)
        # end_time and t_upper_bound are offsets from start_time, so they must move with it.
        self.update_end_time()

    def update_end_time(self):
        """Update the ending timestamp of the spline."""
        self.end_time = self.start_time + self.config.sampling_interval * (len(self) - 1)
        if self.config.degree == 1:
            self.t_upper_bound = self.end_time
        elif self.config.degree == 3:
            self.t_upper_bound = self.end_time - self.config.sampling_interval
        else:
            assert_never(self.config.degree)
--------------------------------------------------------------------------------
/bad_gaussians/spline_functor.py:
--------------------------------------------------------------------------------
1 | """
2 | SE(3) B-spline trajectory library
3 |
4 | Created by lzzhao on 2023.09.19
5 | """
6 | from __future__ import annotations
7 |
8 | import pypose as pp
9 | import scipy
10 | import torch
11 | from jaxtyping import Float
12 | from pypose import LieTensor
13 | from torch import Tensor
14 |
# Margin used to keep normalized positions strictly inside (0, 1) when ``enable_eps`` is set.
_EPS = 1e-6


def linear_interpolation_mid(
    ctrl_knots: Float[LieTensor, "*batch_size 2 7"],
) -> Float[LieTensor, "*batch_size 7"]:
    """Return the pose half-way between two SE(3) control knots.

    The translation is the arithmetic mean of the two knot translations. The rotation moves
    half-way along the relative rotation ``q_start^-1 @ q_end`` in the tangent space, i.e.
    ``q_start @ Exp(0.5 * Log(q_start^-1 @ q_end))``.

    Args:
        ctrl_knots: Pairs of SE(3) control knots; the pair index is dim -2.

    Returns:
        The midpoint pose for every batch element.
    """
    first = ctrl_knots[..., 0, :]
    second = ctrl_knots[..., 1, :]

    mid_translation = (first.translation() + second.translation()) * 0.5

    rot_first = first.rotation()
    relative_rot = rot_first.Inv() @ second.rotation()
    half_step = pp.Exp(pp.so3(relative_rot.Log() * 0.5))
    mid_rotation = rot_first @ half_step

    return pp.SE3(torch.cat([mid_translation, mid_rotation], dim=-1))
41 |
42 |
def linear_interpolation(
    ctrl_knots: Float[LieTensor, "*batch_size 2 7"],
    u: Float[Tensor, "interpolations"] | Float[Tensor, "*batch_size interpolations"],
    enable_eps: bool = False,
) -> Float[LieTensor, "*batch_size interpolations 7"]:
    """Linearly interpolate between batches of two SE(3) poses.

    Translation: ``(1 - u) * t_a + u * t_b``.
    Rotation: ``q_a @ Exp(u * Log(q_a^-1 @ q_b))``.

    Args:
        ctrl_knots: Pairs of SE(3) control knots; the pair index is dim -2.
        u: Normalized positions in [0, 1]. A 1-D tensor is shared by every batch element;
            otherwise its leading dims must match the batch dims of ``ctrl_knots``.
        enable_eps: Clip ``u`` into ``[_EPS, 1 - _EPS]`` to stay clear of the end points.

    Returns:
        The interpolated poses, shape ``(*batch_size, interpolations, 7)``.
    """
    knot_a = ctrl_knots[..., 0, :]
    knot_b = ctrl_knots[..., 1, :]
    batch_shape = knot_a.shape[:-1]
    n_interp = u.shape[-1]

    # Expand a shared 1-D ``u`` to (*batch_size, interpolations).
    if u.dim() == 1:
        u = u.tile((*batch_shape, 1))
    if enable_eps:
        u = torch.clip(u, _EPS, 1.0 - _EPS)

    # Outer products u x t give (*batch_size, interpolations, 3).
    trans = pp.bvv(1 - u, knot_a.translation()) + pp.bvv(u, knot_b.translation())

    rot_a = knot_a.rotation()
    delta = (rot_a.Inv() @ knot_b.rotation()).Log()
    partial = pp.Exp(pp.so3(pp.bvv(u, delta)))
    rot = rot_a.unsqueeze(-2).tile((n_interp, 1)) @ partial

    return pp.SE3(torch.cat([trans, rot], dim=-1))
81 |
82 |
def cubic_bspline_interpolation(
    ctrl_knots: Float[LieTensor, "*batch_size 4 7"],
    u: Float[Tensor, "interpolations"] | Float[Tensor, "*batch_size interpolations"],
    enable_eps: bool = False,
) -> Float[LieTensor, "*batch_size interpolations 7"]:
    """Uniform cubic B-spline interpolation with batches of four SE(3) control knots.

    The translation uses the standard cubic B-spline basis, so at ``u = 0`` it equals
    ``(p0 + 4 p1 + p2) / 6``. The rotation uses the cumulative basis: the first knot's rotation
    is multiplied by ``Exp(b_k(u) * Log(q_{k-1}^-1 @ q_k))`` for k = 1, 2, 3.

    Args:
        ctrl_knots: Four SE(3) control knots per batch element; the knot index is dim -2.
        u: Normalized positions on the trajectory segment, range [0, 1]. A 1-D tensor is shared
            by every batch element; otherwise its leading dims must match the batch dims.
        enable_eps: Clip ``u`` into ``[_EPS, 1 - _EPS]`` to stay clear of the end points.

    Returns:
        The interpolated poses, shape ``(*batch_size, interpolations, 7)``.
    """
    batch_shape = ctrl_knots.shape[:-2]
    n_interp = u.shape[-1]

    if u.dim() == 1:
        u = u.tile((*batch_shape, 1))
    if enable_eps:
        u = torch.clip(u, _EPS, 1.0 - _EPS)

    u2 = u * u
    u3 = u2 * u
    sixth = 1.0 / 6.0

    # Standard basis for translation: (*batch_size, 4, interpolations).
    basis_t = torch.stack(
        [
            sixth - 0.5 * u + 0.5 * u2 - sixth * u3,
            4.0 * sixth - u2 + 0.5 * u3,
            sixth + 0.5 * u + 0.5 * u2 - 0.5 * u3,
            sixth * u3,
        ],
        dim=-2,
    )
    # (*batch_size, 4, interpolations, 3) summed over the knot dim.
    translation = torch.sum(pp.bvv(basis_t, ctrl_knots.translation()), dim=-3)

    # Cumulative basis for rotation (the leading weight 1 for the first knot is implicit).
    basis_r = torch.stack(
        [
            5.0 * sixth + 0.5 * u - 0.5 * u2 + sixth * u3,
            sixth + 0.5 * u + 0.5 * u2 - 2 * sixth * u3,
            sixth * u3,
        ],
        dim=-2,
    )

    # Tangent-space steps between consecutive knots: (*batch_size, 3, 3).
    deltas = (ctrl_knots[..., :-1, :].rotation().Inv() @ ctrl_knots[..., 1:, :].rotation()).Log()
    scaled_steps = pp.Exp(pp.so3(pp.bvv(basis_r, deltas)))  # (*batch_size, 3, interpolations, 4)
    base = ctrl_knots[..., 0, :].rotation().unsqueeze(-2).tile((n_interp, 1)).unsqueeze(-3)
    chain = torch.cat([base, scaled_steps], dim=-3)  # (*batch_size, 4, interpolations, 4)
    rotation = pp.cumprod(chain, dim=-3, left=False)[..., -1, :, :]

    return pp.SE3(torch.cat([translation, rotation], dim=-1))
143 |
def bezier_interpolation(
    ctrl_knots: Float[LieTensor, "*batch_size order 7"],
    u: Float[Tensor, "interpolations"] | Float[Tensor, "*batch_size interpolations"],
    enable_eps: bool = False,
) -> Float[LieTensor, "*batch_size interpolations 7"]:
    """Bezier interpolation with batches of SE(3) control knots.

    Each control knot is mapped to the tangent space (``Log``), scaled by its Bernstein weight
    ``C(degree, i) * (1 - u)^(degree - i) * u^i``, and mapped back with ``Exp``. The resulting
    poses are multiplied left-to-right in knot order.

    NOTE(review): this is a weighted product of exponentials, not a geodesic De Casteljau
    construction, so it generally differs from a "true" SE(3) Bezier curve.

    Args:
        ctrl_knots: ``order`` SE(3) control knots per batch element (degree = order - 1);
            the knot index is dim -2. Any number of batch dims, including none, is supported.
        u: Normalized positions on the trajectory segments, range [0, 1]. A 1-D tensor is
            shared by every batch element; otherwise its leading dims must match the batch dims.
        enable_eps: Clip ``u`` into ``[_EPS, 1 - _EPS]`` to stay clear of the end points.

    Returns:
        The interpolated poses, shape ``(*batch_size, interpolations, 7)``.
    """
    import math

    batch_size = ctrl_knots.shape[:-2]
    order = ctrl_knots.shape[-2]
    degree = order - 1

    # Expand a shared 1-D ``u`` to (*batch_size, interpolations).
    if u.dim() == 1:
        u = u.tile((*batch_size, 1))
    if enable_eps:
        u = torch.clip(u, _EPS, 1.0 - _EPS)

    # Bernstein weights; math.comb gives exact integer binomials.
    bernstein = [math.comb(degree, i) * pow(1 - u, degree - i) * pow(u, i) for i in range(order)]
    # Stack on dim -2 (not a fixed dim 1) so the layout is (*batch_size, order, interpolations) for
    # any number of batch dims, and keep the knots' dtype so float64 inputs are not downcast.
    bezier_coeffs = torch.stack(bernstein, dim=-2).to(dtype=ctrl_knots.dtype, device=ctrl_knots.device)

    # (*batch_size, order, interpolations, 6) tangent vectors -> (*batch_size, order, interpolations, 7) poses.
    weighted_ctrl_knots = pp.se3(bezier_coeffs.unsqueeze(-1) * ctrl_knots.Log().unsqueeze(-2)).Exp()
    return pp.cumprod(weighted_ctrl_knots, dim=-3, left=False)[..., -1, :, :]
184 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "bad_gaussians"
3 | version = "1.0.3"
4 |
5 | dependencies=[
6 | "nerfstudio>=1.0.3",
7 | "pypose",
8 | "gsplat>=0.1.11,<1.0.0"
9 | ]
10 |
11 | # black
12 | [tool.black]
13 | line-length = 120
14 |
15 | # pylint
16 | [tool.pylint.messages_control]
17 | max-line-length = 120
18 | generated-members = ["numpy.*", "torch.*", "cv2.*", "cv.*"]
19 |
20 | [tool.setuptools.packages.find]
21 | include = ["bad_gaussians"]
22 |
23 | [project.entry-points.'nerfstudio.dataparser_configs']
24 | deblur-nerf-data = 'bad_gaussians.bad_config_dataparser:DeblurNerfDataParser'
25 |
26 | [project.entry-points.'nerfstudio.method_configs']
27 | bad_gaussians = 'bad_gaussians.bad_config_method:bad_gaussians'
28 |
--------------------------------------------------------------------------------
/scripts/tools/export_poses_from_ckpt.py:
--------------------------------------------------------------------------------
1 | """
2 | Export optimized poses from a checkpoint.
3 | """
4 | import argparse
5 | from pathlib import Path
6 |
7 | import pypose as pp
8 | import torch
9 | from typing_extensions import assert_never
10 |
11 | from bad_gaussians.bad_utils import TrajectoryIO
12 | from bad_gaussians.deblur_nerf_dataparser import DeblurNerfDataParserConfig
13 | from bad_gaussians.spline_functor import linear_interpolation_mid, cubic_bspline_interpolation
14 |
# Device used to load the checkpoint tensors and camera poses; e.g. 'cuda:0' to use a GPU instead.
DEVICE = 'cpu'


def main():
    """Command-line entry point: validate the input paths, then call ``export_poses``."""
    arg_parser = argparse.ArgumentParser(description="Export optimized poses from a checkpoint.")
    arg_parser.add_argument("--ckpt_path", type=str, required=True, help="Path to the checkpoint.")
    arg_parser.add_argument("--data_dir", type=str, required=True, help="Path to the dataset.")
    arg_parser.add_argument("--output", type=str, required=True, help="Path to the output TUM trajectory.")
    cli_args = arg_parser.parse_args()

    ckpt_path, data_dir, output_path = (
        Path(cli_args.ckpt_path),
        Path(cli_args.data_dir),
        Path(cli_args.output),
    )
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    if not data_dir.exists():
        raise FileNotFoundError(f"Data directory not found: {data_dir}")

    print(f"Exporting optimized poses from {ckpt_path} using data from {data_dir}")
    export_poses(ckpt_path, data_dir, output_path)
35 |
36 |
def export_poses(ckpt_path, data_dir, output_path):
    """Compose the dataset camera poses with the optimized pose adjustments and write a TUM file.

    The checkpoint entry ``_model.camera_optimizer.pose_adjustment`` has shape
    ``(num_cameras, num_ctrl_knots, 7)``. With 2 knots, the adjustment is their midpoint.
    With 4 knots, it is the cubic B-spline value at u = 0.5. The adjustment is right-multiplied
    onto the training-split camera pose loaded by ``DeblurNerfDataParser``. Camera ``i`` is
    written with timestamp ``float(i)``.

    Args:
        ckpt_path: Path to the checkpoint.
        data_dir: Dataset directory passed to ``DeblurNerfDataParserConfig``.
        output_path: Output file. If it is an existing directory, ``BAD-Gaussians.txt`` is
            written inside it.

    Raises:
        FileExistsError: If the output file already exists. This is checked before anything
            heavy is loaded.
        ValueError: If the number of control knots per camera is neither 2 nor 4.
    """
    # Resolve and validate the destination first so a clash fails fast.
    output_path = Path(output_path)
    if output_path.is_dir():
        output_path = output_path / "BAD-Gaussians.txt"
    if output_path.exists():
        raise FileExistsError(f"File already exists: {output_path}")
    output_path.parent.mkdir(parents=True, exist_ok=True)

    ckpt = torch.load(ckpt_path, map_location=DEVICE)
    pose_adjustment = ckpt['pipeline']['_model.camera_optimizer.pose_adjustment']

    parser = DeblurNerfDataParserConfig(
        data=data_dir,
        downscale_factor=1,
    ).setup()
    parser_outputs = parser.get_dataparser_outputs(split="train")

    print(parser_outputs)
    print(parser_outputs.cameras.camera_to_worlds.shape)

    poses = pp.mat2SE3(parser_outputs.cameras.camera_to_worlds.to(DEVICE))

    num_cameras, num_ctrl_knots, _ = pose_adjustment.shape

    if num_ctrl_knots == 2:
        poses_delta = linear_interpolation_mid(pose_adjustment)
    elif num_ctrl_knots == 4:
        poses_delta = cubic_bspline_interpolation(
            pose_adjustment,
            torch.tensor([0.5], device=pose_adjustment.device),
        ).squeeze(1)
    else:
        # A data-dependent value, so report it instead of using assert_never (meant for static exhaustiveness).
        raise ValueError(
            f"Unsupported number of control knots per camera: {num_ctrl_knots} (expected 2 or 4)"
        )

    print(poses.shape)
    print(poses_delta.shape)
    poses_optimized = poses @ poses_delta

    timestamps = [float(x) for x in range(num_cameras)]
    TrajectoryIO.write_tum_trajectory(output_path, torch.tensor(timestamps), poses_optimized)

    print(f"Exported optimized poses to {output_path}")


if __name__ == "__main__":
    main()
83 |
--------------------------------------------------------------------------------
/scripts/tools/export_poses_from_colmap.py:
--------------------------------------------------------------------------------
1 | """
2 | Export poses from colmap images.txt
3 | """
4 |
5 | import argparse
6 | from pathlib import Path
7 |
8 | import pypose as pp
9 | import torch
10 |
11 | from bad_gaussians.bad_utils import TrajectoryIO
12 | from nerfstudio.data.dataparsers.colmap_dataparser import ColmapDataParserConfig
13 |
DEVICE = 'cpu'


def main():
    """Export camera-to-world poses from a COLMAP model as a TUM trajectory.

    Frames are ordered by image file path and assigned timestamps 0.0, 1.0, 2.0, ...
    """
    parser = argparse.ArgumentParser(description="Export poses from colmap results to TUM trajectory.")
    parser.add_argument("--input", type=str, required=True, help="Path to colmap files. E.g. ./sparse/0")
    parser.add_argument("--output", type=str, required=True, help="Path to the output TUM trajectory.")
    args = parser.parse_args()

    input_path = Path(args.input)
    if not input_path.exists():
        raise FileNotFoundError(f"File not found: {input_path}")

    output_path = Path(args.output)
    if output_path.exists():
        raise FileExistsError(f"File already exists: {output_path}")
    output_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Exporting poses from {input_path} to {output_path}")

    # NOTE(review): relies on a private nerfstudio helper; may break across nerfstudio versions.
    dataparser = ColmapDataParserConfig(data=input_path, colmap_path=".").setup()
    frames = dataparser._get_all_images_and_cameras(input_path)["frames"]

    # Sort by file path only. Sorting (name, pose) tuples would fall back to comparing the
    # pose matrices on equal names, which is not a meaningful (or always valid) comparison.
    frames = sorted(frames, key=lambda frame: frame["file_path"])
    poses = torch.stack([torch.as_tensor(frame["transform_matrix"]) for frame in frames])

    poses = pp.mat2SE3(poses)
    num_cameras = poses.shape[0]
    timestamps = [float(x) for x in range(num_cameras)]
    TrajectoryIO.write_tum_trajectory(output_path, torch.tensor(timestamps), poses)


if __name__ == "__main__":
    main()
49 |
--------------------------------------------------------------------------------
/scripts/tools/export_poses_from_npy.py:
--------------------------------------------------------------------------------
1 | """
2 | Export poses from poses_bounds.npy
3 | """
4 | import argparse
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | import pypose as pp
9 | import torch
10 |
11 | from bad_gaussians.bad_utils import TrajectoryIO
12 |
DEVICE = 'cpu'


def main():
    """Export the camera poses stored in a ``poses_bounds.npy`` file as a TUM trajectory.

    Poses at indices 0, 8, 16, ... are dropped before export. The remaining poses get the
    timestamps 0.0, 1.0, 2.0, ...

    Raises:
        FileNotFoundError: If the input file does not exist.
        FileExistsError: If the output file already exists.
        ValueError: If the array is not of shape (N, 17).
    """
    parser = argparse.ArgumentParser(description="Export poses from poses_bounds.npy.")
    parser.add_argument("--input", type=str, required=True, help="Path to the npy file.")
    parser.add_argument("--output", type=str, required=True, help="Path to the output TUM trajectory.")
    args = parser.parse_args()

    npy_path = Path(args.input)
    if not npy_path.exists():
        raise FileNotFoundError(f"File not found: {npy_path}")

    output_path = Path(args.output)
    if output_path.exists():
        raise FileExistsError(f"File already exists: {output_path}")
    output_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Exporting poses from {npy_path} to {output_path}")

    # One row per image: 15 values forming a 3x5 matrix, followed by 2 trailing values
    # (presumably LLFF's [pose | hwf] matrix and near/far bounds).
    pose_bounds = np.load(npy_path)
    # Raise instead of assert so the check survives `python -O`.
    if pose_bounds.ndim != 2 or pose_bounds.shape[-1] != 17:
        raise ValueError(f"Expected an (N, 17) array in {npy_path}, got shape {pose_bounds.shape}")
    matrices = np.reshape(pose_bounds[:, :-2], (-1, 3, 5))

    # Drop indices 0, 8, 16, ... (presumably the held-out every-8th test views).
    matrices = np.delete(matrices, np.arange(0, matrices.shape[0], 8), axis=0)

    # The first four columns form the 3x4 pose; the rotation part is not validated (check=False).
    poses = pp.from_matrix(torch.tensor(matrices[:, :, :4]).to(DEVICE), check=False, ltype=pp.SE3_type)

    num_cameras = poses.shape[0]
    timestamps = [float(x) for x in range(num_cameras)]
    TrajectoryIO.write_tum_trajectory(output_path, torch.tensor(timestamps), poses)


if __name__ == "__main__":
    main()
52 |
--------------------------------------------------------------------------------
/scripts/tools/interpolate_traj.py:
--------------------------------------------------------------------------------
1 | """
2 | Interpolate TUM trajectory with cubic B-spline given timestamps.
3 | """
4 |
5 | import argparse
6 | from pathlib import Path
7 |
8 | import torch
9 |
10 | from bad_gaussians.bad_utils import TrajectoryIO
11 | from bad_gaussians.spline import SplineConfig
12 |
DEVICE = 'cpu'
torch.set_default_dtype(torch.float64)


def main():
    """Resample a TUM trajectory at user-given timestamps with a cubic B-spline.

    The input trajectory's poses are used directly as spline knots. The sampling interval is
    taken from the first two input timestamps, so the input is assumed to be uniformly sampled
    in time (not checked). Query timestamps should lie at least one sampling interval inside
    either end of the input; out-of-range queries are presumably rejected by the spline's
    bounds assertions.

    Raises:
        FileNotFoundError: If the trajectory or timestamps file does not exist.
        FileExistsError: If the output file already exists.
        ValueError: If the input trajectory has fewer than two poses.
    """
    parser = argparse.ArgumentParser(description="Interpolate TUM trajectory with cubic B-spline given timestamps.")
    parser.add_argument("--input", type=str, required=True, help="Path to the TUM trajectory.")
    parser.add_argument("--times", type=str, required=True, help="Path to the timestamps.")
    parser.add_argument("--output", type=str, required=True, help="Path to the output TUM trajectory.")
    args = parser.parse_args()

    input_path = Path(args.input)
    if not input_path.exists():
        raise FileNotFoundError(f"File not found: {input_path}")

    times_path = Path(args.times)
    if not times_path.exists():
        raise FileNotFoundError(f"File not found: {times_path}")

    output_path = Path(args.output)
    if output_path.exists():
        raise FileExistsError(f"File already exists: {output_path}")
    output_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Interpolating TUM trajectory from {input_path} to {output_path} using timestamps from {times_path}")

    knot_times, tum_trajectory = TrajectoryIO.load_tum_trajectory(input_path)
    if len(knot_times) < 2:
        raise ValueError(f"Need at least two poses in {input_path} to infer the sampling interval")

    # Only the cubic spline is needed; the previous linear spline was evaluated but never used.
    cubic_spline = SplineConfig(
        degree=3,
        sampling_interval=(knot_times[1] - knot_times[0]),
        start_time=knot_times[0],
    ).setup()
    cubic_spline.set_data(tum_trajectory)

    # One query timestamp per line; blank lines and '#' comment lines are ignored.
    query_times = []
    with open(times_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if line and not line.startswith('#'):
                query_times.append(float(line))

    query_times = torch.tensor(query_times, dtype=torch.float64)
    cubic_poses = cubic_spline(query_times)

    TrajectoryIO.write_tum_trajectory(output_path, query_times, cubic_poses)


if __name__ == "__main__":
    main()
72 |
--------------------------------------------------------------------------------
/scripts/tools/kitti_to_tum.py:
--------------------------------------------------------------------------------
1 | """
2 | Convert KITTI trajectory to TUM format.
3 |
4 | Usage:
5 | python tools/kitti_to_tum.py \
6 | --input=data/Replica/office0/traj.txt \
7 | --output=data/Replica/office0/traj_tum.txt
8 | """
9 | from __future__ import annotations
10 |
11 | import argparse
12 |
13 | import torch
14 |
15 | from bad_gaussians.bad_utils import TrajectoryIO
16 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert a KITTI trajectory to TUM format.")
    # Both paths are required: without them the loader/writer would be called with None.
    parser.add_argument("--input", type=str, required=True, help="KITTI trajectory file")
    parser.add_argument("--output", type=str, required=True, help="TUM trajectory file")
    args = parser.parse_args()

    poses = TrajectoryIO.load_kitti_trajectory(args.input)
    # KITTI files carry no timestamps; use the pose index as a float timestamp,
    # matching the other exporters in scripts/tools.
    timestamps = torch.tensor([float(x) for x in range(len(poses))])
    TrajectoryIO.write_tum_trajectory(args.output, timestamps, poses)
26 |
--------------------------------------------------------------------------------
/scripts/tools/tum_to_kitti.py:
--------------------------------------------------------------------------------
1 | """
2 | Convert TUM trajectory to KITTI format.
3 |
4 | Usage:
5 | python tools/tum_to_kitti.py \
6 | --input=data/MBA-VO/archviz_sharp1/groundtruth_synced.txt \
7 | --output=data/MBA-VO/archviz_sharp1/traj.txt
8 | """
9 | from __future__ import annotations
10 |
11 | import argparse
12 |
13 | import pypose as pp
14 | import torch
15 |
16 | from bad_gaussians.bad_utils import TrajectoryIO
17 |
# Fixed transform right-multiplied onto every TUM pose before writing KITTI.
# Layout [tx, ty, tz, qx, qy, qz, qw]: zero translation, quaternion (-0.5, 0.5, -0.5, 0.5).
# NOTE(review): by its name, presumably the camera-to-body extrinsic of the source data — confirm.
EXTRINSICS_C2B = pp.SE3(torch.tensor([0, 0, 0, -0.5, 0.5, -0.5, 0.5]))

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert a TUM trajectory to KITTI format.")
    # Both paths are required: without them the loader/writer would be called with None.
    parser.add_argument("--input", type=str, required=True, help="TUM trajectory file")
    parser.add_argument("--output", type=str, required=True, help="KITTI trajectory file")
    args = parser.parse_args()

    # The KITTI writer takes poses only, so the TUM timestamps are discarded.
    _, poses = TrajectoryIO.load_tum_trajectory(args.input)
    poses = poses @ EXTRINSICS_C2B
    TrajectoryIO.write_kitti_trajectory(args.output, poses)
29 |
--------------------------------------------------------------------------------
/tests/data/traj.txt:
--------------------------------------------------------------------------------
1 | # timestamp tx ty tz qx qy qz qw
2 | 4090.80000000 -0.18243881 0.03254267 1.40185679 -0.00173178 0.00025674 0.13497136 -0.99084795
3 | 4090.81000000 -0.16016920 0.02855653 1.40362307 -0.00146991 0.01213789 0.14717870 -0.98903435
4 | 4090.82000000 -0.14011959 0.02443513 1.40735599 -0.00198271 0.02332832 0.15841676 -0.98709472
5 | 4090.83000000 -0.12291090 0.02040767 1.41288924 -0.00327176 0.03378584 0.16837958 -0.98513762
6 | 4090.84000000 -0.10899440 0.01671207 1.41993526 -0.00527953 0.04342881 0.17683406 -0.98326791
7 | 4090.85000000 -0.09858330 0.01356689 1.42803953 -0.00785705 0.05211664 0.18365110 -0.98157750
8 | 4090.86000000 -0.09180161 0.01116808 1.43669946 -0.01082509 0.05969303 0.18873273 -0.98015280
9 | 4090.87000000 -0.08875377 0.00969074 1.44541554 -0.01399851 0.06600539 0.19198015 -0.97907658
10 | 4090.88000000 -0.08952103 0.00927787 1.45371707 -0.01719510 0.07091593 0.19329913 -0.97842256
11 | 4090.89000000 -0.09408323 0.00999614 1.46123216 -0.02026520 0.07434388 0.19264398 -0.97823852
12 | 4090.90000000 -0.10228215 0.01181107 1.46772354 -0.02310640 0.07628709 0.19004132 -0.97853496
13 | 4090.91000000 -0.11393439 0.01463452 1.47301262 -0.02562829 0.07676846 0.18553636 -0.97929876
14 | 4090.92000000 -0.12887693 0.01834777 1.47694328 -0.02773746 0.07581525 0.17917102 -0.98050009
15 | 4090.93000000 -0.14695062 0.02279857 1.47939470 -0.02934404 0.07346718 0.17099498 -0.98209074
16 | 4090.94000000 -0.16790675 0.02777456 1.48036611 -0.03039753 0.06982359 0.16112294 -0.98399190
17 | 4090.95000000 -0.19135118 0.03300418 1.48001450 -0.03090727 0.06506945 0.14976596 -0.98609374
18 | 4090.96000000 -0.21684238 0.03821645 1.47853498 -0.03090178 0.05941785 0.13716523 -0.98828149
19 | 4090.97000000 -0.24392885 0.04317235 1.47610302 -0.03041002 0.05308818 0.12356516 -0.99044865
20 | 4090.98000000 -0.27215072 0.04767324 1.47286446 -0.02946022 0.04629492 0.10921146 -0.99250276
21 | 4090.99000000 -0.30104959 0.05155234 1.46895543 -0.02808955 0.03920172 0.09435056 -0.99437024
22 | 4091.00000000 -0.33017403 0.05467048 1.46452261 -0.02635045 0.03189638 0.07922914 -0.99599750
23 | 4091.01000000 -0.35906556 0.05693766 1.45970096 -0.02429888 0.02444497 0.06409508 -0.99734840
24 | 4091.02000000 -0.38725823 0.05832269 1.45460346 -0.02199011 0.01691482 0.04919677 -0.99840373
25 | 4091.03000000 -0.41428787 0.05885385 1.44930689 -0.01947043 0.00937571 0.03478069 -0.99916130
26 | 4091.04000000 -0.43971627 0.05861918 1.44378425 -0.01673474 0.00189726 0.02108405 -0.99963584
27 | 4091.05000000 -0.46314950 0.05775755 1.43786593 -0.01370147 -0.00545027 0.00832970 -0.99985658
28 | 4091.06000000 -0.48422138 0.05643706 1.43132212 -0.01026296 -0.01259155 -0.00326694 -0.99986272
29 | 4091.07000000 -0.50258397 0.05484005 1.42389867 -0.00630679 -0.01944668 -0.01349425 -0.99969993
30 | 4091.08000000 -0.51792112 0.05314779 1.41536639 -0.00174138 -0.02593295 -0.02215351 -0.99941667
31 | 4091.09000000 -0.53003169 0.05151765 1.40571642 0.00339966 -0.03195767 -0.02911037 -0.99905943
32 | 4091.10000000 -0.53886722 0.05007196 1.39526624 0.00890261 -0.03741195 -0.03432306 -0.99867063
33 | 4091.11000000 -0.54441339 0.04891248 1.38443232 0.01450070 -0.04218010 -0.03777889 -0.99829020
34 | 4091.12000000 -0.54664754 0.04812773 1.37363738 0.01992689 -0.04614062 -0.03947000 -0.99795595
35 | 4091.13000000 -0.54555868 0.04778838 1.36328536 0.02493200 -0.04917457 -0.03940206 -0.99770123
36 | 4091.14000000 -0.54124719 0.04792123 1.35366642 -0.02934762 0.05120863 0.03764078 0.99754677
37 | 4091.15000000 -0.53398271 0.04849199 1.34490758 -0.03311942 0.05223923 0.03433627 0.99749446
38 | 4091.16000000 -0.52409072 0.04943453 1.33709817 -0.03622871 0.05228240 0.02966474 0.99753398
39 | 4091.17000000 -0.51190864 0.05066986 1.33033471 -0.03866195 0.05135183 0.02380575 0.99764800
40 | 4091.18000000 -0.49778911 0.05211288 1.32470566 -0.04041808 0.04946564 0.01694362 0.99781383
41 | 4091.19000000 -0.48212413 0.05366856 1.32021547 -0.04155425 0.04668800 0.00928221 0.99800166
42 | 4091.20000000 -0.46535643 0.05523369 1.31673746 -0.04221047 0.04315405 0.00105214 0.99817578
--------------------------------------------------------------------------------
/tests/test_spline.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from pathlib import Path
3 |
4 | import pypose as pp
5 | import torch
6 |
7 | from bad_gaussians.bad_utils import TrajectoryIO
8 | from bad_gaussians.spline import SplineConfig
9 |
torch.set_default_dtype(torch.float64)
torch.set_printoptions(precision=5, sci_mode=False)

# Resolve fixtures relative to this file rather than the current working directory,
# so the suite also runs from the repository root (e.g. `python -m pytest tests`).
DATA_DIR = Path(__file__).resolve().parent / "data"


class TestSpline(unittest.TestCase):
    """Sanity checks for linear (degree 1) and cubic (degree 3) SE(3) splines."""

    def test_spline_basic(self):
        """Query both splines between the 2nd and 3rd of four hand-made knots.

        Assumes the SplineConfig defaults place t = 0.15001 at u ~= 0.5 on that segment,
        where both splines give a translation of ~2.5 in every component.
        """
        print(f"\n{'='*10}{self.id()}{'='*10}")
        linear_spline_config = SplineConfig(degree=1)
        cubic_spline_config = SplineConfig(degree=3)
        linear_spline = linear_spline_config.setup()
        cubic_spline = cubic_spline_config.setup()

        p0 = pp.identity_SE3(1)
        p1 = pp.randn_SE3(1)
        p2 = pp.randn_SE3(1)
        p3 = pp.randn_SE3(1)

        # Overwrite all translation components so the expected translation is known regardless
        # of the random rotations (relies on translation() returning a writable view).
        p0.translation()[0] = 1
        p1.translation()[0] = 2
        p2.translation()[0] = 3
        p3.translation()[0] = 4

        for spline in [linear_spline, cubic_spline]:
            spline.insert(p0)
            spline.insert(p1)
            spline.insert(p2)
            spline.insert(p3)

        pose0 = linear_spline(torch.tensor([0.15001]))
        print(f"linear: {pose0}")
        self.assertTrue(torch.isclose(pose0.translation(), torch.tensor([2.5001, 2.5001, 2.5001]), atol=1e-3).all())
        pose1 = cubic_spline(torch.tensor([0.15001]))
        print(f"cubic: {pose1}")
        self.assertTrue(torch.isclose(pose1.translation(), torch.tensor([2.5001, 2.5001, 2.5001]), atol=1e-3).all())

    def test_spline_tum(self):
        """Both splines should reproduce recorded poses when queried at their own timestamps."""
        print(f"\n{'='*10}{self.id()}{'='*10}")
        timestamps, tum_trajectory = TrajectoryIO.load_tum_trajectory(DATA_DIR / "traj.txt")

        linear_spline_config = SplineConfig(
            degree=1,
            sampling_interval=(timestamps[1] - timestamps[0]),
            start_time=timestamps[0]
        )
        cubic_spline_config = SplineConfig(
            degree=3,
            sampling_interval=(timestamps[1] - timestamps[0]),
            start_time=timestamps[0]
        )
        linear_spline = linear_spline_config.setup()
        cubic_spline = cubic_spline_config.setup()

        linear_spline.set_data(tum_trajectory)
        cubic_spline.set_data(tum_trajectory)

        poses0 = linear_spline(torch.tensor([timestamps[20], timestamps[30]]))
        print(f"linear: {poses0}")
        self.assertTrue(torch.isclose(poses0[0], tum_trajectory[20], atol=1e-3).all())
        self.assertTrue(torch.isclose(poses0[1], tum_trajectory[30], atol=1e-3).all())
        poses1 = cubic_spline(torch.tensor([timestamps[20], timestamps[30]]))
        print(f"cubic: {poses1}")
        self.assertTrue(torch.isclose(poses1[0], tum_trajectory[20], atol=1e-3).all())
        self.assertTrue(torch.isclose(poses1[1], tum_trajectory[30], atol=1e-3).all())


if __name__ == "__main__":
    unittest.main()
78 |
--------------------------------------------------------------------------------