├── .gitignore
├── LICENSE
├── README.md
├── docs
│   └── demo.gif
├── helpers
│   ├── PointModel.py
│   └── SH_helper.py
├── interpolate.py
├── requirements.txt
├── requirements_visualizer.txt
└── visualizer.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[codz]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py.cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 | 
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 | 
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 | 
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 | 
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | # Pipfile.lock
96 | 
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # uv.lock
102 | 
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | # poetry.lock
109 | # poetry.toml
110 | 
111 | # pdm
112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113 | # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114 | # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115 | # pdm.lock
116 | # pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 | 
120 | # pixi
121 | # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122 | # pixi.lock 123 | # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one 124 | # in the .venv directory. It is recommended not to include this directory in version control. 125 | .pixi 126 | 127 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 128 | __pypackages__/ 129 | 130 | # Celery stuff 131 | celerybeat-schedule 132 | celerybeat.pid 133 | 134 | # Redis 135 | *.rdb 136 | *.aof 137 | *.pid 138 | 139 | # RabbitMQ 140 | mnesia/ 141 | rabbitmq/ 142 | rabbitmq-data/ 143 | 144 | # ActiveMQ 145 | activemq-data/ 146 | 147 | # SageMath parsed files 148 | *.sage.py 149 | 150 | # Environments 151 | .env 152 | .envrc 153 | .venv 154 | env/ 155 | venv/ 156 | ENV/ 157 | env.bak/ 158 | venv.bak/ 159 | 160 | # Spyder project settings 161 | .spyderproject 162 | .spyproject 163 | 164 | # Rope project settings 165 | .ropeproject 166 | 167 | # mkdocs documentation 168 | /site 169 | 170 | # mypy 171 | .mypy_cache/ 172 | .dmypy.json 173 | dmypy.json 174 | 175 | # Pyre type checker 176 | .pyre/ 177 | 178 | # pytype static type analyzer 179 | .pytype/ 180 | 181 | # Cython debug symbols 182 | cython_debug/ 183 | 184 | # PyCharm 185 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 186 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 187 | # and can be added to the global gitignore or merged into this file. For a more nuclear 188 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 189 | # .idea/ 190 | 191 | # Abstra 192 | # Abstra is an AI-powered process automation framework. 193 | # Ignore directories containing user credentials, local state, and settings. 194 | # Learn more at https://abstra.io/docs 195 | .abstra/ 196 | 197 | # Visual Studio Code 198 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 199 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 200 | # and can be added to the global gitignore or merged into this file. However, if you prefer, 201 | # you could uncomment the following to ignore the entire vscode folder 202 | # .vscode/ 203 | 204 | # Ruff stuff: 205 | .ruff_cache/ 206 | 207 | # PyPI configuration file 208 | .pypirc 209 | 210 | # Marimo 211 | marimo/_static/ 212 | marimo/_lsp/ 213 | __marimo__/ 214 | 215 | # Streamlit 216 | .streamlit/secrets.toml -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 feel3x 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Gaussian Splatting Morphing Tool
2 | 
3 | ![Gaussian Morphing Demo](docs/demo.gif)
4 | 
5 | A **CLI and visualization tool** for smoothly interpolating between two or more **Gaussian Splatting models**.
6 | It builds one-to-one point correspondences from spatial and color similarity and generates intermediate morphs that transition seamlessly.
7 | Includes both command-line and real-time interactive visualizer modes.
8 | 
9 | ---
10 | 
11 | ## 🚀 Features
12 | 
13 | - **Interpolate between multiple Gaussian Splatting models**
14 | - **Automatic one-to-one point correspondences** based on spatial and color similarity
15 | - **Spherical linear interpolation (SLERP)** for rotations
16 | - **Optional real-time visualizer** with interactive morphing slider
17 | - **GPU-accelerated (PyTorch)** processing
18 | - **Save interpolated frames** as `.ply` files for animation sequences
19 | 
20 | ---
21 | 
22 | ## 🧰 Installation
23 | 
24 | Clone the repository and install the base dependencies:
25 | 
26 | ```bash
27 | git clone https://github.com/feel3x/Gaussian_Splat_Morpher.git
28 | cd Gaussian_Splat_Morpher
29 | pip install -r requirements.txt
30 | ```
31 | 
32 | If you want to use the **visualizer**, install the extra dependencies:
33 | 
34 | ```bash
35 | pip install -r requirements_visualizer.txt
36 | ```
37 | 
38 | ---
39 | 
40 | ## 🧑‍💻 Usage (CLI Mode)
41 | 
42 | You can interpolate between `.ply` Gaussian Splat models directly from the command line.
43 | 
44 | ### Basic Example
45 | ```bash
46 | python interpolate.py -d ./models/ -o ./output/
47 | ```
48 | 
49 | This searches the `./models/` folder for `.ply` files and generates intermediate models in `./output/`.
50 | 
51 | ---
52 | 
53 | ### Specify Individual Models
54 | ```bash
55 | python interpolate.py -m model1.ply model2.ply model3.ply -o ./output/
56 | ```
57 | 
58 | ---
59 | 
60 | ### Adjust Interpolation Settings
61 | 
62 | | Option | Description | Default |
63 | |--------|--------------|----------|
64 | | `--models_to_create` | Number of intermediate models per pair | `10` |
65 | | `--direct_interpolation_value` | Create a single interpolated model between two specific models (e.g. `1.3`) | `None` |
66 | | `--spatial_weight` | Weight for spatial distance during correspondence | `0.7` |
67 | | `--color_weight` | Weight for color difference during correspondence | `0.3` |
68 | | `--distance_threshold` | Max allowed distance for point matching | `None` |
69 | | `--batch_size` | Size of point batches for matching (reduce for lower VRAM) | `512` |
70 | | `--recenter_models` | Recenter all models before interpolation | `False` |
71 | | `--normalize_scales` | Normalize scales of models | `False` |
72 | 
73 | ---
74 | 
75 | ### Example: Generate 20 Intermediate Morphs
76 | 
77 | ```bash
78 | python interpolate.py -m bunny_A.ply bunny_B.ply -o ./morph_output --models_to_create 20
79 | ```
80 | 
81 | ---
82 | 
83 | ### Example: Export a Single Interpolated Model at 0.4
84 | 
85 | ```bash
86 | python interpolate.py -m face_1.ply face_2.ply -o ./output --direct_interpolation_value 0.4
87 | ```
88 | 
89 | This saves one `.ply` at 40% interpolation between the first and second models.
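
You can also drive the interpolation from Python. Below is a minimal sketch using the classes from `interpolate.py`, mirroring what the CLI does (the model paths are placeholders):

```python
from helpers.PointModel import PointModel
from interpolate import GaussianInterpolator

# Load the input splats.
models = []
for path in ["face_1.ply", "face_2.ply"]:
    pm = PointModel()
    pm.load_ply(path)
    models.append(pm)

interp = GaussianInterpolator()  # picks CUDA when available, else CPU
interp.load_pointmodels(models)

# One-to-one matching by weighted spatial + color distance.
interp.build_correspondences(spatial_weight=0.7, color_weight=0.3, batch_size=512)

# Save a single morph at t=0.4 between model 0 and model 1.
interp.save_interpolated_ply(0, 1, 0.4, "./output/interpolated_model.ply")
```

Note that `--direct_interpolation_value` maps its integer part to the model pair and its fractional part to `t`: e.g. `1.3` produces a morph 30% of the way from the second model to the third.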
90 | 
91 | ---
92 | 
93 | ## 🎨 Real-Time Visualizer (Optional)
94 | 
95 | To explore morphs interactively:
96 | 
97 | ```bash
98 | python visualizer.py -m model1.ply model2.ply
99 | ```
100 | 
101 | This launches a GUI with a **slider** that lets you morph smoothly between loaded Gaussian Splat models in real time.
102 | 
103 | *(The visualizer requires the extra dependencies from `requirements_visualizer.txt`.)*
104 | 
105 | 
106 | ---
107 | 
108 | ## 🧠 How It Works
109 | 
110 | 1. Loads two or more Gaussian Splat models (`.ply` format)
111 | 2. Builds **point correspondences** between consecutive models
112 | 3. Interpolates:
113 |    - Positions, features, scales, and opacities (linearly)
114 |    - Rotations using **SLERP**
115 | 4. Handles unmatched points via fade-in/fade-out blending
116 | 5. Outputs intermediate `.ply` models or visualizes them in real time
117 | 
118 | ---
119 | 
120 | ## 📜 License
121 | 
122 | This project is released under the **MIT License**.
123 | See the [LICENSE](./LICENSE) file for details.
124 | 
125 | **Author:** Felix Hirt
126 | **Copyright © 2025**
127 | 
128 | ---
129 | 
130 | ## 🌟 Acknowledgements
131 | 
132 | This project builds upon 3D Gaussian Splatting (Kerbl et al., SIGGRAPH 2023).
133 | 
134 | The visualizer uses Nerfstudio's gsplat rasterizer: [GitHub](https://github.com/nerfstudio-project/gsplat)
135 | 
136 | ---
137 | 
138 | **Enjoy morphing! 🧩**
139 | 
--------------------------------------------------------------------------------
/docs/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feel3x/Gaussian_Splat_Morpher/872fe19355116728f5f1a13f353e245a1eb5a5ba/docs/demo.gif
--------------------------------------------------------------------------------
/helpers/PointModel.py:
--------------------------------------------------------------------------------
1 | # Gaussian Splat Morphing Tool
2 | # Author: Felix Hirt
3 | # License: MIT License (see LICENSE file for details)
4 | 
5 | # Note:
6 | # This file contains original code by Felix Hirt, licensed under MIT.
7 | 8 | 9 | import torch 10 | import numpy as np 11 | from plyfile import PlyData, PlyElement 12 | import os 13 | 14 | 15 | class PointModel: 16 | def __init__(self, sh_degree: int = None): 17 | self.max_sh_degree = sh_degree 18 | self._xyz = torch.empty(0) 19 | self._features_dc = torch.empty(0) 20 | self._features_rest = torch.empty(0) 21 | self._scaling = torch.empty(0) 22 | self._rotation = torch.empty(0) 23 | self._opacity = torch.empty(0) 24 | 25 | # activation functions 26 | self.scaling_activation = torch.exp 27 | self.scaling_inverse_activation = torch.log 28 | self.rotation_activation = torch.nn.functional.normalize 29 | self.opacity_activation = torch.sigmoid 30 | def inverse_sigmoid(x): 31 | return torch.log(x/(1-x)) 32 | self.inverse_opacity_activation = inverse_sigmoid 33 | 34 | @property 35 | def get_scaling(self): 36 | return self.scaling_activation(self._scaling) 37 | 38 | @property 39 | def get_rotation(self): 40 | return self.rotation_activation(self._rotation) 41 | 42 | @property 43 | def get_opacity(self): 44 | return self.opacity_activation(self._opacity) 45 | 46 | def save_ply(self, path: str): 47 | """Save model parameters to a PLY file (matching original format).""" 48 | os.makedirs(os.path.dirname(path), exist_ok=True) 49 | 50 | xyz = self._xyz.detach().cpu().numpy() 51 | normals = np.zeros_like(xyz) 52 | f_dc = ( 53 | self._features_dc.detach() 54 | .transpose(1, 2) 55 | .flatten(start_dim=1) 56 | .contiguous() 57 | .cpu() 58 | .numpy() 59 | ) 60 | f_rest = ( 61 | self._features_rest.detach() 62 | .transpose(1, 2) 63 | .flatten(start_dim=1) 64 | .contiguous() 65 | .cpu() 66 | .numpy() 67 | ) 68 | opacities = self._opacity.detach().cpu().numpy() 69 | scale = self._scaling.detach().cpu().numpy() 70 | rotation = self._rotation.detach().cpu().numpy() 71 | 72 | dtype_full = [(attr, "f4") for attr in self.construct_save_list()] 73 | elements = np.empty(xyz.shape[0], dtype=dtype_full) 74 | 75 | attributes = np.concatenate( 76 | (xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1 77 | ) 78 | elements[:] = list(map(tuple, attributes)) 79 | el = PlyElement.describe(elements, "vertex") 80 | PlyData([el]).write(path) 81 | 82 | def construct_save_list(self): 83 | """Generate attribute names in the same order as save_ply concatenation.""" 84 | attributes = ["x", "y", "z"] 85 | attributes += ["nx", "ny", "nz"] # normals 86 | # f_dc and f_rest: flatten over channels 87 | n_dc = self._features_dc.shape[1] * self._features_dc.shape[2] 88 | n_rest = self._features_rest.shape[1] * self._features_rest.shape[2] 89 | attributes += [f"f_dc_{i}" for i in range(n_dc)] 90 | attributes += [f"f_rest_{i}" for i in range(n_rest)] 91 | attributes += ["opacity"] 92 | attributes += [f"scale_{i}" for i in range(self._scaling.shape[1])] 93 | attributes += [f"rot_{i}" for i in range(self._rotation.shape[1])] 94 | return attributes 95 | 96 | def load_ply(self, path): 97 | plydata = PlyData.read(path) 98 | 99 | if(self.max_sh_degree == None): 100 | self.max_sh_degree = self.get_sh_bands_from_plydata(plydata) 101 | 102 | vertex = plydata.elements[0] 103 | 104 | # xyz 105 | xyz = np.stack( 106 | [np.asarray(vertex["x"]), np.asarray(vertex["y"]), np.asarray(vertex["z"])], 107 | axis=1, 108 | ) 109 | 110 | # opacity 111 | opacities = np.asarray(vertex["opacity"])[..., np.newaxis] 112 | 113 | # features_dc (first 3 SH coefficients) 114 | features_dc = np.zeros((xyz.shape[0], 3, 1)) 115 | features_dc[:, 0, 0] = np.asarray(vertex["f_dc_0"]) 116 | features_dc[:, 1, 0] = 
np.asarray(vertex["f_dc_1"]) 117 | features_dc[:, 2, 0] = np.asarray(vertex["f_dc_2"]) 118 | 119 | # features_rest 120 | extra_f_names = [p.name for p in vertex.properties if p.name.startswith("f_rest_")] 121 | extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split("_")[-1])) 122 | assert len(extra_f_names) == 3 * (self.max_sh_degree + 1) ** 2 - 3 123 | 124 | features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) 125 | for idx, attr_name in enumerate(extra_f_names): 126 | features_extra[:, idx] = np.asarray(vertex[attr_name]) 127 | features_extra = features_extra.reshape( 128 | (features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1) 129 | ) 130 | 131 | # scaling 132 | scale_names = [p.name for p in vertex.properties if p.name.startswith("scale_")] 133 | scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1])) 134 | scales = np.zeros((xyz.shape[0], len(scale_names))) 135 | for idx, attr_name in enumerate(scale_names): 136 | scales[:, idx] = np.asarray(vertex[attr_name]) 137 | 138 | # rotation 139 | rot_names = [p.name for p in vertex.properties if p.name.startswith("rot")] 140 | rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1])) 141 | rots = np.zeros((xyz.shape[0], len(rot_names))) 142 | for idx, attr_name in enumerate(rot_names): 143 | rots[:, idx] = np.asarray(vertex[attr_name]) 144 | 145 | # Convert to Parameters 146 | device = "cuda" if torch.cuda.is_available() else "cpu" 147 | self._xyz = torch.tensor(xyz, dtype=torch.float, device=device) 148 | self._features_dc = torch.tensor(features_dc, dtype=torch.float, device=device).transpose(1, 2).contiguous() 149 | self._features_rest = torch.tensor(features_extra, dtype=torch.float, device=device).transpose(1, 2).contiguous() 150 | self._opacity = torch.tensor(opacities, dtype=torch.float, device=device) 151 | self._scaling = torch.tensor(scales, dtype=torch.float, device=device) 152 | self._rotation = torch.tensor(rots, dtype=torch.float, device=device) 153 | 154 | self.active_sh_degree = self.max_sh_degree 155 | 156 | def get_sh_bands_from_plydata(self, plydata): 157 | """ 158 | Returns the number of spherical harmonics (SH) bands in a Gaussian Splatting .PLY file. 
159 | """ 160 | try: 161 | #Get the vertex element 162 | if 'vertex' not in plydata: 163 | raise ValueError("PLY file does not contain vertex data") 164 | 165 | vertex = plydata['vertex'] 166 | 167 | #Count SH coefficient properties 168 | sh_properties = [] 169 | 170 | # Handle different PLY file structures 171 | if hasattr(vertex, 'dtype') and hasattr(vertex.dtype, 'names') and vertex.dtype.names: 172 | property_names = vertex.dtype.names 173 | elif hasattr(vertex, 'data') and len(vertex.data) > 0: 174 | # Try to get property names from the first data element 175 | property_names = vertex.data[0].dtype.names if hasattr(vertex.data[0], 'dtype') else [] 176 | elif hasattr(vertex, 'properties'): 177 | # Alternative: get from properties if available 178 | property_names = [prop.name for prop in vertex.properties] 179 | else: 180 | raise ValueError("Cannot determine property names from PLY vertex data") 181 | 182 | # Look for SH-related properties 183 | for prop_name in property_names: 184 | if prop_name.startswith('f_dc_') or prop_name.startswith('f_rest_'): 185 | sh_properties.append(prop_name) 186 | 187 | if not sh_properties: 188 | #No SH coefficients found 189 | return 0 190 | 191 | #Count DC components (band 0) 192 | dc_count = len([name for name in sh_properties if name.startswith('f_dc_')]) 193 | 194 | #Count rest components (bands 1+) 195 | rest_count = len([name for name in sh_properties if name.startswith('f_rest_')]) 196 | 197 | #Total SH coefficients 198 | total_sh_coeffs = dc_count + rest_count 199 | 200 | #Calculate number of bands 201 | #3 color channels (RGB) 202 | if total_sh_coeffs % 3 != 0: 203 | raise ValueError(f"Invalid number of SH coefficients: {total_sh_coeffs} (not divisible by 3)") 204 | 205 | coeffs_per_channel = total_sh_coeffs // 3 206 | 207 | #Find the number of bands 208 | #coeffs_per_channel = (max_band + 1)^2 209 | #So max_band = sqrt(coeffs_per_channel) - 1 210 | max_band = int(np.sqrt(coeffs_per_channel)) - 1 211 | 212 | # Verify 213 | expected_coeffs = (max_band + 1) ** 2 214 | if expected_coeffs != coeffs_per_channel: 215 | raise ValueError(f"Invalid SH coefficient count: {coeffs_per_channel} per channel doesn't match any valid band configuration") 216 | 217 | #print("SH Degree: "+ str(max_band)) 218 | return max_band # Return number of bands (0-indexed max_band + 1) 219 | 220 | except Exception as e: 221 | raise ValueError(f"Error reading PLY file: {str(e)}") 222 | 223 | def recenter_point_cloud(model): 224 | """ 225 | Recenter the point cloud to have its centroid at the origin. 226 | """ 227 | with torch.no_grad(): 228 | # Calculate centroid 229 | centroid = model._xyz.mean(dim=0, keepdim=True) 230 | 231 | # Recenter points 232 | model._xyz -= centroid 233 | 234 | return centroid.squeeze() 235 | 236 | 237 | def normalize_scale(model, target_scale=1.0): 238 | """ 239 | Normalize the scale of the point cloud based on its extent. 
240 | """ 241 | with torch.no_grad(): 242 | # Calculate bounding box extent 243 | min_coords = model._xyz.min(dim=0)[0] 244 | max_coords = model._xyz.max(dim=0)[0] 245 | extent = max_coords - min_coords 246 | 247 | # Find maximum extent across all dimensions 248 | max_extent = extent.max() 249 | 250 | # Calculate scale factor 251 | scale_factor = target_scale / max_extent 252 | 253 | # Scale the positions 254 | model._xyz *= scale_factor 255 | 256 | # Scale the Gaussian splat scales 257 | model._scaling += torch.log(torch.tensor(scale_factor, 258 | device=model._scaling.device, 259 | dtype=model._scaling.dtype)) 260 | 261 | return scale_factor.item() 262 | 263 | -------------------------------------------------------------------------------- /helpers/SH_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from helpers.PointModel import PointModel 5 | 6 | def green_spill_in_SH(model: PointModel, threshold: float = 0.2, 7 | view_dirs: np.ndarray = None, num_sample_views: int = 8) -> np.ndarray: 8 | """ 9 | Detect points with green-dominant SH coefficients (view-dependent spill). 10 | 11 | Args: 12 | model: PointModel with SH coefficients 13 | threshold: Relative threshold for green dominance 14 | view_dirs: Optional specific viewing directions to check (N_views, 3) 15 | num_sample_views: Number of random viewing directions to sample if view_dirs not provided 16 | 17 | Returns: 18 | Boolean array indicating green spill points 19 | """ 20 | 21 | if model._features_rest.numel() == 0: 22 | return np.zeros(model._features_dc.shape[0], dtype=bool) 23 | 24 | with torch.no_grad(): 25 | # Get DC and SH coefficients 26 | dc = model._features_dc.cpu().numpy() # (N, 3) - base color 27 | rest = model._features_rest.cpu().numpy() # (N, SH_coeffs, 3) 28 | 29 | N = dc.shape[0] 30 | 31 | # Method 1: Analyze DC component for obvious green bias 32 | dc_green_mask = detect_dc_green_bias(dc, threshold) 33 | 34 | # Method 2: Sample multiple viewing directions to detect view-dependent green spill 35 | if view_dirs is None: 36 | # Generate sample viewing directions (roughly uniform on sphere) 37 | view_dirs = generate_sample_views(num_sample_views) 38 | 39 | view_dependent_mask = detect_view_dependent_green_spill( 40 | dc, rest, view_dirs, threshold 41 | ) 42 | 43 | # Method 3: Check for anomalous SH coefficient patterns 44 | sh_pattern_mask = detect_anomalous_sh_patterns(rest, threshold) 45 | 46 | # Combine all detection methods 47 | combined_mask = dc_green_mask | view_dependent_mask | sh_pattern_mask 48 | 49 | return combined_mask 50 | 51 | def detect_dc_green_bias(dc: np.ndarray, threshold: float) -> np.ndarray: 52 | """Detect obvious green bias in DC (base) color component.""" 53 | # Handle different DC shapes 54 | if dc.shape[1] == 1: 55 | # Single channel - can't detect green spill 56 | return np.zeros(dc.shape[0], dtype=bool) 57 | elif dc.shape[1] == 3: 58 | # RGB channels 59 | r, g, b = dc[:, 0], dc[:, 1], dc[:, 2] 60 | else: 61 | # Unexpected format - skip DC analysis 62 | return np.zeros(dc.shape[0], dtype=bool) 63 | 64 | # Method 1: Ratio-based for brighter colors 65 | brightness = r + g + b 66 | bright_mask = brightness > 0.1 # Only use ratios for reasonably bright pixels 67 | 68 | ratio_green_dominant = np.zeros_like(bright_mask) 69 | ratio_green_dominant[bright_mask] = ( 70 | (g[bright_mask] > r[bright_mask] * (1 + threshold)) & 71 | (g[bright_mask] > b[bright_mask] * (1 + threshold)) 72 | ) 73 | 74 | # Method 2: 
Absolute difference for darker colors 75 | dark_mask = brightness <= 0.1 76 | min_green_diff = threshold * 0.05 # Minimum absolute green difference 77 | 78 | abs_green_dominant = np.zeros_like(dark_mask) 79 | if np.any(dark_mask): 80 | g_diff_r = g[dark_mask] - r[dark_mask] 81 | g_diff_b = g[dark_mask] - b[dark_mask] 82 | abs_green_dominant[dark_mask] = ( 83 | (g_diff_r > min_green_diff) & (g_diff_b > min_green_diff) 84 | ) 85 | 86 | # Method 3: Statistical outlier detection for very subtle spill 87 | # Convert to LAB color space approximation for better perceptual analysis 88 | lab_green_bias = detect_perceptual_green_bias(r, g, b, threshold) 89 | 90 | # Method 4: Check for green shift in darker regions 91 | # Dark green spill often shows as green being the dominant channel even in low light 92 | dark_green_dominant = dark_mask & (g > r) & (g > b) & (g > threshold * 0.02) 93 | 94 | return ratio_green_dominant | abs_green_dominant | lab_green_bias | dark_green_dominant 95 | 96 | def detect_perceptual_green_bias(r: np.ndarray, g: np.ndarray, b: np.ndarray, threshold: float) -> np.ndarray: 97 | """Detect green bias using perceptual color differences.""" 98 | # Simple RGB to approximate LAB conversion for better perceptual analysis 99 | # This helps detect green tints that might not be obvious in RGB space 100 | 101 | # Normalize to prevent issues with very dark colors 102 | rgb_sum = r + g + b + 1e-8 103 | r_norm = r / rgb_sum 104 | g_norm = g / rgb_sum 105 | b_norm = b / rgb_sum 106 | 107 | # Expected neutral would be roughly equal proportions 108 | expected_prop = 1/3 109 | 110 | # Check if green proportion is significantly higher than expected 111 | green_excess = g_norm - expected_prop 112 | other_deficit = (expected_prop - r_norm) + (expected_prop - b_norm) 113 | 114 | # Green bias detected if green is significantly over-represented 115 | perceptual_bias = (green_excess > threshold * 0.1) & (other_deficit > threshold * 0.05) 116 | 117 | return perceptual_bias 118 | 119 | def detect_view_dependent_green_spill(dc: np.ndarray, rest: np.ndarray, 120 | view_dirs: np.ndarray, threshold: float) -> np.ndarray: 121 | """Detect green spill that appears in specific viewing directions.""" 122 | N = dc.shape[0] 123 | spill_detected = np.zeros(N, dtype=bool) 124 | 125 | # Skip if DC doesn't have RGB channels 126 | if dc.shape[1] != 3: 127 | return spill_detected 128 | 129 | for view_dir in view_dirs: 130 | # Evaluate SH for this viewing direction 131 | rendered_color = evaluate_sh_at_direction(dc, rest, view_dir) 132 | 133 | # Check for green spill in this view 134 | r, g, b = rendered_color[:, 0], rendered_color[:, 1], rendered_color[:, 2] 135 | 136 | # Detect green spill in this specific view 137 | view_green_spill = (g > r * (1 + threshold)) & (g > b * (1 + threshold)) 138 | 139 | # Also check for color shift towards green compared to base color 140 | base_r, base_g, base_b = dc[:, 0], dc[:, 1], dc[:, 2] 141 | green_shift = (g - base_g) > (np.maximum(r - base_r, b - base_b) + threshold) 142 | 143 | spill_detected |= (view_green_spill | green_shift) 144 | 145 | return spill_detected 146 | 147 | def detect_anomalous_sh_patterns(rest: np.ndarray, threshold: float) -> np.ndarray: 148 | """Detect anomalous patterns in SH coefficients that suggest green spill.""" 149 | if rest.shape[1] == 0: 150 | return np.zeros(rest.shape[0], dtype=bool) 151 | 152 | # Handle different rest shapes 153 | if rest.ndim != 3: 154 | return np.zeros(rest.shape[0], dtype=bool) 155 | 156 | # Check if we have RGB channels 157 | 
if rest.shape[2] < 3: 158 | return np.zeros(rest.shape[0], dtype=bool) 159 | 160 | # Analyze the energy distribution across SH bands 161 | # Green spill often shows up as unusual energy in higher-order SH coefficients 162 | 163 | # Calculate energy per color channel across all SH coefficients 164 | energy_per_channel = (rest ** 2).sum(axis=1) # (N, 3) 165 | r_energy, g_energy, b_energy = energy_per_channel[:, 0], energy_per_channel[:, 1], energy_per_channel[:, 2] 166 | 167 | # Check for disproportionate green energy 168 | total_energy = energy_per_channel.sum(axis=1) 169 | green_ratio = g_energy / (total_energy + 1e-8) 170 | 171 | # Detect points where green takes up too much of the total SH energy 172 | anomalous_green_energy = green_ratio > (1/3 + threshold) 173 | 174 | # Check for inconsistent signs in green SH coefficients (can indicate spill) 175 | if rest.shape[1] > 1: 176 | green_coeffs = rest[:, :, 1] # All green SH coefficients 177 | sign_changes = np.sum(np.diff(np.sign(green_coeffs), axis=1) != 0, axis=1) 178 | # Too many sign changes might indicate problematic coefficients 179 | inconsistent_signs = sign_changes > (rest.shape[1] * 0.6) 180 | else: 181 | inconsistent_signs = np.zeros(rest.shape[0], dtype=bool) 182 | 183 | return anomalous_green_energy | inconsistent_signs 184 | 185 | def generate_sample_views(num_views: int) -> np.ndarray: 186 | """Generate roughly uniform sample directions on unit sphere.""" 187 | # Use Fibonacci sphere for good uniform distribution 188 | indices = np.arange(0, num_views, dtype=float) + 0.5 189 | phi = np.arccos(1 - 2 * indices / num_views) # Inclination 190 | theta = np.pi * (1 + 5**0.5) * indices # Azimuth (golden angle) 191 | 192 | x = np.sin(phi) * np.cos(theta) 193 | y = np.sin(phi) * np.sin(theta) 194 | z = np.cos(phi) 195 | 196 | return np.column_stack([x, y, z]) 197 | 198 | def evaluate_sh_at_direction(dc: np.ndarray, rest: np.ndarray, direction: np.ndarray) -> np.ndarray: 199 | """Evaluate spherical harmonics at a specific viewing direction.""" 200 | # This is a simplified version - you'll need to implement proper SH evaluation 201 | # based on your specific SH basis functions and degree 202 | 203 | # For now, approximating with DC + first-order directional component 204 | result = dc.copy() 205 | 206 | if rest.shape[1] > 0: 207 | # Add contribution from first SH band (assuming it represents directional variation) 208 | # This is simplified - proper implementation needs full SH basis evaluation 209 | dir_contrib = rest[:, 0, :] * direction[0] # Simplified directional term 210 | if rest.shape[1] > 1: 211 | dir_contrib += rest[:, 1, :] * direction[1] 212 | if rest.shape[1] > 2: 213 | dir_contrib += rest[:, 2, :] * direction[2] 214 | 215 | result += dir_contrib 216 | 217 | # Clamp to reasonable color range 218 | result = np.clip(result, 0, 1) 219 | 220 | return result 221 | 222 | 223 | # Enhanced version with better dark green detection 224 | def green_spill_statistical_enhanced(model: PointModel, threshold: float = 1.5) -> np.ndarray: 225 | """ 226 | Enhanced statistical detection with better handling of dark green spill. 
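    Here `threshold` acts as a z-score cutoff: a point is flagged when its
    green/red and green/blue ratios (bright points) or its absolute green excess
    (dark points, with a reduced cutoff) lie more than `threshold` standard
    deviations above the mean of the respective population.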
227 | """ 228 | if model._features_rest.numel() == 0: 229 | return np.zeros(model._features_dc.shape[0], dtype=bool) 230 | 231 | with torch.no_grad(): 232 | dc = model._features_dc.cpu().numpy() 233 | rest = model._features_rest.cpu().numpy() 234 | 235 | print(f"DC shape: {dc.shape}, Rest shape: {rest.shape}") 236 | 237 | if dc.shape[1] == 1 or dc.shape[1] != 3: 238 | return np.zeros(dc.shape[0], dtype=bool) 239 | 240 | r, g, b = dc[:, 0], dc[:, 1], dc[:, 2] 241 | 242 | # Separate analysis for different brightness levels 243 | brightness = r + g + b 244 | 245 | # Method 1: Ratio-based for bright pixels 246 | bright_mask = brightness > np.percentile(brightness, 25) # Top 75% brightness 247 | ratio_spill = np.zeros(len(r), dtype=bool) 248 | 249 | if np.any(bright_mask): 250 | gr_ratio = g[bright_mask] / (r[bright_mask] + 1e-8) 251 | gb_ratio = g[bright_mask] / (b[bright_mask] + 1e-8) 252 | 253 | gr_mean, gr_std = np.mean(gr_ratio), np.std(gr_ratio) 254 | gb_mean, gb_std = np.mean(gb_ratio), np.std(gb_ratio) 255 | 256 | gr_outliers = gr_ratio > (gr_mean + threshold * gr_std) 257 | gb_outliers = gb_ratio > (gb_mean + threshold * gb_std) 258 | 259 | ratio_spill[bright_mask] = gr_outliers | gb_outliers 260 | 261 | # Method 2: Absolute difference for dark pixels 262 | dark_mask = brightness <= np.percentile(brightness, 25) # Bottom 25% brightness 263 | abs_spill = np.zeros(len(r), dtype=bool) 264 | 265 | if np.any(dark_mask): 266 | # For dark pixels, look for absolute green excess 267 | g_excess_r = g[dark_mask] - r[dark_mask] 268 | g_excess_b = g[dark_mask] - b[dark_mask] 269 | 270 | # Statistical analysis of green excess in dark regions 271 | if len(g_excess_r) > 1: 272 | gr_excess_mean, gr_excess_std = np.mean(g_excess_r), np.std(g_excess_r) 273 | gb_excess_mean, gb_excess_std = np.mean(g_excess_b), np.std(g_excess_b) 274 | 275 | # Lower threshold for dark regions 276 | dark_threshold = threshold * 0.7 277 | 278 | gr_excess_outliers = g_excess_r > (gr_excess_mean + dark_threshold * gr_excess_std) 279 | gb_excess_outliers = g_excess_b > (gb_excess_mean + dark_threshold * gb_excess_std) 280 | 281 | abs_spill[dark_mask] = gr_excess_outliers & gb_excess_outliers 282 | 283 | # Method 3: Normalized green proportion analysis 284 | total_color = r + g + b + 1e-8 285 | green_proportion = g / total_color 286 | 287 | # Statistical analysis of green proportions 288 | gp_mean, gp_std = np.mean(green_proportion), np.std(green_proportion) 289 | prop_outliers = green_proportion > (gp_mean + threshold * gp_std) 290 | 291 | # Also check if green proportion is unnaturally high (>40% in any pixel) 292 | high_green_prop = green_proportion > 0.4 293 | 294 | # Method 4: SH coefficient analysis if available 295 | sh_spill = np.zeros(len(r), dtype=bool) 296 | if rest.shape[1] > 0 and rest.ndim == 3 and rest.shape[2] >= 3: 297 | rest_energy = (rest ** 2).sum(axis=1) 298 | r_energy, g_energy, b_energy = rest_energy[:, 0], rest_energy[:, 1], rest_energy[:, 2] 299 | 300 | # Analyze green energy patterns 301 | total_sh_energy = r_energy + g_energy + b_energy + 1e-8 302 | green_sh_prop = g_energy / total_sh_energy 303 | 304 | gsh_mean, gsh_std = np.mean(green_sh_prop), np.std(green_sh_prop) 305 | sh_spill = green_sh_prop > (gsh_mean + threshold * gsh_std) 306 | 307 | # Combine all methods with weighted voting 308 | combined_mask = ( 309 | ratio_spill | # Good for bright regions 310 | abs_spill | # Good for dark regions 311 | (prop_outliers & (brightness < np.percentile(brightness, 50))) | # Proportional analysis for 
mid-dark 312 | high_green_prop | # Catch extreme cases 313 | sh_spill # SH pattern analysis 314 | ) 315 | 316 | return combined_mask -------------------------------------------------------------------------------- /interpolate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Gaussian Splat Model Interpolator 3 | ================================ 4 | Author: Felix Hirt 5 | License: MIT License (see LICENSE file for details) 6 | 7 | Note: 8 | This file contains original code by Felix Hirt, licensed under MIT. 9 | """ 10 | 11 | from typing import List, Optional 12 | import torch 13 | from tqdm import tqdm 14 | import os 15 | import argparse 16 | 17 | from helpers.PointModel import PointModel 18 | 19 | 20 | class GaussianInterpolator: 21 | def __init__(self, device: Optional[str] = None): 22 | self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") 23 | self.models: List = [] 24 | self.max_points = 0 25 | self.correspondences = {} 26 | 27 | def load_pointmodels(self, models: List): 28 | self.models = models 29 | self.max_points = max([m._xyz.shape[0] for m in models]) if models else 0 30 | 31 | def slerp(self, q1, q2, t): 32 | q1 = q1 / torch.linalg.norm(q1, dim=-1, keepdim=True) 33 | q2 = q2 / torch.linalg.norm(q2, dim=-1, keepdim=True) 34 | 35 | dot = (q1 * q2).sum(dim=-1, keepdim=True) 36 | 37 | mask = dot < 0.0 38 | q2 = torch.where(mask, -q2, q2) 39 | dot = torch.where(mask, -dot, dot) 40 | 41 | DOT_THRESHOLD = 0.9995 42 | 43 | linear_interp = q1 + t * (q2 - q1) 44 | linear_interp = linear_interp / torch.linalg.norm(linear_interp, dim=-1, keepdim=True) 45 | 46 | theta_0 = torch.arccos(dot.clamp(-1, 1)) 47 | sin_theta_0 = torch.sin(theta_0) 48 | theta = theta_0 * t 49 | s1 = torch.sin(theta_0 - theta) / sin_theta_0 50 | s2 = torch.sin(theta) / sin_theta_0 51 | slerp_interp = s1 * q1 + s2 * q2 52 | 53 | result = torch.where(dot > DOT_THRESHOLD, linear_interp, slerp_interp) 54 | 55 | return result 56 | 57 | def build_correspondences(self, spatial_weight: float = 0.7, color_weight: float = 0.3, distance_threshold: float = None, batch_size = 2048, force_rebuild: bool = False, show_progress: bool = True): 58 | """Build pairwise correspondences for consecutive model pairs. 59 | """ 60 | if not force_rebuild and self.correspondences: 61 | return 62 | 63 | pairs = list(zip(range(len(self.models) - 1), range(1, len(self.models)))) 64 | iterator = pairs if not show_progress else tqdm(pairs, desc="Building correspondences") 65 | for i, j in iterator: 66 | mi = self.models[i] 67 | mj = self.models[j] 68 | idx_map = self.correspond_one_to_one( 69 | mi._xyz, 70 | mj._xyz, 71 | spatial_weight=spatial_weight, 72 | color_weight=color_weight, 73 | a_features_dc=mi._features_dc, 74 | b_features_dc=mj._features_dc, 75 | distance_threshold=distance_threshold, 76 | batch_size=batch_size) 77 | self.correspondences[(i, j)] = idx_map 78 | 79 | def correspond_one_to_one(self, 80 | a_xyz: torch.Tensor, 81 | b_xyz: torch.Tensor, 82 | spatial_weight: float = 1.0, 83 | color_weight: float = 0.5, 84 | a_features_dc: torch.Tensor = None, 85 | b_features_dc: torch.Tensor = None, 86 | batch_size: int = 1024, 87 | distance_threshold: float = None, 88 | device: str = "cuda" if torch.cuda.is_available() else "cpu", 89 | ): 90 | """ 91 | one-to-one correspondence ensuring all points in the smaller set 92 | are matched exactly once. 
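        Returns (matched_a_idx, matched_b_idx, leftover_a_idx, leftover_b_idx):
        matched_*_idx are index tensors pairing corresponding points in A and B;
        leftover_*_idx are the unmatched points, including candidate matches
        rejected by distance_threshold.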
93 | """ 94 | 95 | a_xyz = a_xyz.to(device) 96 | b_xyz = b_xyz.to(device) 97 | if a_features_dc is not None: 98 | a_features_dc = a_features_dc.to(device) 99 | if b_features_dc is not None: 100 | b_features_dc = b_features_dc.to(device) 101 | 102 | # helper function to calculate weight 103 | def make_feat(xyz, feat, sw, cw): 104 | if feat is None: 105 | return xyz * sw 106 | return torch.cat([xyz * sw, feat.view(feat.shape[0], -1) * cw], dim=1) 107 | 108 | a_feat = make_feat(a_xyz, a_features_dc, spatial_weight, color_weight) 109 | b_feat = make_feat(b_xyz, b_features_dc, spatial_weight, color_weight) 110 | 111 | N, D = a_feat.shape 112 | M = b_feat.shape[0] 113 | 114 | # Identify smaller and larger sets 115 | if N <= M: 116 | small_feat, large_feat = a_feat, b_feat 117 | small_is_a = True 118 | else: 119 | small_feat, large_feat = b_feat, a_feat 120 | small_is_a = False 121 | 122 | n_small, n_large = small_feat.shape[0], large_feat.shape[0] 123 | 124 | large_used = torch.zeros(n_large, dtype=torch.bool, device=device) 125 | small_used = torch.zeros(n_small, dtype=torch.bool, device=device) 126 | matched_small_idx = [] 127 | matched_large_idx = [] 128 | 129 | # Track leftovers from threshold filtering 130 | skipped_small_idx = [] 131 | skipped_large_idx = [] 132 | 133 | for start in tqdm(range(0, n_small, batch_size), total=int(n_small/batch_size+1), desc="Matching points", leave=False): 134 | end = min(start + batch_size, n_small) 135 | chunk = small_feat[start:end] 136 | chunk_size = end - start 137 | 138 | # Compute distances for this chunk 139 | dists = torch.cdist(chunk, large_feat) 140 | dists[:, large_used] = float("inf") 141 | 142 | for i in range(chunk_size): 143 | best_j = torch.argmin(dists[i]) 144 | best_dist = dists[i, best_j] 145 | if best_dist == float("inf"): 146 | # all large set points used 147 | break 148 | 149 | if distance_threshold is not None and best_dist > distance_threshold: 150 | # Skip this match — mark both as leftovers 151 | skipped_small_idx.append(start + i) 152 | skipped_large_idx.append(best_j.item()) 153 | continue 154 | 155 | matched_small_idx.append(start + i) 156 | matched_large_idx.append(best_j.item()) 157 | large_used[best_j] = True 158 | small_used[start + i] = True 159 | dists[:, best_j] = float("inf") # prevent reuse 160 | 161 | del dists 162 | torch.cuda.empty_cache() 163 | 164 | matched_small_idx = torch.tensor(matched_small_idx, device=device, dtype=torch.long) 165 | matched_large_idx = torch.tensor(matched_large_idx, device=device, dtype=torch.long) 166 | 167 | skipped_small_idx = torch.tensor(skipped_small_idx, device=device, dtype=torch.long) 168 | skipped_large_idx = torch.tensor(skipped_large_idx, device=device, dtype=torch.long) 169 | 170 | # Compute leftovers 171 | leftover_small_idx = torch.nonzero(~small_used, as_tuple=False).squeeze(1) 172 | leftover_large_idx = torch.nonzero(~large_used, as_tuple=False).squeeze(1) 173 | 174 | # Combine threshold-skipped points with leftovers 175 | leftover_small_idx = torch.unique(torch.cat([leftover_small_idx, skipped_small_idx])) 176 | leftover_large_idx = torch.unique(torch.cat([leftover_large_idx, skipped_large_idx])) 177 | 178 | if small_is_a: 179 | matched_a_idx = matched_small_idx 180 | matched_b_idx = matched_large_idx 181 | leftover_a_idx = leftover_small_idx 182 | leftover_b_idx = leftover_large_idx 183 | else: 184 | matched_a_idx = matched_large_idx 185 | matched_b_idx = matched_small_idx 186 | leftover_a_idx = leftover_large_idx 187 | leftover_b_idx = leftover_small_idx 188 | 189 | 
print(str(min(leftover_a_idx.shape[0], leftover_b_idx.shape[0])) + " points over distance threshold") 190 | 191 | return matched_a_idx, matched_b_idx, leftover_a_idx, leftover_b_idx 192 | 193 | def interpolate_between( 194 | self, 195 | idx_a: int, 196 | idx_b: int, 197 | t: float 198 | ): 199 | """Interpolate between model idx_a and idx_b at fraction t in [0,1]. 200 | Uses one-to-one correspondences (and handles leftover points). 201 | - At t=0: exactly model A (+ fade-in points from B with opacity 0). 202 | - At t=1: exactly model B (+ fade-out points from A with opacity 0). 203 | - In between: linear interpolation in parameter space. 204 | """ 205 | assert 0.0 <= t <= 1.0 206 | a = self.models[idx_a] 207 | b = self.models[idx_b] 208 | device = self.device 209 | 210 | key = (idx_a, idx_b) 211 | 212 | if key not in self.correspondences: 213 | raise RuntimeError(f"Failed to build correspondence for pair {key}") 214 | 215 | matched_a, matched_b, leftover_a, leftover_b = self.correspondences[key] 216 | matched_a = matched_a.to(device) 217 | matched_b = matched_b.to(device) 218 | leftover_a = leftover_a.to(device) 219 | leftover_b = leftover_b.to(device) 220 | 221 | # These correspond to "matched" points that will interpolate between A and B 222 | src_idx = matched_a 223 | tgt_idx = matched_b 224 | 225 | pos_a = a._xyz[src_idx].to(device) 226 | pos_b = b._xyz[tgt_idx].to(device) 227 | 228 | fdc_a = a._features_dc[src_idx].to(device) 229 | fdc_b = b._features_dc[tgt_idx].to(device) 230 | 231 | fret_a = a._features_rest[src_idx].to(device) 232 | fret_b = b._features_rest[tgt_idx].to(device) 233 | 234 | scale_a = a._scaling[src_idx].to(device) 235 | scale_b = b._scaling[tgt_idx].to(device) 236 | 237 | rot_a = a._rotation[src_idx].to(device) 238 | rot_b = b._rotation[tgt_idx].to(device) 239 | 240 | op_a = a._opacity[src_idx].to(device) 241 | op_b = b._opacity[tgt_idx].to(device) 242 | 243 | # Interpolate 244 | out_xyz = pos_a * (1.0 - t) + pos_b * t 245 | out_fdc = fdc_a * (1.0 - t) + fdc_b * t 246 | out_frest = fret_a * (1.0 - t) + fret_b * t 247 | out_scale = scale_a * (1.0 - t) + scale_b * t 248 | out_rot = self.slerp(rot_a, rot_b, t) 249 | out_op = op_a * (1.0 - t) + op_b * t 250 | 251 | # Handle fade-out points (leftovers from A) 252 | fade_out_xyz = a._xyz[leftover_a] 253 | fade_out_fdc = a._features_dc[leftover_a] 254 | fade_out_frest = a._features_rest[leftover_a] 255 | fade_out_rot = a._rotation[leftover_a] 256 | # Interpolate scale and opacity to disappear 257 | fade_out_scale = a._scaling[leftover_a] * (1.0 - t) + -10.0 * t 258 | fade_out_op = a._opacity[leftover_a] * (1.0 - t) + -10.0 * t 259 | 260 | # Handle fade-in points (leftovers from B) 261 | fade_in_xyz = b._xyz[leftover_b] 262 | fade_in_fdc = b._features_dc[leftover_b] 263 | fade_in_frest = b._features_rest[leftover_b] 264 | fade_in_rot = b._rotation[leftover_b] 265 | # Interpolate scale and opacity to appear 266 | fade_in_scale = -10.0 * (1.0 - t) + b._scaling[leftover_b] * t 267 | fade_in_op = -10.0 * (1.0 - t) + b._opacity[leftover_b] * t 268 | 269 | # Concatenate all parts 270 | out_xyz = torch.cat([out_xyz, fade_out_xyz, fade_in_xyz], dim=0) 271 | out_fdc = torch.cat([out_fdc, fade_out_fdc, fade_in_fdc], dim=0) 272 | out_frest = torch.cat([out_frest, fade_out_frest, fade_in_frest], dim=0) 273 | out_scale = torch.cat([out_scale, fade_out_scale, fade_in_scale], dim=0) 274 | out_rot = torch.cat([out_rot, fade_out_rot, fade_in_rot], dim=0) 275 | out_op = torch.cat([out_op, fade_out_op, fade_in_op], dim=0) 276 | 277 | # 
Construct output model 278 | out_model = type(a)(sh_degree=getattr(a, "max_sh_degree", None)) 279 | out_model._xyz = out_xyz.contiguous() 280 | out_model._features_dc = out_fdc.contiguous() 281 | out_model._features_rest = out_frest.contiguous() 282 | out_model._scaling = out_scale.contiguous() 283 | out_model._rotation = out_rot.contiguous() 284 | out_model._opacity = out_op.contiguous() 285 | 286 | return out_model 287 | 288 | def save_interpolated_ply(self, idx_a: int, idx_b: int, t: float, path: str, keep_count: Optional[int] = None, show_progress: bool = True): 289 | pm = self.interpolate_between(idx_a, idx_b, t) 290 | 291 | pm.save_ply(path) 292 | 293 | if(__name__ == "__main__"): 294 | parser = argparse.ArgumentParser( 295 | description="Interpolates between 3D point cloud models (.ply files).", 296 | formatter_class=argparse.RawTextHelpFormatter # For better help text formatting 297 | ) 298 | 299 | #Input Arguments 300 | input_group = parser.add_mutually_exclusive_group(required=True) 301 | input_group.add_argument( 302 | '-d', '--directory', 303 | type=str, 304 | help="Path to a directory containing .ply models." 305 | ) 306 | input_group.add_argument( 307 | '-m', '--models', 308 | nargs='+', # '+' means one or more arguments 309 | type=str, 310 | help="Paths to two or more individual .ply model files." 311 | ) 312 | 313 | # --- Output Argument --- 314 | parser.add_argument( 315 | '-o', '--output_dir', 316 | type=str, 317 | required=True, 318 | help="Required. Path to the directory where output will be saved." 319 | ) 320 | 321 | # --- Optional Interpolation Parameters --- 322 | parser.add_argument( 323 | '--models_to_create', 324 | type=int, 325 | default=10, 326 | help="Number of intermediate models to generate. Default: 10" 327 | ) 328 | parser.add_argument( 329 | '--direct_interpolation_value', 330 | type=float, 331 | default=None, 332 | help="Directly export a PLY with a certain interpolation value. Note: 0.0 - 1.0 is between Model #1 and Model #2. 1.0 - 2.0 is between Model #2 and Model #3..." 333 | ) 334 | parser.add_argument( 335 | '--spatial_weight', 336 | type=float, 337 | default=0.7, 338 | help="Weight for spatial distance in correspondence. Default: 0.7" 339 | ) 340 | parser.add_argument( 341 | '--color_weight', 342 | type=float, 343 | default=0.3, 344 | help="Weight for color difference in correspondence. Default: 0.3" 345 | ) 346 | parser.add_argument( 347 | '--distance_threshold', 348 | type=float, 349 | default=None, 350 | help="Max distance for point correspondences. If not set, no threshold is used." 351 | ) 352 | parser.add_argument( 353 | '--batch_size', 354 | type=int, 355 | default=512, 356 | help="Size of point batches to process. Lower for less GPU memory usage." 
357 |     )
358 |     parser.add_argument(
359 |         "--recenter_models",
360 |         action="store_true",
361 |         help="Recenter all the used models before interpolating")
362 |     parser.add_argument(
363 |         "--normalize_scales",
364 |         action="store_true",
365 |         help="Normalize the scale of all models before interpolating")
366 | 
367 |     args = parser.parse_args()
368 | 
369 |     # File Collection
370 |     ply_files = []
371 |     if args.directory:
372 |         print(f"Searching for .ply files in directory: {args.directory}")
373 |         # Check if directory exists
374 |         if not os.path.isdir(args.directory):
375 |             parser.error(f"Directory not found: {args.directory}")
376 |         ply_files = sorted([
377 |             os.path.join(args.directory, f) for f in os.listdir(args.directory) if f.lower().endswith('.ply')
378 |         ])
379 |     elif args.models:
380 |         print("Using provided list of models.")
381 |         ply_files = args.models
382 | 
383 |     # Ensure we have at least 2 models to work with
384 |     if len(ply_files) < 2:
385 |         parser.error("At least two .ply models are required for interpolation, but found " + str(len(ply_files)))
386 | 
387 |     print(f"\nFound {len(ply_files)} models for processing.")
388 | 
389 |     # Ensure output directory exists
390 |     os.makedirs(args.output_dir, exist_ok=True)
391 |     print(f"Output will be saved to: {args.output_dir}\n")
392 | 
393 |     # Load models
394 |     point_models = []
395 |     for ply_file in ply_files:
396 |         pm = PointModel()
397 |         pm.load_ply(ply_file)
398 |         if(args.recenter_models):
399 |             pm.recenter_point_cloud()
400 |         if(args.normalize_scales):
401 |             pm.normalize_scale(10)
402 |         point_models.append(pm)
403 | 
404 |     interp = GaussianInterpolator()  # auto-selects CUDA when available, else CPU
405 |     interp.load_pointmodels(point_models)
406 |     interp.build_correspondences(spatial_weight=args.spatial_weight, color_weight=args.color_weight, distance_threshold=args.distance_threshold, batch_size=args.batch_size)
407 | 
408 |     if args.direct_interpolation_value is not None:
409 |         value = args.direct_interpolation_value
410 |         idx_a = int(value)
411 |         t = value - idx_a
412 |         idx_b = idx_a + 1
413 |         if(idx_b > len(interp.models) - 1):
414 |             idx_a -= 1
415 |             idx_b -= 1
416 |             t=1
417 |         interp.save_interpolated_ply(idx_a, idx_b, t, os.path.join(args.output_dir, "interpolated_model.ply"))
418 |     else:
419 |         models_to_create = args.models_to_create
420 |         for model_idx in tqdm(range(0, len(interp.models)-1), desc="Interpolating"):
421 |             #interpolate
422 |             for i in tqdm(range(0, models_to_create+1), desc="Saving PLYs"):
423 |                 interp.save_interpolated_ply(model_idx, model_idx+1, i/models_to_create, os.path.join(args.output_dir, "interpolated_sequence_frame"+str(i + models_to_create * model_idx)+".ply"))
424 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==2.3.2
2 | plyfile==1.1.2
3 | torch
4 | tqdm==4.67.0
5 | 
--------------------------------------------------------------------------------
/requirements_visualizer.txt:
--------------------------------------------------------------------------------
1 | numpy==2.3.2
2 | plyfile==1.1.2
3 | torch
4 | tqdm==4.67.0
5 | gsplat
6 | viser
7 | 
--------------------------------------------------------------------------------
/visualizer.py:
--------------------------------------------------------------------------------
1 | """Gaussian Splatting Viewer using viser + gsplat rasterization
2 | ================================
3 | Author: Felix Hirt
4 | License: MIT License (see LICENSE file for details)
5 | 
6 | Note:
7 | This file contains original code by Felix Hirt, licensed under MIT.
8 | """
9 | 
10 | import argparse
11 | import time
12 | from typing import Tuple
13 | 
14 | import numpy as np
15 | import torch
16 | import torch.nn.functional as F
17 | import viser
18 | import viser.transforms as vt
19 | import os
20 | 
21 | from gsplat.rendering import rasterization
22 | 
23 | from helpers.PointModel import PointModel
24 | 
25 | from interpolate import GaussianInterpolator
26 | 
27 | 
28 | class CameraHelpers:
29 |     @staticmethod
30 |     def c2w_from_camera(camera) -> np.ndarray:
31 |         wxyz = camera.wxyz
32 |         # build 4x4 c2w matrix
33 |         rot = vt.SO3(wxyz).as_matrix()  # 3x3
34 |         pos = camera.position
35 |         c2w = np.concatenate([np.concatenate([rot, pos[:, None]], 1), [[0, 0, 0, 1]]], 0)
36 |         return c2w
37 | 
38 |     @staticmethod
39 |     def K_from_camera_fov_aspect(fov: float, img_wh: Tuple[int, int]) -> np.ndarray:
40 |         W, H = img_wh
41 |         focal_length = H / 2.0 / np.tan(fov / 2.0)
42 |         K = np.array(
43 |             [
44 |                 [focal_length, 0.0, W / 2.0],
45 |                 [0.0, focal_length, H / 2.0],
46 |                 [0.0, 0.0, 1.0],
47 |             ]
48 |         )
49 |         return K
50 | 
51 | class GsplatViserViewer:
52 |     '''Adapted for interpolation'''
53 |     def __init__(self, server: viser.ViserServer, model: PointModel, interpolator: GaussianInterpolator, device: torch.device):
54 |         self.server = server
55 |         self.model = model
56 |         self.device = device
57 |         self.interpolator = interpolator
58 | 
59 |         # default render resolution shown to clients
60 |         self.render_w = 1920
61 |         self.render_h = 1080
62 | 
63 |         self._auto_rotate_enabled = False
64 |         self._rotation_speed = 10.0  # degrees per frame
65 |         self._last_rotation_time = time.time()
66 | 
67 |         self._auto_interp_enabled = False
68 |         self._interp_direction = 1.0  # forward or backward
69 |         self._interp_speed = 0.05  # slider change per update
70 | 
71 |         # GUI
72 |         with self.server.gui.add_folder("Controls"):
73 |             self.interpolation_slider = self.server.gui.add_slider(
74 |                 "Interpolation", min=1, max=len(interpolator.models), marks=range(1, len(interpolator.models)), step=0.001, initial_value=1
75 |             )
76 |             self.interpolation_slider.on_update(self._on_slider_update)
77 | 
78 |             self.auto_interp_checkbox = self.server.gui.add_checkbox(
79 |                 label="Enable Auto-Interpolation",
80 |                 initial_value=False,
81 |                 hint="Continuously move the interpolation slider back and forth.",
82 |             )
83 |             self.auto_interp_checkbox.on_update(self._on_auto_interp_update)
84 | 
85 |             self.auto_rotate_checkbox = self.server.gui.add_checkbox(
86 |                 label="Enable Auto-Rotate",
87 |                 initial_value=False,
88 |                 hint="Automatically rotate the camera around the model."
89 | ) 90 | self.auto_rotate_checkbox.on_update(self._on_auto_rotate_update) 91 | 92 | self.fps_label = self.server.gui.add_number(label="Render FPS", initial_value=0, disabled=True) 93 | 94 | # register client connect/disconnect 95 | server.on_client_connect(self._connect_client) 96 | server.on_client_disconnect(self._disconnect_client) 97 | 98 | self._clients = {} 99 | 100 | self._start_background_tasks() 101 | 102 | def _on_slider_update(self, event): 103 | value = self.interpolation_slider.value - 1 104 | idx_a = int(value) 105 | t = value - idx_a 106 | idx_b = idx_a + 1 107 | if(idx_b > len(self.interpolator.models) - 1): 108 | idx_a -= 1 109 | idx_b -= 1 110 | t=1 111 | self.model = self.interpolator.interpolate_between(idx_a, idx_b, t) 112 | 113 | for cid, client in list(self._clients.items()): 114 | self._render_for_client(client) 115 | 116 | def _on_auto_rotate_update(self, event): 117 | self._auto_rotate_enabled = self.auto_rotate_checkbox.value 118 | print(f"Auto-rotate {'enabled' if self._auto_rotate_enabled else 'disabled'}") 119 | 120 | def _on_auto_interp_update(self, event): 121 | self._auto_interp_enabled = self.auto_interp_checkbox.value 122 | print(f"Auto-interpolation {'enabled' if self._auto_interp_enabled else 'disabled'}") 123 | 124 | def _connect_client(self, client: viser.ClientHandle): 125 | self._clients[client.client_id] = client 126 | 127 | # camera movement 128 | @client.camera.on_update 129 | def _camera_moved(_: viser.CameraHandle): 130 | # small debounce could be added 131 | self._render_for_client(client) 132 | 133 | # when a client connects send render 134 | self._render_for_client(client) 135 | 136 | def _disconnect_client(self, client: viser.ClientHandle): 137 | self._clients.pop(client.client_id, None) 138 | 139 | def _prepare_render_inputs(self): 140 | """Convert PointModel tensors into tensors expected by gsplat.rasterization. 141 | Returns means, quats, scales, opacities, colors, and sh_degree (or None). 
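        colors is packed as (N, K, 3) SH coefficients with the DC term first,
        the layout gsplat expects; sh_degree is inferred from K as sqrt(K) - 1.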
142 | """ 143 | pm = self.model 144 | device = self.device 145 | 146 | means = pm._xyz.to(device) 147 | 148 | try: 149 | quats = pm.rotation_activation(pm._rotation).to(device) 150 | except Exception: 151 | # fallback: if rotation already stored normalized 152 | quats = pm._rotation.to(device) 153 | 154 | try: 155 | scales = pm.get_scaling.to(device) 156 | except Exception: 157 | scales = pm._scaling.to(device) 158 | 159 | try: 160 | opacities = pm.get_opacity.squeeze(-1).to(device) 161 | except Exception: 162 | opacities = pm._opacity.squeeze(-1).to(device) 163 | 164 | # colors / SH coefficients: construct colors as [N, 3, coeffs] 165 | if hasattr(pm, "_features_dc") and pm._features_dc.numel() > 0: 166 | parts = [pm._features_dc] 167 | if hasattr(pm, "_features_rest") and pm._features_rest.numel() > 0: 168 | parts.append(pm._features_rest) 169 | colors = torch.cat(parts, dim=1).to(device) 170 | sh_degree = int(colors.shape[2] ** 0.5) - 1 if colors.shape[2] > 0 else None 171 | else: 172 | # fallback: try to use 'colors' attribute if present 173 | if hasattr(pm, "_colors"): 174 | colors = pm._colors.to(device) 175 | sh_degree = None 176 | else: 177 | # empty colors -> create a dummy gray 178 | N = means.shape[0] 179 | colors = torch.ones((N, 3, 1), device=device) * 0.5 180 | sh_degree = 0 181 | 182 | return means, quats, scales, opacities, colors, sh_degree 183 | 184 | def _render_for_client(self, client: viser.ClientHandle): 185 | start_time = time.time() 186 | camera = client.camera 187 | img_wh = (self.render_w, self.render_h) 188 | c2w = CameraHelpers.c2w_from_camera(camera) 189 | K = CameraHelpers.K_from_camera_fov_aspect(camera.fov, img_wh) 190 | 191 | means, quats, scales, opacities, colors, sh_degree = self._prepare_render_inputs() 192 | 193 | c2w_t = torch.from_numpy(c2w).float().to(self.device) 194 | 195 | R_fix = torch.tensor([ 196 | [1, 0, 0, 0], 197 | [0, 0, -1, 0], 198 | [0, 1, 0, 0], 199 | [0, 0, 0, 1], 200 | ], dtype=torch.float32, device=self.device) 201 | c2w_t = R_fix @ c2w_t # adjust camera orientation 202 | 203 | K_t = torch.from_numpy(K).float().to(self.device) 204 | viewmat = c2w_t.inverse()[None] 205 | K_in = K_t[None] 206 | 207 | with torch.no_grad(): 208 | render_colors, render_alphas, meta = rasterization( 209 | means, 210 | quats, 211 | scales, 212 | opacities, 213 | colors, 214 | viewmat, 215 | K_in, 216 | img_wh[0], 217 | img_wh[1], 218 | sh_degree=sh_degree, 219 | render_mode="RGB", 220 | radius_clip=3, 221 | ) 222 | 223 | img = render_colors[0, ..., 0:3].cpu().numpy() 224 | 225 | end_time = time.time() 226 | dt = end_time - start_time 227 | fps = 1.0 / dt if dt > 0 else 0.0 228 | 229 | self.fps_label.value = fps 230 | 231 | client.scene.set_background_image( 232 | img, 233 | format="jpeg", 234 | jpeg_quality=70, 235 | ) 236 | 237 | def _start_background_tasks(self): 238 | """Background loop for camera rotation and auto interpolation.""" 239 | import threading 240 | 241 | def loop(): 242 | while True: 243 | # Auto-rotation 244 | if self._auto_rotate_enabled: 245 | for cid, client in list(self._clients.items()): 246 | self._rotate_camera(client) 247 | self._render_for_client(client) 248 | 249 | # Auto-interpolation 250 | if self._auto_interp_enabled: 251 | self._update_auto_interpolation() 252 | 253 | time.sleep(0.05) # ~20 FPS 254 | 255 | thread = threading.Thread(target=loop, daemon=True) 256 | thread.start() 257 | 258 | def _rotate_camera(self, client: viser.ClientHandle): 259 | """Orbit the camera around the model center.""" 260 | camera = client.camera 
261 | 
262 |         target = np.array([0.0, 0.0, 0.0], dtype=np.float32)
263 | 
264 |         offset = camera.position - target
265 | 
266 |         theta = np.radians(self._rotation_speed)
267 |         rot_z = np.array([
268 |             [np.cos(theta), -np.sin(theta), 0],
269 |             [np.sin(theta), np.cos(theta), 0],
270 |             [0, 0, 1],
271 |         ])
272 |         new_offset = rot_z @ offset
273 | 
274 |         # update the camera position and keep it looking at the target
275 |         camera.position = target + new_offset
276 |         camera.look_at = target
277 | 
278 |     def _update_auto_interpolation(self):
279 |         """Continuously move the interpolation slider back and forth."""
280 |         value = self.interpolation_slider.value
281 |         decimal = value - int(value)  # fractional part of the slider value
282 |         value += self._interp_direction * (0.001 if decimal < 0.01 or decimal > 0.99 else self._interp_speed)  # creep near integer values so each keyframe model lingers briefly
283 |         # bounce between the slider minimum and maximum
284 |         if value >= len(self.interpolator.models):
285 |             value = len(self.interpolator.models)
286 |             self._interp_direction = -1.0
287 |         elif value <= 1.0:
288 |             value = 1.0
289 |             self._interp_direction = 1.0
290 | 
291 |         self.interpolation_slider.value = value
292 |         self._on_slider_update(None)
293 | 
294 | 
295 | def main():
296 |     parser = argparse.ArgumentParser(
297 |         description="Interpolates between 3D point cloud models and visualizes them.",
298 |         formatter_class=argparse.RawTextHelpFormatter  # preserves newlines in help text
299 |     )
300 |     parser.add_argument("--port", type=int, default=8080)
301 |     parser.add_argument("--device", type=str, default="cuda:0")
302 | 
303 |     # Input arguments
304 |     input_group = parser.add_mutually_exclusive_group(required=True)
305 |     input_group.add_argument(
306 |         '-d', '--directory',
307 |         type=str,
308 |         help="Path to a directory containing .ply models."
309 |     )
310 |     input_group.add_argument(
311 |         '-m', '--models',
312 |         nargs='+',  # one or more arguments
313 |         type=str,
314 |         help="Paths to two or more individual .ply model files."
315 |     )
316 | 
317 |     # --- Optional Interpolation Parameters ---
318 |     parser.add_argument(
319 |         '--spatial_weight',
320 |         type=float,
321 |         default=0.7,
322 |         help="Weight for spatial distance in correspondence matching. Default: 0.7"
323 |     )
324 |     parser.add_argument(
325 |         '--color_weight',
326 |         type=float,
327 |         default=0.3,
328 |         help="Weight for color difference in correspondence matching. Default: 0.3"
329 |     )
330 |     parser.add_argument(
331 |         '--distance_threshold',
332 |         type=float,
333 |         default=None,
334 |         help="Maximum distance for point correspondences. If not set, no threshold is applied."
335 |     )
336 |     parser.add_argument(
337 |         '--batch_size',
338 |         type=int,
339 |         default=512,
340 |         help="Size of point batches to process. Lower this to reduce GPU memory usage."
341 |     )
342 |     parser.add_argument(
343 |         "--disable_recenter_models",
344 |         action="store_true",
345 |         help="Disables recentering of all models before interpolation.")
346 |     parser.add_argument(
347 |         "--disable_normalize_scales",
348 |         action="store_true",
349 |         help="Disables normalization of the models' scales before interpolation.")
350 | 
351 |     args = parser.parse_args()
352 | 
353 |     # File collection
354 |     ply_files = []
355 |     if args.directory:
356 |         print(f"Searching for .ply files in directory: {args.directory}")
357 |         # check that the directory exists
358 |         if not os.path.isdir(args.directory):
359 |             parser.error(f"Directory not found: {args.directory}")
360 |         ply_files = sorted([
361 |             os.path.join(args.directory, f) for f in os.listdir(args.directory) if f.lower().endswith('.ply')
362 |         ])
363 |     elif args.models:
364 |         print("Using the provided list of models.")
365 |         ply_files = args.models
366 | 
367 |     # Ensure we have at least two models to work with
368 |     if len(ply_files) < 2:
369 |         parser.error(f"At least two .ply models are required for interpolation, but found {len(ply_files)}.")
370 | 
371 |     print(f"\nFound {len(ply_files)} models for processing.")
372 | 
373 |     # Load models
374 |     point_models = []
375 |     for ply_file in ply_files:
376 |         pm = PointModel()
377 |         pm.load_ply(ply_file)
378 |         if not args.disable_recenter_models:
379 |             pm.recenter_point_cloud()
380 |         if not args.disable_normalize_scales:
381 |             pm.normalize_scale(10)
382 |         point_models.append(pm)
383 | 
384 |     # resolve the device once; fall back to CPU when CUDA is unavailable
385 |     device = torch.device(args.device if torch.cuda.is_available() else "cpu")
386 | 
387 |     interp = GaussianInterpolator(device=device)
388 |     interp.load_pointmodels(point_models)
389 |     interp.build_correspondences(spatial_weight=args.spatial_weight, color_weight=args.color_weight, distance_threshold=args.distance_threshold, batch_size=args.batch_size)
390 | 
391 |     # start the viser server and viewer
392 |     server = viser.ViserServer(port=args.port, verbose=False)
393 |     viewer = GsplatViserViewer(server=server, model=point_models[0], interpolator=interp, device=device)
394 | 
395 |     print(f"Viewer server running on port {args.port}. Connect with a viser client.")
396 |     try:
397 |         while True:
398 |             time.sleep(10)
399 |     except KeyboardInterrupt:
400 |         print("Shutting down viewer...")
401 | 
402 | 
403 | if __name__ == "__main__":
404 |     main()
--------------------------------------------------------------------------------
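
For readers who want to drive the interpolator headlessly, without the viser viewer, the sketch below mirrors the setup in main(). The import paths are assumptions inferred from the repository layout (helpers/PointModel.py, interpolate.py) and the .ply paths are placeholders; every call used (load_ply, recenter_point_cloud, normalize_scale, load_pointmodels, build_correspondences, interpolate_between) appears in the code above.

    # Headless interpolation sketch; import paths assumed from the repo layout.
    import torch

    from helpers.PointModel import PointModel        # assumed module path
    from interpolate import GaussianInterpolator     # assumed module path

    # Load and preprocess two keyframe models, as main() does by default.
    models = []
    for path in ["frame_a.ply", "frame_b.ply"]:      # placeholder files
        pm = PointModel()
        pm.load_ply(path)
        pm.recenter_point_cloud()
        pm.normalize_scale(10)
        models.append(pm)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    interp = GaussianInterpolator(device=device)
    interp.load_pointmodels(models)
    interp.build_correspondences(spatial_weight=0.7, color_weight=0.3,
                                 distance_threshold=None, batch_size=512)

    # Ten evenly spaced blends between model 0 and model 1.
    frames = [interp.interpolate_between(0, 1, i / 9.0) for i in range(10)]

The interactive viewer itself is launched with e.g. `python visualizer.py -d ./models --port 8080`, or `python visualizer.py -m a.ply b.ply`, per the argument parser above.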