├── .gitignore ├── LICENSE ├── README.md ├── pyproject.toml └── splat_viewer ├── __init__.py ├── camera ├── __init__.py ├── fov.py ├── transforms.py └── visibility.py ├── gaussians ├── __init__.py ├── data_types.py ├── loading.py ├── sh_utils.py └── workspace.py ├── renderer ├── __init__.py ├── arguments.py └── taichi_splatting.py ├── scripts ├── __init__.py ├── compare_clouds.py ├── crop_foreground.py ├── debug_tiles.py ├── depth_fusion.py ├── export_rgb_cloud.py ├── export_workspace.py ├── label_foreground.py └── splat_viewer.py └── viewer ├── __init__.py ├── interaction.py ├── interactions ├── __init__.py ├── animate.py ├── fly_control.py └── scribble.py ├── keyboard.py ├── mesh.py ├── renderer.py ├── scene_camera.py ├── scene_widget.py ├── settings.py └── viewer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Pytorch models 10 | *.pth 11 | *.pt 12 | 13 | # Distribution / packaging 14 | .Python 15 | *.nfs* 16 | 17 | wandb/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | 141 | # pytype static type analyzer 142 | .pytype/ 143 | 144 | # Cython debug symbols 145 | cython_debug/ 146 | 147 | 148 | # IDE settings 149 | .vscode 150 | .idea 151 | 152 | # pixi environments 153 | .pixi 154 | *.egg-info 155 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 | 
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 | 
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 | 
176 | END OF TERMS AND CONDITIONS
177 | 
178 | APPENDIX: How to apply the Apache License to your work.
179 | 
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 | 
189 | Copyright [yyyy] [name of copyright owner]
190 | 
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Overview
2 | 
3 | A viewer and some tools to work with gaussian splatting reconstructions. It currently opens .ply files and gaussian splatting workspaces from the original gaussian-splatting implementation.
Intended primarily for testing [taichi-splatting](https://github.com/uc-vision/taichi-splatting)
4 | 
5 | 
6 | # Example data
7 | 
8 | Some example scenes can be found on the official [gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting) page, under [Pre-trained models](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/datasets/pretrained/models.zip).
9 | 
10 | 
11 | ## Installing
12 | 
13 | ### External dependencies
14 | Create an environment (for example conda with mambaforge) with the following dependencies:
15 | 
16 | * python >= 3.10
17 | * pytorch - install via conda or pip, following the instructions [here](https://pytorch.org/).
18 | * taichi-nightly `pip install --upgrade -i https://pypi.taichi.graphics/simple/ taichi-nightly`
19 | 
20 | ### Install
21 | 
22 | One of:
23 | * `pip install splat-viewer`
24 | * Clone the repository with `git clone` and install with `pip install ./splat-viewer`
25 | 
26 | 
27 | # splat-viewer
28 | 
29 | A gaussian splatting viewer. An example of the visualizations produced by this viewer can be seen here:
30 | [![Watch the video](https://img.youtube.com/vi/4ysMY5lti7c/hqdefault.jpg)](https://www.youtube.com/embed/4ysMY5lti7c)
31 | 
32 | 
33 | ## Arguments
34 | 
35 | ```
36 | usage: splat-viewer [-h] [--model MODEL] [--device DEVICE] [--debug] model_path
37 | 
38 | positional arguments:
39 |   model_path       workspace folder containing cameras.json, input.ply and point_cloud folder with .ply models
40 | 
41 | options:
42 |   -h, --help       show this help message and exit
43 |   --model MODEL    load model from point_clouds folder, default is latest iteration
44 |   --device DEVICE  torch device to use
45 |   --debug          enable taichi kernels in debug mode
46 | ```
47 | 
48 | ## Keyboard Controls
49 | 
50 | 
51 | ### Switch view mode
52 | * 1: normal rendering
53 | * 2: render gaussian centers as points
54 | * 3: render depth map
55 | 
56 | ### Show/hide
57 | * 0 : cropped foreground
58 | * 9 : initial points
59 | * 8 : camera markers
60 | 
61 | ### Misc
62 | * print screen: save a high-res snapshot into the workspace directory
63 | * shift return: toggle fullscreen
64 | 
65 | ### Camera
66 | * '[' : Prev camera
67 | * ']' : Next camera
68 | 
69 | * '=' : zoom in
70 | * '-' : zoom out
71 | 
72 | * w/s a/d q/e : forward/backward left/right up/down
73 | * keypad plus/minus: navigate faster/slower
74 | 
75 | 
76 | ### Animation
77 | * space: add current viewpoint to animation sequence
78 | * control-space: save current animation sequence to workspace folder
79 | * return: animate current sequence
80 | * shift plus/minus: animation speed faster/slower
81 | 
82 | 
83 | # splat-benchmark
84 | 
85 | A benchmark to test the forward and backward passes of differentiable renderers.
86 | Example `splat-benchmark models/garden --sh_degree 1 --image_size 1920` 87 | 88 | ## Arguments 89 | 90 | ``` 91 | usage: splat-benchmark [-h] [--device DEVICE] [--model MODEL] [--profile] [--debug] [-n N] [--tile_size TILE_SIZE] [--backward] [--sh_degree SH_DEGREE] [--no_sort] [--depth] 92 | [--image_size RESIZE_IMAGE] [--taichi] 93 | model_path 94 | 95 | positional arguments: 96 | model_path workspace folder containing cameras.json, input.ply and point_cloud folder with .ply models 97 | 98 | options: 99 | -h, --help show this help message and exit 100 | --device DEVICE torch device to use 101 | --model MODEL model iteration to load from point_clouds folder 102 | --profile enable profiling 103 | --debug enable taichi kernels in debug mode 104 | -n N number of iterations to render 105 | --tile_size TILE_SIZE 106 | tile size for rasterizer 107 | --backward benchmark backward pass 108 | --sh_degree SH_DEGREE 109 | modify spherical harmonics degree 110 | --no_sort disable sorting by scale (sorting makes tilemapping faster) 111 | --depth render depth maps 112 | --image_size RESIZE_IMAGE 113 | resize longest edge of camera image sizes 114 | --taichi use taichi renderer 115 | 116 | ``` 117 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "splat-viewer" 3 | version = "0.21.0" 4 | description = "A viewer for gaussian-splatting reconstructions" 5 | readme = "README.md" 6 | requires-python = ">=3.10" 7 | license = {file = "LICENSE"} 8 | 9 | maintainers = [ 10 | {name = "Oliver Batchelor", email = "oliver.batchelor@canterbury.ac.nz" } 11 | ] 12 | 13 | dependencies = [ 14 | "beartype", 15 | "taichi-splatting >= 0.21.2", 16 | "tqdm", 17 | "tensordict", 18 | "plyfile", 19 | "pyside6", 20 | "natsort", 21 | "opencv_python", 22 | "pyrender", 23 | "roma" 24 | 25 | ] 26 | 27 | 28 | [project.urls] 29 | "Homepage" = "https://github.com/uc-vision/splat-viewer" 30 | 31 | [build-system] 32 | # Hatching 33 | requires = ["hatchling>=1.5.0"] 34 | build-backend = "hatchling.build" 35 | 36 | [project.scripts] # Optional 37 | splat-viewer = "splat_viewer.scripts.splat_viewer:main" 38 | 39 | debug-tiles = "splat_viewer.scripts.debug_tiles:main" 40 | label-foreground = "splat_viewer.scripts.label_foreground:main" 41 | export-workspace = "splat_viewer.scripts.export_workspace:main" 42 | export-rgb-cloud = "splat_viewer.scripts.export_rgb_cloud:main" 43 | 44 | crop-foreground = "splat_viewer.scripts.crop_foreground:main" 45 | compare-clouds = "splat_viewer.scripts.compare_clouds:main" 46 | fuse-depths = "splat_viewer.scripts.depth_fusion:main" 47 | 48 | 49 | [tool.ruff] 50 | indent-width = 2 51 | 52 | [tool.pytest.ini_options] 53 | filterwarnings = [ 54 | # disable "UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor" 55 | "ignore::UserWarning" 56 | ] 57 | 58 | 59 | -------------------------------------------------------------------------------- /splat_viewer/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import * 2 | -------------------------------------------------------------------------------- /splat_viewer/camera/__init__.py: -------------------------------------------------------------------------------- 1 | from .fov import FOVCamera 2 | 3 | __all__ = ["FOVCamera"] -------------------------------------------------------------------------------- /splat_viewer/camera/fov.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field, replace 2 | from numbers import Number, Integral 3 | 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import json 8 | 9 | from beartype import beartype 10 | from beartype.typing import Tuple 11 | 12 | from splat_viewer.camera.transforms import project_points, unproject_pixels 13 | 14 | 15 | NumberPair = np.ndarray | Tuple[Number, Number] 16 | IntPair = np.ndarray | Tuple[int, int] 17 | 18 | Box = Tuple[int, int, int, int] 19 | 20 | 21 | def resize_shortest(image_size:Tuple[Integral, Integral], 22 | min_size, max_size=None) -> Tuple[Tuple[int, int], float]: 23 | 24 | if max_size is None: 25 | max_size = min_size 26 | 27 | shortest = min(image_size) 28 | 29 | scale = (min_size / shortest if shortest < min_size 30 | else max_size / shortest) 31 | 32 | new_size = tuple(np.round(np.array(image_size) * scale).astype(np.int32)) 33 | return new_size, scale 34 | 35 | 36 | 37 | @beartype 38 | @dataclass 39 | class FOVCamera: 40 | 41 | position: np.ndarray 42 | rotation: np.ndarray 43 | focal_length : np.ndarray # 2 44 | image_size : np.ndarray # 2 45 | 46 | image_name: str 47 | principal_point : np.ndarray = field(default_factory=lambda: np.array([0., 0.])) 48 | 49 | near:float = 0.01 50 | far :float = 1000.0 51 | 52 | @staticmethod 53 | def from_json(json_dict) -> 'FOVCamera': 54 | return from_json(json_dict) 55 | 56 | def to_json(self): 57 | return to_json(self) 58 | 59 | 60 | @property 61 | def aspect(self): 62 | width, height = self.image_size 63 | return width / height 64 | 65 | @property 66 | def width(self): 67 | return self.image_size[0] 68 | 69 | @property 70 | def height(self): 71 | return self.image_size[1] 72 | 73 | def scale_size(self, scale_factor) -> 'FOVCamera': 74 | 75 | return replace(self, 76 | image_size=np.round(self.image_size * scale_factor).astype(np.int32), 77 | focal_length=self.focal_length * scale_factor, 78 | principal_point=self.principal_point * scale_factor 79 | ) 80 | 81 | def scale_to(self, new_size:NumberPair, scale_factor:float) -> 'FOVCamera': 82 | return replace(self, 83 | image_size=np.array(new_size).astype(np.int32), 84 | focal_length=self.focal_length * scale_factor, 85 | principal_point=self.principal_point * scale_factor 86 | ) 87 | 88 | def crop_offset_size(self, offset:NumberPair, size:NumberPair) -> 'FOVCamera': 89 | offset, size = np.array(offset), np.array(size) 90 | return replace(self, 91 | image_size=size.astype(np.int32), 92 | principal_point=self.principal_point - offset 93 | ) 94 | 95 | def crop_extent(self, centre:NumberPair, size:NumberPair) -> 'FOVCamera': 96 | centre, size = np.array(centre), np.array(size) 97 | return self.crop_offset_size(centre - size / 2, size) 98 | 99 | def crop_box(self, box:Box) -> 'FOVCamera': 100 | x_min, x_max, y_min, y_max = box 101 | return self.crop_offset_size( 102 | np.array([x_min, y_min]), 103 | np.array([x_max - x_min, y_max - y_min]) 104 | ) 105 | 106 | def pad_to(self, image_size:NumberPair) -> 'FOVCamera': 107 | image_size = np.array(image_size) 108 | return replace(self, 109 | 
image_size=image_size.astype(np.int32), 110 | principal_point=self.principal_point + (image_size - self.image_size) / 2 111 | ) 112 | 113 | def pad_bottom_right(self, image_size:NumberPair) -> 'FOVCamera': 114 | image_size = np.array(image_size) 115 | return replace(self, 116 | image_size=image_size.astype(np.int32), 117 | principal_point=self.principal_point 118 | ) 119 | 120 | def resize_shortest(self, min_size, max_size=None) -> 'FOVCamera': 121 | new_size, scale = resize_shortest(self.image_size, min_size, max_size) 122 | return self.scale_to(new_size, scale) 123 | 124 | def resize_longest(self, size) -> 'FOVCamera': 125 | longest = max(self.image_size) 126 | return self.scale_size(size / longest) 127 | 128 | 129 | def resize_to(self, size:NumberPair) -> 'FOVCamera': 130 | size = np.array(size) 131 | return self.scale_size(size / self.image_size) 132 | 133 | 134 | def zoom(self, zoom_factor) -> 'FOVCamera': 135 | return replace(self, focal_length=self.focal_length * zoom_factor) 136 | 137 | 138 | @property 139 | def world_t_camera(self): 140 | return join_rt(self.rotation, self.position) 141 | 142 | @property 143 | def camera_t_world(self): 144 | return np.linalg.inv(self.world_t_camera) 145 | 146 | def __repr__(self): 147 | w, h = self.image_size 148 | fx, fy = self.focal_length 149 | cx, cy = self.principal_point 150 | return f"FOVCamera(name={self.image_name}@{w}x{h} pos={self.position}, z={self.forward}, fx={fx} fy={fy}, cx={cx} cy={cy})" 151 | 152 | def __str__(self): 153 | return repr(self) 154 | 155 | 156 | @property 157 | def right(self): 158 | return self.rotation[0] 159 | 160 | @property 161 | def up(self): 162 | return -self.rotation[1] 163 | 164 | @property 165 | def forward(self): 166 | return self.rotation[2] 167 | 168 | 169 | @property 170 | def fov(self): 171 | return np.arctan2(self.image_size, self.focal_length * 2) * 2 172 | 173 | @property 174 | def intrinsic(self): 175 | 176 | cx, cy = self.principal_point 177 | fx, fy = self.focal_length 178 | 179 | return np.array( 180 | [[fx, 0, cx], 181 | [0, fy, cy], 182 | [0, 0, 1]] 183 | ) 184 | 185 | def unproject_pixels(self, xy:np.ndarray, depth:np.ndarray): 186 | return unproject_pixels(self.world_t_image, xy, depth) 187 | 188 | def project_points(self, points:np.ndarray): 189 | return project_points(self.image_t_world, points) 190 | 191 | 192 | def unproject_pixel(self, x, y, depth): 193 | points = self.unproject_pixels(np.array([[x, y]]), np.array([[depth]])) 194 | return tuple(points[0]) 195 | 196 | def project_point(self, x, y, z): 197 | xy, depth = self.project_points(np.array([[x, y, z]])) 198 | return tuple([*xy[0], *depth[0]]) 199 | 200 | @property 201 | def image_t_camera(self): 202 | m44 = np.eye(4) 203 | m44[:3, :3] = self.intrinsic 204 | 205 | return m44 206 | 207 | @property 208 | def image_t_world(self): 209 | return self.image_t_camera @ self.camera_t_world 210 | 211 | @property 212 | def world_t_image(self): 213 | return np.linalg.inv(self.image_t_world) 214 | 215 | @property 216 | def projection(self): 217 | return self.image_t_world 218 | 219 | @property 220 | def ndc_t_camera(self): 221 | """ OpenGL projection - Camera to Normalised Device Coordinates (NDC) 222 | """ 223 | w, h = self.image_size 224 | 225 | cx, cy = self.principal_point 226 | fx, fy = self.focal_length 227 | n, f = self.near, self.far 228 | 229 | return np.array([ 230 | [2.0 * fx / w, 0, 1.0 - 2.0 * cx / w, 0], 231 | [0, 2.0 * fy / h, 2.0 * cy / h - 1.0, 0], 232 | [0, 0, (f + n) / (n - f), (2 * f * n) / (n - f)], 233 | [0, 0, 1.0, 0] 
234 | ], dtype=np.float32) 235 | 236 | 237 | 238 | @property 239 | def gl_camera_t_image(self): 240 | return np.linalg.inv(self.ndc_t_camera) 241 | 242 | @property 243 | def gl_camera_t_world(self): 244 | 245 | flip_yz = np.array([ 246 | [1, 0, 0], 247 | [0, -1, 0], 248 | [0, 0, -1] 249 | ]) 250 | 251 | rotation = self.rotation @ flip_yz 252 | return join_rt(rotation, self.position) 253 | 254 | @property 255 | def ndc_t_world(self): 256 | 257 | return self.ndc_t_camera @ self.gl_camera_t_world 258 | 259 | 260 | def join_rt(R, T): 261 | Rt = np.zeros((4, 4)) 262 | Rt[:3, :3] = R 263 | Rt[:3, 3] = T 264 | Rt[3, 3] = 1.0 265 | 266 | return Rt 267 | 268 | 269 | def split_rt(Rt): 270 | R = Rt[:3, :3] 271 | T = Rt[:3, 3] 272 | return R, T 273 | 274 | 275 | 276 | def from_json(camera_info) -> FOVCamera: 277 | pos = np.array(camera_info['position']) 278 | rotation = np.array(camera_info['rotation']).reshape(3, 3) 279 | w, h = (camera_info['width'], camera_info['height']) 280 | cx, cy = (camera_info.get('cx', w/2.), camera_info.get('cy', h/2.)) 281 | 282 | 283 | return FOVCamera( 284 | position=pos, 285 | rotation=rotation, 286 | image_size=np.array([w, h], dtype=np.int32), 287 | focal_length=np.array([camera_info['fx'], camera_info['fy']]), 288 | principal_point=np.array([cx, cy]), 289 | image_name=camera_info['img_name'] 290 | ) 291 | 292 | def to_json(camera:FOVCamera): 293 | cx, cy = camera.principal_point 294 | fx, fy = camera.focal_length 295 | w, h = camera.image_size 296 | return { 297 | 'id': camera.image_name, 298 | 'img_name': camera.image_name, 299 | 'width': int(w), 300 | 'height': int(h), 301 | 'fx': fx, 302 | 'fy': fy, 303 | 'cx': cx, 304 | 'cy': cy, 305 | 'position': camera.position.tolist(), 306 | 'rotation': camera.rotation.tolist(), 307 | } 308 | 309 | 310 | 311 | def load_camera_json(filename:Path): 312 | filename = Path(filename) 313 | cameras = sorted(json.loads(filename.read_text()), key=lambda x: x['id']) 314 | 315 | return {camera_info['id']: from_json(camera_info) for camera_info in cameras} 316 | 317 | 318 | -------------------------------------------------------------------------------- /splat_viewer/camera/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def expand_identity(m, shape=(4, 4)): 5 | expanded = np.eye(*shape) 6 | expanded[0:m.shape[0], 0:m.shape[1]] = m 7 | return expanded 8 | 9 | 10 | def join_rt(r, t): 11 | assert t.ndim == r.ndim - 1 and t.shape[-1] == 3 and r.shape[-2:] == (3, 3),\ 12 | 'got r.shape:' + str(r.shape) + ' and t.shape:' + str(t.shape) 13 | 14 | d = t.ndim 15 | m_34 = np.concatenate([r, np.expand_dims(t, d)], axis=d) 16 | row = np.broadcast_to(np.array([[0, 0, 0, 1]]), (*r.shape[:-2], 1, 4)) 17 | return np.concatenate([m_34, row], axis=d - 1) 18 | 19 | 20 | def split_rt(m): 21 | assert m.shape[-2:] == (4, 4) 22 | return m[..., :3, :3], m[..., :3, 3] 23 | 24 | 25 | def translation(m): 26 | assert m.shape[-2:] == (4, 4) 27 | return m[..., :3, 3] 28 | 29 | 30 | def rotation(m): 31 | assert m.shape[-2:] == (4, 4) 32 | return m[..., :3, :3] 33 | 34 | 35 | def cross_matrix(vec): 36 | return np.cross(vec, np.identity(vec.shape[0]) * -1) 37 | 38 | 39 | def make_grid(width, height): 40 | return np.meshgrid(np.arange(0, width, dtype=np.float32), 41 | np.arange(0, height, dtype=np.float32)) 42 | 43 | 44 | def make_grid_centred(width, height): 45 | return np.meshgrid(np.arange(0.5, width, dtype=np.float32), 46 | np.arange(0.5, height, dtype=np.float32)) 47 | 48 | 49 | def 
transform_grid_homog(x, y, z, w, transform): 50 | points = np.stack([x, y, z, w], axis=2) 51 | transformed = points.reshape( 52 | (-1, 4)) @ np.transpose(transform.astype(np.float32)) 53 | return transformed.reshape(*z.shape, 4) 54 | 55 | 56 | def make_homog(points): 57 | shape = list(points.shape) 58 | shape[-1] = 1 59 | return np.concatenate([points, np.ones(shape, dtype=np.float32)], axis=-1) 60 | 61 | 62 | def transform_grid(x, y, z, transform): 63 | """ transform points of (x, y, z) by 4x4 matrix """ 64 | return transform_grid_homog(x, y, z, np.ones(z.shape, dtype=np.float32), 65 | transform) 66 | 67 | 68 | def uproject_invdepth(invdepth: np.ndarray, depth_t_disparity: np.ndarray): 69 | """ perspective transform points of (x, y, 1/depth) by 4x4 matrix """ 70 | x, y = make_grid(invdepth.shape[1], invdepth.shape[0]) 71 | return transform_grid_homog(x, y, np.ones(x.shape, dtype=x.dtype), invdepth, 72 | depth_t_disparity) 73 | 74 | 75 | def uproject_depth(depth, transform): 76 | x, y = make_grid(depth.shape[1], depth.shape[0]) 77 | return transform_grid_homog(x * depth, y * depth, depth, 78 | np.ones(x.shape, dtype=x.dtype), transform) 79 | 80 | 81 | def uproject_disparity(disparity, transform): 82 | x, y = make_grid(disparity.shape[1], disparity.shape[0]) 83 | return transform_grid_homog(x, y, disparity, np.ones(x.shape, dtype=x.dtype), 84 | transform) 85 | 86 | 87 | def transform_invdepth(invdepth, depth_t_disparity): 88 | """ transform image grid of inverse-depth image by 4x4 matrix 89 | returns: inverse-depth image in new coordinate system 90 | """ 91 | points = uproject_invdepth(invdepth, depth_t_disparity) 92 | return points[:, :, 3] / points[:, :, 2] 93 | 94 | 95 | def transform_depth(depth, depth_t_disparity): 96 | """ transform depth image by 4x4 matrix 97 | returns: image in new coordinate system 98 | """ 99 | points = uproject_invdepth(1 / depth, depth_t_disparity) 100 | return points[:, :, 2] / points[:, :, 3] 101 | 102 | 103 | def _batch_transform(transforms, points): 104 | assert points.shape[ 105 | -1] == 3 and points.ndim == 2, 'transform_points: expected 3d points of Nx3, got:' + str( 106 | points.shape) 107 | assert transforms.shape[-2:] == ( 108 | 4, 4 109 | ) and transforms.ndim == 3, 'transform_points: expected Mx4x4, got:' + str( 110 | transforms.shape) 111 | 112 | homog = make_homog(points) 113 | transformed = transforms.reshape(transforms.shape[0], 1, 4, 4) @ homog.reshape( 114 | 1, *homog.shape, 1) 115 | 116 | return transformed[..., 0].reshape([transforms.shape[0], -1, 4]) 117 | 118 | 119 | def batch_transform_points(transforms, points): 120 | return _batch_transform(transforms, points)[..., 0:3] 121 | 122 | 123 | def batch_project_points(transforms, points): 124 | homog = _batch_transform(transforms, points) 125 | return homog[..., 0:3] / homog[..., 3:4] 126 | 127 | 128 | def _transform_points(transform, points): 129 | assert points.shape[ 130 | -1] == 3, 'transform_points: expected 3d points of ...x3, got:' + str( 131 | points.shape) 132 | 133 | homog = make_homog(points).reshape([-1, 4, 1]) 134 | transformed = transform.reshape([1, 4, 4]) @ homog 135 | return transformed[..., 0].reshape(-1, 4) 136 | 137 | 138 | def transform_points(transform, points): 139 | return _transform_points(transform, points)[..., 0:3] 140 | 141 | def unproject_pixels(transform, xy, depth): 142 | points = np.concatenate([xy * depth, depth], axis=-1) 143 | 144 | homog = _transform_points(transform, points) 145 | return homog[..., 0:3] / homog[..., 3:4] 146 | 147 | def 
project_points(transform, xyz): 148 | homog = _transform_points(transform, xyz) 149 | depth = homog[..., 2:3] 150 | xy = homog[..., 0:2] 151 | return (xy / depth), depth 152 | 153 | def affine_transform_points(transform, points): 154 | assert points.shape[ 155 | -1] == 3, 'affine_transform_points: expected 3d points of ...x3, got:' + str( 156 | points.shape) 157 | 158 | r, t = split_rt(transform) 159 | transformed = (r.reshape(1, 3, 3) @ points.reshape(-1, 3, 1)).reshape(-1, 160 | 3) + t 161 | 162 | return transformed.reshape(points.shape) 163 | 164 | 165 | def translate_33(tx, ty): 166 | return np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]]) 167 | 168 | 169 | def scale_33(sx, sy): 170 | return np.diag([sx, sy, 1]) 171 | 172 | 173 | def translate_44(tx, ty, tz): 174 | return np.array([ 175 | [1, 0, 0, tx], 176 | [0, 1, 0, ty], 177 | [0, 0, 1, tz], 178 | [0, 0, 0, 1], 179 | ]) 180 | 181 | 182 | def scale_44(sx, sy, sz): 183 | return np.diag([sx, sy, sz, 1]) 184 | 185 | 186 | def dict_subset(dict, keys): 187 | return {k: dict[k] for k in keys} 188 | 189 | 190 | def check_size(image, expected_size): 191 | size = image.shape[1], image.shape[0] 192 | assert size == tuple( 193 | expected_size), 'got: ' + str(size) + ', should be: ' + str(expected_size) 194 | 195 | 196 | def estimate_rigid_transform(ref_points, points): 197 | assert ref_points.shape == points.shape and points.shape[0] >= 3,\ 198 | 'estimate_transform: expected at least 3 points, got:' + str(points.shape) 199 | 200 | centroid_ref = np.mean(ref_points, axis=0) 201 | centroid_points = np.mean(points, axis=0) 202 | 203 | centered_ref = ref_points - centroid_ref 204 | centered_points = points - centroid_points 205 | 206 | s = centered_ref.T @ centered_points 207 | 208 | u, s, vh = np.linalg.svd(s) 209 | r = np.dot(vh.T, u.T) 210 | 211 | if np.linalg.det(r) < 0: 212 | vh[-1, :] *= -1 213 | r = np.dot(vh.T, u.T) 214 | 215 | t = centroid_points - np.dot(r, centroid_ref) 216 | transform = join_rt(r, t) 217 | 218 | err = np.linalg.norm(transform_points(transform, ref_points) - points, axis=1) 219 | return transform, err 220 | 221 | 222 | 223 | def normalize(v): 224 | return v / np.linalg.norm(v, axis=-1) 225 | -------------------------------------------------------------------------------- /splat_viewer/camera/visibility.py: -------------------------------------------------------------------------------- 1 | 2 | from beartype.typing import List 3 | from beartype import beartype 4 | import numpy as np 5 | 6 | from tqdm import tqdm 7 | 8 | import torch 9 | from .fov import FOVCamera 10 | 11 | 12 | 13 | 14 | def make_homog(points): 15 | shape = list(points.shape) 16 | shape[-1] = 1 17 | return torch.concatenate([points, torch.ones(shape, dtype=torch.float32, device=points.device)], axis=-1) 18 | 19 | def _transform_points(transform, points): 20 | assert points.shape[ 21 | -1] == 3, 'transform_points: expected 3d points of ...x3, got:' + str( 22 | points.shape) 23 | 24 | homog = make_homog(points).reshape([-1, 4, 1]) 25 | transformed = transform.reshape([1, 4, 4]) @ homog 26 | return transformed[..., 0].reshape(-1, 4) 27 | 28 | def project_points(transform, xyz): 29 | homog = _transform_points(transform, xyz) 30 | depth = homog[..., 2:3] 31 | xy = homog[..., 0:2] 32 | return (xy / depth), depth 33 | 34 | 35 | 36 | @beartype 37 | def visibility(cameras:List[FOVCamera], points:torch.Tensor, near=0.1, far=torch.inf): 38 | vis_counts = torch.zeros(len(points), dtype=torch.int32, device=points.device) 39 | 40 | projections = 
np.array([camera.projection for camera in cameras]) 41 | torch_projections = torch.from_numpy(projections).to(dtype=torch.float32, device=points.device) 42 | 43 | min_distance = torch.full((len(points), ), fill_value=far, dtype=torch.float32, device=points.device) 44 | 45 | for camera, proj in tqdm(zip(cameras, torch_projections), total=len(cameras), desc="Evaluating visibility"): 46 | 47 | proj, depth = project_points(proj, points) 48 | width, height = camera.image_size 49 | 50 | is_valid = ((proj[:, 0] >= 0) & (proj[:, 0] < width) & 51 | (proj[:, 1] >= 0) & (proj[:, 1] < height) 52 | & (depth[:, 0] > near) & (depth[:, 0] < far) 53 | ) 54 | 55 | min_distance[is_valid] = torch.minimum(depth[is_valid, 0], min_distance[is_valid]) 56 | vis_counts[is_valid] += 1 57 | 58 | 59 | return vis_counts, min_distance 60 | 61 | 62 | -------------------------------------------------------------------------------- /splat_viewer/gaussians/__init__.py: -------------------------------------------------------------------------------- 1 | from .loading import read_gaussians, write_gaussians 2 | from .workspace import Workspace, load_workspace, load_camera_json 3 | 4 | from .data_types import Rendering, Gaussians 5 | 6 | __all__ = [ 7 | "Gaussians", 8 | "read_gaussians", 9 | "write_gaussians", 10 | "Rendering", 11 | 12 | "Workspace", 13 | "load_workspace", 14 | "load_camera_json" 15 | ] -------------------------------------------------------------------------------- /splat_viewer/gaussians/data_types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import replace 2 | import math 3 | from beartype.typing import Optional 4 | from tensordict import TensorClass 5 | import torch 6 | 7 | from .sh_utils import check_sh_degree, num_sh_features, rgb_to_sh, sh_to_rgb 8 | 9 | from dataclasses import dataclass 10 | from splat_viewer.camera.fov import FOVCamera 11 | 12 | from taichi_splatting import Gaussians3D 13 | import roma 14 | 15 | 16 | 17 | @dataclass 18 | class Rendering: 19 | image : torch.Tensor 20 | camera : FOVCamera 21 | 22 | depth : torch.Tensor 23 | 24 | @property 25 | def image_size(self): 26 | y, x = self.image.shape[1:] 27 | return x, y 28 | 29 | class Gaussians(TensorClass): 30 | position : torch.Tensor # 3 - xyz 31 | log_scaling : torch.Tensor # 3 - scale = exp(log_scalining) 32 | rotation : torch.Tensor # 4 - quaternion wxyz 33 | alpha_logit : torch.Tensor # 1 - alpha = sigmoid(alpha_logit) 34 | 35 | sh_feature : torch.Tensor # (spherical harmonics (3, deg + 1)**2)) 36 | 37 | foreground : Optional[torch.Tensor] = None # 1 (bool) 38 | label : Optional[torch.Tensor] = None # 1 (int) 39 | instance_label : Optional[torch.Tensor] = None # 1 (int) 40 | 41 | def __post_init__(self): 42 | assert self.position.shape[1] == 3, f"Expected shape (N, 3), got {self.position.shape}" 43 | assert self.log_scaling.shape[1] == 3, f"Expected shape (N, 3), got {self.log_scaling.shape}" 44 | assert self.rotation.shape[1] == 4, f"Expected shape (N, 4), got {self.rotation.shape}" 45 | assert self.alpha_logit.shape[1] == 1, f"Expected shape (N, 1), got {self.alpha_logit.shape}" 46 | 47 | check_sh_degree(self.sh_feature) 48 | assert self.foreground is None or self.foreground.shape[1] == 1, f"Expected shape (N, 1), got {self.foreground.shape}" 49 | assert self.label is None or self.label.shape[1] == 1, f"Expected shape (N, 1), got {self.label.shape}" 50 | assert self.instance_label is None or self.instance_label.shape[1] == 1, f"Expected shape (N, 1), got 
{self.instance_label.shape}" 51 | 52 | def n(self): 53 | return self.batch_size[0] 54 | 55 | def __repr__(self): 56 | return f"Gaussians with {self.batch_shape[0]} points, sh_degree={self.sh_degree}" 57 | 58 | def __str__(self): 59 | return repr(self) 60 | 61 | def packed(self): 62 | return torch.cat([self.position, self.log_scaling, self.rotation, self.alpha_logit], dim=-1) 63 | 64 | @property 65 | def device(self): 66 | return self.position.device 67 | 68 | def to_gaussians3d(self): 69 | return Gaussians3D( 70 | position=self.position, 71 | log_scaling=self.log_scaling, 72 | rotation=self.rotation, 73 | alpha_logit=self.alpha_logit, 74 | feature=self.sh_feature, 75 | batch_size=self.batch_size 76 | ) 77 | 78 | @staticmethod 79 | def from_gaussians3d(g:Gaussians3D): 80 | return Gaussians( 81 | position=g.position, 82 | log_scaling=g.log_scaling, 83 | rotation=g.rotation, 84 | alpha_logit=g.alpha_logit, 85 | sh_feature=g.feature, 86 | batch_size=g.batch_size 87 | ) 88 | 89 | def crop_foreground(self): 90 | if self.foreground is None: 91 | return self 92 | else: 93 | return self[self.foreground[:, 0]] 94 | 95 | 96 | 97 | def scale(self): 98 | return torch.exp(self.log_scaling) 99 | 100 | def alpha(self): 101 | return torch.sigmoid(self.alpha_logit) 102 | 103 | def mul_alpha(self, factor) -> 'Gaussians': 104 | return self.replace(alpha_logit=inverse_sigmoid(self.alpha() * factor)) 105 | 106 | def split_sh(self): 107 | return self.sh_feature[:, :, 0], self.sh_feature[:, :, 1:] 108 | 109 | def sh_degree(self): 110 | return check_sh_degree(self.sh_feature) 111 | 112 | def with_fixed_scale(self, scale:float): 113 | return replace(self, 114 | log_scaling=torch.full_like(self.log_scaling, math.log(scale)), 115 | batch_size=self.batch_size) 116 | 117 | def get_colors(self): 118 | return sh_to_rgb(self.sh_feature[:, :, 0]) 119 | 120 | def get_rotation_matrix(self): 121 | return roma.unitquat_to_rotmat(self.rotation) 122 | 123 | def set_colors(self, color: tuple[float, float, float], indexes: Optional[torch.Tensor]): 124 | colors = torch.tensor(color, device=self.device).expand(indexes.shape[0], -1) 125 | return self.with_colors(colors, indexes) 126 | 127 | 128 | def with_colors(self, colors, index=None): 129 | sh_feature = self.sh_feature.clone() 130 | if index is None: 131 | sh_feature[:, :, 0] = rgb_to_sh(colors) 132 | else: 133 | sh_feature[index, :, 0] = rgb_to_sh(colors) 134 | return self.replace(sh_feature=sh_feature) 135 | 136 | def with_labels(self, labels, index=None): 137 | if index is None: 138 | return self.replace(label=labels) 139 | else: 140 | return self.replace(label=self.label[index]) 141 | 142 | def with_sh_degree(self, sh_degree:int): 143 | assert sh_degree >= 0 144 | 145 | if sh_degree <= self.sh_degree(): 146 | return self.replace(sh_feature = self.sh_feature[:, :, :num_sh_features(sh_degree)]) 147 | else: 148 | num_extra = num_sh_features(sh_degree) - num_sh_features(self.sh_degree) 149 | extra_features = torch.zeros((self.batch_shape[0], 150 | 3, num_extra), device=self.device) 151 | 152 | return self.replace(sh_feature = torch.cat( 153 | [self.sh_feature, extra_features], dim=2)) 154 | 155 | 156 | def replace(self, **kwargs): 157 | return replace(self, **kwargs, batch_size=self.batch_size) 158 | 159 | 160 | def sorted(self): 161 | max_axis = torch.max(self.log_scaling, dim=1).values 162 | indices = torch.argsort(max_axis, descending=False) 163 | 164 | return self[indices] 165 | 166 | 167 | 168 | def inverse_sigmoid(x, eps=1e-6): 169 | return -torch.log((1 / (x + eps)) - 1) 
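# --- Illustrative usage sketch (added for documentation, not part of the original file) ---
# A minimal example of constructing a Gaussians container and converting it to the
# taichi-splatting Gaussians3D format. Shapes follow the assertions in __post_init__:
# sh_degree 3 gives (3 + 1)**2 = 16 SH coefficients per channel. The values here are
# random placeholders, not real reconstruction data.
#
#   import torch
#   import torch.nn.functional as F
#
#   n = 1000
#   g = Gaussians(
#       position=torch.randn(n, 3),
#       log_scaling=torch.randn(n, 3),
#       rotation=F.normalize(torch.randn(n, 4), dim=1),   # unit quaternions
#       alpha_logit=torch.randn(n, 1),
#       sh_feature=torch.randn(n, 3, 16),
#       batch_size=(n,))
#
#   packed = g.to_gaussians3d()                  # taichi_splatting.Gaussians3D
#   restored = Gaussians.from_gaussians3d(packed)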
170 | -------------------------------------------------------------------------------- /splat_viewer/gaussians/loading.py: -------------------------------------------------------------------------------- 1 | 2 | from pathlib import Path 3 | import tempfile 4 | import plyfile 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from .data_types import Gaussians 11 | 12 | 13 | def to_plydata(gaussians:Gaussians) -> plyfile.PlyData: 14 | gaussians = gaussians.to('cpu') 15 | 16 | num_sh = gaussians.sh_feature.shape[2] * gaussians.sh_feature.shape[1] 17 | 18 | dtype = [ 19 | ('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 20 | ('opacity', 'f4'), 21 | ('scale_0', 'f4'), ('scale_1', 'f4'), ('scale_2', 'f4'), 22 | ('rot_0', 'f4'), ('rot_1', 'f4'), ('rot_2', 'f4'), ('rot_3', 'f4'), 23 | ('f_dc_0', 'f4'), ('f_dc_1', 'f4'), ('f_dc_2', 'f4'), 24 | ] + [('f_rest_' + str(i), 'f4') for i in range(num_sh - 3)] 25 | 26 | if gaussians.foreground is not None: 27 | dtype.append(('foreground', 'u1')) 28 | 29 | if gaussians.label is not None: 30 | dtype.append(('label', 'f4')) 31 | 32 | if gaussians.instance_label is not None: 33 | dtype.append(('instance_label', 'int16')) 34 | 35 | vertex = np.zeros(gaussians.batch_size[0], dtype=dtype ) 36 | 37 | for i, name in enumerate(['x', 'y', 'z']): 38 | vertex[name] = gaussians.position[:, i].numpy() 39 | 40 | for i in range(3): 41 | vertex['scale_' + str(i)] = gaussians.log_scaling[:, i].numpy() 42 | 43 | rotation = torch.roll(F.normalize(gaussians.rotation), 1, dims=(1,)) 44 | for i in range(4): 45 | vertex['rot_' + str(i)] = rotation[:, i].numpy() 46 | 47 | vertex['opacity'] = gaussians.alpha_logit[:, 0].numpy() 48 | 49 | sh_dc, sh_rest = gaussians.split_sh() 50 | 51 | sh_dc = sh_dc.view(-1, 3) 52 | sh_rest = sh_rest.reshape(sh_rest.shape[0], sh_rest.shape[1] * sh_rest.shape[2]) 53 | 54 | for i in range(3): 55 | vertex['f_dc_' + str(i)] = sh_dc[:, i].numpy() 56 | 57 | for i in range(sh_rest.shape[1]): 58 | vertex['f_rest_' + str(i)] = sh_rest[:, i].numpy() 59 | 60 | if gaussians.foreground is not None: 61 | vertex['foreground'] = gaussians.foreground[:, 0].numpy() 62 | 63 | if gaussians.label is not None: 64 | vertex['label'] = gaussians.label[:, 0].numpy() 65 | 66 | if gaussians.instance_label is not None: 67 | vertex['instance_label'] = gaussians.instance_label[:, 0].numpy() 68 | 69 | el = plyfile.PlyElement.describe(vertex, 'vertex') 70 | return plyfile.PlyData([el]) 71 | 72 | 73 | def from_plydata(plydata:plyfile.PlyData) -> Gaussians: 74 | 75 | vertex = plydata['vertex'] 76 | 77 | def get_keys(ks): 78 | values = [torch.from_numpy(vertex[k].copy()) for k in ks] 79 | return torch.stack(values, dim=-1) 80 | 81 | 82 | positions = torch.stack( 83 | [ torch.from_numpy(vertex[i].copy()) for i in ['x', 'y', 'z']], dim=-1) 84 | 85 | attrs = sorted(plydata['vertex'].data.dtype.names) 86 | sh_attrs = [k for k in attrs if k.startswith('f_rest_') or k.startswith('f_dc_')] 87 | 88 | n_sh = len(sh_attrs) // 3 89 | deg = int(np.sqrt(n_sh)) 90 | 91 | assert deg * deg == n_sh, f"SH feature count must be square ({deg} * {deg} != {n_sh}), got {len(sh_attrs)}" 92 | log_scaling = get_keys([f'scale_{k}' for k in range(3)]) 93 | 94 | 95 | sh_dc = get_keys([f'f_dc_{k}' for k in range(3)]).view(positions.shape[0], 3, 1) 96 | 97 | 98 | if deg > 1: 99 | sh_rest = get_keys([f'f_rest_{k}' for k in range(3 * (deg * deg - 1))]) 100 | sh_rest = sh_rest.view(positions.shape[0], 3, n_sh - 1) 101 | 102 | sh_features = torch.cat([sh_dc, sh_rest], dim=2) 103 | else: 104 | 
sh_features = sh_dc 105 | 106 | rotation = get_keys([f'rot_{k}' for k in range(4)]) 107 | # convert from wxyz to xyzw quaternion and normalize 108 | rotation = torch.roll(F.normalize(rotation), -1, dims=(1,)) 109 | 110 | alpha_logit = get_keys(['opacity']) 111 | 112 | foreground = (get_keys(['foreground']).to(torch.bool) 113 | if 'foreground' in plydata['vertex'].data.dtype.names else None) 114 | 115 | label = get_keys(['label']) if 'label' in vertex.data.dtype.names else None 116 | instance_label = get_keys(['instance_label']) if 'instance_label' in vertex.data.dtype.names else None 117 | 118 | return Gaussians( 119 | position = positions, 120 | rotation = rotation, 121 | alpha_logit = alpha_logit, 122 | log_scaling = log_scaling, 123 | sh_feature = sh_features, 124 | 125 | foreground = foreground, 126 | label = label, 127 | instance_label = instance_label, 128 | 129 | batch_size = (positions.shape[0],) 130 | ) 131 | 132 | def write_gaussians(filename:Path | str, gaussians:Gaussians): 133 | filename = Path(filename) 134 | 135 | plydata = to_plydata(gaussians) 136 | plydata.write(str(filename)) 137 | 138 | 139 | 140 | def read_gaussians(filename:Path | str) -> Gaussians: 141 | filename = Path(filename) 142 | 143 | plydata = plyfile.PlyData.read(str(filename)) 144 | return from_plydata(plydata) 145 | 146 | 147 | 148 | 149 | def random_gaussians(n:int, sh_degree:int): 150 | points = torch.randn(n, 3) 151 | 152 | return Gaussians( 153 | position = points, 154 | rotation = F.normalize(torch.randn(n, 4), dim=1), 155 | alpha_logit = torch.randn(n, 1), 156 | log_scaling = torch.randn(n, 3) * 4, 157 | sh_feature = torch.randn(n, 3, (sh_degree + 1)**2), 158 | 159 | batch_size = (n,) 160 | ) 161 | 162 | def test_read_write(): 163 | temp_path = Path(tempfile.mkdtemp()) 164 | 165 | print("Testing write/read") 166 | for i in range(10): 167 | g = random_gaussians((i + 1) * 1000, 3) 168 | write_gaussians(temp_path / f'gaussians_{i}.ply', g) 169 | g2 = read_gaussians(temp_path / f'gaussians_{i}.ply') 170 | 171 | assert torch.allclose(g.position, g2.position) 172 | assert torch.allclose(g.rotation, g2.rotation) 173 | assert torch.allclose(g.alpha(), g2.alpha()) 174 | assert torch.allclose(g.log_scaling, g2.log_scaling) 175 | assert torch.allclose(g.sh_feature, g2.sh_feature) 176 | 177 | 178 | 179 | 180 | 181 | if __name__ == '__main__': 182 | test_read_write() 183 | 184 | 185 | -------------------------------------------------------------------------------- /splat_viewer/gaussians/sh_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | def check_sh_degree(sh_features): 4 | assert len(sh_features.shape) == 3, f"SH features must have 3 dimensions, got {sh_features.shape}" 5 | 6 | n_sh = sh_features.shape[2] 7 | n = int(math.sqrt(n_sh)) 8 | 9 | assert n * n == n_sh, f"SH feature count must be square, got {n_sh} ({sh_features.shape})" 10 | return (n - 1) 11 | 12 | def num_sh_features(deg): 13 | return (deg + 1) ** 2 14 | 15 | 16 | sh0 = 0.282094791773878 17 | 18 | def rgb_to_sh(rgb): 19 | return (rgb - 0.5) / sh0 20 | 21 | def sh_to_rgb(sh): 22 | return sh * sh0 + 0.5 -------------------------------------------------------------------------------- /splat_viewer/gaussians/workspace.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | from pathlib import Path 3 | import re 4 | from beartype import beartype 5 | from beartype.typing import Optional 6 | from natsort import 
natsorted 7 | from dataclasses import dataclass 8 | 9 | import numpy as np 10 | 11 | from splat_viewer.camera.fov import FOVCamera, load_camera_json 12 | 13 | from .loading import read_gaussians 14 | from .data_types import Gaussians 15 | 16 | import plyfile 17 | 18 | @beartype 19 | @dataclass 20 | class Workspace: 21 | model_path:Path 22 | 23 | cloud_files : dict[str, Path] 24 | cameras:list[FOVCamera] 25 | 26 | @staticmethod 27 | def load(model_path:Path | str) -> 'Workspace': 28 | return load_workspace(model_path) 29 | 30 | @cached_property 31 | def camera_extent(self): 32 | camera_positions = np.array([c.position for c in self.cameras]) 33 | 34 | scene_diagonal = np.linalg.norm( 35 | camera_positions.max(axis=0) - camera_positions.min(axis=0)) 36 | 37 | return scene_diagonal / 2.0 38 | 39 | 40 | def latest_iteration(self) -> str: 41 | paths = [(m.group(1), name) for name in self.cloud_files.keys() 42 | if (m:=re.search("iteration_(\d+)", name))] 43 | 44 | if len(paths) == 0: 45 | raise FileNotFoundError(f"No point clouds named iteration_(\d+) in {str(self.model_path)}") 46 | 47 | paths = sorted(paths, key=lambda x: int(x[0])) 48 | return paths[-1][1] 49 | 50 | def model_filename(self, model:str) -> Path: 51 | if model not in self.cloud_files: 52 | options = list(self.cloud_files.keys()) 53 | raise LookupError(f"Model {model} not found in {self.model_path} options are: {options}") 54 | 55 | return self.cloud_files[model] 56 | 57 | def load_model(self, model:Optional[str]=None) -> Gaussians: 58 | if model is None: 59 | model = self.latest_iteration() 60 | return read_gaussians(self.model_filename(model)) 61 | 62 | def load_seed_points(self) -> plyfile.PlyData: 63 | return plyfile.PlyData.read(str(self.model_path / "input.ply")) 64 | 65 | 66 | def find_clouds(p:Path): 67 | clouds = {d.name : file for d in p.iterdir() 68 | if d.is_dir() and (file :=d / "point_cloud.ply").exists() 69 | } 70 | 71 | 72 | if len(clouds) == 0: 73 | raise FileNotFoundError(f"No point clouds found in {str(p)}") 74 | 75 | return clouds 76 | 77 | 78 | def load_workspace(model_path:Path | str) -> Workspace: 79 | model_path = Path(model_path) 80 | cloud_path = model_path / "point_cloud" 81 | 82 | if not cloud_path.exists(): 83 | raise FileNotFoundError(f"Could not find point cloud directory at {str(cloud_path)}") 84 | 85 | cloud_files = find_clouds(cloud_path) 86 | 87 | transfer = model_path / "transfer" / "clustered.ply" 88 | if transfer.exists(): 89 | cloud_files["transfer"] = transfer 90 | 91 | cameras = load_camera_json(model_path / "cameras.json") 92 | cameras = natsorted(cameras.values(), key=lambda x: x.image_name) 93 | 94 | return Workspace( 95 | model_path = model_path, 96 | cloud_files = cloud_files, 97 | cameras = cameras 98 | ) 99 | 100 | -------------------------------------------------------------------------------- /splat_viewer/renderer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uc-vision/splat-viewer/c2fd37c42c3671b5719be28863e3ea6cf14770ea/splat_viewer/renderer/__init__.py -------------------------------------------------------------------------------- /splat_viewer/renderer/arguments.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | from dataclasses import dataclass 5 | from enum import Enum 6 | from typing import Tuple 7 | 8 | 9 | def add_render_arguments(parser): 10 | parser.add_argument("--tile_size", type=int, default=16, help="tile size for rasterizer") 11 | 
parser.add_argument("--antialias", action="store_true", help="enable analytic antialiasing") 12 | parser.add_argument("--depth16", action="store_true", help="use 16 bit depth in sorting (default is 32 bit)") 13 | 14 | parser.add_argument("--blur_cov", type=float, default=0.3, help="add isotropic gaussian blur with given covariance") 15 | parser.add_argument("--pixel_stride", type=str, default="2,2", help="pixel tile size for rasterizer, e.g. 2,2") 16 | return parser 17 | 18 | 19 | 20 | 21 | 22 | @dataclass(frozen=True) 23 | class RendererArgs: 24 | tile_size: int = 16 25 | pixel_stride: Tuple[int, int] = (2, 2) 26 | antialias: bool = False 27 | depth16: bool = False 28 | blur_cov: float = 0.3 29 | 30 | def renderer_from_args(args:RendererArgs): 31 | from splat_viewer.renderer.taichi_splatting import GaussianRenderer 32 | 33 | return GaussianRenderer(tile_size=args.tile_size, 34 | antialias=args.antialias, 35 | use_depth16=args.depth16, 36 | pixel_stride=args.pixel_stride, 37 | blur_cov=args.blur_cov if not args.antialias else 0.0) 38 | 39 | def make_renderer_args(args): 40 | 41 | 42 | pixel_stride = tuple(map(int, args.pixel_stride.split(','))) 43 | 44 | return RendererArgs( 45 | tile_size=args.tile_size, 46 | antialias=args.antialias, 47 | depth16=args.depth16, 48 | pixel_stride=pixel_stride, 49 | blur_cov=args.blur_cov 50 | ) 51 | 52 | -------------------------------------------------------------------------------- /splat_viewer/renderer/taichi_splatting.py: -------------------------------------------------------------------------------- 1 | 2 | from dataclasses import dataclass, replace 3 | from typing import Tuple 4 | from beartype import beartype 5 | import torch 6 | from splat_viewer.camera.fov import FOVCamera 7 | 8 | from taichi_splatting import Gaussians3D, renderer 9 | from taichi_splatting.perspective import CameraParams 10 | 11 | from splat_viewer.gaussians.data_types import Gaussians, Rendering 12 | 13 | 14 | def to_camera_params(camera:FOVCamera, device=torch.device("cuda:0")): 15 | 16 | params = CameraParams( 17 | T_camera_world=torch.from_numpy(camera.camera_t_world), 18 | projection=torch.tensor([*camera.focal_length, *camera.principal_point]), 19 | 20 | image_size=tuple(int(x) for x in camera.image_size), 21 | near_plane=camera.near, 22 | far_plane=camera.far 23 | ) 24 | 25 | return params.to(device=device, dtype=torch.float32) 26 | 27 | 28 | 29 | class GaussianRenderer: 30 | @dataclass 31 | class Config: 32 | tile_size : int = 16 33 | antialias : bool = False 34 | use_depth16 : bool = False 35 | pixel_stride : Tuple[int, int] = (2, 2) 36 | blur_cov: float = 0.3 37 | 38 | def __init__(self, **kwargs): 39 | 40 | 41 | self.config = GaussianRenderer.Config(**kwargs) 42 | 43 | @beartype 44 | def pack_inputs(self, gaussians:Gaussians, requires_grad=False): 45 | return gaussians.to_gaussians3d().requires_grad_(requires_grad) 46 | 47 | 48 | def update_settings(self, **kwargs): 49 | self.config = replace(self.config, **kwargs) 50 | 51 | @beartype 52 | def render(self, inputs:Gaussians3D, camera:FOVCamera): 53 | device = inputs.position.device 54 | 55 | config = renderer.RasterConfig( 56 | tile_size=self.config.tile_size, 57 | antialias=self.config.antialias, 58 | pixel_stride=self.config.pixel_stride, 59 | blur_cov=self.config.blur_cov 60 | ) 61 | 62 | rendering = renderer.render_gaussians( 63 | gaussians=inputs, 64 | camera_params=to_camera_params(camera, device), 65 | config=config, 66 | use_sh=True, 67 | use_depth16=self.config.use_depth16, 68 | render_median_depth=True) 69 | 
70 | 71 | return Rendering(image=rendering.image, 72 | depth=rendering.median_depth_image, 73 | camera=camera) 74 | 75 | -------------------------------------------------------------------------------- /splat_viewer/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uc-vision/splat-viewer/c2fd37c42c3671b5719be28863e3ea6cf14770ea/splat_viewer/scripts/__init__.py -------------------------------------------------------------------------------- /splat_viewer/scripts/compare_clouds.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from splat_viewer.gaussians import Workspace 3 | from splat_trainer.util.pointcloud import PointCloud 4 | import open3d as o3d 5 | import argparse 6 | from torch.utils.dlpack import to_dlpack 7 | import numpy as np 8 | import torch 9 | import torchvision.transforms.functional as F 10 | 11 | from pykeops.torch import LazyTensor 12 | 13 | def rgb_to_hsv(image: torch.Tensor) -> torch.Tensor: 14 | r, g, _ = image.unbind(dim=-1) 15 | 16 | # Implementation is based on 17 | # https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/src/libImaging/Convert.c#L330 18 | minc, maxc = torch.aminmax(image, dim=-1) 19 | 20 | # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN 21 | # from happening in the results, because 22 | # + S channel has division by `maxc`, which is zero only if `maxc = minc` 23 | # + H channel has division by `(maxc - minc)`. 24 | # 25 | # Instead of overwriting NaN afterwards, we just prevent it from occurring so 26 | # we don't need to deal with it in case we save the NaN in a buffer in 27 | # backprop, if it is ever supported, but it doesn't hurt to do so. 28 | eqc = maxc == minc 29 | 30 | channels_range = maxc - minc 31 | # Since `eqc => channels_range = 0`, replacing denominator with 1 when `eqc` is fine. 32 | ones = torch.ones_like(maxc) 33 | s = channels_range / torch.where(eqc, ones, maxc) 34 | # Note that `eqc => maxc = minc = r = g = b`. So the following calculation 35 | # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it 36 | # would not matter what values `rc`, `gc`, and `bc` have here, and thus 37 | # replacing denominator with 1 when `eqc` is fine. 
38 | channels_range_divisor = torch.where(eqc, ones, channels_range).unsqueeze_(dim=-1) 39 | rc, gc, bc = ((maxc.unsqueeze(dim=-1) - image) / channels_range_divisor).unbind(dim=-1) 40 | 41 | mask_maxc_neq_r = maxc != r 42 | mask_maxc_eq_g = maxc == g 43 | 44 | hg = rc.add(2.0).sub_(bc).mul_(mask_maxc_eq_g & mask_maxc_neq_r) 45 | hr = bc.sub_(gc).mul_(~mask_maxc_neq_r) 46 | hb = gc.add_(4.0).sub_(rc).mul_(mask_maxc_neq_r.logical_and_(mask_maxc_eq_g.logical_not_())) 47 | 48 | h = hr.add_(hg).add_(hb) 49 | h = h.mul_(1.0 / 6.0).add_(1.0).fmod_(1.0) 50 | return torch.stack((h, s, maxc), dim=-1) 51 | 52 | 53 | 54 | 55 | def smooth_colors_spatial(points: torch.Tensor, colors: torch.Tensor, k: int = 16) -> torch.Tensor: 56 | """Smooth colors by averaging with k nearest spatial neighbors""" 57 | 58 | N, D = points.shape 59 | x_i = LazyTensor(points.view(N, 1, D)) # (N, 1, D) samples 60 | x_j = LazyTensor(points.view(1, N, D)) # (1, N, D) samples 61 | 62 | # Compute pairwise squared distances 63 | D_ij = ((x_i - x_j) ** 2).sum(-1) # (N, N) symbolic squared distances 64 | 65 | # Find k nearest neighbors using PyKeOps 66 | knn_indices = D_ij.argKmin(K=k, dim=1) # (N, k) tensor of neighbor indices 67 | 68 | # Gather colors of k nearest neighbors 69 | knn_colors = colors[knn_indices] # (N, k, 3) 70 | 71 | # Average the colors 72 | smoothed_colors = knn_colors.mean(dim=1) # (N, 3) 73 | return smoothed_colors 74 | 75 | def is_vegetation(colors: torch.Tensor, green_hue_tolerance: float = 0.1) -> torch.Tensor: 76 | hsv = rgb_to_hsv(colors) 77 | h, s, v = hsv.unbind(dim=-1) 78 | 79 | # Debug: show hue distribution 80 | q05, q25, q50, q75, q95 = torch.quantile(h, torch.tensor([0.05, 0.25, 0.5, 0.75, 0.95])) 81 | print(f"Hue quantiles: 5%={q05:.3f}, 25%={q25:.3f}, 50%={q50:.3f}, 75%={q75:.3f}, 95%={q95:.3f}") 82 | print(f"Hue mean: {h.mean():.3f}, std: {h.std():.3f}") 83 | 84 | green_center = 0.2 # Around 72 degrees - more yellow-green 85 | hue_min = green_center - green_hue_tolerance 86 | hue_max = green_center + green_hue_tolerance 87 | 88 | print(f"Looking for hues between {hue_min:.3f} and {hue_max:.3f}") 89 | 90 | mask = (h > hue_min) & (h < hue_max) & (s > 0.2) 91 | 92 | return mask 93 | 94 | 95 | def load_cloud(path:str): 96 | workspace = Workspace.load(model_path=path) 97 | gaussians3d = workspace.load_model() 98 | 99 | gaussians3d = gaussians3d.crop_foreground() 100 | gaussians3d = gaussians3d[gaussians3d.alpha().squeeze() > 0.5] 101 | 102 | return PointCloud( 103 | points=gaussians3d.position, 104 | colors=gaussians3d.get_colors(), 105 | batch_size=(gaussians3d.batch_size[0])) 106 | 107 | 108 | def to_o3d(cloud:PointCloud, o3d_device:str = "CUDA:0"): 109 | device = o3d.core.Device(o3d_device) 110 | pcd = o3d.t.geometry.PointCloud(device=device) 111 | 112 | assert torch.all(cloud.points.isfinite()) and torch.all(cloud.colors.isfinite()) 113 | 114 | # Move tensors to CUDA device matching o3d_device 115 | torch_device = f"cuda:{o3d_device.split(':')[1]}" 116 | points_tensor = cloud.points.to(torch_device).contiguous() 117 | colors_tensor = cloud.colors.to(torch_device).contiguous() 118 | 119 | pcd.point["positions"] = o3d.core.Tensor.from_dlpack(to_dlpack(points_tensor)) 120 | pcd.point["colors"] = o3d.core.Tensor.from_dlpack(to_dlpack(colors_tensor)) 121 | return pcd 122 | 123 | 124 | 125 | def vis_clouds(clouds:List[PointCloud], o3d_device:str = "CUDA:0"): 126 | o3d_clouds = [to_o3d(cloud, o3d_device) for cloud in clouds] 127 | 128 | for i in range(1, len(o3d_clouds)): 129 | 130 | result = 
icp(o3d_clouds[i], o3d_clouds[0]) 131 | print(result) 132 | print("Transformation matrix:") 133 | print(result.transformation.cpu().numpy()) 134 | 135 | # Apply transformation to align first cloud to second 136 | o3d_clouds[i] = o3d_clouds[i].transform(result.transformation) 137 | 138 | print("Visualizing aligned clouds...") 139 | o3d.visualization.draw(o3d_clouds) 140 | 141 | 142 | def icp(cloud1, cloud2, max_distance=0.05, max_iterations=300, init_transform=np.identity(4)): 143 | # Convert init_transform to tensor (Float32 like their example) 144 | init_transform_tensor = o3d.core.Tensor.eye(4, o3d.core.Dtype.Float32) 145 | 146 | # Use multi-scale ICP matching their example 147 | voxel_sizes = o3d.utility.DoubleVector([0.01, 0.005, 0.001]) 148 | 149 | # Use stricter convergence criteria to actually converge 150 | criteria_list = [ 151 | o3d.t.pipelines.registration.ICPConvergenceCriteria(relative_fitness=1e-4, max_iteration=200), 152 | o3d.t.pipelines.registration.ICPConvergenceCriteria(relative_fitness=1e-5, max_iteration=100), 153 | o3d.t.pipelines.registration.ICPConvergenceCriteria(relative_fitness=1e-6, max_iteration=50) 154 | ] 155 | 156 | # Use their max correspondence distances 157 | max_correspondence_distances = o3d.utility.DoubleVector([0.1, 0.05, 0.01]) 158 | 159 | icp_result = o3d.t.pipelines.registration.multi_scale_icp( 160 | source=cloud1, 161 | target=cloud2, 162 | voxel_sizes=voxel_sizes, 163 | criteria_list=criteria_list, 164 | max_correspondence_distances=max_correspondence_distances, 165 | init_source_to_target=init_transform_tensor, 166 | estimation_method=o3d.t.pipelines.registration.TransformationEstimationPointToPoint() 167 | ) 168 | 169 | return icp_result 170 | 171 | def filter_vegetation(cloud:PointCloud, green_hue_tolerance: float = 0.1, use_smoothing: bool = True) -> PointCloud: 172 | colors = cloud.colors 173 | if use_smoothing: 174 | colors = smooth_colors_spatial(cloud.points, cloud.colors, k=16) 175 | 176 | is_leaf = is_vegetation(colors, green_hue_tolerance) 177 | 178 | # # color the vegetation red 179 | # cloud.colors[is_leaf] = torch.tensor([1., 0., 0.]) 180 | 181 | # return cloud 182 | return cloud[~is_leaf] 183 | 184 | def main(): 185 | 186 | 187 | parser = argparse.ArgumentParser() 188 | parser.add_argument("workspace_paths", type=str, nargs="+") 189 | parser.add_argument("--o3d_device", type=str, default="CUDA:0") 190 | parser.add_argument("--green_hue_tolerance", type=float, default=0.1) 191 | args = parser.parse_args() 192 | 193 | clouds = [load_cloud(workspace_path) for workspace_path in args.workspace_paths] 194 | print(clouds) 195 | 196 | clouds = [filter_vegetation(cloud, args.green_hue_tolerance) for cloud in clouds] 197 | print(clouds) 198 | 199 | 200 | vis_clouds(clouds, args.o3d_device) 201 | 202 | if __name__ == "__main__": 203 | main() -------------------------------------------------------------------------------- /splat_viewer/scripts/crop_foreground.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from typing import List 4 | 5 | from splat_viewer.camera.fov import FOVCamera 6 | from splat_viewer.camera.visibility import visibility 7 | import torch 8 | 9 | from splat_viewer.gaussians.loading import read_gaussians, write_gaussians 10 | from splat_viewer.gaussians.workspace import load_workspace 11 | 12 | 13 | 14 | def crop_model(model, cameras:List[FOVCamera], args): 15 | num_visible, min_distance = visibility(cameras, model.position, near = 
args.near) 16 | 17 | min_views = max(1, len(cameras) * args.min_percent / 100) 18 | 19 | is_near = (min_distance < args.far) & (num_visible > min_views) 20 | n_near = is_near.sum(dtype=torch.int32) 21 | 22 | print(f"Cropped model to {n_near} points from {model.batch_size[0]} points") 23 | return model[is_near].to(model.device) 24 | 25 | 26 | def main(): 27 | 28 | parser = argparse.ArgumentParser(description="Add a 'foreground' annotation to a .ply gaussian splatting file") 29 | parser.add_argument("model_path", type=Path, help="Path to the gaussian splatting workspace") 30 | 31 | parser.add_argument("--far", default=torch.inf, type=float, help="Max depth to determine the visible ROI") 32 | parser.add_argument("--near", default=0.01, type=float, help="Min depth to determine the visible ROI") 33 | 34 | parser.add_argument("--min_percent", type=float, default=0, help="Minimum percent of views to be included") 35 | parser.add_argument("--device", default='cuda:0') 36 | 37 | parser.add_argument("--write_to", type=Path, help="Write the model to a ply file") 38 | parser.add_argument("--show", action="store_true") 39 | 40 | args = parser.parse_args() 41 | 42 | assert args.show or args.write_to, "Nothing to do. Please specify --show or --write_to" 43 | 44 | 45 | workspace = load_workspace(args.model_path) 46 | 47 | with torch.inference_mode(): 48 | workspace = load_workspace(args.model_path) 49 | 50 | model_name = workspace.latest_iteration() 51 | model_file = workspace.model_filename(model_name) 52 | model = read_gaussians(model_file) 53 | 54 | model = model.to(args.device) 55 | model = crop_model(model, workspace.cameras, args) 56 | 57 | 58 | if args.write_to: 59 | write_gaussians(args.write_to, model) 60 | print(f"Wrote {model} to {args.write_to}") 61 | 62 | if args.show: 63 | from splat_viewer.viewer.viewer import show_workspace 64 | show_workspace(workspace, model) 65 | 66 | 67 | 68 | if __name__ == "__main__": 69 | main() 70 | 71 | 72 | 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /splat_viewer/scripts/debug_tiles.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | from pathlib import Path 4 | import torch 5 | from tqdm import tqdm 6 | 7 | from splat_viewer.gaussians.data_types import Gaussians 8 | from splat_viewer.gaussians.workspace import load_workspace 9 | 10 | from taichi_splatting import RasterConfig, map_to_tiles, pad_to_tile, perspective, tile_mapper, TaichiQueue 11 | 12 | import taichi as ti 13 | from splat_viewer.renderer.taichi_splatting import to_camera_params 14 | 15 | 16 | def overlap_stats(overlap_counts, tile_counts, max_hist = 8): 17 | large_thresh = (overlap_counts.std() * 3).item() 18 | very_large = (overlap_counts > large_thresh).sum() 19 | 20 | hist = torch.histc(overlap_counts.clamp(0, max_hist + 1), bins=max_hist + 1, min=0, max=max_hist + 1) 21 | 22 | n = overlap_counts.shape[0] 23 | print(f"Overlap mean: {overlap_counts.mean():.2f} very large (> {large_thresh:.2f}) %: {100.0 * very_large / n:.2f}") 24 | print(f"Overlap histogram %: {100 * hist[:max_hist] / n} > 10: {100 * hist[max_hist] / n}") 25 | 26 | 27 | print(f"Mean tile count: {tile_counts.mean():.2f}, max tile count: {tile_counts.max()}") 28 | 29 | 30 | def main(): 31 | 32 | torch.set_printoptions(precision=5, sci_mode=False, linewidth=120) 33 | 34 | parser = argparse.ArgumentParser() 35 | 36 | parser.add_argument("model_path", type=Path, help="workspace folder containing cameras.json, input.ply 
and point_cloud folder with .ply models") 37 | parser.add_argument("--device", type=str, default="cuda:0", help="torch device to use") 38 | parser.add_argument("--model", type=str, default=None, help="model to load") 39 | 40 | parser.add_argument("--tile_size", type=int, default=16, help="tile size for rasterizer") 41 | parser.add_argument("--image_size", type=int, default=None, help="resize longest edge of camera image sizes") 42 | parser.add_argument("--antialias", action="store_true", help="enable analytic antialiasing") 43 | parser.add_argument("--debug", action="store_true", help="enable taichi debug mode") 44 | 45 | args = parser.parse_args() 46 | 47 | TaichiQueue.init(arch=ti.cuda, offline_cache=True, log_level=ti.INFO, 48 | debug=args.debug, device_memory_GB=0.1) 49 | 50 | 51 | workspace = load_workspace(args.model_path) 52 | if args.model is None: 53 | args.model = workspace.latest_iteration() 54 | 55 | gaussians:Gaussians = workspace.load_model(args.model).to(args.device) 56 | gaussians = gaussians.sorted() 57 | 58 | 59 | print(f"Using {args.model_path} with {gaussians.batch_size[0]} points") 60 | 61 | image_sizes = set([tuple(camera.image_size) for camera in workspace.cameras]) 62 | print(f"Cameras: {len(workspace.cameras)}, Image sizes: {image_sizes}") 63 | 64 | config = RasterConfig( 65 | tile_size=args.tile_size, 66 | antialias=args.antialias) 67 | 68 | packed = gaussians.packed() 69 | 70 | overlaps = [] 71 | tile_counts = [] 72 | 73 | for camera in tqdm(workspace.cameras): 74 | camera_params = to_camera_params(camera, device=args.device) 75 | 76 | mask = perspective.frustum_culling(packed, camera_params, margin_pixels=50) 77 | gaussians2d, depth = perspective.project_to_image(packed[mask], camera_params) 78 | 79 | overlap_to_point, tile_ranges = map_to_tiles( 80 | gaussians2d, depth, camera_params.image_size, config) 81 | 82 | image_size = pad_to_tile(camera.image_size, config.tile_size) 83 | 84 | mapper = tile_mapper.tile_mapper(config) 85 | overlap_offsets, total_overlap = mapper.generate_tile_overlaps( 86 | gaussians2d, image_size) 87 | 88 | cum_overlap_counts = torch.cat([overlap_offsets.cpu(), torch.tensor([total_overlap])]) 89 | 90 | 91 | overlaps.append((cum_overlap_counts[1:] - cum_overlap_counts[:-1]).float()) 92 | tile_counts.append((tile_ranges[:, 1] - tile_ranges[:, 0]).float()) 93 | 94 | 95 | overlap_stats(torch.cat(overlaps), torch.cat(tile_counts)) 96 | max_counts = [torch.max(overlap) for overlap in overlaps] 97 | print(f"Max overlap (mean): {sum(max_counts) / len(max_counts):.2f}") 98 | 99 | 100 | if __name__ == "__main__": 101 | main() 102 | 103 | -------------------------------------------------------------------------------- /splat_viewer/scripts/depth_fusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from splat_viewer.gaussians import read_gaussians 4 | import argparse 5 | from pathlib import Path 6 | import open3d as o3d 7 | from tqdm import tqdm 8 | from dataclasses import dataclass 9 | 10 | from splat_viewer.gaussians.data_types import Gaussians 11 | from splat_viewer.gaussians.workspace import load_workspace 12 | from splat_viewer.camera import FOVCamera 13 | 14 | from splat_viewer.renderer.taichi_splatting import GaussianRenderer, Rendering 15 | import open3d.core as o3c 16 | 17 | from taichi_splatting import TaichiQueue 18 | import taichi as ti 19 | 20 | 21 | @dataclass 22 | class TSDFConfig: 23 | voxel_size: float = 0.0005 24 | block_resolution: int = 8 25 | 
depth_max: float = 1.5 26 | device: str = "CUDA:0" 27 | 28 | 29 | def torch_to_o3d(tensor:torch.Tensor) -> o3d.core.Tensor: 30 | return o3d.core.Tensor.from_dlpack(torch.utils.dlpack.to_dlpack(tensor)) 31 | 32 | def o3d_to_torch(tensor:o3c.Tensor) -> torch.Tensor: 33 | return torch.from_dlpack(o3d.core.Tensor.to_dlpack(tensor)) 34 | 35 | 36 | def to_o3d_rgbd(rendering:Rendering) -> o3d.t.geometry.RGBDImage: 37 | return o3d.t.geometry.RGBDImage.create_from_color_and_depth( 38 | torch_to_o3d(rendering.image), 39 | torch_to_o3d(rendering.depth) 40 | ) 41 | 42 | 43 | def create_voxel_block_grid(config: TSDFConfig) -> o3d.t.geometry.VoxelBlockGrid: 44 | """Create and initialize a voxel block grid for TSDF integration""" 45 | device = o3d.core.Device(config.device) 46 | 47 | vbg = o3d.t.geometry.VoxelBlockGrid( 48 | attr_names=('tsdf', 'weight'), 49 | attr_dtypes=(o3c.float32, o3c.float32), 50 | attr_channels=((1), (1)), 51 | voxel_size=config.voxel_size, 52 | block_resolution=config.block_resolution, 53 | block_count=10000, 54 | device=device 55 | ) 56 | 57 | return vbg 58 | 59 | 60 | def integrate_tsdf(vbg: o3d.t.geometry.VoxelBlockGrid, 61 | rendering: Rendering, 62 | camera: FOVCamera, 63 | config: TSDFConfig) -> None: 64 | """Integrate a single depth/color image into the TSDF volume""" 65 | 66 | # Convert rendering to Open3D format 67 | depth_tensor = torch_to_o3d(rendering.depth.squeeze()) 68 | depth = o3d.t.geometry.Image(depth_tensor) 69 | 70 | K = camera.intrinsic 71 | intrinsic = o3d.core.Tensor(K, dtype=o3d.core.float64) 72 | 73 | # Create extrinsic matrix (world-to-camera transform) 74 | extrinsic = o3d.core.Tensor(camera.camera_t_world, dtype=o3d.core.float64) 75 | 76 | # Compute unique block coordinates in current viewing frustum 77 | frustum_block_coords = vbg.compute_unique_block_coordinates( 78 | depth, intrinsic, extrinsic, 1.0, config.depth_max 79 | ) 80 | 81 | # Integrate depth only (no color to save memory) 82 | vbg.integrate(frustum_block_coords, depth, intrinsic, extrinsic, 1.0, config.depth_max) 83 | 84 | 85 | def main(): 86 | 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument('input', type=Path) 89 | parser.add_argument('--write', type=Path) 90 | parser.add_argument('--show', action='store_true') 91 | 92 | parser.add_argument('--densify', default=1, type=int) 93 | parser.add_argument('--device', default='cuda:0') 94 | parser.add_argument('--sample', default=None, type=float) 95 | 96 | # TSDF parameters 97 | parser.add_argument('--voxel_size', default=0.001, type=float) 98 | parser.add_argument('--depth_max', default=1.0, type=float) 99 | parser.add_argument('--image_scale', default=1.0, type=float, help='Scale factor for images to reduce memory usage') 100 | 101 | args = parser.parse_args() 102 | 103 | if args.write is None and not args.show: 104 | raise ValueError("Must specify --output and/or --show") 105 | 106 | input:Path = args.input 107 | 108 | TaichiQueue.init(ti.gpu, offline_cache=True, debug=False, device_memory_GB=0.1) 109 | 110 | torch.set_grad_enabled(False) 111 | 112 | assert input.is_dir() 113 | workspace = load_workspace(input) 114 | gaussians:Gaussians = workspace.load_model() 115 | 116 | gaussians = gaussians.to(device=args.device) 117 | 118 | renderer = GaussianRenderer() 119 | inputs = renderer.pack_inputs(gaussians) 120 | 121 | # Create TSDF configuration 122 | config = TSDFConfig( 123 | voxel_size=args.voxel_size, 124 | depth_max=args.depth_max, 125 | device="CUDA:0" if args.device.startswith('cuda') else "CPU:0" 126 | ) 127 | 128 | # 
Initialize voxel block grid 129 | vbg = create_voxel_block_grid(config) 130 | 131 | print(f"Starting TSDF integration with {len(workspace.cameras)} views...") 132 | 133 | for camera in tqdm(workspace.cameras, desc="Integrating views"): 134 | # Scale down camera to reduce memory usage 135 | scaled_camera = camera.scale_size(args.image_scale) if args.image_scale != 1.0 else camera 136 | rendering = renderer.render(inputs, scaled_camera) 137 | integrate_tsdf(vbg, rendering, scaled_camera, config) 138 | 139 | torch.cuda.empty_cache() 140 | 141 | print("TSDF integration complete, extracting geometry...") 142 | 143 | # Extract point cloud 144 | pcd = vbg.extract_point_cloud() 145 | 146 | if args.write: 147 | o3d.io.write_point_cloud(str(args.write), pcd) 148 | print(f"Saved point cloud to {args.write}") 149 | 150 | if args.show: 151 | o3d.visualization.draw([pcd]) 152 | 153 | 154 | if __name__ == "__main__": 155 | main() -------------------------------------------------------------------------------- /splat_viewer/scripts/export_rgb_cloud.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from splat_viewer.gaussians import read_gaussians 4 | import argparse 5 | from pathlib import Path 6 | import open3d as o3d 7 | 8 | from splat_viewer.gaussians.data_types import Gaussians 9 | from splat_viewer.gaussians.workspace import load_workspace 10 | import open3d.core as o3c 11 | 12 | 13 | 14 | def torch_to_o3d(tensor:torch.Tensor) -> o3d.core.Tensor: 15 | return o3d.core.Tensor.from_dlpack(torch.utils.dlpack.to_dlpack(tensor)) 16 | 17 | def o3d_to_torch(tensor:o3c.Tensor) -> torch.Tensor: 18 | return torch.from_dlpack(o3d.core.Tensor.to_dlpack(tensor)) 19 | 20 | 21 | 22 | def sample_points(gaussians:Gaussians, n:int): 23 | m = gaussians.batch_size[0] 24 | basis = gaussians.get_rotation_matrix() # (M, 3, 3) 25 | 26 | samples = (torch.randn((m, n, 3), device=gaussians.device) 27 | * gaussians.scale()[:, None, :]) # (M, N, 3) 28 | 29 | samples = torch.einsum('mij,mnj->mni', basis, samples) # (M, N, 3) 30 | return samples + gaussians.position[:, None, :] # (M, N, 3) 31 | 32 | 33 | 34 | 35 | def to_rgb(gaussians:Gaussians, densify=1) -> o3d.t.geometry.PointCloud: 36 | colors = gaussians.get_colors() 37 | 38 | if densify > 1: 39 | positions = sample_points(gaussians, densify).reshape(-1, 3) 40 | colors = colors.repeat_interleave(densify, dim=0) 41 | 42 | else: 43 | positions = gaussians.position 44 | 45 | 46 | positions, colors = [torch_to_o3d(t) for t in (positions, colors)] 47 | 48 | cloud = o3d.t.geometry.PointCloud(positions.device) 49 | 50 | cloud.point['positions'] = positions 51 | cloud.point['colors'] = colors 52 | 53 | return cloud 54 | 55 | 56 | def main(): 57 | 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument('input', type=Path) 60 | parser.add_argument('--write', type=Path) 61 | parser.add_argument('--show', action='store_true') 62 | 63 | parser.add_argument('--densify', default=1, type=int) 64 | parser.add_argument('--device', default='cuda:0') 65 | parser.add_argument('--sample', default=None, type=float) 66 | args = parser.parse_args() 67 | 68 | if args.write is None and not args.show: 69 | raise ValueError("Must specify --output or --show") 70 | 71 | input:Path = args.input 72 | 73 | if input.is_dir(): 74 | workspace = load_workspace(input) 75 | gaussians:Gaussians = workspace.load_model() 76 | else: 77 | gaussians = read_gaussians(args.input) 78 | 79 | gaussians = gaussians.to(device=args.device) 80 
| 81 | 82 | print("Loaded:", gaussians) 83 | 84 | if gaussians.foreground is not None: 85 | gaussians = gaussians.crop_foreground() 86 | 87 | pcd:o3d.t.geometry.PointCloud = to_rgb(gaussians, densify=args.densify) 88 | 89 | print(pcd) 90 | 91 | if args.sample is not None: 92 | pcd = pcd.voxel_down_sample(args.sample) 93 | print(f"After sampling to {args.sample}", pcd) 94 | 95 | 96 | 97 | if args.show: 98 | o3d.visualization.draw([pcd]) 99 | 100 | if args.write: 101 | o3d.t.io.write_point_cloud(str(args.write), pcd) 102 | 103 | print(f"Wrote {pcd} to {args.write}") -------------------------------------------------------------------------------- /splat_viewer/scripts/export_workspace.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from natsort import natsorted 3 | from pathlib import Path 4 | from beartype.typing import List 5 | import shutil 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser(description="Export a trained workspace (only the minimum files required)") 10 | parser.add_argument("model_path", type=Path) 11 | parser.add_argument("output", type=Path) 12 | 13 | args = parser.parse_args() 14 | 15 | clouds = natsorted([file for file in args.model_path.glob("point_cloud/iteration_*/point_cloud.ply")]) 16 | 17 | if len(clouds) == 0: 18 | raise Exception("No point clouds found in {}".format(args.model_path)) 19 | 20 | cloud_file = clouds[-1] 21 | print("Using point cloud {}".format(cloud_file)) 22 | 23 | camera_file = args.model_path/"cameras.json" 24 | input_file = args.model_path/"input.ply" 25 | 26 | scene_file = args.model_path/"scene.json" 27 | cfg_file = args.model_path/"cfg_args" 28 | 29 | files:List[Path] = [camera_file, input_file, cloud_file, cfg_file, scene_file] 30 | 31 | for file in [camera_file, input_file]: 32 | if not file.exists(): 33 | raise Exception("Missing file {}".format(file)) 34 | 35 | 36 | args.output.mkdir(parents=True, exist_ok=True) 37 | for file in files: 38 | if not file.exists(): 39 | continue 40 | 41 | filename = file.relative_to(args.model_path) 42 | out_filename = args.output/filename 43 | 44 | print(f"Copying {file} to {out_filename}") 45 | out_filename.parent.mkdir(parents=True, exist_ok=True) 46 | 47 | shutil.copyfile(file, out_filename) 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | if __name__=="__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /splat_viewer/scripts/label_foreground.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from dataclasses import replace 3 | from pathlib import Path 4 | from beartype.typing import List 5 | 6 | import torch 7 | from splat_viewer.camera.fov import FOVCamera 8 | from splat_viewer.camera.visibility import visibility 9 | 10 | from splat_viewer.gaussians.loading import read_gaussians, write_gaussians 11 | from splat_viewer.gaussians.workspace import load_workspace 12 | 13 | 14 | def label_model(model, cameras:List[FOVCamera], args): 15 | num_visible, min_distance = visibility(cameras, model.position, near = args.near) 16 | 17 | min_views = max(1, len(cameras) * args.min_percent / 100) 18 | 19 | is_near = (min_distance < args.far) & (num_visible > min_views) 20 | n_near = is_near.sum(dtype=torch.int32) 21 | 22 | print(f"Labelled {n_near} points as near ({100.0 * n_near / model.batch_size[0]:.2f}%)") 23 | model = model.replace(foreground=is_near.reshape(-1, 1)) 24 | 25 | return model 26 | 27 | def main(): 28 | 29 | parser = 
argparse.ArgumentParser(description="Add a 'foreground' annotation to a .ply gaussian splatting file") 30 | parser.add_argument("model_path", type=Path, help="Path to the gaussian splatting workspace") 31 | 32 | parser.add_argument("--far", default=torch.inf, type=float, help="Max depth to determine the visible ROI") 33 | parser.add_argument("--near", default=0.01, type=float, help="Min depth to determine the visible ROI") 34 | 35 | parser.add_argument("--min_percent", type=float, default=0, help="Minimum percent of views to be included") 36 | parser.add_argument("--device", default='cuda:0') 37 | 38 | parser.add_argument("--write", action="store_true", help="Write the labelled model back to the file") 39 | parser.add_argument("--show", action="store_true") 40 | 41 | args = parser.parse_args() 42 | 43 | assert args.show or args.write, "Nothing to do. Please specify --show or --write" 44 | 45 | 46 | workspace = load_workspace(args.model_path) 47 | 48 | with torch.inference_mode(): 49 | workspace = load_workspace(args.model_path) 50 | 51 | model_name = workspace.latest_iteration() 52 | model_file = workspace.model_filename(model_name) 53 | model = read_gaussians(model_file) 54 | 55 | 56 | model = model.to(args.device) 57 | model = label_model(model, workspace.cameras, args) 58 | 59 | 60 | if args.write: 61 | write_gaussians(model_file, model) 62 | print(f"Wrote {model} to {model_file}") 63 | 64 | if args.show: 65 | from splat_viewer.viewer.viewer import show_workspace 66 | show_workspace(workspace, model) 67 | 68 | 69 | 70 | if __name__ == "__main__": 71 | main() 72 | 73 | 74 | 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /splat_viewer/scripts/splat_viewer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | from PySide6 import QtWidgets 5 | from PySide6.QtWidgets import QApplication 6 | 7 | from splat_viewer.gaussians.workspace import load_workspace 8 | from splat_viewer.renderer.arguments import add_render_arguments, make_renderer_args, renderer_from_args 9 | 10 | from splat_viewer.viewer.scene_widget import SceneWidget, Settings 11 | from taichi_splatting import TaichiQueue 12 | 13 | import signal 14 | import taichi as ti 15 | import torch 16 | 17 | def process_cl_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('model_path', help="workspace folder containing cameras.json, input.ply and point_cloud folder with .ply models") # positional argument 20 | parser.add_argument('--model', default=None, help="load model from point_cloud folder, default is latest iteration") 21 | parser.add_argument('--device', default='cuda:0', help="torch device to use") 22 | parser.add_argument('--debug', action='store_true', help="enable taichi kernels in debug mode") 23 | 24 | add_render_arguments(parser) 25 | 26 | parsed_args, unparsed_args = parser.parse_known_args() 27 | return parsed_args, unparsed_args 28 | 29 | def sigint_handler(*args): 30 | QApplication.quit() 31 | 32 | 33 | 34 | def main(): 35 | signal.signal(signal.SIGINT, sigint_handler) 36 | torch.set_printoptions(precision=5, sci_mode=False, linewidth=120) 37 | 38 | 39 | parsed_args, unparsed_args = process_cl_args() 40 | workspace = load_workspace(parsed_args.model_path) 41 | 42 | if parsed_args.model is None: 43 | parsed_args.model = workspace.latest_iteration() 44 | 45 | gaussians = workspace.load_model(parsed_args.model) 46 | print(f"Loaded model {parsed_args.model}: {gaussians}") 47 | 48 |
TaichiQueue.init(ti.gpu, offline_cache=True, debug=parsed_args.debug, device_memory_GB=0.1) 49 | 50 | 51 | qt_args = sys.argv[:1] + unparsed_args 52 | app = QApplication(qt_args) 53 | 54 | 55 | window = QtWidgets.QMainWindow() 56 | 57 | renderer = renderer_from_args(make_renderer_args(parsed_args)) 58 | print(renderer) 59 | scene_widget = SceneWidget( 60 | settings=Settings(device=parsed_args.device), 61 | renderer = renderer 62 | ) 63 | 64 | scene_widget.load_workspace(workspace, gaussians) 65 | 66 | window.setCentralWidget(scene_widget) 67 | 68 | window.show() 69 | sys.exit(app.exec()) 70 | 71 | -------------------------------------------------------------------------------- /splat_viewer/viewer/__init__.py: -------------------------------------------------------------------------------- 1 | from .viewer import (init_viewer, show_workspace, 2 | ViewerProcess, Viewer) 3 | 4 | from .settings import (Settings, Show) 5 | 6 | __all__ = ["init_viewer", "show_workspace", 7 | "ViewerProcess", "Settings", "Show", "Viewer"] -------------------------------------------------------------------------------- /splat_viewer/viewer/interaction.py: -------------------------------------------------------------------------------- 1 | 2 | from beartype.typing import Optional, Set, Tuple 3 | from PySide6 import QtGui 4 | from PySide6.QtCore import QEvent 5 | 6 | from beartype import beartype 7 | import numpy as np 8 | import torch 9 | from splat_viewer.gaussians.data_types import Gaussians, Rendering 10 | 11 | from splat_viewer.viewer.settings import Settings 12 | 13 | 14 | class Interaction(): 15 | def __init__(self): 16 | super(Interaction, self).__init__() 17 | self._child = None 18 | self.active = False 19 | 20 | def transition(self, interaction:Optional['Interaction']): 21 | self.pop() 22 | if interaction is not None: 23 | self.push(interaction) 24 | 25 | def push(self, interaction:'Interaction'): 26 | self._child = interaction 27 | if self.active: 28 | self._child._activate() 29 | 30 | 31 | def pop(self): 32 | if self._child is not None: 33 | child = self._child 34 | self._child = None 35 | child._deactivate() 36 | 37 | 38 | def _activate(self): 39 | self.on_activate() 40 | 41 | if self._child is not None: 42 | self._child._activate() 43 | 44 | self.active = True 45 | 46 | 47 | def _deactivate(self): 48 | if self._child is not None: 49 | child = self._child 50 | self._child = None 51 | child._deactivate() 52 | 53 | self.on_deactivate() 54 | 55 | 56 | def trigger_event(self, event: QEvent) -> bool: 57 | if self._child is not None: 58 | if self._child.trigger_event(event): 59 | return True 60 | 61 | return self.event(event) 62 | 63 | def _update(self, dt:float) -> bool: 64 | if self._child is not None: 65 | if self._child._update(dt): 66 | return True 67 | 68 | return self.update(dt) 69 | 70 | @beartype 71 | def event(self, event: QEvent) -> bool: 72 | event_callbacks = { 73 | QEvent.KeyPress: self.keyPressEvent, 74 | QEvent.KeyRelease: self.keyReleaseEvent, 75 | QEvent.MouseButtonPress: self.mousePressEvent, 76 | QEvent.MouseButtonRelease: self.mouseReleaseEvent, 77 | QEvent.MouseMove: self.mouseMoveEvent, 78 | QEvent.Wheel: self.wheelEvent, 79 | QEvent.FocusIn: self.focusInEvent, 80 | QEvent.FocusOut: self.focusOutEvent, 81 | } 82 | 83 | if event.type() in event_callbacks: 84 | return event_callbacks[event.type()](event) or False 85 | 86 | return False 87 | 88 | def trigger_paint(self, event: QtGui.QPaintEvent, view_changed:bool) -> bool: 89 | if self._child is not None: 90 | if 
self._child.trigger_paint(event, view_changed): 91 | return True 92 | 93 | return self.paintEvent(event, view_changed) 94 | 95 | def keyPressEvent(self, event: QtGui.QKeyEvent): 96 | return False 97 | 98 | def keyReleaseEvent(self, event: QtGui.QKeyEvent): 99 | return False 100 | 101 | def mousePressEvent(self, event: QtGui.QMouseEvent): 102 | return False 103 | 104 | def mouseReleaseEvent(self, event: QtGui.QMouseEvent): 105 | return False 106 | 107 | def mouseMoveEvent(self, event: QtGui.QMouseEvent): 108 | return False 109 | 110 | def wheelEvent(self, event: QtGui.QWheelEvent): 111 | return False 112 | 113 | def focusInEvent(self, event: QtGui.QFocusEvent): 114 | return False 115 | 116 | def focusOutEvent(self, event: QtGui.QFocusEvent): 117 | return False 118 | 119 | def paintEvent(self, event: QtGui.QPaintEvent, view_changed:bool): 120 | return False 121 | 122 | @beartype 123 | def update(self, dt) -> bool: 124 | return False 125 | 126 | def on_activate(self): 127 | pass 128 | 129 | def on_deactivate(self): 130 | pass 131 | 132 | @property 133 | def scene_widget(self): 134 | from .scene_widget import SceneWidget 135 | return SceneWidget.instance 136 | 137 | 138 | 139 | @property 140 | def settings(self) -> Settings: 141 | return self.scene_widget.settings 142 | 143 | @property 144 | def modifiers(self) -> QtGui.Qt.KeyboardModifier: 145 | return self.scene_widget.modifiers 146 | 147 | @property 148 | def keys_down(self) -> Set[QtGui.Qt.Key]: 149 | return self.scene_widget.keys_down 150 | 151 | @property 152 | def cursor_pos(self) -> Tuple[int, int]: 153 | return self.scene_widget.cursor_pos 154 | 155 | @property 156 | def current_point(self) -> np.ndarray: 157 | return self.scene_widget.current_point_3d 158 | 159 | def lookup_point_3d(self, p:Tuple[int, int]) -> np.ndarray: 160 | return self.scene_widget.lookup_point_3d(p) 161 | 162 | def lookup_depth(self, p:Tuple[int, int]) -> np.ndarray: 163 | return self.scene_widget.lookup_depth(p) 164 | 165 | 166 | def lookup_depths(self, p:np.ndarray) -> np.ndarray: 167 | return self.scene_widget.lookup_depths(p) 168 | 169 | def test_depths(self, p:np.ndarray, depth:np.ndarray) -> np.ndarray: 170 | return self.scene_widget.test_depths(p, depth) 171 | 172 | def unproject_point(self, p:Tuple[int, int], depth:float) -> np.ndarray: 173 | return self.scene_widget.unproject_point(p, depth) 174 | 175 | def unproject_radius(self, p:Tuple[int, int], depth:float, radius:float 176 | ) -> Tuple[np.ndarray, float]: 177 | return self.scene_widget.unproject_radius(p, depth, radius) 178 | 179 | def set_dirty(self): 180 | self.scene_widget.set_dirty() 181 | 182 | @property 183 | def depth_map(self): 184 | return self.scene_widget.depth_map 185 | 186 | def from_numpy(self, a:np.ndarray): 187 | return torch.from_numpy(a).to(device=self.settings.device) 188 | 189 | @property 190 | def rendering(self) -> Rendering: 191 | return self.scene_widget.renderer.rendering 192 | 193 | 194 | @property 195 | def renderer(self): 196 | return self.scene_widget.renderer 197 | 198 | @property 199 | def gaussians(self) -> Gaussians: 200 | return self.scene_widget.gaussians 201 | 202 | def update_gaussians(self, gaussians:Gaussians): 203 | return self.scene_widget.update_gaussians(gaussians) 204 | 205 | def update_setting(self, **kwargs): 206 | self.scene_widget.update_setting(**kwargs) 207 | 208 | 209 | -------------------------------------------------------------------------------- /splat_viewer/viewer/interactions/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/uc-vision/splat-viewer/c2fd37c42c3671b5719be28863e3ea6cf14770ea/splat_viewer/viewer/interactions/__init__.py -------------------------------------------------------------------------------- /splat_viewer/viewer/interactions/animate.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | from beartype.typing import Callable, List, Optional 4 | from PySide6 import QtGui 5 | from PySide6.QtCore import Qt 6 | from beartype import beartype 7 | 8 | import numpy as np 9 | import scipy 10 | 11 | from splat_viewer.camera.fov import split_rt 12 | 13 | from ..interaction import Interaction 14 | from scipy.spatial.transform import Rotation as R 15 | from scipy.spatial.transform import Slerp 16 | 17 | 18 | 19 | def generalized_sigmoid(x, smooth, eps=1e-6): 20 | if x < eps: 21 | return 0.0 22 | if (1 - x) < eps: 23 | return 1.0 24 | else: 25 | return 1/(1 + (x / (1 - x)) ** -smooth) 26 | 27 | 28 | def animate_to(state, current_pose, dest_pose, on_finish=None): 29 | to = AnimateCamera([current_pose, dest_pose], loop=False, on_finish=on_finish) 30 | state.transition(to) 31 | 32 | def animate_to_loop(state, current_pose, loop_poses, on_finish=None): 33 | loop = AnimateCamera(loop_poses, loop=True, on_finish=on_finish) 34 | to = AnimateCamera([current_pose, loop_poses[0]], loop=False, 35 | on_finish=lambda: state.transition(loop)) 36 | 37 | state.transition(to) 38 | 39 | class AnimateCamera(Interaction): 40 | @beartype 41 | def __init__(self, motion_path:List[np.ndarray], loop=True, 42 | on_finish:Optional[Callable]=None): 43 | super(AnimateCamera, self).__init__() 44 | 45 | bc_type = 'not-a-knot' 46 | 47 | if loop: 48 | motion_path = [*motion_path, motion_path[0]] 49 | bc_type = 'periodic' 50 | 51 | self.loop = loop 52 | self.on_finish = on_finish 53 | 54 | self.total = len(motion_path) - 1 55 | times = np.arange(len(motion_path)) 56 | 57 | r, t = zip(*[split_rt(m) for m in motion_path]) 58 | 59 | self.rots, self.pos = np.array(r), np.array(t) 60 | 61 | self.slerp = Slerp(times, R.from_matrix(self.rots)) 62 | self.interp = scipy.interpolate.CubicSpline(times, self.pos, axis=0, bc_type=bc_type) 63 | 64 | self.t = 0.0 65 | 66 | self.speed_controls = { 67 | Qt.Key_Plus : 1, 68 | Qt.Key_Minus : -1, 69 | } 70 | 71 | 72 | def keyPressEvent(self, event: QtGui.QKeyEvent): 73 | if event.key() in self.speed_controls: 74 | modifier = self.speed_controls[event.key()] 75 | 76 | if event.modifiers() & Qt.ShiftModifier: 77 | self.update_setting(animate_pausing = np.clip(self.settings.animate_pausing + (0.2 * modifier), 0, 2)) 78 | else: 79 | self.update_setting(animate_speed = self.settings.animate_speed * (2 ** modifier)) 80 | 81 | return True 82 | 83 | 84 | 85 | def update(self, dt:float): 86 | scene = self.scene_widget 87 | 88 | inc = dt * self.settings.animate_speed 89 | 90 | if self.t + inc >= self.total: 91 | if self.on_finish is not None: 92 | self.on_finish() 93 | 94 | if not self.loop: 95 | self.t = min(self.total, self.t + inc) 96 | else: 97 | self.t = (self.t + inc) % self.total 98 | 99 | frac = math.fmod(self.t, 1) 100 | t = math.floor(self.t) + generalized_sigmoid(frac, self.settings.animate_pausing + 1) 101 | 102 | r = self.slerp(np.array([t])).as_matrix()[0] 103 | 104 | pos = self.interp(t) 105 | 106 | scene.set_camera_pose(r, pos) 107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- 
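A minimal standalone sketch of the pose interpolation that AnimateCamera above relies on (Slerp for rotations, a cubic spline for positions); treating the keyframes as plain 4x4 pose matrices and indexing them over [0, n-1] is an assumption for illustration, in place of the split_rt helper:

import numpy as np
import scipy.interpolate
from scipy.spatial.transform import Rotation as R, Slerp

def interpolate_pose(keyframes, t):
    # keyframes: list of 4x4 pose matrices, t in [0, len(keyframes) - 1]
    rotations = np.array([m[:3, :3] for m in keyframes])
    positions = np.array([m[:3, 3] for m in keyframes])

    times = np.arange(len(keyframes))
    slerp = Slerp(times, R.from_matrix(rotations))                     # smooth rotation path
    spline = scipy.interpolate.CubicSpline(times, positions, axis=0)   # smooth position path

    return slerp(np.array([t])).as_matrix()[0], spline(t)              # 3x3 rotation, 3-vector position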
/splat_viewer/viewer/interactions/fly_control.py: -------------------------------------------------------------------------------- 1 | 2 | from PySide6 import QtGui 3 | from PySide6.QtCore import Qt, QEvent 4 | 5 | import numpy as np 6 | 7 | from ..interaction import Interaction 8 | 9 | 10 | class FlyControl(Interaction): 11 | def __init__(self): 12 | super(FlyControl, self).__init__() 13 | self.drag_mouse_pos = None 14 | 15 | self.directions = { 16 | Qt.Key_Q : np.array([0., -1., 0.]), 17 | Qt.Key_E : np.array([0., 1., 0.]), 18 | 19 | Qt.Key_W : np.array([0., 0., 1.]), 20 | Qt.Key_S : np.array([0., 0., -1.]), 21 | 22 | Qt.Key_A : np.array([-1., 0., 0.]), 23 | Qt.Key_D : np.array([1., 0., 0.]) 24 | } 25 | 26 | self.rotations = { 27 | Qt.Key_Z : np.array([0., 0., 1.]), 28 | Qt.Key_C : np.array([0., 0., -1.]), 29 | 30 | Qt.Key_Up : np.array([0., 1., 0.]), 31 | Qt.Key_Down : np.array([0., -1., 0.]), 32 | 33 | Qt.Key_Left : np.array([-1., 0., 0.]), 34 | Qt.Key_Right : np.array([1., 0., 0.]), 35 | } 36 | 37 | 38 | self.speed_controls = { 39 | Qt.Key_Plus : 2.0, 40 | Qt.Key_Minus : 0.5, 41 | } 42 | 43 | 44 | self.held_keys = set(self.directions.keys()) | set(self.rotations.keys()) 45 | 46 | 47 | def keyPressEvent(self, event: QtGui.QKeyEvent): 48 | if event.key() in self.held_keys and not event.isAutoRepeat(): 49 | self.transition(None) 50 | 51 | if event.key() in self.speed_controls and event.modifiers() & Qt.KeypadModifier: 52 | self.update_setting(move_speed = self.settings.move_speed * self.speed_controls[event.key()]) 53 | return True 54 | 55 | 56 | def update(self, dt:float): 57 | scene = self.scene_widget 58 | mod = 0.1 if Qt.Key_Shift in self.keys_down else 1.0 59 | 60 | for key in self.keys_down: 61 | if key in self.rotations: 62 | scene.rotate_camera(mod * self.rotations[key] * dt * self.settings.rotate_speed) 63 | 64 | elif key in self.directions: 65 | scene.move_camera(mod * self.directions[key] * dt * self.settings.move_speed) 66 | 67 | 68 | def mousePressEvent(self, event: QtGui.QMouseEvent): 69 | if event.buttons() & Qt.RightButton: 70 | self.drag_mouse_pos = event.localPos() 71 | return True 72 | 73 | 74 | def mouseReleaseEvent(self, event: QtGui.QMouseEvent): 75 | if event.button() & Qt.RightButton: 76 | self.drag_mouse_pos = None 77 | return True 78 | 79 | def mouseMoveEvent(self, event: QtGui.QMouseEvent): 80 | if event.buttons() & Qt.RightButton and self.drag_mouse_pos is not None: 81 | delta = event.localPos() - self.drag_mouse_pos 82 | 83 | sz = self.scene_widget.size() 84 | 85 | speed = self.settings.drag_speed 86 | self.scene_widget.rotate_camera([delta.x() / sz.width() * speed, 87 | -delta.y() / sz.height() * speed, 88 | 0]) 89 | 90 | self.drag_mouse_pos = event.localPos() 91 | -------------------------------------------------------------------------------- /splat_viewer/viewer/interactions/scribble.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Callable, Tuple 3 | 4 | from PySide6 import QtGui, QtCore 5 | from PySide6.QtCore import Qt 6 | import numpy as np 7 | import torch 8 | 9 | from splat_viewer.gaussians.data_types import Gaussians 10 | from splat_viewer.viewer.interaction import Interaction 11 | 12 | 13 | 14 | 15 | def in_sphere(positions:torch.Tensor, center:torch.Tensor, radius:float): 16 | idx = in_box(positions, center - radius, center + radius) 17 | return idx[torch.linalg.norm(positions[idx] - center, dim=-1) <= radius] 18 | 19 | def in_box(positions:torch.Tensor, 
lower:torch.Tensor, upper:np.array): 20 | mask = ((positions >= lower) & (positions <= upper)).all(dim=-1) 21 | return torch.nonzero(mask, as_tuple=True)[0] 22 | 23 | 24 | class ScribbleGeometric(Interaction): 25 | def __init__(self): 26 | super(ScribbleGeometric, self).__init__() 27 | 28 | self.drawing = False 29 | 30 | self.current_label = 0 31 | self.current_points = None 32 | 33 | self.color = torch.tensor([1, 0, 0], dtype=torch.float32) 34 | 35 | @property 36 | def ready(self): 37 | return bool(self.modifiers & Qt.ControlModifier) 38 | 39 | 40 | def mousePressEvent(self, event: QtGui.QMouseEvent): 41 | if event.button() == Qt.LeftButton and event.modifiers() & Qt.ControlModifier: 42 | self.drawing = True 43 | self.draw((event.x(), event.y())) 44 | return True 45 | 46 | def mouseReleaseEvent(self, event: QtGui.QMouseEvent): 47 | if event.button() == Qt.LeftButton and self.drawing: 48 | self.drawing = False 49 | return True 50 | 51 | def draw(self, cursor_pos:Tuple[int, int]): 52 | depth = self.lookup_depth(cursor_pos) 53 | 54 | p, r = self.unproject_radius(cursor_pos, depth, self.settings.brush_size) 55 | idx = in_sphere(self.gaussians.position, self.from_numpy(p), r) 56 | 57 | if self.current_points is None: 58 | self.current_points = idx 59 | else: 60 | self.current_points = torch.cat([self.current_points, idx]).unique() 61 | 62 | self.update_gaussians(self.gaussians.set_colors(self.color, self.current_points)) 63 | 64 | self.set_dirty() 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | def mouseMoveEvent(self, event: QtGui.QMouseEvent): 73 | if self.drawing: 74 | self.draw((event.x(), event.y())) 75 | 76 | def wheelEvent(self, event: QtGui.QWheelEvent): 77 | if self.ready: 78 | dy = event.pixelDelta().y() 79 | factor = math.pow(1.0015, dy) 80 | 81 | self.update_setting(brush_size = np.clip(self.settings.brush_size * factor, 1, 100)) 82 | return True 83 | 84 | def keyPressEvent(self, event: QtGui.QKeyEvent): 85 | 86 | 87 | return super().keyPressEvent(event) 88 | 89 | def paintEvent(self, event: QtGui.QPaintEvent, dirty:bool): 90 | if self.ready: 91 | painter = QtGui.QPainter(self.scene_widget) 92 | painter.setRenderHint(QtGui.QPainter.Antialiasing) 93 | painter.setPen(QtGui.QPen(Qt.red, 1, Qt.DashLine)) 94 | 95 | point = QtCore.QPointF(*self.cursor_pos) 96 | painter.drawEllipse(point, 97 | self.settings.brush_size, self.settings.brush_size) 98 | painter.end() 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /splat_viewer/viewer/keyboard.py: -------------------------------------------------------------------------------- 1 | from PySide6.QtCore import Qt, QEvent 2 | 3 | 4 | keymap = {} 5 | for key, value in vars(Qt).items(): 6 | if isinstance(value, Qt.Key): 7 | keymap[value] = key 8 | 9 | def keyevent_to_string(event): 10 | sequence = [] 11 | 12 | key = keymap.get(event.key(), event.text()) 13 | if key not in sequence: 14 | sequence.append(key) 15 | return '+'.join(sequence) -------------------------------------------------------------------------------- /splat_viewer/viewer/mesh.py: -------------------------------------------------------------------------------- 1 | import math 2 | from beartype.typing import List 3 | import trimesh 4 | 5 | import numpy as np 6 | import pyrender 7 | 8 | import torch 9 | 10 | from splat_viewer.camera.fov import FOVCamera 11 | from splat_viewer.camera.transforms import batch_transform_points 12 | from splat_viewer.gaussians import Gaussians 13 | 14 | 15 | def instance_meshes(mesh:trimesh.Trimesh, 
transforms:np.array): 16 | vertices = batch_transform_points(transforms, mesh.vertices) 17 | n = transforms.shape[0] 18 | 19 | offsets = np.arange(n).reshape(n, 1, 1) * mesh.vertices.shape[0] 20 | faces = mesh.faces.reshape(1, -1, 3) + offsets 21 | 22 | return trimesh.Trimesh(vertices=vertices.reshape(-1, 3), 23 | faces=faces.reshape(-1, 3)) 24 | 25 | def camera_marker(camera:FOVCamera, scale): 26 | fov = camera.fov 27 | 28 | x = math.tan(fov[0] / 2) 29 | y = math.tan(fov[1] / 2) 30 | 31 | points = np.array([ 32 | [0, 0, 0], 33 | [-x, y, 1], 34 | [x, y, 1], 35 | [x, -y, 1], 36 | [-x, -y, 1] 37 | ]) 38 | 39 | 40 | triangles = np.array([ 41 | [0, 1, 2], 42 | [0, 2, 3], 43 | [0, 3, 4], 44 | [0, 4, 1], 45 | 46 | [1, 2, 3], 47 | [1, 3, 4] 48 | ], dtype=np.int32) 49 | 50 | return trimesh.Trimesh(vertices=points * scale, faces=triangles, process=False) 51 | 52 | 53 | def make_camera_markers(cameras:List[FOVCamera], scale:float): 54 | mesh = camera_marker(cameras[0], scale) 55 | 56 | markers = instance_meshes(mesh, np.array([cam.world_t_camera for cam in cameras], dtype=np.float32)) 57 | markers = pyrender.Mesh.from_trimesh(markers, wireframe=True, smooth=False, 58 | material=pyrender.MetallicRoughnessMaterial 59 | (doubleSided=True, wireframe=True, smooth=False, baseColorFactor=(255, 0, 0, 255))) 60 | 61 | 62 | return markers 63 | 64 | 65 | def extract_instance_corner_points(gaussians: Gaussians): 66 | 67 | corner_points = [] 68 | 69 | mask = gaussians.instance_label != -1 70 | valid_labels = gaussians.instance_label[mask] 71 | unique_labels = torch.unique(valid_labels) 72 | 73 | for label in unique_labels: 74 | positions = gaussians.position[(gaussians.instance_label == label).squeeze()] 75 | corner_points.append((torch.min(positions, dim=0)[0], torch.max(positions, dim=0)[0])) 76 | 77 | return corner_points 78 | 79 | 80 | def make_bounding_box(gaussians: Gaussians): 81 | assert gaussians.instance_label is not None 82 | 83 | all_vertices = [] 84 | all_indices = [] 85 | current_vertex_count = 0 86 | 87 | for (min_coords, max_coords) in extract_instance_corner_points(gaussians): 88 | min_x, min_y, min_z = min_coords.cpu().numpy() 89 | max_x, max_y, max_z = max_coords.cpu().numpy() 90 | 91 | vertices = np.array([ 92 | [min_x, min_y, min_z], 93 | [max_x, min_y, min_z], 94 | [max_x, max_y, min_z], 95 | [min_x, max_y, min_z], 96 | [min_x, min_y, max_z], 97 | [max_x, min_y, max_z], 98 | [max_x, max_y, max_z], 99 | [min_x, max_y, max_z] 100 | ]) 101 | 102 | edges = np.array([ 103 | [0, 1], [1, 2], [2, 3], [3, 0], 104 | [4, 5], [5, 6], [6, 7], [7, 4], 105 | [0, 4], [1, 5], [2, 6], [3, 7] 106 | ], dtype=np.uint32) + current_vertex_count 107 | 108 | current_vertex_count += len(vertices) 109 | 110 | all_vertices.append(vertices) 111 | all_indices.append(edges) 112 | 113 | if not all_vertices: 114 | all_vertices = np.empty((0, 3)) 115 | all_indices = np.empty((0, 2), dtype=np.uint32) 116 | 117 | else: 118 | all_vertices = np.vstack(all_vertices) 119 | all_indices = np.vstack(all_indices) 120 | 121 | primitive = pyrender.Primitive( 122 | positions=all_vertices, 123 | indices=all_indices, 124 | mode=1, 125 | material=pyrender.MetallicRoughnessMaterial 126 | (doubleSided=True, wireframe=True, smooth=False, baseColorFactor=(255, 255, 0, 255)) 127 | ) 128 | 129 | mesh = pyrender.Mesh([primitive]) 130 | 131 | return mesh 132 | 133 | 134 | 135 | def make_sphere(radius=1.0, subdivisions=3, color=(0.0, 0.0, 1.0)): 136 | sphere = trimesh.creation.icosphere(radius=radius, subdivisions=subdivisions) 137 | 138 | material 
= pyrender.MetallicRoughnessMaterial( 139 | metallicFactor=0.0, 140 | baseColorFactor=[*color, 1.0], 141 | ) 142 | return pyrender.Mesh.from_trimesh(sphere, smooth=True, material=material) -------------------------------------------------------------------------------- /splat_viewer/viewer/renderer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, replace 2 | import cv2 3 | 4 | import numpy as np 5 | import pyrender 6 | import torch 7 | from splat_viewer.camera.fov import FOVCamera 8 | 9 | from splat_viewer.gaussians.workspace import Workspace 10 | from splat_viewer.gaussians import Gaussians, Rendering 11 | from splat_viewer.viewer.scene_camera import to_pyrender_camera 12 | 13 | 14 | from .mesh import make_camera_markers, make_bounding_box 15 | from .settings import Settings, ViewMode 16 | 17 | import plyfile 18 | 19 | 20 | def plyfile_to_mesh(plydata: plyfile.PlyData): 21 | vertex = plydata['vertex'] 22 | positions = torch.stack( 23 | [torch.from_numpy(vertex[i].copy()) for i in ['x', 'y', 'z']], dim=-1) 24 | 25 | colors = torch.stack( 26 | [torch.from_numpy(vertex[i]) for i in ['red', 'green', 'blue']], dim=-1) 27 | 28 | return pyrender.Mesh.from_points(positions, colors) 29 | 30 | 31 | def get_cv_colormap(cmap): 32 | colormap = np.arange(256, dtype=np.uint8).reshape(1, -1) 33 | colormap = cv2.applyColorMap(colormap, cmap) 34 | return colormap.astype(np.uint8) 35 | 36 | 37 | class PyrenderScene: 38 | 39 | def __init__(self, workspace: Workspace, gaussians: Gaussians): 40 | self.bbox_node = None 41 | self.renderer = None 42 | 43 | self.seed_points = workspace.load_seed_points() 44 | self.initial_scene = pyrender.Scene() 45 | 46 | self.points = plyfile_to_mesh(self.seed_points) 47 | self.initial_node = self.initial_scene.add(self.points, pose=np.eye(4)) 48 | 49 | self.initial_scene.ambient_light = np.array([1.0, 1.0, 1.0, 1.0]) 50 | 51 | self.cameras = make_camera_markers(workspace.cameras, workspace.camera_extent / 50.) 
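# The camera frusta are wireframe marker meshes from make_camera_markers, scaled by camera_extent / 50
# so they stay proportionate to the scene; they are added to the same scene as the seed points just below.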
52 | self.initial_scene.add(self.cameras) 53 | 54 | self.update_gaussians(gaussians) 55 | 56 | def update_gaussians(self, gaussians: Gaussians): 57 | if self.bbox_node is not None: 58 | self.initial_scene.remove_node(self.bbox_node) 59 | 60 | if gaussians.instance_label is not None: 61 | bounding_boxes = make_bounding_box(gaussians) 62 | self.bbox_node = self.initial_scene.add(bounding_boxes) 63 | else: 64 | self.bbox_node = None 65 | 66 | def create_renderer(self, camera, settings: Settings): 67 | if self.renderer is None: 68 | self.renderer = pyrender.OffscreenRenderer(camera.image_size[0], camera.image_size[1], point_size=settings.point_size) 69 | else: 70 | self.renderer.viewport_width = camera.image_size[0] 71 | self.renderer.viewport_height = camera.image_size[1] 72 | self.renderer.point_size = settings.point_size 73 | 74 | return self.renderer 75 | 76 | def render(self, camera, settings: Settings): 77 | 78 | renderer = self.create_renderer(camera, settings) 79 | 80 | self.cameras.is_visible = settings.show.cameras 81 | self.points.is_visible = settings.show.initial_points 82 | 83 | if self.bbox_node is not None: 84 | self.bbox_node.mesh.is_visible = settings.show.bounding_boxes 85 | 86 | node = to_pyrender_camera(camera) 87 | scene = self.initial_scene 88 | scene.add_node(node) 89 | 90 | image, depth = renderer.render(scene) 91 | scene.remove_node(node) 92 | 93 | return image, depth 94 | 95 | 96 | @dataclass(frozen=True) 97 | class RenderState: 98 | as_points: bool = False 99 | cropped: bool = False 100 | filtered_points: bool = False 101 | color_instances: bool = False 102 | 103 | def update_setting(self, settings: Settings): 104 | return replace(self, 105 | as_points=settings.view_mode == ViewMode.Points, 106 | cropped=settings.show.cropped, 107 | filtered_points=settings.show.filtered_points, 108 | color_instances=settings.show.color_instances) 109 | 110 | def updated(self, gaussians: Gaussians) -> Gaussians: 111 | 112 | if self.as_points: 113 | 114 | alpha_logit = torch.full_like(gaussians.alpha_logit, 10.0) 115 | gaussians = gaussians.with_fixed_scale(0.001).replace(alpha_logit=alpha_logit) 116 | 117 | if self.cropped and gaussians.foreground is not None: 118 | gaussians = gaussians[gaussians.foreground.squeeze()] 119 | 120 | if self.color_instances and gaussians.instance_label is not None: 121 | 122 | instance_mask = (gaussians.instance_label >= 0).squeeze() 123 | valid_instances = gaussians.instance_label[instance_mask].squeeze().long() 124 | 125 | unique_instance_labels = torch.unique(valid_instances) 126 | color_space = (torch.randn(unique_instance_labels.shape[0], 3, device=instance_mask.device) * 2).sigmoid() 127 | 128 | gaussians = gaussians.with_colors(color_space[valid_instances], instance_mask) 129 | 130 | if self.filtered_points and gaussians.label is not None: 131 | gaussians = gaussians[gaussians.label.squeeze() > 0.3] 132 | 133 | return gaussians 134 | 135 | 136 | class WorkspaceRenderer: 137 | def __init__(self, workspace: Workspace, gaussians: Gaussians, gaussian_renderer): 138 | self.workspace = workspace 139 | 140 | self.gaussians = gaussians 141 | self.packed_gaussians = None 142 | self.render_state = RenderState() 143 | 144 | self.gaussian_renderer = gaussian_renderer 145 | 146 | self.pyrender_scene = PyrenderScene(workspace, gaussians) 147 | 148 | self.rendering = None 149 | self.color_map = torch.from_numpy(get_cv_colormap(cv2.COLORMAP_TURBO) 150 | ).squeeze(0).to(device=self.gaussians.device) 151 | 152 | def render_gaussians(self, camera, settings: 
Settings) -> Rendering: 153 | render_state = self.render_state.update_setting(settings) 154 | 155 | if self.packed_gaussians is None or self.render_state != render_state: 156 | self.packed_gaussians = self.gaussian_renderer.pack_inputs( 157 | render_state.updated(self.gaussians)) 158 | 159 | self.render_state = render_state 160 | return self.gaussian_renderer.render(self.packed_gaussians, camera) 161 | 162 | def update_gaussians(self, gaussians: Gaussians): 163 | self.gaussians = gaussians 164 | self.packed_gaussians = None 165 | 166 | self.pyrender_scene.update_gaussians(gaussians) 167 | 168 | def unproject_mask(self, camera: FOVCamera, 169 | mask: torch.Tensor, alpha_multiplier=1.0, threshold=1.0): 170 | return self.gaussian_renderer.unproject_mask(self.gaussians, 171 | camera, mask, alpha_multiplier, threshold) 172 | 173 | def colormap_torch(self, depth, near=0.2, far=2.0): 174 | depth = (depth - near) / (far - near) 175 | depth = 1 - depth.clamp(0, 1).sqrt() 176 | 177 | return (self.color_map[(255 * depth).to(torch.int)]) 178 | 179 | def colormap_np(self, depth, near_point=0.2): 180 | 181 | inv_depth = (near_point / depth) 182 | inv_depth = (255 * inv_depth).astype(np.uint8) 183 | return cv2.applyColorMap(inv_depth, cv2.COLORMAP_TURBO) 184 | 185 | def render(self, camera, settings: Settings): 186 | show = settings.show 187 | 188 | with torch.inference_mode(): 189 | self.rendering = self.render_gaussians(camera, settings) 190 | 191 | depth = self.rendering.depth 192 | 193 | if settings.view_mode == ViewMode.Depth: 194 | image_gaussian = self.colormap_torch(depth, near=settings.depth_near, far=settings.depth_far).to(torch.uint8).cpu().numpy() 195 | else: 196 | image_gaussian = (self.rendering.image.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy() 197 | 198 | if any([show.initial_points, show.cameras, show.bounding_boxes]): 199 | 200 | image, depth = self.pyrender_scene.render(camera, settings) 201 | depth_gaussian = self.rendering.depth.cpu().numpy() 202 | depth_gaussian[depth_gaussian == 0] = np.inf 203 | 204 | mask = np.bitwise_and(depth_gaussian > depth, depth > 0) 205 | return np.where(np.expand_dims(mask, [-1]), image, image_gaussian) 206 | 207 | else: 208 | return image_gaussian 209 | -------------------------------------------------------------------------------- /splat_viewer/viewer/scene_camera.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from dataclasses import replace 3 | from beartype.typing import Tuple 4 | import trimesh 5 | import pyrender 6 | 7 | import numpy as np 8 | from scipy.spatial.transform import Rotation as R 9 | 10 | from splat_viewer.camera.fov import FOVCamera, join_rt 11 | 12 | 13 | def normalize(v): 14 | return v / np.linalg.norm(v) 15 | 16 | 17 | def to_pyrender_camera(camera:FOVCamera): 18 | fx, fy = camera.focal_length 19 | cx, cy = camera.principal_point 20 | 21 | pr_camera = pyrender.IntrinsicsCamera(fx, fy, cx, cy, 22 | znear=camera.near, zfar=camera.far) 23 | 24 | rotation = camera.rotation @ flip_yz 25 | m = join_rt(rotation, camera.position) 26 | 27 | return pyrender.Node(camera=pr_camera, matrix=m) 28 | 29 | 30 | def look_at(eye, target, up=np.array([0., 0., 1.])): 31 | forward = normalize(target - eye) 32 | left = normalize(np.cross(up, forward)) 33 | true_up = np.cross(forward, left) 34 | return np.stack([left, true_up, forward]) 35 | 36 | 37 | def look_at_pose(eye, target, up=np.array([0., 0., 1.])): 38 | pose = np.eye(4) 39 | pose[:3, :3] = look_at(eye, target, up) 40 | 
pose[:3, 3] = eye 41 | return pose 42 | 43 | def make_sphere(pos, color, radius): 44 | sphere = trimesh.creation.icosphere(radius=radius) 45 | sphere.visual.vertex_colors = color 46 | sphere_mesh = pyrender.Mesh.from_trimesh(sphere) 47 | node = pyrender.Node(mesh=sphere_mesh, translation=pos) 48 | return node 49 | 50 | def fov_to_focal(fov, image_size): 51 | return image_size / (2 * np.tan(fov / 2)) 52 | 53 | flip_yz = np.array([ 54 | [1, 0, 0], 55 | [0, -1, 0], 56 | [0, 0, -1] 57 | ]) 58 | 59 | 60 | class SceneCamera: 61 | def __init__(self): 62 | 63 | self._camera = FOVCamera(position=np.array([0, 0, 0]), 64 | rotation=np.eye(3), 65 | focal_length=fov_to_focal(60, (640, 480)), 66 | image_size = np.array((640, 480)), 67 | image_name="viewport") 68 | 69 | 70 | def set_camera(self, camera): 71 | self._camera = deepcopy(camera) 72 | 73 | 74 | def look_at(self, pos:np.array, target:np.ndarray, up:np.ndarray=np.array([0, 0, 1])): 75 | self._camera = self._camera.with_pose(look_at_pose(pos, target, up)) 76 | 77 | 78 | def resized(self, size:Tuple[int, int]): 79 | scale_factor = max(size[1] / self._camera.image_size[1], size[0] / self._camera.image_size[0]) 80 | return self._camera.scale_size(scale_factor).pad_to(np.array(size)) 81 | 82 | @property 83 | def view_matrix(self): 84 | return self._camera.world_t_camera 85 | 86 | @property 87 | def rotation(self): 88 | return self._camera.rotation 89 | 90 | @rotation.setter 91 | def rotation(self, value): 92 | return self.set_pose(value, self.pos) 93 | 94 | @property 95 | def pos(self): 96 | return self._camera.position 97 | 98 | @pos.setter 99 | def pos(self, value): 100 | return self.set_pose(self.rotation, value) 101 | 102 | 103 | def set_pose(self, r, t): 104 | self._camera = replace(self._camera, rotation=r, position=t) 105 | 106 | 107 | def move(self, delta:np.ndarray): 108 | self.pos += self.rotation @ delta 109 | 110 | def rotate(self, ypr): 111 | m = self.view_matrix 112 | m[:3, :3] = m[:3, :3] @ R.from_euler('yxz', ypr).as_matrix() 113 | 114 | self.rotation = m[:3, :3] 115 | 116 | 117 | def zoom(self, factor): 118 | self._camera = self._camera.zoom(factor) -------------------------------------------------------------------------------- /splat_viewer/viewer/scene_widget.py: -------------------------------------------------------------------------------- 1 | from dataclasses import replace 2 | from typing import List 3 | from beartype.typing import Tuple, Optional 4 | 5 | from PySide6 import QtGui, QtCore, QtWidgets 6 | from PySide6.QtCore import Qt, QEvent 7 | from beartype import beartype 8 | import cv2 9 | 10 | import math 11 | 12 | from pathlib import Path 13 | 14 | import numpy as np 15 | from splat_viewer.camera.visibility import visibility 16 | import torch 17 | from splat_viewer.camera.fov import FOVCamera 18 | 19 | from splat_viewer.gaussians.workspace import Workspace 20 | from splat_viewer.gaussians import Gaussians 21 | from splat_viewer.viewer.interactions.animate import animate_to_loop 22 | from splat_viewer.viewer.interaction import Interaction 23 | from splat_viewer.viewer.interactions.scribble import ScribbleGeometric 24 | from splat_viewer.viewer.renderer import WorkspaceRenderer 25 | 26 | 27 | from .interactions.fly_control import FlyControl 28 | from .scene_camera import SceneCamera 29 | from .settings import Settings, ViewMode 30 | 31 | 32 | 33 | class SceneWidget(QtWidgets.QWidget): 34 | def __init__(self, settings:Settings = Settings(), renderer=None, parent=None): 35 | super(SceneWidget, self).__init__(parent=parent) 
36 | 37 | SceneWidget.instance = self 38 | 39 | self.camera_state = Interaction() 40 | self.interaction = ScribbleGeometric() 41 | 42 | self.camera = SceneCamera() 43 | self.settings = settings 44 | self.renderer = renderer 45 | 46 | self.setFocusPolicy(Qt.StrongFocus) 47 | self.setMouseTracking(True) 48 | 49 | 50 | self.cursor_pos = (0, 0) 51 | self.modifiers = Qt.NoModifier 52 | self.keys_down = set() 53 | 54 | self.dirty = True 55 | 56 | self.timer = QtCore.QTimer(self) 57 | self.timer.timeout.connect(self.update) 58 | self.timer.start(1000 / Settings.update_rate) 59 | 60 | 61 | 62 | def update_setting(self, **kwargs): 63 | self.settings = replace(self.settings, **kwargs) 64 | self.dirty = True 65 | 66 | @property 67 | def gaussians(self) -> Gaussians: 68 | return self.workspace_renderer.gaussians 69 | 70 | 71 | def median_point(self, points: List[np.ndarray]) -> int: 72 | 73 | stacked = np.vstack(points) # shape: (n, 3) 74 | median_coords = np.median(stacked, axis=0) 75 | # Compute the Euclidean distances from each point to the median coordinates 76 | dists = np.linalg.norm(stacked - median_coords, axis=1) 77 | return int(np.argmin(dists)) 78 | 79 | def load_workspace(self, workspace:Workspace, gaussians:Gaussians): 80 | self.workspace = workspace 81 | 82 | gaussians = gaussians.to(self.settings.device) 83 | if gaussians.foreground is None: 84 | foreground, depths = visibility(workspace.cameras, gaussians.position) 85 | 86 | q = torch.quantile(depths, 0.75) 87 | mask = (foreground > 0.05 * len(workspace.cameras)) & (depths < q) 88 | gaussians = gaussians.replace(foreground=mask.unsqueeze(1)) 89 | 90 | 91 | self.workspace_renderer = WorkspaceRenderer(workspace, gaussians, self.renderer) 92 | self.keypoints = self.read_keypoints() 93 | 94 | centers = [c.position for c in workspace.cameras] 95 | self.set_camera_index(self.median_point(centers)) 96 | self.camera_state = FlyControl() 97 | 98 | 99 | 100 | def update_workspace(self, gaussians:Gaussians, index:Optional[int]=None): 101 | self.load_workspace(self.workspace, gaussians) 102 | if index is not None: 103 | self.set_camera_index(index) 104 | self.show() 105 | 106 | def update_gaussians(self, gaussians:Gaussians): 107 | self.workspace_renderer.update_gaussians(gaussians.to(self.settings.device)) 108 | self.dirty = True 109 | 110 | def set_dirty(self): 111 | self.dirty = True 112 | 113 | @property 114 | def camera_path_file(self): 115 | return self.workspace.model_path / "camera_path.npy" 116 | 117 | def write_keypoints(self): 118 | np.save(self.camera_path_file, np.array(self.keypoints)) 119 | print(f"Saved {len(self.keypoints)} keypoints to {self.camera_path_file}") 120 | 121 | def read_keypoints(self): 122 | if self.camera_path_file.exists(): 123 | kp = list(np.load(self.camera_path_file)) 124 | print(f"Loaded {len(kp)} keypoints from {self.camera_path_file}") 125 | return kp 126 | 127 | return [] 128 | 129 | 130 | 131 | def set_camera_index(self, index:int): 132 | self.camera_state.transition(None) 133 | 134 | camera = self.workspace.cameras[index] 135 | print(f'Showing view from camera {index}, {camera.image_name}') 136 | self.zoom = 1.0 137 | 138 | print(camera) 139 | 140 | self.camera.set_camera(camera) 141 | self.camera_index = index 142 | self.dirty = True 143 | 144 | 145 | @property 146 | def image_size(self): 147 | w, h = self.size().width(), self.size().height() 148 | 149 | return w, h 150 | 151 | def sizeHint(self): 152 | return QtCore.QSize(1024, 768) 153 | 154 | def event(self, event: QEvent): 155 | 156 | if 
(self.interaction.trigger_event(event) or 157 | self.camera_state.trigger_event(event)): 158 | return True 159 | 160 | return super(SceneWidget, self).event(event) 161 | 162 | 163 | 164 | def keyReleaseEvent(self, event: QtGui.QKeyEvent) -> bool: 165 | self.modifiers = event.modifiers() 166 | self.keys_down.discard(event.key()) 167 | 168 | 169 | return super().keyPressEvent(event) 170 | 171 | def focusOutEvent(self, event: QtGui.QFocusEvent): 172 | self.keys_down.clear() 173 | return super().focusOutEvent(event) 174 | 175 | def keyPressEvent(self, event: QtGui.QKeyEvent) -> bool: 176 | self.modifiers = event.modifiers() 177 | self.keys_down.add(event.key()) 178 | 179 | view_modes = { 180 | Qt.Key_1 : ViewMode.Normal, 181 | Qt.Key_2 : ViewMode.Points, 182 | Qt.Key_3 : ViewMode.Depth, 183 | Qt.Key_4 : ViewMode.DepthVar, 184 | 185 | } 186 | 187 | enable_disable = { 188 | Qt.Key_0 : 'cropped', 189 | Qt.Key_9 : 'initial_points', 190 | Qt.Key_8 : 'cameras', 191 | Qt.Key_7 : 'bounding_boxes', 192 | Qt.Key_6 : 'filtered_points', 193 | Qt.Key_5 : 'color_instances' 194 | } 195 | 196 | if event.key() == Qt.Key_Print: 197 | self.save_snapshot() 198 | return True 199 | 200 | elif event.key() == Qt.Key_BraceLeft: 201 | self.set_camera_index((self.camera_index - 1) % len(self.workspace.cameras)) 202 | return True 203 | elif event.key() == Qt.Key_BraceRight: 204 | self.set_camera_index((self.camera_index + 1) % len(self.workspace.cameras)) 205 | return True 206 | 207 | 208 | elif event.key() == Qt.Key_Equal: 209 | self.camera.zoom(self.settings.zoom_discrete) 210 | self.dirty = True 211 | return True 212 | elif event.key() == Qt.Key_Minus: 213 | self.camera.zoom(1/self.settings.zoom_discrete) 214 | self.dirty = True 215 | return True 216 | 217 | 218 | elif event.key() == Qt.Key_O: 219 | shift = event.modifiers() & Qt.ShiftModifier 220 | self.update_setting(depth_near = self.settings.depth_near * (0.9 if shift else 1/0.9)) 221 | self.dirty = True 222 | return True 223 | elif event.key() == Qt.Key_P: 224 | shift = event.modifiers() & Qt.ShiftModifier 225 | self.update_setting(depth_far = self.settings.depth_far * (0.9 if shift else 1/0.9)) 226 | self.dirty = True 227 | return True 228 | 229 | 230 | 231 | elif event.key() in enable_disable: 232 | k = enable_disable[event.key()] 233 | update = {k: not getattr(self.settings.show, k)} 234 | 235 | self.update_setting(show = replace(self.settings.show, **update)) 236 | return True 237 | 238 | elif event.key() in view_modes: 239 | k = view_modes[event.key()] 240 | self.update_setting(view_mode = k) 241 | return True 242 | 243 | 244 | elif event.key() == Qt.Key_Space: 245 | self.keypoints.append(self.camera.view_matrix) 246 | 247 | if event.key() == Qt.Key_Space and event.modifiers() & Qt.ControlModifier: 248 | self.write_keypoints() 249 | 250 | elif event.key() == Qt.Key_Return: 251 | if event.modifiers() & Qt.ShiftModifier: 252 | if self.window().isFullScreen(): 253 | self.window().showNormal() 254 | else: 255 | self.window().showFullScreen() 256 | elif len(self.keypoints) > 0: 257 | animate_to_loop(self.camera_state, 258 | self.camera.view_matrix, self.keypoints) 259 | 260 | return super().keyPressEvent(event) 261 | 262 | 263 | def update(self): 264 | self.camera_state._update(1 / self.settings.update_rate) 265 | self.repaint() 266 | 267 | def resizeEvent(self, event: QtGui.QResizeEvent): 268 | self.dirty = True 269 | return super().resizeEvent(event) 270 | 271 | def mouseMoveEvent(self, event: QtGui.QMouseEvent): 272 | p = event.localPos() 273 | self.cursor_pos 
= (p.x(), p.y()) 274 | return super().mouseMoveEvent(event) 275 | 276 | def current_point_3d(self) -> np.ndarray: 277 | return self.lookup_point_3d(self.cursor_pos) 278 | 279 | @property 280 | def rendering(self): 281 | if self.workspace_renderer.rendering is None: 282 | raise ValueError("No depth render available") 283 | 284 | return self.workspace_renderer.rendering 285 | 286 | def unproject_point(self, p:np.ndarray, depth:float) -> np.ndarray: 287 | render = self.rendering 288 | camera:FOVCamera = render.camera 289 | scene_point = camera.unproject_pixel(*p, depth) 290 | return np.array([scene_point], dtype=np.float32) 291 | 292 | 293 | def unproject_radius(self, p:np.ndarray, depth:float, radius:float 294 | ) -> Tuple[np.ndarray, float]: 295 | p = np.array(p) 296 | 297 | p1 = self.unproject_point(p, depth) 298 | p2 = self.unproject_point(p + np.array([radius, 0]), depth) 299 | 300 | return p1, np.linalg.norm(p2 - p1) 301 | 302 | 303 | def lookup_depth(self, p:Tuple[int, int]) -> np.ndarray: 304 | render = self.rendering 305 | 306 | p = np.round(p).astype(np.int32) 307 | x = np.clip(p[0], 0, render.depth.shape[1] - 1) 308 | y = np.clip(p[1], 0, render.depth.shape[0] - 1) 309 | 310 | return render.depth[y, x].item() 311 | 312 | def from_numpy(self, a:np.ndarray): 313 | return torch.from_numpy(a).to(device=self.settings.device) 314 | 315 | @beartype 316 | def lookup_depths(self, p:np.ndarray) -> np.ndarray: 317 | render = self.rendering 318 | p = np.round(p).astype(np.int32) 319 | 320 | x = np.clip(p[:, 0], 0, render.depth.shape[1] - 1) 321 | y = np.clip(p[:, 1], 0, render.depth.shape[0] - 1) 322 | 323 | x, y = self.from_numpy(x), self.from_numpy(y) 324 | return render.depth[y, x].cpu().numpy() 325 | 326 | @beartype 327 | def test_depths(self, p:np.ndarray, depth:np.ndarray, tol=0.98) -> np.ndarray: 328 | 329 | return ((depth * tol <= self.lookup_depths(p)) & 330 | (p[:, 0] >= 0) & (p[:, 0] < self.image_size[0] - 1) & 331 | (p[:, 1] >= 0) & (p[:, 1] < self.image_size[1] - 1)) 332 | 333 | 334 | @property 335 | def depth_map(self) -> torch.Tensor: 336 | render = self.rendering 337 | return render.depth 338 | 339 | def lookup_point_3d(self, p:np.ndarray) -> np.ndarray: 340 | render = self.rendering 341 | scene_point = render.camera.unproject_pixel(*p, self.lookup_depth(p)) 342 | return np.array([scene_point], dtype=np.float32) 343 | 344 | 345 | def render_camera(self) -> FOVCamera: 346 | return self.camera.resized(self.image_size) 347 | 348 | def render(self): 349 | camera = self.render_camera() 350 | 351 | self.view_image = np.ascontiguousarray( 352 | self.workspace_renderer.render(camera, self.settings)) 353 | 354 | self.dirty = False 355 | return self.view_image 356 | 357 | 358 | def paintEvent(self, event: QtGui.QPaintEvent): 359 | with QtGui.QPainter(self) as painter: 360 | dirty = self.dirty 361 | if dirty: 362 | self.render() 363 | 364 | image = QtGui.QImage(self.view_image.data, 365 | self.view_image.shape[1], self.view_image.shape[0], 366 | self.view_image.strides[0], 367 | QtGui.QImage.Format_RGB888) 368 | 369 | 370 | painter.drawImage(0, 0, image) 371 | 372 | self.interaction.paintEvent(event, dirty) 373 | 374 | 375 | def snapshot_file(self): 376 | pictures = Path.home() / "Pictures" 377 | filename = pictures / "snapshot.jpg" 378 | 379 | i = 0 380 | while filename.exists(): 381 | i += 1 382 | filename = pictures / f"snapshot_{i}.jpg" 383 | 384 | return filename 385 | 386 | 387 | def render_tiled(self, camera:FOVCamera): 388 | tile_size = self.settings.snapshot_tile 389 | nw, nh = 
[int(math.ceil(x / tile_size)) 390 | for x in camera.image_size] 391 | 392 | full_image = np.zeros((nh * tile_size, nw * tile_size, 3), dtype=np.uint8) 393 | 394 | for x in range(0, nw): 395 | for y in range(0, nh): 396 | tile_camera = camera.crop_offset_size(np.array([x * tile_size, y * tile_size]), 397 | np.array([tile_size, tile_size])) 398 | 399 | image = self.workspace_renderer.render(tile_camera, self.settings) 400 | tile = full_image[y * tile_size:(y + 1) * tile_size, 401 | x * tile_size:(x + 1) * tile_size, :] 402 | 403 | print(tile.shape, image.shape, x, y) 404 | tile[:] = image 405 | 406 | return full_image[:camera.image_size[1], :camera.image_size[0]] 407 | 408 | def save_snapshot(self): 409 | camera = self.camera.resized(self.settings.snapshot_size) 410 | filename = self.snapshot_file() 411 | 412 | w, h = camera.image_size 413 | print(f"Rendering snapshot ({w}x{h})...") 414 | print(camera) 415 | 416 | image = self.render_tiled(camera) 417 | 418 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 419 | cv2.imwrite(str(filename), image, [cv2.IMWRITE_JPEG_QUALITY, 92]) 420 | 421 | print(f"Saved to {filename}") 422 | 423 | 424 | 425 | def move_camera(self, delta:np.ndarray): 426 | self.camera.move(delta) 427 | self.dirty = True 428 | 429 | def rotate_camera(self, delta:np.ndarray): 430 | self.camera.rotate(delta) 431 | self.dirty = True 432 | 433 | def set_camera_pose(self, r:np.ndarray, t:np.ndarray): 434 | self.camera.set_pose(r, t) 435 | self.dirty = True 436 | 437 | 438 | 439 | 440 | 441 | 442 | 443 | -------------------------------------------------------------------------------- /splat_viewer/viewer/settings.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from enum import Enum 3 | from beartype.typing import Tuple 4 | 5 | 6 | class ViewMode(Enum): 7 | Normal = 0 8 | Depth = 1 9 | Points = 2 10 | Hidden = 3 11 | DepthVar = 4 12 | 13 | @dataclass(frozen=True) 14 | class Show: 15 | initial_points: bool = False 16 | cameras: bool = False 17 | cropped : bool = False 18 | bounding_boxes: bool = False 19 | filtered_points: bool = False 20 | color_instances: bool = False 21 | 22 | 23 | 24 | @dataclass(frozen=True) 25 | class Settings: 26 | update_rate : int = 20 27 | move_speed : float = 1.0 28 | 29 | transition_time : float = 0.5 30 | rotate_speed : float = 1.0 31 | 32 | animate_speed : float = 0.5 33 | animate_pausing: float = 0.4 34 | 35 | zoom_discrete : float = 1.2 36 | zoom_continuous : float = 0.1 37 | 38 | drag_speed : float = 1.0 39 | point_size : float = 2.0 40 | 41 | snapshot_size: Tuple[int, int] = (8192, 6144) 42 | snapshot_tile: int = 1024 43 | 44 | 45 | depth_near: float = 0.2 46 | depth_far: float = 2.0 47 | 48 | 49 | device : str = 'cuda:0' 50 | bg_color : Tuple[float, float, float] = (1, 1, 1) 51 | 52 | show : Show = Show() 53 | view_mode : ViewMode = ViewMode.Normal 54 | 55 | brush_size : int = 10 56 | 57 | 58 | -------------------------------------------------------------------------------- /splat_viewer/viewer/viewer.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import signal 4 | 5 | from PySide6 import QtWidgets 6 | from PySide6 import QtCore 7 | from torch.multiprocessing import Process, Queue, get_start_method 8 | 9 | from splat_viewer.gaussians import Gaussians 10 | from splat_viewer.gaussians.workspace import Workspace 11 | from splat_viewer.renderer.taichi_splatting import GaussianRenderer 12 | from splat_viewer.viewer.scene_widget 
import SceneWidget, Settings 13 | 14 | from taichi_splatting import TaichiQueue 15 | 16 | 17 | def init_viewer(workspace:Workspace, 18 | gaussians:Gaussians = None, 19 | settings:Settings=Settings()): 20 | 21 | app = QtWidgets.QApplication.instance() 22 | if app is None: 23 | app = QtWidgets.QApplication(["viewer"]) 24 | 25 | 26 | widget = SceneWidget(settings=settings, 27 | renderer = GaussianRenderer()) 28 | 29 | if gaussians is None: 30 | gaussians = workspace.load_model(workspace.latest_iteration()) 31 | widget.load_workspace(workspace, gaussians) 32 | 33 | 34 | widget.show() 35 | return app, widget 36 | 37 | def show_workspace(workspace:Workspace, 38 | gaussians:Gaussians = None, 39 | settings:Settings=Settings()): 40 | 41 | import taichi as ti 42 | TaichiQueue.init(ti.gpu, offline_cache=True, device_memory_GB=0.1) 43 | 44 | 45 | from splat_viewer.viewer.viewer import sigint_handler 46 | signal.signal(signal.SIGINT, sigint_handler) 47 | 48 | print(f"Showing model from {workspace.model_path}: {gaussians}") 49 | app, _ = init_viewer(workspace, gaussians, settings) 50 | app.exec() 51 | 52 | 53 | def sigint_handler(*args): 54 | QtWidgets.QApplication.quit() 55 | 56 | 57 | def run_process(workspace:Workspace, 58 | update_queue:Queue, 59 | 60 | gaussians:Gaussians = None, 61 | settings:Settings=Settings()): 62 | 63 | import taichi as ti 64 | TaichiQueue.init(ti.gpu, offline_cache=True, device_memory_GB=0.1) 65 | 66 | from splat_viewer.viewer.viewer import sigint_handler 67 | signal.signal(signal.SIGINT, sigint_handler) 68 | 69 | app, widget = init_viewer(workspace, gaussians, settings) 70 | 71 | def on_timer(): 72 | if not update_queue.empty(): 73 | update = update_queue.get() 74 | 75 | if update is None: 76 | app.quit() 77 | return 78 | 79 | if isinstance(update, Gaussians): 80 | widget.update_gaussians(update) 81 | elif isinstance(update, dict): 82 | widget.update_gaussians(widget.gaussians.replace(**update)) 83 | else: 84 | raise TypeError(f"Unknown type of update: {type(update)}") 85 | 86 | timer = QtCore.QTimer(widget) 87 | timer.timeout.connect(on_timer) 88 | 89 | timer.start(10) 90 | 91 | app.exec() 92 | 93 | 94 | class Viewer: 95 | def __init__(self): 96 | pass 97 | 98 | def quit(self): 99 | pass 100 | 101 | def update_gaussians(self, gaussians:Gaussians): 102 | pass 103 | 104 | def __enter__(self): 105 | return self 106 | 107 | def __exit__(self, exc_type, exc_value, traceback): 108 | pass 109 | 110 | def start(self): 111 | pass 112 | 113 | def close(self): 114 | pass 115 | 116 | class ViewerProcess(Viewer): 117 | def __init__(self, workspace:Workspace, 118 | gaussians:Gaussians, 119 | settings:Settings, 120 | queue_size=1): 121 | 122 | assert get_start_method() == "spawn", "For ViewerProcess, torch multiprocessing must be started with spawn" 123 | self.update_queue = Queue(queue_size) 124 | self.view_process = Process(target=run_process, 125 | args=(workspace, self.update_queue, gaussians, settings)) 126 | 127 | 128 | def quit(self): 129 | self.update_queue.put(None) 130 | self.join() 131 | 132 | def update_gaussians(self, gaussians:Gaussians): 133 | self.update_queue.put(gaussians) 134 | 135 | def __enter__(self): 136 | self.start() 137 | return self 138 | 139 | def __exit__(self, exc_type, exc_value, traceback): 140 | self.join() 141 | 142 | def start(self): 143 | self.view_process.start() 144 | 145 | def join(self): 146 | while not self.update_queue.empty(): 147 | pass 148 | 149 | self.view_process.join() 150 | self.update_queue.close() 151 | 152 | 153 | 154 | 
--------------------------------------------------------------------------------
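A minimal usage sketch (not part of the repository files above). It assumes a loader such as load_workspace is exported from splat_viewer.gaussians.workspace and that a device matching Settings.device (default cuda:0) is available; the workspace path below is a placeholder.

from splat_viewer.gaussians.workspace import load_workspace  # assumed helper; adjust to the actual loading API
from splat_viewer.viewer.viewer import show_workspace
from splat_viewer.viewer.settings import Settings, ViewMode

# Load a trained workspace and its most recent model checkpoint
# (init_viewer does the same when no gaussians are passed in).
workspace = load_workspace("path/to/workspace")                 # placeholder path
gaussians = workspace.load_model(workspace.latest_iteration())

# Start in depth view with a smaller snapshot size than the 8192x6144 default.
settings = Settings(view_mode=ViewMode.Depth, snapshot_size=(1920, 1080))
show_workspace(workspace, gaussians, settings)

For non-blocking use from a training loop, ViewerProcess wraps the same viewer in a torch multiprocessing Process (the torch start method must be "spawn") and accepts updated Gaussians through its update_gaussians method, as shown in splat_viewer/viewer/viewer.py.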