├── .gitignore ├── LICENSE ├── README.md ├── bad_gaussians ├── bad_camera_optimizer.py ├── bad_config_dataparser.py ├── bad_config_method.py ├── bad_gaussians.py ├── bad_losses.py ├── bad_utils.py ├── bad_viewer.py ├── deblur_nerf_dataparser.py ├── image_restoration_dataloader.py ├── image_restoration_full_image_datamanager.py ├── image_restoration_pipeline.py ├── image_restoration_trainer.py ├── spline.py └── spline_functor.py ├── pyproject.toml ├── scripts └── tools │ ├── export_poses_from_ckpt.py │ ├── export_poses_from_colmap.py │ ├── export_poses_from_npy.py │ ├── interpolate_traj.py │ ├── kitti_to_tum.py │ └── tum_to_kitti.py └── tests ├── data └── traj.txt └── test_spline.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | **/__pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | !scripts/downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .envrc 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | # Experiments and outputs 165 | outputs/ 166 | exports/ 167 | renders/ 168 | # tensorboard log files 169 | events.out.* 170 | 171 | # Data 172 | data 173 | !*/data 174 | 175 | # Misc 176 | .vscode/ 177 | .idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

😈BAD-Gaussians: Bundle-Adjusted Deblur Gaussian Splatting

2 | 3 | 4 | 5 | 6 | This is the official implementation of our arXiv 2024 paper
7 | [**BAD-Gaussians**: Bundle Adjusted Deblur Gaussian Splatting](https://lingzhezhao.github.io/BAD-Gaussians/), based on the [nerfstudio](https://github.com/nerfstudio-project/nerfstudio) framework.
8 | 9 | ## Demo
10 | 11 | Deblurring & novel-view synthesis results on [Deblur-NeRF](https://github.com/limacv/Deblur-NeRF/)'s real-world motion-blurred data:
12 | 13 | 14 | 15 | > Left: BAD-Gaussians deblurred novel-view renderings;
16 | > 17 | > Right: Input images.
18 | 19 | 20 | ## Quickstart
21 | 22 | ### 1. Installation
23 | 24 | You may check out the original [`nerfstudio`](https://github.com/nerfstudio-project/nerfstudio) repo for prerequisites and dependencies.
25 | Currently, our codebase is tested with nerfstudio v1.0.3.
26 | 27 | TL;DR: You can install `nerfstudio` with:
28 | 29 | ```bash
30 | # (Optional) create a fresh conda env
31 | conda create --name nerfstudio -y "python<3.11"
32 | conda activate nerfstudio
33 | 34 | # install dependencies
35 | pip install --upgrade pip setuptools
36 | pip install "torch==2.1.2+cu118" "torchvision==0.16.2+cu118" --extra-index-url https://download.pytorch.org/whl/cu118
37 | 38 | conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit
39 | pip install ninja git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
40 | 41 | # install nerfstudio!
42 | pip install nerfstudio==1.0.3
43 | ```
44 | 45 | Then you can install this repo as a Python package with:
46 | 47 | ```bash
48 | pip install git+https://github.com/WU-CVGL/BAD-Gaussians
49 | ```
50 | 51 | ### 2. Prepare the dataset
52 | 53 | #### Deblur-NeRF Synthetic Dataset (Re-rendered)
54 | 55 | As described in the previous BAD-NeRF paper, we re-rendered Deblur-NeRF's synthetic dataset with 51 interpolations per blurry image.
56 | 57 | Additionally, in the previous BAD-NeRF paper, we directly ran COLMAP on blurry images only, with neither ground-truth
58 | camera intrinsics nor sharp novel-view images. We found this quite challenging for COLMAP - it may fail to
59 | reconstruct the scene, and we had to re-run COLMAP several times. To this end, we provide a new set of data,
60 | where we ran COLMAP with ground-truth camera intrinsics over both blurry and sharp novel-view images,
61 | named `bad-nerf-gtK-colmap-nvs`:
62 | 63 | [Download link](https://westlakeu-my.sharepoint.com/:f:/g/personal/cvgl_westlake_edu_cn/EoCe3vaC9V5Fl74DjbGriwcBKj1nbB0HQFSWnVTLX7qT9A)
64 | 65 | #### Deblur-NeRF Real Dataset
66 | 67 | You can directly download the `real_camera_motion_blur` folder from [Deblur-NeRF](https://limacv.github.io/deblurnerf/).
68 | 69 | #### Your Custom Dataset
70 | 71 | 1. Use the [`ns-process-data` tool from Nerfstudio](https://docs.nerf.studio/reference/cli/ns_process_data.html)
72 | to process deblur-nerf training images.
73 | 74 | For example, if the
75 | [dataset from BAD-NeRF](https://westlakeu-my.sharepoint.com/:f:/g/personal/cvgl_westlake_edu_cn/EsgdW2cRic5JqerhNbTsxtkBqy9m6cbnb2ugYZtvaib3qA?e=bjK7op)
76 | is in `llff_data`, execute:
77 | 78 | ```
79 | ns-process-data images \
80 | --data llff_data/blurtanabata/images \
81 | --output-dir data/my_data/blurtanabata
82 | ```
83 | 84 | 2. The folder `data/my_data/blurtanabata` is ready.
85 | 86 | > Note: Although nerfstudio does not model the NDC scene contraction for LLFF data,
87 | > we found that `scale_factor = 0.25` works well on LLFF datasets.
88 | > If your data is captured in a [LLFF fashion](https://github.com/Fyusion/LLFF#using-your-own-input-images-for-view-synthesis) (i.e. forward-facing),
89 | > instead of object-centric like Mip-NeRF 360,
90 | > you can pass the `scale_factor = 0.25` parameter to the nerfstudio dataparser (which is already set as the default in our `DeblurNerfDataParser`),
91 | > e.g., `ns-train bad-gaussians --data data/my_data/my_seq --vis viewer+tensorboard nerfstudio-data --scale_factor 0.25`
92 | 93 | ### 3. Training
94 | 95 | 1. For the `Deblur-NeRF synthetic` dataset, train with:
96 | 97 | ```bash
98 | ns-train bad-gaussians \
99 | --data data/bad-nerf-gtK-colmap-nvs/blurtanabata \
100 | --pipeline.model.camera-optimizer.mode "linear" \
101 | --vis viewer+tensorboard \
102 | deblur-nerf-data
103 | ```
104 | where
105 | - `--data data/bad-nerf-gtK-colmap-nvs/blurtanabata` is the relative path of the data sequence;
106 | - `--pipeline.model.camera-optimizer.mode "linear"` enables linear camera pose interpolation;
107 | - `--vis viewer+tensorboard` enables both the viewer and tensorboard metric logging;
108 | - `deblur-nerf-data` selects the DeblurNerfDataParser.
109 | 110 | 2. For the `Deblur-NeRF real` dataset with `downscale_factor=4`, train with:
111 | ```bash
112 | ns-train bad-gaussians \
113 | --data data/real_camera_motion_blur/blurdecoration \
114 | --pipeline.model.camera-optimizer.mode "cubic" \
115 | --vis viewer+tensorboard \
116 | deblur-nerf-data \
117 | --downscale_factor 4
118 | ```
119 | where
120 | - `--pipeline.model.camera-optimizer.mode "cubic"` enables cubic B-spline interpolation;
121 | - `--downscale_factor 4` after `deblur-nerf-data` tells the DeblurNerfDataParser to downscale the images' width and height to `1/4` of the original.
122 | 123 | 3. For the `Deblur-NeRF real` dataset at *full resolution*, train with:
124 | ```bash
125 | ns-train bad-gaussians \
126 | --data data/real_camera_motion_blur/blurdecoration \
127 | --pipeline.model.camera-optimizer.mode "cubic" \
128 | --pipeline.model.camera-optimizer.num_virtual_views 15 \
129 | --pipeline.model.num_downscales 2 \
130 | --pipeline.model.resolution_schedule 3000 \
131 | --vis viewer+tensorboard \
132 | deblur-nerf-data
133 | ```
134 | where
135 | - `--pipeline.model.camera-optimizer.mode "cubic"` enables cubic B-spline interpolation;
136 | - `--pipeline.model.camera-optimizer.num_virtual_views 15` increases the number of virtual cameras to 15;
137 | - `--pipeline.model.num_downscales 2` and `--pipeline.model.resolution_schedule 3000` enable coarse-to-fine training.
138 | 139 | 4. For custom data processed with `ns-process-data`, train with:
140 | 141 | ```bash
142 | ns-train bad-gaussians \
143 | --data data/my_data/blurtanabata \
144 | --vis viewer+tensorboard \
145 | nerfstudio-data --eval_mode "all"
146 | ```
147 | 148 | > Note: To improve reconstruction quality on your custom dataset, you may need to add
149 | some of the parameters that enable *cubic B-spline*, *more virtual cameras* and
150 | *coarse-to-fine training*, as shown in the examples above.
151 | 152 | ### 4. Render videos
153 | 154 | This command generates a trajectory from the camera poses of the training images, keeping their original order and interpolating 10 frames between adjacent images, at a frame rate of 30. It will load the `config.yml` and save the video to `renders/.mp4`.
155 | 156 | ```bash
157 | ns-render interpolate \
158 | --load-config outputs/blurtanabata/bad-gaussians//config.yml \
159 | --pose-source train \
160 | --frame-rate 30 \
161 | --interpolation-steps 10 \
162 | --output-path renders/.mp4
163 | ```
164 | 165 | > Note1: You can add the `--render-nearest-camera True` option to compare with the blurry inputs, but it will slow down the rendering process significantly.
166 | > 167 | > Note2: The working directory when executing this command must be the parent of `outputs`, i.e. the same directory as when training.
168 | > 169 | > Note3: You can find more information about this command in the [nerfstudio docs](https://docs.nerf.studio/reference/cli/ns_render.html#ns-render).
170 | 171 | ### 5. Export the 3D Gaussians
172 | 173 | This command will load the `config.yml` and export a `splat.ply` into the same folder:
174 | 175 | ```bash
176 | ns-export gaussian-splat \
177 | --load-config outputs/blurtanabata/bad-gaussians//config.yml \
178 | --output-dir outputs/blurtanabata/bad-gaussians/
179 | ```
180 | 181 | > Note1: We use `rasterize_mode = antialiased` by default. However, if you want to export the 3D gaussians, since the `antialiased` mode (i.e. *Mip-Splatting*) is not supported by most 3D-GS viewers, it is better to turn it off during training using: `--pipeline.model.rasterize_mode "classic"`
182 | > 183 | > Note2: The working directory when executing this command must be the parent of `outputs`, i.e. the same directory as when training.
184 | 185 | Then you can visualize this file with any viewer, for example the [WebGL Viewer](https://antimatter15.com/splat/).
186 | 187 | ### 6. Debug with your IDE
188 | 189 | Open this repo with your IDE, create a run configuration, and set the Python script to execute to
190 | `/nerfstudio/scripts/train.py`, with the parameters above.
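Alternatively, you can debug through a small wrapper script instead of pointing the IDE at nerfstudio's `train.py` directly. Below is a minimal sketch under the assumption that the `ns-train` console script maps to `nerfstudio.scripts.train:entrypoint` (as registered in nerfstudio's `pyproject.toml`); the file name `debug_train.py` and the argument values are only examples:

```python
# debug_train.py - illustrative debug wrapper, not part of this repo
import sys

from nerfstudio.scripts.train import entrypoint

if __name__ == "__main__":
    # Forward the same CLI arguments you would pass to `ns-train`, so that
    # breakpoints inside `bad_gaussians/` are hit by the IDE debugger.
    sys.argv = [
        "ns-train",
        "bad-gaussians",
        "--data", "data/my_data/blurtanabata",
        "--vis", "viewer+tensorboard",
        "nerfstudio-data", "--eval_mode", "all",
    ]
    entrypoint()
```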
191 | 192 | 193 | ## Citation 194 | 195 | If you find this useful, please consider citing: 196 | 197 | ```bibtex 198 | @inproceedings{zhao2024badgaussians, 199 | author = {Zhao, Lingzhe and Wang, Peng and Liu, Peidong}, 200 | title = {Bad-gaussians: Bundle adjusted deblur gaussian splatting}, 201 | booktitle = {European Conference on Computer Vision (ECCV)}, 202 | year = {2024} 203 | } 204 | ``` 205 | 206 | ## Acknowledgments 207 | 208 | - Kudos to the [Nerfstudio](https://github.com/nerfstudio-project/) and [gsplat](https://github.com/nerfstudio-project/gsplat) contributors for their amazing works: 209 | 210 | ```bibtex 211 | @inproceedings{nerfstudio, 212 | title = {Nerfstudio: A Modular Framework for Neural Radiance Field Development}, 213 | author = { 214 | Tancik, Matthew and Weber, Ethan and Ng, Evonne and Li, Ruilong and Yi, Brent 215 | and Kerr, Justin and Wang, Terrance and Kristoffersen, Alexander and Austin, 216 | Jake and Salahi, Kamyar and Ahuja, Abhik and McAllister, David and Kanazawa, 217 | Angjoo 218 | }, 219 | year = 2023, 220 | booktitle = {ACM SIGGRAPH 2023 Conference Proceedings}, 221 | series = {SIGGRAPH '23} 222 | } 223 | 224 | @software{Ye_gsplat, 225 | author = {Ye, Vickie and Turkulainen, Matias, and the Nerfstudio team}, 226 | title = {{gsplat}}, 227 | url = {https://github.com/nerfstudio-project/gsplat} 228 | } 229 | 230 | @misc{ye2023mathematical, 231 | title={Mathematical Supplement for the $\texttt{gsplat}$ Library}, 232 | author={Vickie Ye and Angjoo Kanazawa}, 233 | year={2023}, 234 | eprint={2312.02121}, 235 | archivePrefix={arXiv}, 236 | primaryClass={cs.MS} 237 | } 238 | ``` 239 | 240 | - Kudos to the [pypose](https://github.com/pypose/pypose) contributors for their amazing library: 241 | 242 | ```bibtex 243 | @inproceedings{wang2023pypose, 244 | title = {{PyPose}: A Library for Robot Learning with Physics-based Optimization}, 245 | author = {Wang, Chen and Gao, Dasong and Xu, Kuan and Geng, Junyi and Hu, Yaoyu and Qiu, Yuheng and Li, Bowen and Yang, Fan and Moon, Brady and Pandey, Abhinav and Aryan and Xu, Jiahe and Wu, Tianhao and He, Haonan and Huang, Daning and Ren, Zhongqiang and Zhao, Shibo and Fu, Taimeng and Reddy, Pranay and Lin, Xiao and Wang, Wenshan and Shi, Jingnan and Talak, Rajat and Cao, Kun and Du, Yi and Wang, Han and Yu, Huai and Wang, Shanzhao and Chen, Siyu and Kashyap, Ananth and Bandaru, Rohan and Dantu, Karthik and Wu, Jiajun and Xie, Lihua and Carlone, Luca and Hutter, Marco and Scherer, Sebastian}, 246 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 247 | year = {2023} 248 | } 249 | ``` 250 | -------------------------------------------------------------------------------- /bad_gaussians/bad_camera_optimizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pose and Intrinsics Optimizers 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | import functools 8 | from copy import deepcopy 9 | from typing import List, Literal, Optional, Type, Union 10 | 11 | import pypose as pp 12 | import torch 13 | from dataclasses import dataclass, field 14 | from jaxtyping import Float, Int 15 | from pypose import LieTensor 16 | from torch import Tensor 17 | from typing_extensions import assert_never 18 | 19 | from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig 20 | from nerfstudio.cameras.cameras import Cameras 21 | from nerfstudio.cameras.rays import RayBundle 22 | 23 | from bad_gaussians.spline_functor import ( 24 | 
bezier_interpolation, 25 | cubic_bspline_interpolation, 26 | linear_interpolation, 27 | linear_interpolation_mid, 28 | ) 29 | 30 | 31 | TrajSamplingMode = Literal["uniform", "start", "mid", "end"] 32 | """How to sample the camera trajectory""" 33 | 34 | 35 | @dataclass 36 | class BadCameraOptimizerConfig(CameraOptimizerConfig): 37 | """Configuration of BAD-Gaussians camera optimizer.""" 38 | 39 | _target: Type = field(default_factory=lambda: BadCameraOptimizer) 40 | """The target class to be instantiated.""" 41 | 42 | mode: Literal["off", "linear", "cubic", "bezier"] = "linear" 43 | """Pose optimization strategy to use. 44 | linear: linear interpolation on SE(3); 45 | cubic: cubic b-spline interpolation on SE(3). 46 | bezier: Bezier curve interpolation on SE(3). 47 | """ 48 | 49 | bezier_degree: int = 9 50 | """Degree of the Bezier curve. Only used when mode is bezier.""" 51 | 52 | trans_l2_penalty: float = 0.0 53 | """L2 penalty on translation parameters.""" 54 | 55 | rot_l2_penalty: float = 0.0 56 | """L2 penalty on rotation parameters.""" 57 | 58 | num_virtual_views: int = 10 59 | """The number of samples used to model the motion-blurring.""" 60 | 61 | initial_noise_se3_std: float = 1e-5 62 | """Initial perturbation to pose delta on se(3). Must be non-zero to prevent NaNs.""" 63 | 64 | 65 | class BadCameraOptimizer(CameraOptimizer): 66 | """Optimization for BAD-Gaussians virtual camera trajectories.""" 67 | 68 | config: BadCameraOptimizerConfig 69 | 70 | def __init__( 71 | self, 72 | config: BadCameraOptimizerConfig, 73 | num_cameras: int, 74 | device: Union[torch.device, str], 75 | non_trainable_camera_indices: Optional[Int[Tensor, "num_non_trainable_cameras"]] = None, 76 | **kwargs, 77 | ) -> None: 78 | super().__init__(CameraOptimizerConfig(), num_cameras, device) 79 | self.config = config 80 | self.num_cameras = num_cameras 81 | self.device = device 82 | self.non_trainable_camera_indices = non_trainable_camera_indices 83 | self.dof = 6 84 | """Degrees of freedom of manifold, i.e. number of dimensions of the tangent space""" 85 | self.dim = 7 86 | """Dimentions of pose parameterization. Three for translation, 4-tuple for quaternion""" 87 | 88 | # Initialize learnable parameters. 89 | if self.config.mode == "off": 90 | return 91 | elif self.config.mode == "linear": 92 | self.num_control_knots = 2 93 | elif self.config.mode == "cubic": 94 | self.num_control_knots = 4 95 | elif self.config.mode == "bezier": 96 | self.num_control_knots = self.config.bezier_degree 97 | else: 98 | assert_never(self.config.mode) 99 | 100 | self.pose_adjustment = pp.Parameter( 101 | pp.randn_se3( 102 | (num_cameras, self.num_control_knots), 103 | sigma=self.config.initial_noise_se3_std, 104 | device=device, 105 | ), 106 | ) 107 | 108 | def forward( 109 | self, 110 | indices: Int[Tensor, "camera_indices"], 111 | mode: TrajSamplingMode = "mid", 112 | ) -> Float[LieTensor, "camera_indices self.num_control_knots self.dof"]: 113 | """Indexing into camera adjustments. 114 | 115 | Args: 116 | indices: indices of Cameras to optimize. 117 | mode: interpolate between start and end, or return start / mid / end. 118 | 119 | Returns: 120 | Transformation matrices from optimized camera coordinates 121 | to given camera coordinates. 122 | """ 123 | outputs = [] 124 | 125 | # Apply learned transformation delta. 
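        # `indices` may repeat camera ids within a batch: the learnable se(3) deltas are
        # mapped onto SE(3) via Exp() once per unique camera, then gathered back through
        # `lut` (the inverse lookup returned by torch.unique) so that every requested
        # index receives its interpolated virtual-camera pose.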
126 | if self.config.mode == "off": 127 | pass 128 | else: 129 | indices = indices.int() 130 | unique_indices, lut = torch.unique(indices, return_inverse=True) 131 | camera_opt = self.pose_adjustment[unique_indices].Exp() 132 | outputs.append(self._interpolate(camera_opt, mode)[lut]) 133 | 134 | # Detach non-trainable indices by setting to identity transform 135 | if ( 136 | torch.is_grad_enabled() 137 | and self.non_trainable_camera_indices is not None 138 | and len(indices) > len(self.non_trainable_camera_indices) 139 | ): 140 | if self.non_trainable_camera_indices.device != self.pose_adjustment.device: 141 | self.non_trainable_camera_indices = self.non_trainable_camera_indices.to(self.pose_adjustment.device) 142 | nt = self.non_trainable_camera_indices 143 | outputs[0][nt] = outputs[0][nt].clone().detach() 144 | 145 | # Return: identity if no transforms are needed, otherwise composite transforms together. 146 | if len(outputs) == 0: 147 | return pp.identity_SE3(*indices.shape, device=self.device) 148 | return functools.reduce(pp.mul, outputs) 149 | 150 | def _interpolate( 151 | self, 152 | camera_opt: Float[LieTensor, "*batch_size self.num_control_knots self.dof"], 153 | mode: TrajSamplingMode 154 | ) -> Float[Tensor, "*batch_size interpolations self.dof"]: 155 | if mode == "uniform": 156 | u = torch.linspace( 157 | start=0, 158 | end=1, 159 | steps=self.config.num_virtual_views, 160 | device=camera_opt.device, 161 | ) 162 | if self.config.mode == "linear": 163 | return linear_interpolation(camera_opt, u) 164 | elif self.config.mode == "cubic": 165 | return cubic_bspline_interpolation(camera_opt, u) 166 | elif self.config.mode == "bezier": 167 | return bezier_interpolation(camera_opt, u) 168 | else: 169 | assert_never(self.config.mode) 170 | elif mode == "mid": 171 | if self.config.mode == "linear": 172 | return linear_interpolation_mid(camera_opt) 173 | elif self.config.mode == "cubic": 174 | return cubic_bspline_interpolation( 175 | camera_opt, 176 | torch.tensor([0.5], device=camera_opt.device) 177 | ).squeeze(1) 178 | elif self.config.mode == "bezier": 179 | return bezier_interpolation(camera_opt, torch.tensor([0.5], device=camera_opt.device)).squeeze(1) 180 | else: 181 | assert_never(self.config.mode) 182 | elif mode == "start": 183 | if self.config.mode == "linear": 184 | return camera_opt[..., 0, :] 185 | elif self.config.mode == "cubic": 186 | return cubic_bspline_interpolation( 187 | camera_opt, 188 | torch.tensor([0.0], device=camera_opt.device) 189 | ).squeeze(1) 190 | elif self.config.mode == "bezier": 191 | return bezier_interpolation(camera_opt, torch.tensor([0.0], device=camera_opt.device)).squeeze(1) 192 | else: 193 | assert_never(self.config.mode) 194 | elif mode == "end": 195 | if self.config.mode == "linear": 196 | return camera_opt[..., 1, :] 197 | elif self.config.mode == "cubic": 198 | return cubic_bspline_interpolation( 199 | camera_opt, 200 | torch.tensor([1.0], device=camera_opt.device) 201 | ).squeeze(1) 202 | elif self.config.mode == "bezier": 203 | return bezier_interpolation(camera_opt, torch.tensor([1.0], device=camera_opt.device)).squeeze(1) 204 | else: 205 | assert_never(self.config.mode) 206 | else: 207 | assert_never(mode) 208 | 209 | def apply_to_raybundle(self, *args, **kwargs): 210 | """Not implemented. Should not be called.""" 211 | raise NotImplementedError("Not implemented in BAD-Gaussians. 
Please checkout https://github.com/WU-CVGL/Bad-RFs") 212 | 213 | def apply_to_camera(self, camera: Cameras, mode: TrajSamplingMode) -> List[Cameras]: 214 | """Apply pose correction to the camera""" 215 | # assert camera.metadata is not None, "Must provide camera metadata" 216 | # assert "cam_idx" in camera.metadata, "Must provide id of camera in its metadata" 217 | if self.config.mode == "off" or camera.metadata is None or not ("cam_idx" in camera.metadata): 218 | # print("[WARN] Cannot get cam_idx in camera.metadata") 219 | return [deepcopy(camera)] 220 | 221 | camera_idx = camera.metadata["cam_idx"] 222 | c2w = camera.camera_to_worlds # shape: (1, 4, 4) 223 | if c2w.shape[1] == 3: 224 | c2w = torch.cat([c2w, torch.tensor([0, 0, 0, 1], device=c2w.device).view(1, 1, 4)], dim=1) 225 | 226 | poses_delta = self((torch.tensor([camera_idx])), mode) 227 | 228 | if mode == "uniform": 229 | c2ws = c2w.tile((self.config.num_virtual_views, 1, 1)) # shape: (num_virtual_views, 4, 4) 230 | c2ws_adjusted = torch.bmm(c2ws, poses_delta.matrix().squeeze()) 231 | cameras_list = [deepcopy(camera) for _ in range(self.config.num_virtual_views)] 232 | for i in range(self.config.num_virtual_views): 233 | cameras_list[i].camera_to_worlds = c2ws_adjusted[None, i, :, :] 234 | else: 235 | c2w_adjusted = torch.bmm(c2w, poses_delta.matrix()) 236 | cameras_list = [deepcopy(camera)] 237 | cameras_list[0].camera_to_worlds = c2w_adjusted 238 | 239 | assert len(cameras_list) 240 | return cameras_list 241 | 242 | def get_metrics_dict(self, metrics_dict: dict) -> None: 243 | """Get camera optimizer metrics""" 244 | if self.config.mode != "off": 245 | metrics_dict["camera_opt_trajectory_translation"] = ( 246 | self.pose_adjustment[:, 1, :3] - self.pose_adjustment[:, 0, :3]).norm() 247 | metrics_dict["camera_opt_trajectory_rotation"] = ( 248 | self.pose_adjustment[:, 1, 3:] - self.pose_adjustment[:, 0, 3:]).norm() 249 | metrics_dict["camera_opt_translation"] = 0 250 | metrics_dict["camera_opt_rotation"] = 0 251 | for i in range(self.num_control_knots): 252 | metrics_dict["camera_opt_translation"] += self.pose_adjustment[:, i, :3].norm() 253 | metrics_dict["camera_opt_rotation"] += self.pose_adjustment[:, i, 3:].norm() 254 | 255 | def get_loss_dict(self, loss_dict: dict) -> None: 256 | """Add regularization""" 257 | pass 258 | -------------------------------------------------------------------------------- /bad_gaussians/bad_config_dataparser.py: -------------------------------------------------------------------------------- 1 | """ 2 | BAD-Gaussians dataparser configs. 3 | """ 4 | 5 | from nerfstudio.plugins.registry_dataparser import DataParserSpecification 6 | 7 | from bad_gaussians.deblur_nerf_dataparser import DeblurNerfDataParserConfig 8 | 9 | DeblurNerfDataParser = DataParserSpecification(config=DeblurNerfDataParserConfig()) 10 | -------------------------------------------------------------------------------- /bad_gaussians/bad_config_method.py: -------------------------------------------------------------------------------- 1 | """ 2 | BAD-Gaussians configs. 
3 | """ 4 | 5 | from nerfstudio.configs.base_config import ViewerConfig 6 | from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig 7 | from nerfstudio.engine.optimizers import AdamOptimizerConfig 8 | from nerfstudio.engine.schedulers import ExponentialDecaySchedulerConfig 9 | from nerfstudio.plugins.types import MethodSpecification 10 | 11 | from bad_gaussians.bad_camera_optimizer import BadCameraOptimizerConfig 12 | from bad_gaussians.bad_gaussians import BadGaussiansModelConfig 13 | from bad_gaussians.image_restoration_full_image_datamanager import ImageRestorationFullImageDataManagerConfig 14 | from bad_gaussians.image_restoration_pipeline import ImageRestorationPipelineConfig 15 | from bad_gaussians.image_restoration_trainer import ImageRestorationTrainerConfig 16 | 17 | 18 | bad_gaussians = MethodSpecification( 19 | config=ImageRestorationTrainerConfig( 20 | method_name="bad-gaussians", 21 | steps_per_eval_image=500, 22 | steps_per_eval_batch=500, 23 | steps_per_save=2000, 24 | steps_per_eval_all_images=500, 25 | max_num_iterations=30001, 26 | mixed_precision=False, 27 | use_grad_scaler=False, 28 | gradient_accumulation_steps={"camera_opt": 25}, 29 | pipeline=ImageRestorationPipelineConfig( 30 | eval_render_start_end=True, 31 | eval_render_estimated=True, 32 | datamanager=ImageRestorationFullImageDataManagerConfig( 33 | # cache_images="gpu", # reduce CPU usage, caused by pin_memory()? 34 | dataparser=NerfstudioDataParserConfig( 35 | load_3D_points=True, 36 | eval_mode="interval", 37 | eval_interval=8, 38 | ), 39 | ), 40 | model=BadGaussiansModelConfig( 41 | camera_optimizer=BadCameraOptimizerConfig(mode="linear", num_virtual_views=10), 42 | use_scale_regularization=True, 43 | continue_cull_post_densification=False, 44 | cull_alpha_thresh=5e-3, 45 | densify_grad_thresh=4e-4, 46 | num_downscales=0, 47 | resolution_schedule=250, 48 | tv_loss_lambda=None, 49 | ), 50 | ), 51 | optimizers={ 52 | "means": { 53 | "optimizer": AdamOptimizerConfig(lr=1.6e-4, eps=1e-15), 54 | "scheduler": ExponentialDecaySchedulerConfig( 55 | lr_final=1.6e-6, 56 | max_steps=30000, 57 | ), 58 | }, 59 | "features_dc": { 60 | "optimizer": AdamOptimizerConfig(lr=0.0025, eps=1e-15), 61 | "scheduler": None, 62 | }, 63 | "features_rest": { 64 | "optimizer": AdamOptimizerConfig(lr=0.0025 / 20, eps=1e-15), 65 | "scheduler": None, 66 | }, 67 | "opacities": { 68 | "optimizer": AdamOptimizerConfig(lr=0.05, eps=1e-15), 69 | "scheduler": None, 70 | }, 71 | "scales": { 72 | "optimizer": AdamOptimizerConfig(lr=0.005, eps=1e-15), 73 | "scheduler": None, 74 | }, 75 | "quats": { 76 | "optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), 77 | "scheduler": None 78 | }, 79 | "camera_opt": { 80 | "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15), 81 | "scheduler": ExponentialDecaySchedulerConfig( 82 | lr_final=1e-5, 83 | max_steps=30000, 84 | ), 85 | }, 86 | }, 87 | viewer=ViewerConfig(num_rays_per_chunk=1 << 15), 88 | vis="viewer", 89 | ), 90 | description="Implementation of BAD-Gaussians", 91 | ) 92 | -------------------------------------------------------------------------------- /bad_gaussians/bad_gaussians.py: -------------------------------------------------------------------------------- 1 | """ 2 | BAD-Gaussians model. 
3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | from dataclasses import dataclass, field 8 | from typing import Dict, List, Literal, Optional, Tuple, Type, Union 9 | 10 | import torch 11 | 12 | from gsplat.project_gaussians import project_gaussians 13 | from gsplat.rasterize import rasterize_gaussians 14 | from gsplat.sh import spherical_harmonics 15 | 16 | from nerfstudio.cameras.cameras import Cameras 17 | from nerfstudio.data.scene_box import OrientedBox 18 | from nerfstudio.models.splatfacto import SplatfactoModel, SplatfactoModelConfig 19 | from nerfstudio.model_components import renderers 20 | 21 | from bad_gaussians.bad_camera_optimizer import ( 22 | BadCameraOptimizer, 23 | BadCameraOptimizerConfig, 24 | TrajSamplingMode, 25 | ) 26 | from bad_gaussians.bad_losses import EdgeAwareVariationLoss 27 | 28 | 29 | @dataclass 30 | class BadGaussiansModelConfig(SplatfactoModelConfig): 31 | """BAD-Gaussians Model config""" 32 | 33 | _target: Type = field(default_factory=lambda: BadGaussiansModel) 34 | """The target class to be instantiated.""" 35 | 36 | rasterize_mode: Literal["classic", "antialiased"] = "antialiased" 37 | """ 38 | Classic mode of rendering will use the EWA volume splatting with a [0.3, 0.3] screen space blurring kernel. This 39 | approach is however not suitable to render tiny gaussians at higher or lower resolution than the captured, which 40 | results "aliasing-like" artifacts. The antialiased mode overcomes this limitation by calculating compensation factors 41 | and apply them to the opacities of gaussians to preserve the total integrated density of splats. 42 | 43 | However, PLY exported with antialiased rasterize mode is not compatible with classic mode. Thus many web viewers that 44 | were implemented for classic mode can not render antialiased mode PLY properly without modifications. 45 | Refs: 46 | 1. https://github.com/nerfstudio-project/gsplat/pull/117 47 | 2. https://github.com/nerfstudio-project/nerfstudio/pull/2888 48 | 3. Yu, Zehao, et al. "Mip-Splatting: Alias-free 3D Gaussian Splatting." arXiv preprint arXiv:2311.16493 (2023). 49 | """ 50 | 51 | camera_optimizer: BadCameraOptimizerConfig = field(default_factory=BadCameraOptimizerConfig) 52 | """Config of the camera optimizer to use""" 53 | 54 | cull_alpha_thresh: float = 0.005 55 | """Threshold for alpha to cull gaussians. Default: 0.1 in splatfacto, 0.005 in splatfacto-big.""" 56 | 57 | densify_grad_thresh: float = 4e-4 58 | """[IMPORTANT] Threshold for gradient to densify gaussians. Default: 4e-4. Tune it smaller with complex scenes.""" 59 | 60 | continue_cull_post_densification: bool = False 61 | """Whether to continue culling after densification. Default: True in splatfacto, False in splatfacto-big.""" 62 | 63 | resolution_schedule: int = 250 64 | """training starts at 1/d resolution, every n steps this is doubled. 65 | Default: 250. Use 3000 with high resolution images (e.g. higher than 1920x1080). 66 | """ 67 | 68 | num_downscales: int = 0 69 | """at the beginning, resolution is 1/2^d, where d is this number. Default: 0. Use 2 with high resolution images.""" 70 | 71 | enable_absgrad: bool = False 72 | """Whether to enable absgrad for gaussians. (It affects param tuning of densify_grad_thresh) 73 | Default: False. 
Ref: (https://github.com/nerfstudio-project/nerfstudio/pull/3113) 74 | """ 75 | 76 | tv_loss_lambda: Optional[float] = None 77 | """weight of total variation loss""" 78 | 79 | 80 | class BadGaussiansModel(SplatfactoModel): 81 | """BAD-Gaussians Model 82 | 83 | Args: 84 | config: configuration to instantiate model 85 | """ 86 | 87 | config: BadGaussiansModelConfig 88 | camera_optimizer: BadCameraOptimizer 89 | 90 | def __init__(self, config: BadGaussiansModelConfig, **kwargs) -> None: 91 | super().__init__(config=config, **kwargs) 92 | # Scale densify_grad_thresh by the number of virtual views 93 | self.config.densify_grad_thresh /= self.config.camera_optimizer.num_virtual_views 94 | # (Experimental) Total variation loss 95 | self.tv_loss = EdgeAwareVariationLoss(in1_nc=3) 96 | 97 | def populate_modules(self) -> None: 98 | super().populate_modules() 99 | self.camera_optimizer: BadCameraOptimizer = self.config.camera_optimizer.setup( 100 | num_cameras=self.num_train_data, device="cpu" 101 | ) 102 | 103 | def forward( 104 | self, 105 | camera: Cameras, 106 | mode: TrajSamplingMode = "uniform", 107 | ) -> Dict[str, Union[torch.Tensor, List]]: 108 | return self.get_outputs(camera, mode) 109 | 110 | def get_outputs( 111 | self, camera: Cameras, 112 | mode: TrajSamplingMode = "uniform", 113 | ) -> Dict[str, Union[torch.Tensor, List]]: 114 | """Takes in a Camera and returns a dictionary of outputs. 115 | 116 | Args: 117 | camera: Input camera. This camera should have all the needed information to compute the outputs. 118 | 119 | Returns: 120 | Outputs of model. (ie. rendered colors) 121 | """ 122 | if not isinstance(camera, Cameras): 123 | print("Called get_outputs with not a camera") 124 | return {} 125 | assert camera.shape[0] == 1, "Only one camera at a time" 126 | 127 | is_training = self.training and torch.is_grad_enabled() 128 | 129 | # BAD-Gaussians: get virtual cameras 130 | virtual_cameras = self.camera_optimizer.apply_to_camera(camera, mode) 131 | 132 | if is_training: 133 | if self.config.background_color == "random": 134 | background = torch.rand(3, device=self.device) 135 | elif self.config.background_color == "white": 136 | background = torch.ones(3, device=self.device) 137 | elif self.config.background_color == "black": 138 | background = torch.zeros(3, device=self.device) 139 | else: 140 | background = self.background_color.to(self.device) 141 | else: 142 | # logic for setting the background of the scene 143 | if renderers.BACKGROUND_COLOR_OVERRIDE is not None: 144 | background = renderers.BACKGROUND_COLOR_OVERRIDE.to(self.device) 145 | else: 146 | background = self.background_color.to(self.device) 147 | if self.crop_box is not None and not is_training: 148 | crop_ids = self.crop_box.within(self.means).squeeze() 149 | if crop_ids.sum() == 0: 150 | rgb = background.repeat(int(camera.height.item()), int(camera.width.item()), 1) 151 | depth = background.new_ones(*rgb.shape[:2], 1) * 10 152 | accumulation = background.new_zeros(*rgb.shape[:2], 1) 153 | return {"rgb": rgb, "depth": depth, "accumulation": accumulation, "background": background} 154 | else: 155 | crop_ids = None 156 | 157 | camera_downscale = self._get_downscale_factor() 158 | 159 | for cam in virtual_cameras: 160 | cam.rescale_output_resolution(1 / camera_downscale) 161 | 162 | # BAD-Gaussians: render virtual views 163 | virtual_views_rgb = [] 164 | virtual_views_alpha = [] 165 | for cam in virtual_cameras: 166 | # shift the camera to center of scene looking at center 167 | R = cam.camera_to_worlds[0, :3, :3] # 3 x 3 168 
| T = cam.camera_to_worlds[0, :3, 3:4] # 3 x 1 169 | # flip the z axis to align with gsplat conventions 170 | R_edit = torch.diag(torch.tensor([1, -1, -1], device=self.device, dtype=R.dtype)) 171 | R = R @ R_edit 172 | # analytic matrix inverse to get world2camera matrix 173 | R_inv = R.T 174 | T_inv = -R_inv @ T 175 | viewmat = torch.eye(4, device=R.device, dtype=R.dtype) 176 | viewmat[:3, :3] = R_inv 177 | viewmat[:3, 3:4] = T_inv 178 | # update last_size 179 | W, H = int(cam.width.item()), int(cam.height.item()) 180 | self.last_size = (H, W) 181 | 182 | if crop_ids is not None: 183 | opacities_crop = self.opacities[crop_ids] 184 | means_crop = self.means[crop_ids] 185 | features_dc_crop = self.features_dc[crop_ids] 186 | features_rest_crop = self.features_rest[crop_ids] 187 | scales_crop = self.scales[crop_ids] 188 | quats_crop = self.quats[crop_ids] 189 | else: 190 | opacities_crop = self.opacities 191 | means_crop = self.means 192 | features_dc_crop = self.features_dc 193 | features_rest_crop = self.features_rest 194 | scales_crop = self.scales 195 | quats_crop = self.quats 196 | 197 | colors_crop = torch.cat((features_dc_crop[:, None, :], features_rest_crop), dim=1) 198 | BLOCK_WIDTH = 16 # this controls the tile size of rasterization, 16 is a good default 199 | self.xys, depths, self.radii, conics, comp, num_tiles_hit, cov3d = project_gaussians( 200 | means_crop, 201 | torch.exp(scales_crop), 202 | 1, 203 | quats_crop / quats_crop.norm(dim=-1, keepdim=True), 204 | viewmat.squeeze()[:3, :], 205 | cam.fx.item(), 206 | cam.fy.item(), 207 | cam.cx.item(), 208 | cam.cy.item(), 209 | H, 210 | W, 211 | BLOCK_WIDTH, 212 | ) # type: ignore 213 | 214 | # rescale the camera back to original dimensions before returning 215 | cam.rescale_output_resolution(camera_downscale) 216 | 217 | if (self.radii).sum() == 0: 218 | rgb = background.repeat(H, W, 1) 219 | depth = background.new_ones(*rgb.shape[:2], 1) * 10 220 | accumulation = background.new_zeros(*rgb.shape[:2], 1) 221 | 222 | return {"rgb": rgb, "depth": depth, "accumulation": accumulation, "background": background} 223 | 224 | # Important to allow xys grads to populate properly 225 | if is_training: 226 | self.xys.retain_grad() 227 | 228 | if self.config.sh_degree > 0: 229 | viewdirs = means_crop.detach() - cam.camera_to_worlds.detach()[..., :3, 3] # (N, 3) 230 | viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True) 231 | n = min(self.step // self.config.sh_degree_interval, self.config.sh_degree) 232 | rgbs = spherical_harmonics(n, viewdirs, colors_crop) 233 | rgbs = torch.clamp(rgbs + 0.5, min=0.0) # type: ignore 234 | else: 235 | rgbs = torch.sigmoid(colors_crop[:, 0, :]) 236 | 237 | # rescale the camera back to original dimensions 238 | # cam.rescale_output_resolution(camera_downscale) 239 | assert (num_tiles_hit > 0).any() # type: ignore 240 | 241 | # apply the compensation of screen space blurring to gaussians 242 | if self.config.rasterize_mode == "antialiased": 243 | alphas = torch.sigmoid(opacities_crop) * comp[:, None] 244 | elif self.config.rasterize_mode == "classic": 245 | alphas = torch.sigmoid(opacities_crop) 246 | rgb, alpha = rasterize_gaussians( # type: ignore 247 | self.xys, 248 | depths, 249 | self.radii, 250 | conics, 251 | num_tiles_hit, # type: ignore 252 | rgbs, 253 | alphas, 254 | H, 255 | W, 256 | BLOCK_WIDTH, 257 | background=background, 258 | return_alpha=True, 259 | ) # type: ignore 260 | alpha = alpha[..., None] 261 | rgb = torch.clamp(rgb, max=1.0) # type: ignore 262 | virtual_views_rgb.append(rgb) 263 | 
virtual_views_alpha.append(alpha) 264 | depth_im = None 265 | rgb = torch.stack(virtual_views_rgb, dim=0).mean(dim=0) 266 | alpha = torch.stack(virtual_views_alpha, dim=0).mean(dim=0) 267 | 268 | # eval 269 | if not is_training: 270 | depth_im = rasterize_gaussians( # type: ignore 271 | self.xys, 272 | depths, 273 | self.radii, 274 | conics, 275 | num_tiles_hit, # type: ignore 276 | depths[:, None].repeat(1, 3), 277 | torch.sigmoid(opacities_crop), 278 | H, 279 | W, 280 | BLOCK_WIDTH, 281 | background=torch.zeros(3, device=self.device), 282 | )[..., 0:1] # type: ignore 283 | depth_im = torch.where(alpha > 0, depth_im / alpha, depth_im.detach().max()) 284 | 285 | return {"rgb": rgb, "depth": depth_im, "accumulation": alpha, "background": background} # type: ignore 286 | 287 | def after_train(self, step: int): 288 | assert step == self.step 289 | # to save some training time, we no longer need to update those stats post refinement 290 | if self.step >= self.config.stop_split_at: 291 | return 292 | with torch.no_grad(): 293 | # keep track of a moving average of grad norms 294 | visible_mask = (self.radii > 0).flatten() 295 | # BAD-Gaussians: use absgrad if enabled 296 | if self.config.enable_absgrad: 297 | assert self.xys.absgrad is not None # type: ignore 298 | grads = self.xys.absgrad.detach().norm(dim=-1) # type: ignore 299 | else: 300 | assert self.xys.grad is not None 301 | grads = self.xys.grad.detach().norm(dim=-1) 302 | # print(f"grad norm min {grads.min().item()} max {grads.max().item()} mean {grads.mean().item()} size {grads.shape}") 303 | if self.xys_grad_norm is None: 304 | self.xys_grad_norm = grads 305 | self.vis_counts = torch.ones_like(self.xys_grad_norm) 306 | else: 307 | assert self.vis_counts is not None 308 | self.vis_counts[visible_mask] = self.vis_counts[visible_mask] + 1 309 | self.xys_grad_norm[visible_mask] = grads[visible_mask] + self.xys_grad_norm[visible_mask] 310 | # update the max screen size, as a ratio of number of pixels 311 | if self.max_2Dsize is None: 312 | self.max_2Dsize = torch.zeros_like(self.radii, dtype=torch.float32) 313 | newradii = self.radii.detach()[visible_mask] 314 | self.max_2Dsize[visible_mask] = torch.maximum( 315 | self.max_2Dsize[visible_mask], 316 | newradii / float(max(self.last_size[0], self.last_size[1])), 317 | ) 318 | 319 | @torch.no_grad() 320 | def get_outputs_for_camera( 321 | self, 322 | camera: Cameras, 323 | obb_box: Optional[OrientedBox] = None, 324 | mode: TrajSamplingMode = "mid", 325 | ) -> Dict[str, torch.Tensor]: 326 | """Takes in a camera, generates the raybundle, and computes the output of the model. 327 | Overridden for a camera-based gaussian model. 
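        The `mode` argument selects where along the per-image exposure trajectory the virtual camera is sampled ("start", "mid", "end", or "uniform").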
328 | """ 329 | assert camera is not None, "must provide camera to gaussian model" 330 | self.set_crop(obb_box) 331 | # BAD-Gaussians: camera.to(device) will drop metadata 332 | metadata = camera.metadata 333 | camera = camera.to(self.device) 334 | camera.metadata = metadata 335 | outs = self.get_outputs(camera, mode=mode) 336 | return outs # type: ignore 337 | 338 | def get_loss_dict(self, outputs, batch, metrics_dict=None): 339 | loss_dict = super().get_loss_dict(outputs, batch, metrics_dict) 340 | # Add total variation loss 341 | rgb = outputs["rgb"].permute(2, 0, 1).unsqueeze(0) # H, W, 3 to 1, 3, H, W 342 | if self.config.tv_loss_lambda is not None: 343 | loss_dict["tv_loss"] = self.tv_loss(rgb) * self.config.tv_loss_lambda 344 | # Add loss from camera optimizer 345 | self.camera_optimizer.get_loss_dict(loss_dict) 346 | return loss_dict 347 | 348 | def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]: 349 | metrics_dict = super().get_metrics_dict(outputs, batch) 350 | # Add metrics from camera optimizer 351 | self.camera_optimizer.get_metrics_dict(metrics_dict) 352 | return metrics_dict 353 | 354 | def get_param_groups(self) -> Dict[str, List[torch.nn.Parameter]]: 355 | param_groups = super().get_param_groups() 356 | self.camera_optimizer.get_param_groups(param_groups=param_groups) 357 | return param_groups 358 | -------------------------------------------------------------------------------- /bad_gaussians/bad_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class GridGradientCentralDiff: 6 | def __init__(self, nc, padding=True, diagonal=False): 7 | self.conv_x = nn.Conv2d(nc, nc, kernel_size=2, stride=1, bias=False) 8 | self.conv_y = nn.Conv2d(nc, nc, kernel_size=2, stride=1, bias=False) 9 | self.conv_xy = None 10 | if diagonal: 11 | self.conv_xy = nn.Conv2d(nc, nc, kernel_size=2, stride=1, bias=False) 12 | 13 | self.padding = None 14 | if padding: 15 | self.padding = nn.ReplicationPad2d([0, 1, 0, 1]) 16 | 17 | fx = torch.zeros(nc, nc, 2, 2).float().cuda() 18 | fy = torch.zeros(nc, nc, 2, 2).float().cuda() 19 | if diagonal: 20 | fxy = torch.zeros(nc, nc, 2, 2).float().cuda() 21 | 22 | fx_ = torch.tensor([[1, -1], [0, 0]]).cuda() 23 | fy_ = torch.tensor([[1, 0], [-1, 0]]).cuda() 24 | if diagonal: 25 | fxy_ = torch.tensor([[1, 0], [0, -1]]).cuda() 26 | 27 | for i in range(nc): 28 | fx[i, i, :, :] = fx_ 29 | fy[i, i, :, :] = fy_ 30 | if diagonal: 31 | fxy[i, i, :, :] = fxy_ 32 | 33 | self.conv_x.weight = nn.Parameter(fx) 34 | self.conv_y.weight = nn.Parameter(fy) 35 | if diagonal: 36 | self.conv_xy.weight = nn.Parameter(fxy) 37 | 38 | def __call__(self, grid_2d): 39 | _image = grid_2d 40 | if self.padding is not None: 41 | _image = self.padding(_image) 42 | dx = self.conv_x(_image) 43 | dy = self.conv_y(_image) 44 | 45 | if self.conv_xy is not None: 46 | dxy = self.conv_xy(_image) 47 | return dx, dy, dxy 48 | return dx, dy 49 | 50 | 51 | class EdgeAwareVariationLoss(nn.Module): 52 | def __init__(self, in1_nc, grad_fn=GridGradientCentralDiff): 53 | super(EdgeAwareVariationLoss, self).__init__() 54 | self.in1_grad_fn = grad_fn(in1_nc) 55 | # self.in2_grad_fn = grad_fn(in2_nc) 56 | 57 | def forward(self, in1, mean=False): 58 | in1_dx, in1_dy = self.in1_grad_fn(in1) 59 | # in2_dx, in2_dy = self.in2_grad_fn(in2) 60 | 61 | abs_in1_dx, abs_in1_dy = in1_dx.abs().sum(dim=1, keepdim=True), in1_dy.abs().sum(dim=1, keepdim=True) 62 | # abs_in2_dx, abs_in2_dy = 
in2_dx.abs().sum(dim=1,keepdim=True), in2_dy.abs().sum(dim=1,keepdim=True) 63 | 64 | weight_dx, weight_dy = torch.exp(-abs_in1_dx), torch.exp(-abs_in1_dy) 65 | 66 | variation = weight_dx * abs_in1_dx + weight_dy * abs_in1_dy 67 | 68 | if mean != False: 69 | return variation.mean() 70 | return variation.sum() 71 | 72 | 73 | class GrayEdgeAwareVariationLoss(nn.Module): 74 | def __init__(self, in1_nc, in2_nc, grad_fn=GridGradientCentralDiff): 75 | super(GrayEdgeAwareVariationLoss, self).__init__() 76 | self.in1_grad_fn = grad_fn(in1_nc) # Gray 77 | self.in2_grad_fn = grad_fn(in2_nc) # Sharp 78 | 79 | def forward(self, in1, in2, mean=False): 80 | in1_dx, in1_dy = self.in1_grad_fn(in1) 81 | in2_dx, in2_dy = self.in2_grad_fn(in2) 82 | 83 | abs_in1_dx, abs_in1_dy = in1_dx.abs().sum(dim=1, keepdim=True), in1_dy.abs().sum(dim=1, keepdim=True) 84 | abs_in2_dx, abs_in2_dy = in2_dx.abs().sum(dim=1, keepdim=True), in2_dy.abs().sum(dim=1, keepdim=True) 85 | 86 | weight_dx, weight_dy = torch.exp(-abs_in2_dx), torch.exp(-abs_in2_dy) 87 | 88 | variation = weight_dx * abs_in1_dx + weight_dy * abs_in1_dy 89 | 90 | if mean != False: 91 | return variation.mean() 92 | return variation.sum() 93 | -------------------------------------------------------------------------------- /bad_gaussians/bad_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | BAD-Gaussians utils. 3 | """ 4 | from __future__ import annotations 5 | 6 | from pathlib import Path 7 | from typing import Tuple 8 | 9 | import pypose as pp 10 | import torch 11 | from jaxtyping import Float 12 | from pypose import LieTensor 13 | from torch import Tensor 14 | 15 | 16 | class TrajectoryIO: 17 | @staticmethod 18 | def load_tum_trajectory( 19 | filename: Path 20 | ) -> Tuple[ 21 | Float[Tensor, "num_poses"], 22 | Float[LieTensor, "num_poses 7"] 23 | ]: 24 | """Load TUM trajectory from file""" 25 | with open(filename, 'r', encoding="UTF-8") as f: 26 | lines = f.read().splitlines() 27 | if lines[0].startswith('#'): 28 | lines.pop(0) 29 | lines = [line.split() for line in lines] 30 | lines = [[float(val) for val in line] for line in lines] 31 | timestamps = [] 32 | poses = [] 33 | for line in lines: 34 | timestamps.append(torch.tensor(line[0])) 35 | poses.append(torch.tensor(line[1:])) 36 | timestamps = torch.stack(timestamps) 37 | poses = pp.SE3(torch.stack(poses)) 38 | return timestamps, poses 39 | 40 | @staticmethod 41 | def load_kitti_trajectory(filename: Path) -> Float[LieTensor, "num_poses 7"]: 42 | """Load KITTI trajectory from file""" 43 | with open(filename, 'r', encoding="UTF-8") as f: 44 | lines = f.read().splitlines() 45 | lines = [line.split() for line in lines] 46 | lines = [[float(val) for val in line] for line in lines] 47 | poses = [] 48 | for line in lines: 49 | poses.append(torch.tensor(line).reshape(4, 4)) 50 | poses = pp.mat2SE3(torch.stack(poses).cuda()) 51 | return poses 52 | 53 | @staticmethod 54 | def write_tum_trajectory( 55 | filename: Path, 56 | timestamps: Float[Tensor, "num_poses"], 57 | poses: Float[LieTensor, "num_poses 7"] | Float[Tensor, "num_poses 7"] 58 | ): 59 | """Write TUM trajectory to file""" 60 | with open(filename, 'w', encoding="UTF-8") as f: 61 | if pp.is_lietensor(poses): 62 | poses = poses.tensor() 63 | for timestamp, pose in zip(timestamps, poses): 64 | f.write(f'{timestamp.item()} {pose[0]} {pose[1]} {pose[2]} {pose[3]} {pose[4]} {pose[5]} {pose[6]}\n') 65 | 66 | @staticmethod 67 | def write_kitti_trajectory( 68 | filename: Path, 69 | poses: Float[LieTensor, 
"num_poses 7"] | Float[Tensor, "num_poses 7"] 70 | ): 71 | """Write KITTI trajectory to file""" 72 | with open(filename, 'w', encoding="UTF-8") as f: 73 | poses = pp.SE3(poses) 74 | poses = poses.matrix() # 4x4 matrix 75 | for pose in poses: 76 | f.write(f"{' '.join([str(p.item()) for p in pose.flatten()])}\n") 77 | -------------------------------------------------------------------------------- /bad_gaussians/bad_viewer.py: -------------------------------------------------------------------------------- 1 | '''Viewer of BAD-Gaussians''' 2 | import numpy as np 3 | import torch 4 | import viser.transforms as vtf 5 | 6 | from nerfstudio.viewer.viewer import Viewer, VISER_NERFSTUDIO_SCALE_RATIO 7 | 8 | from bad_gaussians.bad_camera_optimizer import BadCameraOptimizer 9 | 10 | 11 | class BadViewer(Viewer): 12 | # BAD-Gaussians: Overriding original update_camera_poses because BadNerfCameraOptimizer returns LieTensor 13 | def update_camera_poses(self): 14 | # TODO this fn accounts for like ~5% of total train time 15 | # Update the train camera locations based on optimization 16 | assert self.camera_handles is not None 17 | if hasattr(self.pipeline.datamanager, "train_camera_optimizer"): 18 | camera_optimizer = self.pipeline.datamanager.train_camera_optimizer 19 | elif hasattr(self.pipeline.model, "camera_optimizer"): 20 | camera_optimizer = self.pipeline.model.camera_optimizer 21 | else: 22 | return 23 | idxs = list(self.camera_handles.keys()) 24 | with torch.no_grad(): 25 | assert isinstance(camera_optimizer, BadCameraOptimizer) 26 | c2ws_delta = camera_optimizer(torch.tensor(idxs, device=camera_optimizer.device)) 27 | for i, key in enumerate(idxs): 28 | # both are numpy arrays 29 | c2w_orig = self.original_c2w[key] 30 | c2w_delta = c2ws_delta[i, ...] 31 | c2w = c2w_orig @ c2w_delta.matrix().cpu().numpy() 32 | R = vtf.SO3.from_matrix(c2w[:3, :3]) # type: ignore 33 | R = R @ vtf.SO3.from_x_radians(np.pi) 34 | self.camera_handles[key].position = c2w[:3, 3] * VISER_NERFSTUDIO_SCALE_RATIO 35 | self.camera_handles[key].wxyz = R.wxyz 36 | -------------------------------------------------------------------------------- /bad_gaussians/deblur_nerf_dataparser.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data parser for Deblur-NeRF COLMAP datasets. 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | import glob 8 | import os 9 | from dataclasses import dataclass, field 10 | from pathlib import Path 11 | from typing import List, Literal, Optional, Type 12 | 13 | import cv2 14 | import torch 15 | from nerfstudio.data.dataparsers.colmap_dataparser import ColmapDataParser, ColmapDataParserConfig 16 | 17 | 18 | def _find_files(directory: Path, exts: List[str]) -> List[Path]: 19 | """Find all files in a directory that have a certain file extension. 20 | 21 | Args: 22 | directory : The directory to search for files. 23 | exts : A list of file extensions to search for. Each file extension should be in the form '*.ext'. 24 | 25 | Returns: 26 | A list of file paths for all the files that were found. The list is sorted alphabetically. 
27 | """ 28 | assert directory.exists() 29 | if os.path.isdir(directory): 30 | # types should be ['*.png', '*.jpg', '*.JPG', '*.PNG'] 31 | files_grabbed = [] 32 | for ext in exts: 33 | files_grabbed.extend(glob.glob(os.path.join(directory, ext))) 34 | if len(files_grabbed) > 0: 35 | files_grabbed = sorted(list(set(files_grabbed))) 36 | files_grabbed = [Path(f) for f in files_grabbed] 37 | return files_grabbed 38 | return [] 39 | 40 | 41 | @dataclass 42 | class DeblurNerfDataParserConfig(ColmapDataParserConfig): 43 | """Deblur-NeRF dataset config""" 44 | 45 | _target: Type = field(default_factory=lambda: DeblurNerfDataParser) 46 | """target class to instantiate""" 47 | eval_mode: Literal["fraction", "filename", "interval", "all"] = "interval" 48 | """ 49 | The method to use for splitting the dataset into train and eval. 50 | Fraction splits based on a percentage for train and the remaining for eval. 51 | Filename splits based on filenames containing train/eval. 52 | Interval uses every nth frame for eval (used by most academic papers, e.g. MipNerf360, GSplat). 53 | All uses all the images for any split. 54 | """ 55 | eval_interval: int = 8 56 | """The interval between frames to use for eval. Only used when eval_mode is eval-interval.""" 57 | images_path: Path = Path("images") 58 | """Path to images directory relative to the data path.""" 59 | downscale_factor: Optional[int] = 1 60 | """The downscale factor for the images. Default: 1.""" 61 | poses_bounds_path: Path = Path("poses_bounds.npy") 62 | """Path to the poses bounds file relative to the data path.""" 63 | colmap_path: Path = Path("sparse/0") 64 | """Path to the colmap reconstruction directory relative to the data path.""" 65 | drop_distortion: bool = False 66 | """Whether to drop the camera distortion parameters. Default: False.""" 67 | scale_factor: float = 0.25 68 | """[IMPORTANT] How much to scale the camera origins by. 69 | Default: 0.25 suggested for LLFF datasets with COLMAP. 70 | """ 71 | 72 | 73 | @dataclass 74 | class DeblurNerfDataParser(ColmapDataParser): 75 | """Deblur-NeRF COLMAP dataset parser""" 76 | 77 | config: DeblurNerfDataParserConfig 78 | _downscale_factor: Optional[int] = None 79 | 80 | def _get_all_images_and_cameras(self, recon_dir: Path): 81 | out = super()._get_all_images_and_cameras(recon_dir) 82 | out["frames"] = sorted(out["frames"], key=lambda x: x["file_path"]) 83 | return out 84 | 85 | def _check_outputs(self, outputs): 86 | """ 87 | Check if the colmap outputs are estimated on downscaled data. If so, correct the camera parameters. 88 | """ 89 | # load the first image to get the image size 90 | image = cv2.imread(str(outputs.image_filenames[0])) 91 | # get the image size 92 | h, w = image.shape[:2] 93 | # check if the cx and cy are in the correct range 94 | cx = outputs.cameras.cx[0] 95 | cy = outputs.cameras.cy[0] 96 | ideal_cx = torch.tensor(w / 2) 97 | ideal_cy = torch.tensor(h / 2) 98 | if not torch.allclose(cx, ideal_cx, rtol=0.3): 99 | x_scale = cx / ideal_cx 100 | print(f"[WARN] cx is not at the center of the image, correcting... 
cx scale: {x_scale}") 101 | if x_scale < 1: 102 | outputs.cameras.fx *= round(1 / x_scale.item()) 103 | outputs.cameras.cx *= round(1 / x_scale.item()) 104 | outputs.cameras.width *= round(1 / x_scale.item()) 105 | else: 106 | outputs.cameras.fx /= round(x_scale.item()) 107 | outputs.cameras.cx /= round(x_scale.item()) 108 | outputs.cameras.width //= round(x_scale.item()) 109 | 110 | if not torch.allclose(cy, ideal_cy, rtol=0.3): 111 | y_scale = cy / ideal_cy 112 | print(f"[WARN] cy is not at the center of the image, correcting... cy scale: {y_scale}") 113 | if y_scale < 1: 114 | outputs.cameras.fy *= round(1 / y_scale.item()) 115 | outputs.cameras.cy *= round(1 / y_scale.item()) 116 | outputs.cameras.height *= round(1 / y_scale.item()) 117 | else: 118 | outputs.cameras.fy /= round(y_scale.item()) 119 | outputs.cameras.cy /= round(y_scale.item()) 120 | outputs.cameras.height //= round(y_scale.item()) 121 | 122 | return outputs 123 | 124 | def _check_suffixes(self, filenames): 125 | """ 126 | Check if the file path exists. if not, check if the file path with the correct suffix exists. 127 | """ 128 | for i, filename in enumerate(filenames): 129 | if not filename.exists(): 130 | flag_found = False 131 | exts = [".png", ".PNG", ".jpg", ".JPG"] 132 | for ext in exts: 133 | new_filename = filename.with_suffix(ext) 134 | if new_filename.exists(): 135 | filenames[i] = new_filename 136 | flag_found = True 137 | break 138 | if not flag_found: 139 | print(f"[WARN] {filename} not found in the images directory.") 140 | 141 | return filenames 142 | 143 | def _generate_dataparser_outputs(self, split="train"): 144 | assert self.config.data.exists(), f"Data directory {self.config.data} does not exist." 145 | 146 | if self.config.eval_mode == "interval": 147 | # find the file named `hold=n` , n is the eval_interval to be recognized 148 | hold_file = [f for f in os.listdir(self.config.data) if f.startswith('hold=')] 149 | if len(hold_file) == 0: 150 | print(f"[INFO] defaulting hold={self.config.eval_interval}") 151 | else: 152 | self.config.eval_interval = int(hold_file[0].split('=')[-1]) 153 | if self.config.eval_interval < 1: 154 | self.config.eval_mode = "all" 155 | 156 | gt_folder_path = self.config.data / "images_test" 157 | if gt_folder_path.exists(): 158 | outputs = super()._generate_dataparser_outputs("train") 159 | if split != "train": 160 | gt_image_filenames = _find_files(gt_folder_path, exts=["*.png", "*.jpg", "*.JPG", "*.PNG"]) 161 | num_gt_images = len(gt_image_filenames) 162 | print(f"[INFO] Found {num_gt_images} ground truth sharp images.") 163 | # number of GT sharp testing images should be equal to the number of degraded training images 164 | assert num_gt_images == len(outputs.image_filenames) 165 | outputs.image_filenames = gt_image_filenames 166 | else: 167 | print("[INFO] No ground truth sharp images found.") 168 | outputs = super()._generate_dataparser_outputs(split) 169 | 170 | if self.config.drop_distortion: 171 | for camera in outputs.cameras: 172 | camera.distortion_params = None 173 | outputs.image_filenames = self._check_suffixes(outputs.image_filenames) 174 | outputs = self._check_outputs(outputs) 175 | 176 | return outputs 177 | -------------------------------------------------------------------------------- /bad_gaussians/image_restoration_dataloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Image Restoration Dataloaders. 
3 | """ 4 | 5 | from typing import Dict, Optional, Tuple, Union 6 | 7 | import torch 8 | 9 | from nerfstudio.cameras.cameras import Cameras 10 | from nerfstudio.data.datasets.base_dataset import InputDataset 11 | from nerfstudio.data.utils.dataloaders import RandIndicesEvalDataloader, FixedIndicesEvalDataloader 12 | from nerfstudio.utils.misc import get_dict_to_torch 13 | 14 | 15 | class ImageRestorationRandIndicesEvalDataloader(RandIndicesEvalDataloader): 16 | """eval_dataloader that returns random images. 17 | 18 | Args: 19 | input_dataset: Ground-truth images for evaluation 20 | degraded_dataset: Corresponding training images with degradation. 21 | device: Device to load data to. 22 | """ 23 | 24 | def __init__( 25 | self, 26 | input_dataset: InputDataset, 27 | degraded_dataset: InputDataset, 28 | device: Union[torch.device, str] = "cpu", 29 | **kwargs, 30 | ): 31 | super().__init__(input_dataset, device, **kwargs) 32 | self.degraded_dataset = degraded_dataset 33 | 34 | def get_camera(self, image_idx: int = 0) -> Tuple[Cameras, Dict]: 35 | """Returns the data for a specific image index. 36 | 37 | Args: 38 | image_idx: Camera image index 39 | """ 40 | camera = self.cameras[image_idx : image_idx + 1] 41 | batch = self.input_dataset[image_idx] 42 | batch = get_dict_to_torch(batch, device=self.device, exclude=["image"]) 43 | batch["degraded"] = self.degraded_dataset[image_idx]["image"] 44 | assert isinstance(batch, dict) 45 | if camera.metadata is None: 46 | camera.metadata = {} 47 | camera.metadata["cam_idx"] = image_idx 48 | return camera, batch 49 | 50 | 51 | class ImageRestorationFixedIndicesEvalDataloader(FixedIndicesEvalDataloader): 52 | """fixed_indices_eval_dataloader that returns a fixed set of indices. 53 | 54 | Args: 55 | input_dataset: Ground-truth images for evaluation 56 | degraded_dataset: Corresponding training images with degradation. 57 | image_indices: List of image indices to load data from. If None, then use all images. 58 | device: Device to load data to 59 | """ 60 | 61 | def __init__( 62 | self, 63 | input_dataset: InputDataset, 64 | degraded_dataset: InputDataset, 65 | image_indices: Optional[Tuple[int]] = None, 66 | device: Union[torch.device, str] = "cpu", 67 | **kwargs, 68 | ): 69 | super().__init__(input_dataset, image_indices, device, **kwargs) 70 | self.degraded_dataset = degraded_dataset 71 | 72 | def get_camera(self, image_idx: int = 0) -> Tuple[Cameras, Dict]: 73 | """Returns the data for a specific image index. 74 | 75 | Args: 76 | image_idx: Camera image index 77 | """ 78 | camera = self.cameras[image_idx : image_idx + 1] 79 | batch = self.input_dataset[image_idx] 80 | batch = get_dict_to_torch(batch, device=self.device, exclude=["image"]) 81 | batch["degraded"] = self.degraded_dataset[image_idx]["image"] 82 | assert isinstance(batch, dict) 83 | if camera.metadata is None: 84 | camera.metadata = {} 85 | camera.metadata["cam_idx"] = image_idx 86 | return camera, batch 87 | -------------------------------------------------------------------------------- /bad_gaussians/image_restoration_full_image_datamanager.py: -------------------------------------------------------------------------------- 1 | """ 2 | Full image datamanager for image restoration. 
3 | """ 4 | from __future__ import annotations 5 | 6 | import random 7 | from copy import deepcopy 8 | from dataclasses import dataclass, field 9 | from typing import Any, Callable, Literal, Type, Union, cast, Tuple, Dict 10 | 11 | import torch 12 | 13 | from nerfstudio.cameras.cameras import Cameras 14 | from nerfstudio.data.datamanagers.base_datamanager import variable_res_collate 15 | from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanager, FullImageDatamanagerConfig 16 | from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate 17 | from nerfstudio.utils.rich_utils import CONSOLE 18 | 19 | from bad_gaussians.image_restoration_dataloader import ImageRestorationFixedIndicesEvalDataloader, ImageRestorationRandIndicesEvalDataloader 20 | 21 | 22 | @dataclass 23 | class ImageRestorationFullImageDataManagerConfig(FullImageDatamanagerConfig): 24 | """Datamanager for image restoration""" 25 | 26 | _target: Type = field(default_factory=lambda: ImageRestorationFullImageDataManager) 27 | """Target class to instantiate.""" 28 | collate_fn: Callable[[Any], Any] = cast(Any, staticmethod(nerfstudio_collate)) 29 | """Specifies the collate function to use for the train and eval dataloaders.""" 30 | 31 | 32 | class ImageRestorationFullImageDataManager(FullImageDatamanager): # pylint: disable=abstract-method 33 | """Data manager implementation for image restoration 34 | Args: 35 | config: the DataManagerConfig used to instantiate class 36 | """ 37 | 38 | config: ImageRestorationFullImageDataManagerConfig 39 | 40 | def __init__( 41 | self, 42 | config: ImageRestorationFullImageDataManagerConfig, 43 | device: Union[torch.device, str] = "cpu", 44 | test_mode: Literal["test", "val", "inference"] = "val", 45 | world_size: int = 1, 46 | local_rank: int = 0, 47 | **kwargs, 48 | ): 49 | super().__init__(config, device, test_mode, world_size, local_rank, **kwargs) 50 | if self.train_dataparser_outputs is not None: 51 | cameras = self.train_dataparser_outputs.cameras 52 | if len(cameras) > 1: 53 | for i in range(1, len(cameras)): 54 | if cameras[0].width != cameras[i].width or cameras[0].height != cameras[i].height: 55 | CONSOLE.print("Variable resolution, using variable_res_collate") 56 | self.config.collate_fn = variable_res_collate 57 | break 58 | 59 | self._fixed_indices_eval_dataloader = ImageRestorationFixedIndicesEvalDataloader( 60 | input_dataset=self.eval_dataset, 61 | degraded_dataset=self.train_dataset, 62 | device=self.device, 63 | num_workers=self.world_size * 4, 64 | ) 65 | self.eval_dataloader = ImageRestorationRandIndicesEvalDataloader( 66 | input_dataset=self.eval_dataset, 67 | degraded_dataset=self.train_dataset, 68 | device=self.device, 69 | num_workers=self.world_size * 4, 70 | ) 71 | 72 | @property 73 | def fixed_indices_eval_dataloader(self): 74 | """Returns the fixed indices eval dataloader""" 75 | return self._fixed_indices_eval_dataloader 76 | 77 | def next_eval(self, step: int) -> Tuple[Cameras, Dict]: 78 | """Returns the next evaluation batch. 
Returns a Camera instead of raybundle""" 79 | image_idx = self.eval_unseen_cameras.pop(random.randint(0, len(self.eval_unseen_cameras) - 1)) 80 | # Make sure to re-populate the unseen cameras list if we have exhausted it 81 | if len(self.eval_unseen_cameras) == 0: 82 | self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))] 83 | data = deepcopy(self.cached_eval[image_idx]) 84 | data["image"] = data["image"].to(self.device) 85 | assert len(self.eval_dataset.cameras.shape) == 1, "Assumes single batch dimension" 86 | camera = self.eval_dataset.cameras[image_idx : image_idx + 1].to(self.device) 87 | # BAD-Gaussians: pass camera index to BadNerfCameraOptimizer 88 | if camera.metadata is None: 89 | camera.metadata = {} 90 | camera.metadata["cam_idx"] = image_idx 91 | return camera, data 92 | 93 | def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]: 94 | for camera, batch in self.eval_dataloader: 95 | assert camera.shape[0] == 1 96 | return camera, batch 97 | raise ValueError("No more eval images") 98 | -------------------------------------------------------------------------------- /bad_gaussians/image_restoration_pipeline.py: -------------------------------------------------------------------------------- 1 | """Image restoration pipeline.""" 2 | from __future__ import annotations 3 | 4 | import os 5 | from pathlib import Path 6 | from time import time 7 | from typing import Optional, Type 8 | 9 | import torch 10 | from dataclasses import dataclass, field 11 | from rich.progress import BarColumn, MofNCompleteColumn, Progress, TextColumn, TimeElapsedColumn 12 | 13 | from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager 14 | from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanager 15 | from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager 16 | from nerfstudio.pipelines.base_pipeline import VanillaPipeline, VanillaPipelineConfig 17 | from nerfstudio.utils import profiler 18 | from nerfstudio.utils.writer import to8b 19 | 20 | from bad_gaussians.bad_gaussians import BadGaussiansModel 21 | 22 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" 23 | import cv2 24 | 25 | 26 | @dataclass 27 | class ImageRestorationPipelineConfig(VanillaPipelineConfig): 28 | """Image restoration pipeline config""" 29 | 30 | _target: Type = field(default_factory=lambda: ImageRestorationPipeline) 31 | """The target class to be instantiated.""" 32 | 33 | eval_render_start_end: bool = False 34 | """whether to render and save the starting and ending virtual sharp images in eval""" 35 | 36 | eval_render_estimated: bool = False 37 | """whether to render and save the estimated degraded images with learned trajectory in eval. 38 | Note: Slow & VRAM hungry! Reduce VRAM consumption by passing argument 39 | `--pipeline.model.eval_num_rays_per_chunk=16384` or less. 40 | """ 41 | 42 | 43 | class ImageRestorationPipeline(VanillaPipeline): 44 | """Image restoration pipeline""" 45 | 46 | config: ImageRestorationPipelineConfig 47 | 48 | @profiler.time_function 49 | def get_average_eval_image_metrics( 50 | self, step: Optional[int] = None, output_path: Optional[Path] = None, get_std: bool = False 51 | ): 52 | """Iterate over all the images in the eval dataset and get the average. 53 | Also saves the rendered images to disk if output_path is provided. 54 | 55 | Args: 56 | step: current training step 57 | output_path: optional path to save rendered images to 58 | get_std: Set True if you want to return std with the mean metric. 
59 | 60 | Returns: 61 | metrics_dict: dictionary of metrics 62 | """ 63 | self.eval() 64 | metrics_dict_list = [] 65 | render_list = ["mid"] 66 | if self.config.eval_render_start_end: 67 | render_list += ["start", "end"] 68 | if self.config.eval_render_estimated: 69 | render_list += ["uniform"] 70 | assert isinstance(self.datamanager, (VanillaDataManager, ParallelDataManager, FullImageDatamanager)) 71 | num_images = len(self.datamanager.fixed_indices_eval_dataloader) 72 | with Progress( 73 | TextColumn("[progress.description]{task.description}"), 74 | BarColumn(), 75 | TimeElapsedColumn(), 76 | MofNCompleteColumn(), 77 | transient=True, 78 | ) as progress: 79 | task = progress.add_task("[green]Evaluating all eval images...", total=num_images) 80 | for camera, batch in self.datamanager.fixed_indices_eval_dataloader: 81 | # time this the following line 82 | inner_start = time() 83 | image_idx = batch['image_idx'] 84 | images_dict = { 85 | f"{image_idx:04}_input": batch["degraded"][:, :, :3], 86 | f"{image_idx:04}_gt": batch["image"][:, :, :3], 87 | } 88 | if isinstance(self.model, BadGaussiansModel): 89 | for mode in render_list: 90 | outputs = self.model.get_outputs_for_camera(camera, mode=mode) 91 | for key, value in outputs.items(): 92 | if "uniform" == mode: 93 | filename = f"{image_idx:04}_estimated" 94 | else: 95 | filename = f"{image_idx:04}_{key}_{mode}" 96 | if "rgb" in key: 97 | images_dict[filename] = value 98 | if "depth" in key and "uniform" != mode: 99 | images_dict[filename] = value 100 | if "mid" == mode: 101 | metrics_dict, _ = self.model.get_image_metrics_and_images(outputs, batch) 102 | else: 103 | outputs = self.model.get_outputs_for_camera(camera) 104 | metrics_dict, images_dict = self.model.get_image_metrics_and_images(outputs, batch) 105 | if output_path is not None: 106 | image_dir = output_path / f"{step:06}" 107 | if not image_dir.exists(): 108 | image_dir.mkdir(parents=True) 109 | for filename, data in images_dict.items(): 110 | data = data.detach().cpu() 111 | is_u8_image = False 112 | for tag in ["rgb", "input", "gt", "estimated", "mask"]: 113 | if tag in filename: 114 | is_u8_image = True 115 | if is_u8_image: 116 | path = str((image_dir / f"{filename}.png").resolve()) 117 | cv2.imwrite(path, cv2.cvtColor(to8b(data).numpy(), cv2.COLOR_RGB2BGR)) 118 | else: 119 | path = str((image_dir / f"{filename}.exr").resolve()) 120 | cv2.imwrite(path, data.numpy()) 121 | 122 | assert "num_rays_per_sec" not in metrics_dict 123 | height, width = camera.height, camera.width 124 | num_rays = height * width 125 | metrics_dict["num_rays_per_sec"] = (num_rays / (time() - inner_start)).item() 126 | fps_str = "fps" 127 | assert fps_str not in metrics_dict 128 | metrics_dict[fps_str] = (metrics_dict["num_rays_per_sec"] / (height * width)).item() 129 | metrics_dict_list.append(metrics_dict) 130 | progress.advance(task) 131 | # average the metrics list 132 | metrics_dict = {} 133 | for key in metrics_dict_list[0].keys(): 134 | if get_std: 135 | key_std, key_mean = torch.std_mean( 136 | torch.tensor([metrics_dict[key] for metrics_dict in metrics_dict_list]) 137 | ) 138 | metrics_dict[key] = float(key_mean) 139 | metrics_dict[f"{key}_std"] = float(key_std) 140 | else: 141 | metrics_dict[key] = float( 142 | torch.mean(torch.tensor([metrics_dict[key] for metrics_dict in metrics_dict_list])) 143 | ) 144 | self.train() 145 | return metrics_dict 146 | -------------------------------------------------------------------------------- /bad_gaussians/image_restoration_trainer.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Image restoration trainer. 3 | """ 4 | from __future__ import annotations 5 | 6 | import dataclasses 7 | import functools 8 | from typing import Literal, Type 9 | from typing_extensions import assert_never 10 | 11 | import torch 12 | from dataclasses import dataclass, field 13 | from nerfstudio.engine.callbacks import TrainingCallbackAttributes 14 | from nerfstudio.engine.trainer import Trainer, TrainerConfig 15 | from nerfstudio.utils import profiler, writer 16 | from nerfstudio.utils.decorators import check_eval_enabled 17 | from nerfstudio.utils.misc import step_check 18 | from nerfstudio.utils.writer import EventName, TimeWriter 19 | 20 | from bad_gaussians.image_restoration_pipeline import ImageRestorationPipeline, ImageRestorationPipelineConfig 21 | from bad_gaussians.bad_viewer import BadViewer 22 | 23 | 24 | @dataclass 25 | class ImageRestorationTrainerConfig(TrainerConfig): 26 | """Configuration for image restoration training""" 27 | _target: Type = field(default_factory=lambda: ImageRestorationTrainer) 28 | """The target class to be instantiated.""" 29 | 30 | pipeline: ImageRestorationPipelineConfig = field(default_factory=ImageRestorationPipelineConfig) 31 | """Image restoration pipeline configuration""" 32 | 33 | 34 | class ImageRestorationTrainer(Trainer): 35 | """Image restoration Trainer class""" 36 | config: ImageRestorationTrainerConfig 37 | pipeline: ImageRestorationPipeline 38 | 39 | def setup(self, test_mode: Literal["test", "val", "inference"] = "val") -> None: 40 | """Set up the trainer. 41 | 42 | Args: 43 | test_mode: The test mode to use. 44 | """ 45 | # BAD-Gaussians: Overriding original setup since we want to use our BadNerfViewer 46 | self.pipeline = self.config.pipeline.setup( 47 | device=self.device, 48 | test_mode=test_mode, 49 | world_size=self.world_size, 50 | local_rank=self.local_rank, 51 | grad_scaler=self.grad_scaler, 52 | ) 53 | self.optimizers = self.setup_optimizers() 54 | 55 | # set up viewer if enabled 56 | viewer_log_path = self.base_dir / self.config.viewer.relative_log_filename 57 | self.viewer_state, banner_messages = None, None 58 | if self.config.is_viewer_legacy_enabled() and self.local_rank == 0: 59 | assert_never(self.config.vis) 60 | if self.config.is_viewer_enabled() and self.local_rank == 0: 61 | datapath = self.config.data 62 | if datapath is None: 63 | datapath = self.base_dir 64 | self.viewer_state = BadViewer( 65 | self.config.viewer, 66 | log_filename=viewer_log_path, 67 | datapath=datapath, 68 | pipeline=self.pipeline, 69 | trainer=self, 70 | train_lock=self.train_lock, 71 | share=self.config.viewer.make_share_url, 72 | ) 73 | banner_messages = self.viewer_state.viewer_info 74 | self._check_viewer_warnings() 75 | 76 | self._load_checkpoint() 77 | 78 | self.callbacks = self.pipeline.get_training_callbacks( 79 | TrainingCallbackAttributes( 80 | optimizers=self.optimizers, grad_scaler=self.grad_scaler, pipeline=self.pipeline, trainer=self 81 | ) 82 | ) 83 | 84 | # set up writers/profilers if enabled 85 | writer_log_path = self.base_dir / self.config.logging.relative_log_dir 86 | writer.setup_event_writer( 87 | self.config.is_wandb_enabled(), 88 | self.config.is_tensorboard_enabled(), 89 | self.config.is_comet_enabled(), 90 | log_dir=writer_log_path, 91 | experiment_name=self.config.experiment_name, 92 | project_name=self.config.project_name, 93 | ) 94 | writer.setup_local_writer( 95 | self.config.logging, 
max_iter=self.config.max_num_iterations, banner_messages=banner_messages 96 | ) 97 | writer.put_config(name="config", config_dict=dataclasses.asdict(self.config), step=0) 98 | profiler.setup_profiler(self.config.logging, writer_log_path) 99 | 100 | # BAD-Gaussians: disable eval if no eval images 101 | if self.pipeline.datamanager.eval_dataset.cameras is None: 102 | self.config.steps_per_eval_all_images = int(9e9) 103 | self.config.steps_per_eval_batch = int(9e9) 104 | self.config.steps_per_eval_image = int(9e9) 105 | 106 | @check_eval_enabled 107 | @profiler.time_function 108 | def eval_iteration(self, step: int) -> None: 109 | """Run one iteration with different batch/image/all image evaluations depending on step size. 110 | Args: 111 | step: Current training step. 112 | """ 113 | # a batch of eval rays 114 | if step_check(step, self.config.steps_per_eval_batch): 115 | _, eval_loss_dict, eval_metrics_dict = self.pipeline.get_eval_loss_dict(step=step) 116 | eval_loss = functools.reduce(torch.add, eval_loss_dict.values()) 117 | writer.put_scalar(name="Eval Loss", scalar=eval_loss, step=step) 118 | writer.put_dict(name="Eval Loss Dict", scalar_dict=eval_loss_dict, step=step) 119 | writer.put_dict(name="Eval Metrics Dict", scalar_dict=eval_metrics_dict, step=step) 120 | # one eval image 121 | if step_check(step, self.config.steps_per_eval_image): 122 | with TimeWriter(writer, EventName.TEST_RAYS_PER_SEC, write=False) as test_t: 123 | metrics_dict, images_dict = self.pipeline.get_eval_image_metrics_and_images(step=step) 124 | writer.put_time( 125 | name=EventName.TEST_RAYS_PER_SEC, 126 | duration=metrics_dict["num_rays"] / test_t.duration, 127 | step=step, 128 | avg_over_steps=True, 129 | ) 130 | writer.put_dict(name="Eval Images Metrics", scalar_dict=metrics_dict, step=step) 131 | group = "Eval Images" 132 | for image_name, image in images_dict.items(): 133 | writer.put_image(name=group + "/" + image_name, image=image, step=step) 134 | 135 | # all eval images 136 | if step_check(step, self.config.steps_per_eval_all_images): 137 | # BAD-Gaussians: pass output_path to save rendered images 138 | metrics_dict = self.pipeline.get_average_eval_image_metrics(step=step, output_path=self.base_dir) 139 | writer.put_dict(name="Eval Images Metrics Dict (all images)", scalar_dict=metrics_dict, step=step) 140 | -------------------------------------------------------------------------------- /bad_gaussians/spline.py: -------------------------------------------------------------------------------- 1 | """ 2 | SE(3) B-spline trajectory 3 | 4 | Created by lzzhao on 2023.09.29 5 | """ 6 | from __future__ import annotations 7 | 8 | from dataclasses import dataclass, field 9 | from typing import Tuple, Type 10 | 11 | import pypose as pp 12 | import torch 13 | from jaxtyping import Float 14 | from pypose import LieTensor 15 | from torch import nn, Tensor 16 | from typing_extensions import assert_never 17 | 18 | from nerfstudio.configs.base_config import InstantiateConfig 19 | 20 | from bad_gaussians.spline_functor import linear_interpolation, cubic_bspline_interpolation 21 | 22 | 23 | @dataclass 24 | class SplineConfig(InstantiateConfig): 25 | """Configuration for spline instantiation.""" 26 | 27 | _target: Type = field(default_factory=lambda: Spline) 28 | """Target class to instantiate.""" 29 | 30 | degree: int = 1 31 | """Degree of the spline. 
1 for linear spline, 3 for cubic spline.""" 32 | 33 | sampling_interval: float = 0.1 34 | """Sampling interval of the control knots.""" 35 | 36 | start_time: float = 0.0 37 | """Starting timestamp of the spline.""" 38 | 39 | 40 | class Spline(nn.Module): 41 | """SE(3) spline trajectory. 42 | 43 | Args: 44 | config: the SplineConfig used to instantiate class 45 | """ 46 | 47 | config: SplineConfig 48 | data: Float[LieTensor, "num_knots 7"] 49 | start_time: float 50 | end_time: float 51 | t_lower_bound: float 52 | t_upper_bound: float 53 | 54 | def __init__(self, config: SplineConfig): 55 | super().__init__() 56 | self.config = config 57 | self.data = pp.identity_SE3(0) 58 | self.order = self.config.degree + 1 59 | """Order of the spline, i.e. control knots per segment, 2 for linear, 4 for cubic""" 60 | 61 | self.set_start_time(config.start_time) 62 | self.update_end_time() 63 | 64 | def __len__(self): 65 | return self.data.shape[0] 66 | 67 | def forward(self, timestamps: Float[Tensor, "*batch_size"]) -> Float[LieTensor, "*batch_size 7"]: 68 | """Interpolate the spline at the given timestamps. 69 | 70 | Args: 71 | timestamps: Timestamps to interpolate the spline at. Range: [t_lower_bound, t_upper_bound]. 72 | 73 | Returns: 74 | poses: The interpolated pose. 75 | """ 76 | segment, u = self.get_segment(timestamps) 77 | u = u[..., None] # (*batch_size) to (*batch_size, interpolations=1) 78 | if self.config.degree == 1: 79 | poses = linear_interpolation(segment, u) 80 | elif self.config.degree == 3: 81 | poses = cubic_bspline_interpolation(segment, u) 82 | else: 83 | assert_never(self.config.degree) 84 | return poses.squeeze() 85 | 86 | def get_segment( 87 | self, 88 | timestamps: Float[Tensor, "*batch_size"] 89 | ) -> Tuple[ 90 | Float[LieTensor, "*batch_size self.order 7"], 91 | Float[Tensor, "*batch_size"] 92 | ]: 93 | """Get the spline segment and normalized position on segment at the given timestamp. 94 | 95 | Args: 96 | timestamps: Timestamps to get the spline segment and normalized position at. 97 | 98 | Returns: 99 | segment: The spline segment. 100 | u: The normalized position on the segment. 
101 | """ 102 | assert torch.all(timestamps >= self.t_lower_bound) 103 | assert torch.all(timestamps <= self.t_upper_bound) 104 | batch_size = timestamps.shape 105 | relative_time = timestamps - self.start_time 106 | normalized_time = relative_time / self.config.sampling_interval 107 | start_index = torch.floor(normalized_time).int() 108 | u = normalized_time - start_index 109 | if self.config.degree == 3: 110 | start_index -= 1 111 | 112 | indices = (start_index.tile((self.order, 1)).T + 113 | torch.arange(self.order).tile((*batch_size, 1)).to(start_index.device)) 114 | indices = indices[..., None].tile(7) 115 | segment = pp.SE3(torch.gather(self.data.expand(*batch_size, -1, -1), 1, indices)) 116 | 117 | return segment, u 118 | 119 | def insert(self, pose: Float[LieTensor, "1 7"]): 120 | """Insert a control knot""" 121 | self.data = pp.SE3(torch.cat([self.data, pose])) 122 | self.update_end_time() 123 | 124 | def set_data(self, data: Float[LieTensor, "num_knots 7"] | pp.Parameter): 125 | """Set the spline data.""" 126 | self.data = data 127 | self.update_end_time() 128 | 129 | def set_start_time(self, start_time: float): 130 | """Set the starting timestamp of the spline.""" 131 | self.start_time = start_time 132 | if self.config.degree == 1: 133 | self.t_lower_bound = self.start_time 134 | elif self.config.degree == 3: 135 | self.t_lower_bound = self.start_time + self.config.sampling_interval 136 | else: 137 | assert_never(self.config.degree) 138 | 139 | def update_end_time(self): 140 | """Update the ending timestamp of the spline.""" 141 | self.end_time = self.start_time + self.config.sampling_interval * (len(self) - 1) 142 | if self.config.degree == 1: 143 | self.t_upper_bound = self.end_time 144 | elif self.config.degree == 3: 145 | self.t_upper_bound = self.end_time - self.config.sampling_interval 146 | else: 147 | assert_never(self.config.degree) 148 | -------------------------------------------------------------------------------- /bad_gaussians/spline_functor.py: -------------------------------------------------------------------------------- 1 | """ 2 | SE(3) B-spline trajectory library 3 | 4 | Created by lzzhao on 2023.09.19 5 | """ 6 | from __future__ import annotations 7 | 8 | import pypose as pp 9 | import scipy 10 | import torch 11 | from jaxtyping import Float 12 | from pypose import LieTensor 13 | from torch import Tensor 14 | 15 | _EPS = 1e-6 16 | 17 | 18 | def linear_interpolation_mid( 19 | ctrl_knots: Float[LieTensor, "*batch_size 2 7"], 20 | ) -> Float[LieTensor, "*batch_size 7"]: 21 | """Get the midpoint between batches of two SE(3) poses by linear interpolation. 22 | 23 | Args: 24 | ctrl_knots: The control knots. 25 | 26 | Returns: 27 | The midpoint poses. 28 | """ 29 | start_pose, end_pose = ctrl_knots[..., 0, :], ctrl_knots[..., 1, :] 30 | t_start, q_start = start_pose.translation(), start_pose.rotation() 31 | t_end, q_end = end_pose.translation(), end_pose.rotation() 32 | 33 | t = (t_start + t_end) * 0.5 34 | 35 | q_tau_0 = q_start.Inv() @ q_end 36 | q_t_0 = pp.Exp(pp.so3(q_tau_0.Log() * 0.5)) 37 | q = q_start @ q_t_0 38 | 39 | ret = pp.SE3(torch.cat([t, q], dim=-1)) 40 | return ret 41 | 42 | 43 | def linear_interpolation( 44 | ctrl_knots: Float[LieTensor, "*batch_size 2 7"], 45 | u: Float[Tensor, "interpolations"] | Float[Tensor, "*batch_size interpolations"], 46 | enable_eps: bool = False, 47 | ) -> Float[LieTensor, "*batch_size interpolations 7"]: 48 | """Linear interpolation between batches of two SE(3) poses. 49 | 50 | Args: 51 | ctrl_knots: The control knots. 
52 | u: Normalized positions between two SE(3) poses. Range: [0, 1]. 53 | enable_eps: Whether to clip the normalized position with a small epsilon to avoid possible numerical issues. 54 | 55 | Returns: 56 | The interpolated poses. 57 | """ 58 | start_pose, end_pose = ctrl_knots[..., 0, :], ctrl_knots[..., 1, :] 59 | batch_size = start_pose.shape[:-1] 60 | interpolations = u.shape[-1] 61 | 62 | t_start, q_start = start_pose.translation(), start_pose.rotation() 63 | t_end, q_end = end_pose.translation(), end_pose.rotation() 64 | 65 | # If u only has one dim, broadcast it to all batches. This means same interpolations for all batches. 66 | # Otherwise, u should have the same batch size as the control knots (*batch_size, interpolations). 67 | if u.dim() == 1: 68 | u = u.tile((*batch_size, 1)) # (*batch_size, interpolations) 69 | if enable_eps: 70 | u = torch.clip(u, _EPS, 1.0 - _EPS) 71 | 72 | t = pp.bvv(1 - u, t_start) + pp.bvv(u, t_end) 73 | 74 | q_tau_0 = q_start.Inv() @ q_end 75 | r_tau_0 = q_tau_0.Log() 76 | q_t_0 = pp.Exp(pp.so3(pp.bvv(u, r_tau_0))) 77 | q = q_start.unsqueeze(-2).tile((interpolations, 1)) @ q_t_0 78 | 79 | ret = pp.SE3(torch.cat([t, q], dim=-1)) 80 | return ret 81 | 82 | 83 | def cubic_bspline_interpolation( 84 | ctrl_knots: Float[LieTensor, "*batch_size 4 7"], 85 | u: Float[Tensor, "interpolations"] | Float[Tensor, "*batch_size interpolations"], 86 | enable_eps: bool = False, 87 | ) -> Float[LieTensor, "*batch_size interpolations 7"]: 88 | """Cubic B-spline interpolation with batches of four SE(3) control knots. 89 | 90 | Args: 91 | ctrl_knots: The control knots. 92 | u: Normalized positions on the trajectory segments. Range: [0, 1]. 93 | enable_eps: Whether to clip the normalized position with a small epsilon to avoid possible numerical issues. 94 | 95 | Returns: 96 | The interpolated poses. 97 | """ 98 | batch_size = ctrl_knots.shape[:-2] 99 | interpolations = u.shape[-1] 100 | 101 | # If u only has one dim, broadcast it to all batches. This means same interpolations for all batches. 102 | # Otherwise, u should have the same batch size as the control knots (*batch_size, interpolations). 
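# Note (descriptive summary of the code that follows): the rest of this function evaluates the
# uniform cubic B-spline bases. Translations are blended with
# B0..B3 = [(1-u)^3, 3u^3 - 6u^2 + 4, -3u^3 + 3u^2 + 3u + 1, u^3] / 6 (coeffs_t), while rotations
# use the cumulative bases [B1+B2+B3, B2+B3, B3] (coeffs_r), applied to the relative rotations
# between adjacent control knots and accumulated with pp.cumprod.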
103 | if u.dim() == 1: 104 | u = u.tile((*batch_size, 1)) # (*batch_size, interpolations) 105 | if enable_eps: 106 | u = torch.clip(u, _EPS, 1.0 - _EPS) 107 | 108 | uu = u * u 109 | uuu = uu * u 110 | oos = 1.0 / 6.0 # one over six 111 | 112 | # t coefficients 113 | coeffs_t = torch.stack([ 114 | oos - 0.5 * u + 0.5 * uu - oos * uuu, 115 | 4.0 * oos - uu + 0.5 * uuu, 116 | oos + 0.5 * u + 0.5 * uu - 0.5 * uuu, 117 | oos * uuu 118 | ], dim=-2) 119 | 120 | # spline t 121 | t_t = torch.sum(pp.bvv(coeffs_t, ctrl_knots.translation()), dim=-3) 122 | 123 | # q coefficients 124 | coeffs_r = torch.stack([ 125 | 5.0 * oos + 0.5 * u - 0.5 * uu + oos * uuu, 126 | oos + 0.5 * u + 0.5 * uu - 2 * oos * uuu, 127 | oos * uuu 128 | ], dim=-2) 129 | 130 | # spline q 131 | q_adjacent = ctrl_knots[..., :-1, :].rotation().Inv() @ ctrl_knots[..., 1:, :].rotation() 132 | r_adjacent = q_adjacent.Log() 133 | q_ts = pp.Exp(pp.so3(pp.bvv(coeffs_r, r_adjacent))) 134 | q0 = ctrl_knots[..., 0, :].rotation() # (*batch_size, 4) 135 | q_ts = torch.cat([ 136 | q0.unsqueeze(-2).tile((interpolations, 1)).unsqueeze(-3), 137 | q_ts 138 | ], dim=-3) # (*batch_size, num_ctrl_knots=4, interpolations, 4) 139 | q_t = pp.cumprod(q_ts, dim=-3, left=False)[..., -1, :, :] 140 | 141 | ret = pp.SE3(torch.cat([t_t, q_t], dim=-1)) 142 | return ret 143 | 144 | def bezier_interpolation( 145 | ctrl_knots: Float[LieTensor, "*batch_size order 7"], 146 | u: Float[Tensor, "interpolations"] | Float[Tensor, "*batch_size interpolations"], 147 | enable_eps: bool = False, 148 | ) -> Float[LieTensor, "*batch_size interpolations 7"]: 149 | """Bezier interpolation with batches of SE(3) control knots. 150 | 151 | Args: 152 | ctrl_knots: The control knots. 153 | u: Normalized positions on the trajectory segments. Range: [0, 1]. 154 | enable_eps: Whether to clip the normalized position with a small epsilon to avoid possible numerical issues. 155 | 156 | Returns: 157 | The interpolated poses. 158 | """ 159 | batch_size = ctrl_knots.shape[:-2] 160 | order = ctrl_knots.shape[-2] 161 | degree = order - 1 162 | interpolations = u.shape[-1] 163 | binomial_coeffs = [scipy.special.binom(degree, k) for k in range(order)] 164 | 165 | # If u only has one dim, broadcast it to all batches. This means same interpolations for all batches. 166 | # Otherwise, u should have the same batch size as the control knots (*batch_size, interpolations). 167 | if u.dim() == 1: 168 | u = u.tile((*batch_size, 1)) # (*batch_size, interpolations) 169 | if enable_eps: 170 | u = torch.clip(u, _EPS, 1.0 - _EPS) 171 | 172 | # Build coefficient matrix. TODO: precompute the coefficients. 
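# Each coeff_i below is the Bernstein basis polynomial B_{i,n}(u) = C(n, i) * (1 - u)^(n - i) * u^i
# with n = degree; the weights are non-negative and sum to one for every u in [0, 1].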
173 | bezier_coeffs = [] 174 | for i in range(order): 175 | coeff_i = binomial_coeffs[i] * pow(1 - u, degree - i) * pow(u, i) 176 | bezier_coeffs.append(coeff_i) 177 | bezier_coeffs = torch.stack(bezier_coeffs, dim=1).float().to(ctrl_knots.device) # (*batch_size, order, interpolations) 178 | 179 | # (*batch_size, order, interpolations, 7) 180 | weighted_ctrl_knots = pp.se3(bezier_coeffs.unsqueeze(-1) * ctrl_knots.Log().unsqueeze(-2)).Exp() 181 | ret = pp.cumprod(weighted_ctrl_knots, dim=-3, left=False)[..., -1, :, :] 182 | 183 | return ret 184 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "bad_gaussians" 3 | version = "1.0.3" 4 | 5 | dependencies=[ 6 | "nerfstudio>=1.0.3", 7 | "pypose", 8 | "gsplat>=0.1.11,<1.0.0" 9 | ] 10 | 11 | # black 12 | [tool.black] 13 | line-length = 120 14 | 15 | # pylint 16 | [tool.pylint.messages_control] 17 | max-line-length = 120 18 | generated-members = ["numpy.*", "torch.*", "cv2.*", "cv.*"] 19 | 20 | [tool.setuptools.packages.find] 21 | include = ["bad_gaussians"] 22 | 23 | [project.entry-points.'nerfstudio.dataparser_configs'] 24 | deblur-nerf-data = 'bad_gaussians.bad_config_dataparser:DeblurNerfDataParser' 25 | 26 | [project.entry-points.'nerfstudio.method_configs'] 27 | bad_gaussians = 'bad_gaussians.bad_config_method:bad_gaussians' 28 | -------------------------------------------------------------------------------- /scripts/tools/export_poses_from_ckpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Export optimized poses from a checkpoint. 3 | """ 4 | import argparse 5 | from pathlib import Path 6 | 7 | import pypose as pp 8 | import torch 9 | from typing_extensions import assert_never 10 | 11 | from bad_gaussians.bad_utils import TrajectoryIO 12 | from bad_gaussians.deblur_nerf_dataparser import DeblurNerfDataParserConfig 13 | from bad_gaussians.spline_functor import linear_interpolation_mid, cubic_bspline_interpolation 14 | 15 | # DEVICE = 'cuda:0' 16 | DEVICE = 'cpu' 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser(description="Export optimized poses from a checkpoint.") 20 | parser.add_argument("--ckpt_path", type=str, required=True, help="Path to the checkpoint.") 21 | parser.add_argument("--data_dir", type=str, required=True, help="Path to the dataset.") 22 | parser.add_argument("--output", type=str, required=True, help="Path to the output TUM trajectory.") 23 | args = parser.parse_args() 24 | 25 | ckpt_path = Path(args.ckpt_path) 26 | data_dir = Path(args.data_dir) 27 | output_path = Path(args.output) 28 | if not ckpt_path.exists(): 29 | raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}") 30 | if not data_dir.exists(): 31 | raise FileNotFoundError(f"Data directory not found: {data_dir}") 32 | 33 | print(f"Exporting optimized poses from {ckpt_path} using data from {data_dir}") 34 | export_poses(ckpt_path, data_dir, output_path) 35 | 36 | 37 | def export_poses(ckpt_path, data_dir, output_path): 38 | ckpt = torch.load(ckpt_path, map_location=DEVICE) 39 | pose_adjustment = ckpt['pipeline']['_model.camera_optimizer.pose_adjustment'] 40 | 41 | parser = DeblurNerfDataParserConfig( 42 | data=data_dir, 43 | downscale_factor=1, 44 | ).setup() 45 | parser_outputs = parser.get_dataparser_outputs(split="train") 46 | 47 | print(parser_outputs) 48 | 49 | print(parser_outputs.cameras.camera_to_worlds.shape) 50 | 51 | poses = 
pp.mat2SE3(parser_outputs.cameras.camera_to_worlds.to(DEVICE)) 52 | 53 | num_cameras, num_ctrl_knots, _ = pose_adjustment.shape 54 | 55 | if num_ctrl_knots == 2: 56 | poses_delta = linear_interpolation_mid(pose_adjustment) 57 | elif num_ctrl_knots == 4: 58 | poses_delta = cubic_bspline_interpolation( 59 | pose_adjustment, 60 | torch.tensor([0.5], device=pose_adjustment.device) 61 | ).squeeze(1) 62 | else: 63 | assert_never(num_ctrl_knots) 64 | 65 | print(poses.shape) 66 | print(poses_delta.shape) 67 | poses_optimized = poses @ poses_delta 68 | 69 | timestamps = [float(x) for x in range(num_cameras)] 70 | 71 | if output_path.is_dir(): 72 | output_path = output_path / "BAD-Gaussians.txt" 73 | 74 | if output_path.exists(): 75 | raise FileExistsError(f"File already exists: {output_path}") 76 | 77 | TrajectoryIO.write_tum_trajectory(output_path, torch.tensor(timestamps), poses_optimized) 78 | 79 | print(f"Exported optimized poses to {output_path}") 80 | 81 | 82 | main() 83 | -------------------------------------------------------------------------------- /scripts/tools/export_poses_from_colmap.py: -------------------------------------------------------------------------------- 1 | """ 2 | Export poses from colmap images.txt 3 | """ 4 | 5 | import argparse 6 | from pathlib import Path 7 | 8 | import pypose as pp 9 | import torch 10 | 11 | from bad_gaussians.bad_utils import TrajectoryIO 12 | from nerfstudio.data.dataparsers.colmap_dataparser import ColmapDataParserConfig 13 | 14 | DEVICE = 'cpu' 15 | 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser(description="Export poses from colmap results to TUM trajectory.") 19 | parser.add_argument("--input", type=str, required=True, help="Path to colmap files. E.g. ./sparse/0") 20 | parser.add_argument("--output", type=str, required=True, help="Path to the output TUM trajectory.") 21 | args = parser.parse_args() 22 | 23 | input_path = Path(args.input) 24 | if not input_path.exists(): 25 | raise FileNotFoundError(f"File not found: {input_path}") 26 | 27 | output_path = Path(args.output) 28 | if output_path.exists(): 29 | raise FileExistsError(f"File already exists: {output_path}") 30 | if not output_path.parent.exists(): 31 | output_path.parent.mkdir(parents=True) 32 | 33 | print(f"Exporting poses from {input_path} to {output_path}") 34 | 35 | dataparser = ColmapDataParserConfig(data=input_path, colmap_path=".").setup() 36 | frames = dataparser._get_all_images_and_cameras(input_path)["frames"] 37 | 38 | names = [frame["file_path"] for frame in frames] 39 | poses = [frame["transform_matrix"] for frame in frames] 40 | poses = [x for _, x in sorted(zip(names, poses))] 41 | 42 | poses = pp.mat2SE3(torch.tensor(poses)) 43 | num_cameras = poses.shape[0] 44 | timestamps = [float(x) for x in range(num_cameras)] 45 | TrajectoryIO.write_tum_trajectory(output_path, torch.tensor(timestamps), poses) 46 | 47 | 48 | main() 49 | -------------------------------------------------------------------------------- /scripts/tools/export_poses_from_npy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Export poses from poses_bounds.npy 3 | """ 4 | import argparse 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import pypose as pp 9 | import torch 10 | 11 | from bad_gaussians.bad_utils import TrajectoryIO 12 | 13 | DEVICE = 'cpu' 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser(description="Export poses from poses_bounds.npy.") 18 | parser.add_argument("--input", type=str, required=True, 
help="Path to the npy file.") 19 | parser.add_argument("--output", type=str, required=True, help="Path to the output TUM trajectory.") 20 | args = parser.parse_args() 21 | 22 | npy_path = Path(args.input) 23 | if not npy_path.exists(): 24 | raise FileNotFoundError(f"File not found: {npy_path}") 25 | 26 | output_path = Path(args.output) 27 | if output_path.exists(): 28 | raise FileExistsError(f"File already exists: {output_path}") 29 | if not output_path.parent.exists(): 30 | output_path.parent.mkdir(parents=True) 31 | 32 | print(f"Exporting poses from {npy_path} to {output_path}") 33 | 34 | # Load data from npy file, shape -1, 17 35 | pose_bounds = np.load(npy_path) 36 | # (N, 17) 37 | assert pose_bounds.shape[-1] == 17 38 | # extract (N, 15) from (N, 17), reshape to (N, 3, 5) 39 | matrices = np.reshape(pose_bounds[:, :-2], (-1, 3, 5)) 40 | 41 | # pop every 8th pose 42 | matrices = np.delete(matrices, np.arange(0, matrices.shape[0], 8), axis=0) 43 | 44 | poses = pp.from_matrix(torch.tensor(matrices[:, :, :4]).to(DEVICE), check=False, ltype=pp.SE3_type) 45 | 46 | num_cameras = poses.shape[0] 47 | timestamps = [float(x) for x in range(num_cameras)] 48 | TrajectoryIO.write_tum_trajectory(output_path, torch.tensor(timestamps), poses) 49 | 50 | 51 | main() 52 | -------------------------------------------------------------------------------- /scripts/tools/interpolate_traj.py: -------------------------------------------------------------------------------- 1 | """ 2 | Interpolate TUM trajectory with cubic B-spline given timestamps. 3 | """ 4 | 5 | import argparse 6 | from pathlib import Path 7 | 8 | import torch 9 | 10 | from bad_gaussians.bad_utils import TrajectoryIO 11 | from bad_gaussians.spline import SplineConfig 12 | 13 | DEVICE = 'cpu' 14 | torch.set_default_dtype(torch.float64) 15 | 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser(description="Interpolate TUM trajectory with cubic B-spline given timestamps.") 19 | parser.add_argument("--input", type=str, required=True, help="Path to the TUM trajectory.") 20 | parser.add_argument("--times", type=str, required=True, help="Path to the timestamps.") 21 | parser.add_argument("--output", type=str, required=True, help="Path to the output TUM trajectory.") 22 | args = parser.parse_args() 23 | 24 | input_path = Path(args.input) 25 | if not input_path.exists(): 26 | raise FileNotFoundError(f"File not found: {input_path}") 27 | 28 | times_path = Path(args.times) 29 | if not times_path.exists(): 30 | raise FileNotFoundError(f"File not found: {times_path}") 31 | 32 | output_path = Path(args.output) 33 | if output_path.exists(): 34 | raise FileExistsError(f"File already exists: {output_path}") 35 | if not output_path.parent.exists(): 36 | output_path.parent.mkdir(parents=True) 37 | 38 | print(f"Interpolating TUM trajectory from {input_path} to {output_path} using timestamps from {times_path}") 39 | 40 | timestamps, tum_trajectory = TrajectoryIO.load_tum_trajectory(input_path) 41 | 42 | linear_spline_config = SplineConfig( 43 | degree=1, 44 | sampling_interval=(timestamps[1] - timestamps[0]), 45 | start_time=timestamps[0] 46 | ) 47 | cubic_spline_config = SplineConfig( 48 | degree=3, 49 | sampling_interval=(timestamps[1] - timestamps[0]), 50 | start_time=timestamps[0] 51 | ) 52 | linear_spline = linear_spline_config.setup() 53 | cubic_spline = cubic_spline_config.setup() 54 | 55 | linear_spline.set_data(tum_trajectory) 56 | cubic_spline.set_data(tum_trajectory) 57 | 58 | # read timestamps from times_path 59 | timestamps = [] 60 | with 
open(times_path, 'r') as f: 61 | for line in f: 62 | timestamps.append(float(line.strip())) 63 | 64 | timestamps = torch.tensor(timestamps, dtype=torch.float64) 65 | linear_poses = linear_spline(timestamps) 66 | cubic_poses = cubic_spline(timestamps) 67 | 68 | TrajectoryIO.write_tum_trajectory(output_path, timestamps, cubic_poses) 69 | 70 | 71 | main() 72 | -------------------------------------------------------------------------------- /scripts/tools/kitti_to_tum.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert KITTI trajectory to TUM format. 3 | 4 | Usage: 5 | python tools/kitti_to_tum.py \ 6 | --input=data/Replica/office0/traj.txt \ 7 | --output=data/Replica/office0/traj_tum.txt 8 | """ 9 | from __future__ import annotations 10 | 11 | import argparse 12 | 13 | import torch 14 | 15 | from bad_gaussians.bad_utils import TrajectoryIO 16 | 17 | if __name__ == "__main__": 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--input", type=str, help="KITTI trajectory file") 20 | parser.add_argument("--output", type=str, help="TUM trajectory file") 21 | args = parser.parse_args() 22 | 23 | poses = TrajectoryIO.load_kitti_trajectory(args.input) 24 | timestamps = torch.arange(0, len(poses)) 25 | TrajectoryIO.write_tum_trajectory(args.output, timestamps, poses) 26 | -------------------------------------------------------------------------------- /scripts/tools/tum_to_kitti.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert TUM trajectory to KITTI format. 3 | 4 | Usage: 5 | python tools/tum_to_kitti.py \ 6 | --input=data/MBA-VO/archviz_sharp1/groundtruth_synced.txt \ 7 | --output=data/MBA-VO/archviz_sharp1/traj.txt 8 | """ 9 | from __future__ import annotations 10 | 11 | import argparse 12 | 13 | import pypose as pp 14 | import torch 15 | 16 | from bad_gaussians.bad_utils import TrajectoryIO 17 | 18 | EXTRINSICS_C2B = pp.SE3(torch.tensor([0, 0, 0, -0.5, 0.5, -0.5, 0.5])) 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--input", type=str, help="TUM trajectory file") 23 | parser.add_argument("--output", type=str, help="KITTI trajectory file") 24 | args = parser.parse_args() 25 | 26 | timestamps, poses = TrajectoryIO.load_tum_trajectory(args.input) 27 | poses = poses @ EXTRINSICS_C2B 28 | TrajectoryIO.write_kitti_trajectory(args.output, poses) 29 | -------------------------------------------------------------------------------- /tests/data/traj.txt: -------------------------------------------------------------------------------- 1 | # timestamp tx ty tz qx qy qz qw 2 | 4090.80000000 -0.18243881 0.03254267 1.40185679 -0.00173178 0.00025674 0.13497136 -0.99084795 3 | 4090.81000000 -0.16016920 0.02855653 1.40362307 -0.00146991 0.01213789 0.14717870 -0.98903435 4 | 4090.82000000 -0.14011959 0.02443513 1.40735599 -0.00198271 0.02332832 0.15841676 -0.98709472 5 | 4090.83000000 -0.12291090 0.02040767 1.41288924 -0.00327176 0.03378584 0.16837958 -0.98513762 6 | 4090.84000000 -0.10899440 0.01671207 1.41993526 -0.00527953 0.04342881 0.17683406 -0.98326791 7 | 4090.85000000 -0.09858330 0.01356689 1.42803953 -0.00785705 0.05211664 0.18365110 -0.98157750 8 | 4090.86000000 -0.09180161 0.01116808 1.43669946 -0.01082509 0.05969303 0.18873273 -0.98015280 9 | 4090.87000000 -0.08875377 0.00969074 1.44541554 -0.01399851 0.06600539 0.19198015 -0.97907658 10 | 4090.88000000 -0.08952103 0.00927787 1.45371707 -0.01719510 0.07091593 0.19329913 
-0.97842256 11 | 4090.89000000 -0.09408323 0.00999614 1.46123216 -0.02026520 0.07434388 0.19264398 -0.97823852 12 | 4090.90000000 -0.10228215 0.01181107 1.46772354 -0.02310640 0.07628709 0.19004132 -0.97853496 13 | 4090.91000000 -0.11393439 0.01463452 1.47301262 -0.02562829 0.07676846 0.18553636 -0.97929876 14 | 4090.92000000 -0.12887693 0.01834777 1.47694328 -0.02773746 0.07581525 0.17917102 -0.98050009 15 | 4090.93000000 -0.14695062 0.02279857 1.47939470 -0.02934404 0.07346718 0.17099498 -0.98209074 16 | 4090.94000000 -0.16790675 0.02777456 1.48036611 -0.03039753 0.06982359 0.16112294 -0.98399190 17 | 4090.95000000 -0.19135118 0.03300418 1.48001450 -0.03090727 0.06506945 0.14976596 -0.98609374 18 | 4090.96000000 -0.21684238 0.03821645 1.47853498 -0.03090178 0.05941785 0.13716523 -0.98828149 19 | 4090.97000000 -0.24392885 0.04317235 1.47610302 -0.03041002 0.05308818 0.12356516 -0.99044865 20 | 4090.98000000 -0.27215072 0.04767324 1.47286446 -0.02946022 0.04629492 0.10921146 -0.99250276 21 | 4090.99000000 -0.30104959 0.05155234 1.46895543 -0.02808955 0.03920172 0.09435056 -0.99437024 22 | 4091.00000000 -0.33017403 0.05467048 1.46452261 -0.02635045 0.03189638 0.07922914 -0.99599750 23 | 4091.01000000 -0.35906556 0.05693766 1.45970096 -0.02429888 0.02444497 0.06409508 -0.99734840 24 | 4091.02000000 -0.38725823 0.05832269 1.45460346 -0.02199011 0.01691482 0.04919677 -0.99840373 25 | 4091.03000000 -0.41428787 0.05885385 1.44930689 -0.01947043 0.00937571 0.03478069 -0.99916130 26 | 4091.04000000 -0.43971627 0.05861918 1.44378425 -0.01673474 0.00189726 0.02108405 -0.99963584 27 | 4091.05000000 -0.46314950 0.05775755 1.43786593 -0.01370147 -0.00545027 0.00832970 -0.99985658 28 | 4091.06000000 -0.48422138 0.05643706 1.43132212 -0.01026296 -0.01259155 -0.00326694 -0.99986272 29 | 4091.07000000 -0.50258397 0.05484005 1.42389867 -0.00630679 -0.01944668 -0.01349425 -0.99969993 30 | 4091.08000000 -0.51792112 0.05314779 1.41536639 -0.00174138 -0.02593295 -0.02215351 -0.99941667 31 | 4091.09000000 -0.53003169 0.05151765 1.40571642 0.00339966 -0.03195767 -0.02911037 -0.99905943 32 | 4091.10000000 -0.53886722 0.05007196 1.39526624 0.00890261 -0.03741195 -0.03432306 -0.99867063 33 | 4091.11000000 -0.54441339 0.04891248 1.38443232 0.01450070 -0.04218010 -0.03777889 -0.99829020 34 | 4091.12000000 -0.54664754 0.04812773 1.37363738 0.01992689 -0.04614062 -0.03947000 -0.99795595 35 | 4091.13000000 -0.54555868 0.04778838 1.36328536 0.02493200 -0.04917457 -0.03940206 -0.99770123 36 | 4091.14000000 -0.54124719 0.04792123 1.35366642 -0.02934762 0.05120863 0.03764078 0.99754677 37 | 4091.15000000 -0.53398271 0.04849199 1.34490758 -0.03311942 0.05223923 0.03433627 0.99749446 38 | 4091.16000000 -0.52409072 0.04943453 1.33709817 -0.03622871 0.05228240 0.02966474 0.99753398 39 | 4091.17000000 -0.51190864 0.05066986 1.33033471 -0.03866195 0.05135183 0.02380575 0.99764800 40 | 4091.18000000 -0.49778911 0.05211288 1.32470566 -0.04041808 0.04946564 0.01694362 0.99781383 41 | 4091.19000000 -0.48212413 0.05366856 1.32021547 -0.04155425 0.04668800 0.00928221 0.99800166 42 | 4091.20000000 -0.46535643 0.05523369 1.31673746 -0.04221047 0.04315405 0.00105214 0.99817578 -------------------------------------------------------------------------------- /tests/test_spline.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from pathlib import Path 3 | 4 | import pypose as pp 5 | import torch 6 | 7 | from bad_gaussians.bad_utils import TrajectoryIO 8 | from bad_gaussians.spline 
import SplineConfig 9 | 10 | torch.set_default_dtype(torch.float64) 11 | torch.set_printoptions(precision=5, sci_mode=False) 12 | 13 | 14 | class TestSpline(unittest.TestCase): 15 | 16 | def test_spline_basic(self): 17 | print(f"\n{'='*10}{self.id()}{'='*10}") 18 | linear_spline_config = SplineConfig(degree=1) 19 | cubic_spline_config = SplineConfig(degree=3) 20 | linear_spline = linear_spline_config.setup() 21 | cubic_spline = cubic_spline_config.setup() 22 | 23 | p0 = pp.identity_SE3(1) 24 | p1 = pp.randn_SE3(1) 25 | p2 = pp.randn_SE3(1) 26 | p3 = pp.randn_SE3(1) 27 | 28 | p0.translation()[0] = 1 29 | p1.translation()[0] = 2 30 | p2.translation()[0] = 3 31 | p3.translation()[0] = 4 32 | 33 | for spline in [linear_spline, cubic_spline]: 34 | spline.insert(p0) 35 | spline.insert(p1) 36 | spline.insert(p2) 37 | spline.insert(p3) 38 | 39 | pose0 = linear_spline(torch.tensor([0.15001])) 40 | print(f"linear: {pose0}") 41 | self.assertTrue(torch.isclose(pose0.translation(), torch.tensor([2.5001, 2.5001, 2.5001]), atol=1e-3).all()) 42 | pose1 = cubic_spline(torch.tensor([0.15001])) 43 | print(f"cubic: {pose1}") 44 | self.assertTrue(torch.isclose(pose1.translation(), torch.tensor([2.5001, 2.5001, 2.5001]), atol=1e-3).all()) 45 | 46 | def test_spline_tum(self): 47 | print(f"\n{'='*10}{self.id()}{'='*10}") 48 | timestamps, tum_trajectory = TrajectoryIO.load_tum_trajectory(Path("data/traj.txt")) 49 | 50 | linear_spline_config = SplineConfig( 51 | degree=1, 52 | sampling_interval=(timestamps[1] - timestamps[0]), 53 | start_time=timestamps[0] 54 | ) 55 | cubic_spline_config = SplineConfig( 56 | degree=3, 57 | sampling_interval=(timestamps[1] - timestamps[0]), 58 | start_time=timestamps[0] 59 | ) 60 | linear_spline = linear_spline_config.setup() 61 | cubic_spline = cubic_spline_config.setup() 62 | 63 | linear_spline.set_data(tum_trajectory) 64 | cubic_spline.set_data(tum_trajectory) 65 | 66 | poses0 = linear_spline(torch.tensor([timestamps[20], timestamps[30]])) 67 | print(f"linear: {poses0}") 68 | self.assertTrue(torch.isclose(poses0[0], tum_trajectory[20], atol=1e-3).all()) 69 | self.assertTrue(torch.isclose(poses0[1], tum_trajectory[30], atol=1e-3).all()) 70 | poses1 = cubic_spline(torch.tensor([timestamps[20], timestamps[30]])) 71 | print(f"cubic: {poses1}") 72 | self.assertTrue(torch.isclose(poses1[0], tum_trajectory[20], atol=1e-3).all()) 73 | self.assertTrue(torch.isclose(poses1[1], tum_trajectory[30], atol=1e-3).all()) 74 | 75 | 76 | if __name__ == "__main__": 77 | unittest.main() 78 | --------------------------------------------------------------------------------