├── .clang-format ├── .gitignore ├── LICENSE ├── MATH.md ├── README.md ├── analytic_diff.ipynb ├── colmap_splat.py ├── lint.sh ├── pyproject.toml ├── requirements.txt ├── setup.py ├── splat_py ├── __init__.py ├── config.py ├── cuda_autograd_functions.py ├── dataloader.py ├── depth.py ├── optimizer_manager.py ├── rasterize.py ├── read_colmap.py ├── structs.py ├── tile_culling.py ├── trainer.py └── utils.py ├── src ├── bindings.cpp ├── checks.cuh ├── depth.cu ├── matrix.cuh ├── precompute_sh.cu ├── projection.cu ├── projection_backward.cu ├── render.cu ├── render_backward.cu ├── spherical_harmonics.cuh └── tile_culling.cu └── test ├── __init__.py ├── gaussian_test_data.py ├── test_cuda_autograd_functions.py ├── test_dataloader.py ├── test_depth.py ├── test_projection.py ├── test_rasterize.py ├── test_rasterize_autograd.py ├── test_structs.py ├── test_tile_culling.py └── test_utils.py /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: LLVM 2 | 3 | IndentWidth: 4 4 | AlignAfterOpenBracket: BlockIndent 5 | 6 | ColumnLimit: 100 7 | 8 | # use int& a instead of int &a 9 | DerivePointerAlignment: false 10 | PointerAlignment: Left 11 | 12 | # const int* a instead of int const* a 13 | QualifierAlignment: Left 14 | 15 | # arguments on new line if > 1 line 16 | BinPackArguments: false 17 | BinPackParameters: false 18 | ExperimentalAutoDetectBinPacking: false 19 | AllowAllParametersOfDeclarationOnNextLine: false 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | *.so 3 | *.pt 4 | build/ 5 | .vscode/ 6 | __pycache__/ 7 | *.egg-info 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 joeyan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 3D Gaussian Splatting 2 | A "from scratch" re-implementation of [3D Gaussian Splatting 3 | for Real-Time Radiance Field Rendering](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/) by Kerbl and Kopanas et al. 
4 | 5 | This repository implements the forward and backward passes using a PyTorch CUDA extension based on the algorithms described in the paper. Some details of the splatting and adaptive control algorithm are not explicitly described in the paper and there may be differences between this repo and the official implementation. 6 | 7 | ## Motivation 8 | 9 | 1. Provide a detailed explanation of the differentiable rasterization algorithm. The forward and backward passes are detailed in [MATH.md](/MATH.md). 10 | 2. Permissive license. The original implementation does not allow commercial use and was never referenced during the development of this repository. 11 | 3. Modular projection functions and gradient checks allow for easier experimentation with camera/pose gradients, new camera models, etc. 12 | 4. Minimal dependencies. 13 | 14 | If there are any issues/errors, please open an Issue or Pull Request! 15 | 16 | ## Performance 17 | 18 | Evaluations were done with the Mip-NeRF 360 dataset at ~1 megapixel resolution. This corresponds to the 2x downsampled indoor scenes and 4x downsampled outdoor scenes. Every 8th image was used for the test split. Here are some comparisons with the official Inria implementation (copied from "Per-Scene Error Metrics"). 19 | 20 | 21 | | Method | Dataset | PSNR | SSIM | N Gaussians | Train Duration | 22 | |-----------|-------------|------|------|-------------|------------------| 23 | | Inria-30k | Garden 1/4x | 27.41| 0.87 | | | 24 | | Ours-30k | Garden 1/4x | 27.05| 0.85 | 2.86M | 20:18 (RTX4090) | 25 | | Inria-7k | Garden 1/4x | 26.24| 0.83 | | | 26 | | Ours-7k | Garden 1/4x | 25.83| 0.80 | 1.52M | 3:05 (RTX4090) | 27 | | Inria-30k | Counter 1/2x| 28.70| 0.91 | | | 28 | | Ours-30k | Counter 1/2x| 28.75| 0.90 | 1.84M | 23:37 (RTX4090) | 29 | | Inria-7k | Counter 1/2x| 26.70| 0.87 | | | 30 | | Ours-7k | Counter 1/2x| 27.59| 0.89 | 1.37M | 4:10 (RTX4090) | 31 | | Inria-30k | Bonsai 1/2x | 31.98| 0.94 | | | 32 | | Ours-30k | Bonsai 1/2x | 32.21| 0.95 | 2.85M | 27:22 (RTX4090) | 33 | | Inria-7k | Bonsai 1/2x | 28.85| 0.91 | | | 34 | | Ours-7k | Bonsai 1/2x | 30.42| 0.93 | 1.86M | 4:19 (RTX4090) | 35 | | Inria-30k | Room 1/2x | 30.63| 0.91 | | | 36 | | Ours-30k | Room 1/2x | 31.73| 0.93 | 1.53M | 20:13 (RTX4090) | 37 | | Inria-7k | Room 1/2x | 28.14| 0.88 | | | 38 | | Ours-7k | Room 1/2x | 30.30| 0.91 | 1.01M | 3:17 (RTX4090) | 39 | 40 | 41 | A comparison from one of the test images in the `garden` dataset. The official implementation and ground truth images appear to be more saturated since they are screen captures of the PDF. 42 | 43 | Ours - 30k: 44 | ![image](https://github.com/joeyan/gaussian_splatting/assets/17635504/519a5f04-82f3-4291-b063-c122efd22c19) 45 | 46 | Official Inria implementation - 30k: 47 | ![image](https://github.com/joeyan/gaussian_splatting/assets/17635504/1460b7eb-a28c-43ed-b8e2-a2695f6ab805) 48 | 49 | Ground truth: 50 | ![image](https://github.com/joeyan/gaussian_splatting/assets/17635504/e3c1f0c2-3f36-41dc-8441-df856399e987) 51 | 52 |
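For reference, the PSNR values above are standard peak signal-to-noise ratios computed from the mean squared error between a rendered image and the ground truth. A minimal sketch of the computation, assuming both images are float tensors normalized to `[0, 1]` (the training code uses `torchmetrics` for SSIM, while PSNR reduces to a one-liner):

```python
import torch

def psnr(rendered: torch.Tensor, ground_truth: torch.Tensor) -> float:
    # PSNR = 10 * log10(MAX^2 / MSE); MAX = 1.0 for normalized images
    mse = torch.mean((rendered - ground_truth) ** 2)
    return (10.0 * torch.log10(1.0 / mse)).item()
```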
53 | ## Installation 54 | This package requires CUDA which can be installed from [here](https://developer.nvidia.com/cuda-downloads). 55 | 56 | 1. Install Python dependencies 57 | ``` 58 | pip install -r requirements.txt 59 | ``` 60 | 61 | 2. Install the PyTorch CUDA extension 62 | ``` 63 | python setup.py build_ext && python setup.py install 64 | ``` 65 | Note: 66 | - Windows systems may need to modify the compilation flags in `setup.py` 67 | 68 | Optional: 69 | This project uses `clang-format` to lint the C++/CUDA files: 70 | 71 | ``` 72 | sudo apt install clang-format 73 | ``` 74 | Running `lint.sh` will run both `black` and `clang-format`. 75 | 76 | 77 | ## Training on Mip-NeRF 360 Scenes 78 | 79 | 1. Download the [Mip-NeRF 360](https://jonbarron.info/mipnerf360/) dataset and unzip it: 80 | 81 | ``` 82 | wget http://storage.googleapis.com/gresearch/refraw360/360_v2.zip && unzip 360_v2.zip 83 | ``` 84 | 85 | 86 | 2. Run the training script: 87 | ``` 88 | python colmap_splat.py 7k --dataset_path <path/to/dataset> --downsample_factor 4 89 | ``` 90 | 91 | To run the high-quality version, use `30k` instead of `7k`. The `dataset_path` argument refers to the top-level folder for each dataset (`garden`, `kitchen`, etc.). The paper uses `--downsample_factor 4` for the outdoor scenes and `--downsample_factor 2` for the indoor scenes. 92 | 93 | 94 | For more options: 95 | ``` 96 | python colmap_splat.py 7k --help 97 | ``` 98 | 99 | To run all unit tests: 100 | 101 | ``` 102 | python -m unittest discover test 103 | ```
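Training writes the optimized gaussians to `gaussians_final.pt` in the configured output directory (see `colmap_splat.py` below). A minimal sketch of loading them back for inspection, assuming the default `splat_output` output directory:

```python
import torch

gaussians = torch.load("splat_output/gaussians_final.pt")
print(f"{len(gaussians)} gaussians, xyz shape: {tuple(gaussians.xyz.shape)}")
```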
104 | 105 | ## References 106 | 107 | The original paper: 108 | ``` 109 | @Article{kerbl3Dgaussians, 110 | author = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George}, 111 | title = {3D Gaussian Splatting for Real-Time Radiance Field Rendering}, 112 | journal = {ACM Transactions on Graphics}, 113 | number = {4}, 114 | volume = {42}, 115 | month = {July}, 116 | year = {2023}, 117 | url= {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/} 118 | } 119 | ``` 120 | 121 | The EWA Splatting approach that is the basis for 3D Gaussian Splatting: 122 | ``` 123 | @Article{zwicker2002ewa, 124 | author={M. Zwicker and H. Pfister and J. van Baar and M. Gross}, 125 | title={EWA Splatting}, 126 | journal={IEEE Transactions on Visualization and Computer Graphics}, 127 | number={3}, 128 | volume={8}, 129 | month={July}, 130 | year={2002}, 131 | publisher={IEEE}, 132 | url={https://www.cs.umd.edu/~zwicker/publications/EWASplatting-TVCG02.pdf} 133 | } 134 | ``` 135 | 136 | `gsplat` [Mathematical Supplement](https://arxiv.org/abs/2312.02121) 137 | ``` 138 | @misc{ye2023mathematical, 139 | title={Mathematical Supplement for the $\texttt{gsplat}$ Library}, 140 | author={Vickie Ye and Angjoo Kanazawa}, 141 | year={2023}, 142 | eprint={2312.02121}, 143 | archivePrefix={arXiv}, 144 | primaryClass={cs.MS} 145 | } 146 | ``` 147 | 148 | A great reference for matrix derivatives: 149 | ``` 150 | @misc{giles2008extended, 151 | title={An extended collection of matrix derivative results for forward and reverse mode algorithmic differentiation}, 152 | author={Mike Giles}, 153 | month={January}, 154 | year={2008}, 155 | url={https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf} 156 | } 157 | ``` 158 | -------------------------------------------------------------------------------- /colmap_splat.py: 1 | import os 2 | import time 3 | 4 | import numpy as np 5 | import plotext as plt 6 | import torch 7 | import tyro 8 | import yaml 9 | 10 | from splat_py.config import SplatConfigs 11 | from splat_py.dataloader import ColmapData 12 | from splat_py.trainer import SplatTrainer 13 | 14 | 15 | def plot_metrics(metrics, config): 16 | x = np.arange(len(metrics.train_psnr)) 17 | train_psnr = np.array(metrics.train_psnr) 18 | num_gaussians = np.array(metrics.num_gaussians) 19 | 20 | # test psnr has different x-axis 21 | test_psnr = np.array(metrics.test_psnr) 22 | x_test = np.arange(len(test_psnr)) * config.test_eval_interval 23 | 24 | # smooth train psnr for better visualization 25 | smoothing_weights = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.4, 0.3, 0.2, 0.1]) 26 | smoothing_weights /= np.sum(smoothing_weights) 27 | train_psnr = np.convolve(train_psnr, smoothing_weights, mode="valid") 28 | 29 | plt.plot(np.arange(len(train_psnr)), train_psnr, xside="lower", yside="left", label="Train PSNR")  # "valid" convolution shortens the series, so rebuild its x-axis 30 | plt.plot(x_test, test_psnr, xside="lower", yside="left", label="Test PSNR") 31 | plt.plot(x, num_gaussians, xside="upper", yside="right", label="Num Gaussians") 32 | 33 | plt.xlabel("Iteration") 34 | plt.ylabel("Train PSNR", yside="left") 35 | plt.ylabel("Num Gaussians", yside="right") 36 | 37 | plt.title("Gaussian Splatting") 38 | plt.show() 39 | 40 | 41 | config = tyro.cli(SplatConfigs) 42 | 43 | if not os.path.exists(config.output_dir): 44 | os.makedirs(config.output_dir) 45 | # save a copy of the config - the representer is registered on yaml.SafeDumper 46 | yaml.safe_dump(config, open(os.path.join(config.output_dir, "config.yaml"), "w")) 47 | 48 | torch.manual_seed(0) 49 | colmap_data = ColmapData( 50 | config.dataset_path, 51 | torch.device("cuda"), 52 | downsample_factor=config.downsample_factor, 53 | config=config, 54 | ) 55 | 56 | if config.load_checkpoint: 57 | gaussians = torch.load(config.checkpoint_path) 58 | else: 59 | gaussians = colmap_data.create_gaussians() 60 | gaussians.xyz = torch.nn.Parameter(gaussians.xyz) 61 | gaussians.quaternion = torch.nn.Parameter(gaussians.quaternion) 62 | gaussians.scale = torch.nn.Parameter(gaussians.scale) 63 | gaussians.opacity = torch.nn.Parameter(gaussians.opacity) 64 | gaussians.rgb = torch.nn.Parameter(gaussians.rgb) 65 | 66 | images = colmap_data.get_images() 67 | cameras = colmap_data.get_cameras() 68 | 69 | 70 | start = 
time.time() 71 | trainer = SplatTrainer(gaussians, images, cameras, config) 72 | trainer.train() 73 | end = time.time() 74 | 75 | # save gaussians 76 | torch.save(gaussians, os.path.join(config.output_dir, "gaussians_final.pt")) 77 | 78 | # training time 79 | seconds = end - start 80 | minutes, seconds = divmod(seconds, 60) 81 | print("Total training time: {}min {}sec".format(int(minutes), int(seconds))) 82 | print("Max Test PSNR: ", max(trainer.metrics.test_psnr)) 83 | plot_metrics(trainer.metrics, config) 84 | -------------------------------------------------------------------------------- /lint.sh: 1 | #!/bin/bash 2 | 3 | # It seems like "IndentPragma" in clang-format was abandoned. 4 | # Following this blog post: 5 | # https://medicineyeh.wordpress.com/2017/07/13/clang-format-with-pragma/ 6 | 7 | set -eo pipefail 8 | 9 | # get git root dir 10 | GITROOT=$(git rev-parse --show-toplevel) 11 | 12 | # run python formatter in the git root directory 13 | (cd $GITROOT && black .) 14 | 15 | # get all c/c++/cuda files in the git src directory 16 | SRCDIR=$GITROOT/src 17 | FILES=$(find $SRCDIR -type f \( -name "*.c" -o -name "*.cpp" -o -name "*.cu" -o -name "*.h" -o -name "*.hpp" -o -name "*.cuh" \)) 18 | 19 | for FILE in $FILES; do 20 | echo "Linting $FILE" 21 | # Replace "#pragma unroll" by "//#pragma unroll" 22 | sed -i 's/#pragma unroll/\/\/#pragma unroll/g' $FILE 23 | # Do format 24 | clang-format -i $FILE 25 | # Replace "// *#pragma unroll" by "#pragma unroll" 26 | sed -i 's/\/\/ *#pragma unroll/#pragma unroll/g' $FILE 27 | done 28 | -------------------------------------------------------------------------------- /pyproject.toml: 1 | [tool.black] 2 | line-length = 100 3 | -------------------------------------------------------------------------------- /requirements.txt: 1 | black==24.3.0 2 | matplotlib==3.8.2 3 | matplotlib-inline==0.1.6 4 | numpy==1.24.1 5 | opencv-python==4.9.0.80 6 | plotext==5.2.8 7 | pybind11==2.12.0 8 | scipy==1.11.4 9 | sympy==1.12 10 | torch==2.1.2 11 | torchaudio==2.1.2 12 | torchgeometry==0.1.2 13 | torchmetrics==1.2.1 14 | torchvision==0.16.2 15 | triton==2.1.0 16 | tyro==0.6.3 17 | -------------------------------------------------------------------------------- /setup.py: 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 3 | 4 | c_flags = ["-O3", "-std=c++17"] 5 | nvcc_flags = ["-O3", "-std=c++17"] 6 | setup( 7 | name="splat_cuda", 8 | ext_modules=[ 9 | CUDAExtension( 10 | name="splat_cuda", 11 | sources=[ 12 | "src/bindings.cpp", 13 | "src/depth.cu", 14 | "src/precompute_sh.cu", 15 | "src/projection.cu", 16 | "src/projection_backward.cu", 17 | "src/render.cu", 18 | "src/render_backward.cu", 19 | "src/tile_culling.cu", 20 | ], 21 | extra_compile_args={  # compile flags belong to the extension, not to setup() 22 | "cxx": c_flags, 23 | "nvcc": nvcc_flags, 24 | }, 25 | ), 26 | ], 27 | cmdclass={"build_ext": BuildExtension}, 28 | ) 29 | -------------------------------------------------------------------------------- /splat_py/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joeyan/gaussian_splatting/c1f5a71e3549d8bd089be6e9777c16f8e1bc333f/splat_py/__init__.py 
-------------------------------------------------------------------------------- /splat_py/config.py: 1 | from dataclasses import dataclass 2 | import tyro 3 | from typing import Literal 4 | import yaml 5 | 6 | 7 | class yamlEnabled(object): 8 | """ 9 | Decorator to enable yaml serialization for a class. 10 | from: https://stackoverflow.com/questions/74723634/how-do-you-use-a-frozen-dataclass-in-a-dictionary-and-export-it-to-yaml 11 | """ 12 | 13 | def __init__(self, tag): 14 | self.tag = tag 15 | 16 | def __call__(self, cls): 17 | def to_yaml(dumper, data): 18 | return dumper.represent_mapping(self.tag, vars(data)) 19 | 20 | yaml.SafeDumper.add_representer(cls, to_yaml) 21 | 22 | def from_yaml(loader, node): 23 | data = loader.construct_mapping(node) 24 | return cls(**data) 25 | 26 | yaml.SafeLoader.add_constructor(self.tag, from_yaml) 27 | return cls 28 | 29 | 30 | @yamlEnabled("!SplatConfig") 31 | @dataclass 32 | class SplatConfig: 33 | """Path to dataset directory""" 34 | 35 | dataset_path: str = "garden" 36 | """downsample factor for the images - if applicable""" 37 | downsample_factor: int = 4 38 | """output directory for saving the results""" 39 | output_dir: str = "splat_output" 40 | 41 | """interval for saving checkpoints""" 42 | checkpoint_interval: int = 10000 43 | """initialize gaussians from checkpoint""" 44 | load_checkpoint: bool = False 45 | """path to saved gaussian checkpoint""" 46 | checkpoint_path: str = "" 47 | 48 | """interval for saving debug training images""" 49 | save_debug_image_interval: int = 200 50 | """interval to print debug information""" 51 | print_interval: int = 100 52 | 53 | """initial opacity for gaussians initialized from a point cloud""" 54 | initial_opacity: float = 0.2 55 | """number of neighbors used to compute the initial scale""" 56 | initial_scale_num_neighbors: int = 3 57 | """factor to scale the distance to the nearest neighbors""" 58 | initial_scale_factor: float = 0.8 59 | """maximum initial scale""" 60 | max_initial_scale: float = 0.1 61 | 62 | """gaussians closer than this are culled alongside points outside of fov""" 63 | near_thresh: float = 0.3 64 | """gaussians farther than this are culled alongside points outside of fov""" 65 | far_thresh: float = 500.0 66 | """mahalanobis distance for tile culling 3.0 = 99.7%""" 67 | mh_dist: float = 3.0 68 | """keep gaussians that project within this padding of image during frustum culling""" 69 | cull_mask_padding: int = 100 70 | """max rgb value for splatted image""" 71 | saturated_pixel_value: float = 255.0 72 | 73 | """number of iterations for training""" 74 | num_iters: int = 7000 75 | """fraction of ssim loss to l1 loss""" 76 | ssim_frac: float = 0.2 77 | """base learning rate""" 78 | base_lr: float = 0.002 79 | """learning rate multiplier for xyz""" 80 | xyz_lr_multiplier: float = 0.1 81 | """learning rate multiplier for quaternion""" 82 | quat_lr_multiplier: float = 2 83 | """learning rate multiplier for scale""" 84 | scale_lr_multiplier: float = 5 85 | """learning rate multiplier for opacity""" 86 | opacity_lr_multiplier: float = 10 87 | """learning rate multiplier for rgb""" 88 | rgb_lr_multiplier: float = 2 89 | """learning rate multiplier for spherical harmonics""" 90 | sh_lr_multiplier: float = 0.1 91 | 92 | """interval to evaluate test images""" 93 | test_eval_interval: int = 500 94 | """select every nth image for the test split - 8 is same as GS and Mip-NeRF 360 papers""" 95 | test_split_ratio: int = 8 96 | 
97 | """use background color""" 98 | use_background: bool = True 99 | """background color end interval""" 100 | use_background_end: int = 6600 101 | 102 | """interval to reset all opacities to a fixed value""" 103 | reset_opacity_interval: int = 3001 104 | """opacity value to reset to""" 105 | reset_opacity_value: float = 0.20 106 | """start iteration for reset opacity""" 107 | reset_opacity_start: int = 1050 108 | """end iteration for reset opacity""" 109 | reset_opacity_end: int = 6500 110 | 111 | """precompute SH to RGB for each gaussian - speeds up computation ~1.4-2x""" 112 | use_sh_precompute: bool = True 113 | """max SH band to use - 0 is no view dependent color""" 114 | max_sh_band: Literal[0, 1, 2, 3] = 3 115 | """add SH band every interval until all are added""" 116 | add_sh_band_interval: int = 1000 117 | 118 | """use split gaussians""" 119 | use_split: bool = True 120 | """use clone gaussians""" 121 | use_clone: bool = True 122 | """use delete gaussians""" 123 | use_delete: bool = True 124 | 125 | """start iteration for adaptive control""" 126 | adaptive_control_start: int = 750 127 | """end iteration for adaptive control""" 128 | adaptive_control_end: int = 6500 129 | """interval for adaptive control""" 130 | adaptive_control_interval: int = 100 131 | 132 | """max number of gaussians""" 133 | max_gaussians: int = 4250000 134 | 135 | """delete gaussians with opacity below this threshold""" 136 | delete_opacity_threshold: float = 0.1 137 | """clone gaussians with scale below this threshold""" 138 | clone_scale_threshold: float = 0.01 139 | """delete gaussians with scale norm above this threshold""" 140 | max_scale_norm: float = 0.5 141 | """densify a fixed fraction of gaussians every iteration""" 142 | use_fractional_densification: bool = True 143 | """front load densification - slower but slightly higher psnr""" 144 | use_adaptive_fractional_densification: bool = True 145 | 146 | """densify gaussians over this percentile - only used if use_fractional_densification is True""" 147 | uv_grad_percentile: float = 0.96 148 | """densify gaussians over this percentile - only used if use_fractional_densification is True""" 149 | scale_norm_percentile: float = 0.99 150 | 151 | """densify gaussians over this threshold - only used if use_fractional_densification is False""" 152 | uv_grad_threshold: float = 0.0002 153 | 154 | """decrease scale of split gaussians by this factor""" 155 | split_scale_factor: float = 1.6 156 | """number of samples to split gaussians into""" 157 | num_split_samples: int = 2 158 | 159 | 160 | # allow user to choose from 7k or 30k config as base configuration 161 | SplatConfigs = tyro.extras.subcommand_type_from_defaults( 162 | { 163 | "7k": SplatConfig(), # default config is 7k 164 | "30k": SplatConfig( 165 | num_iters=30000, 166 | adaptive_control_start=1500, 167 | adaptive_control_end=27500, 168 | adaptive_control_interval=300, 169 | reset_opacity_end=27500, 170 | use_background_end=28000, 171 | ), 172 | } 173 | ) 174 | -------------------------------------------------------------------------------- /splat_py/cuda_autograd_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from splat_cuda import ( 4 | camera_projection_cuda, 5 | camera_projection_backward_cuda, 6 | compute_sigma_world_cuda, 7 | compute_sigma_world_backward_cuda, 8 | compute_projection_jacobian_cuda, 9 | compute_projection_jacobian_backward_cuda, 10 | compute_conic_cuda, 11 | compute_conic_backward_cuda, 12 | render_tiles_cuda, 
-------------------------------------------------------------------------------- /splat_py/cuda_autograd_functions.py: 1 | import torch 2 | 3 | from splat_cuda import ( 4 | camera_projection_cuda, 5 | camera_projection_backward_cuda, 6 | compute_sigma_world_cuda, 7 | compute_sigma_world_backward_cuda, 8 | compute_projection_jacobian_cuda, 9 | compute_projection_jacobian_backward_cuda, 10 | compute_conic_cuda, 11 | compute_conic_backward_cuda, 12 | render_tiles_cuda, 13 | render_tiles_backward_cuda, 14 | precompute_rgb_from_sh_cuda, 15 | precompute_rgb_from_sh_backward_cuda, 16 | ) 17 | 18 | 19 | class CameraPointProjection(torch.autograd.Function): 20 | @staticmethod 21 | def forward(ctx, xyz_camera, K): 22 | uv = torch.zeros(xyz_camera.shape[0], 2, dtype=xyz_camera.dtype, device=xyz_camera.device) 23 | camera_projection_cuda(xyz_camera, K, uv) 24 | ctx.save_for_backward(xyz_camera, K) 25 | return uv 26 | 27 | @staticmethod 28 | def backward(ctx, grad_uv): 29 | xyz_camera, K = ctx.saved_tensors 30 | grad_xyz_camera = torch.zeros( 31 | xyz_camera.shape, dtype=xyz_camera.dtype, device=xyz_camera.device 32 | ) 33 | camera_projection_backward_cuda(xyz_camera, K, grad_uv, grad_xyz_camera) 34 | return grad_xyz_camera, None 35 | 36 | 37 | class ComputeSigmaWorld(torch.autograd.Function): 38 | @staticmethod 39 | def forward(ctx, quaternion, scale): 40 | sigma_world = torch.zeros( 41 | quaternion.shape[0], 42 | 3, 43 | 3, 44 | dtype=quaternion.dtype, 45 | device=quaternion.device, 46 | ) 47 | compute_sigma_world_cuda(quaternion, scale, sigma_world) 48 | ctx.save_for_backward(quaternion, scale) 49 | return sigma_world 50 | 51 | @staticmethod 52 | def backward(ctx, grad_sigma_world): 53 | quaternion, scale = ctx.saved_tensors 54 | grad_quaternion = torch.zeros( 55 | quaternion.shape, dtype=quaternion.dtype, device=quaternion.device 56 | ) 57 | grad_scale = torch.zeros(scale.shape, dtype=scale.dtype, device=scale.device) 58 | compute_sigma_world_backward_cuda( 59 | quaternion, scale, grad_sigma_world, grad_quaternion, grad_scale 60 | ) 61 | return grad_quaternion, grad_scale 62 | 63 | 64 | class ComputeProjectionJacobian(torch.autograd.Function): 65 | @staticmethod 66 | def forward(ctx, xyz_camera, K): 67 | jacobian = torch.zeros( 68 | xyz_camera.shape[0], 2, 3, dtype=xyz_camera.dtype, device=xyz_camera.device 69 | ) 70 | compute_projection_jacobian_cuda(xyz_camera, K, jacobian) 71 | ctx.save_for_backward(xyz_camera, K) 72 | return jacobian 73 | 74 | @staticmethod 75 | def backward(ctx, grad_jacobian): 76 | xyz_camera, K = ctx.saved_tensors 77 | grad_xyz_camera = torch.zeros( 78 | xyz_camera.shape, dtype=xyz_camera.dtype, device=xyz_camera.device 79 | ) 80 | compute_projection_jacobian_backward_cuda(xyz_camera, K, grad_jacobian, grad_xyz_camera) 81 | return grad_xyz_camera, None 82 | 83 | 84 | class ComputeConic(torch.autograd.Function): 85 | @staticmethod 86 | def forward(ctx, sigma_world, J, camera_T_world): 87 | conic = torch.zeros(J.shape[0], 3, dtype=sigma_world.dtype, device=sigma_world.device) 88 | compute_conic_cuda(sigma_world, J, camera_T_world, conic) 89 | ctx.save_for_backward(sigma_world, camera_T_world, J) 90 | return conic 91 | 92 | @staticmethod 93 | def backward(ctx, grad_conic): 94 | sigma_world, camera_T_world, J = ctx.saved_tensors 95 | grad_sigma_world = torch.zeros( 96 | sigma_world.shape, dtype=sigma_world.dtype, device=sigma_world.device 97 | ) 98 | grad_J = torch.zeros(J.shape, dtype=J.dtype, device=J.device) 99 | compute_conic_backward_cuda( 100 | sigma_world, J, camera_T_world, grad_conic, grad_sigma_world, grad_J 101 | ) 102 | return grad_sigma_world, grad_J, None 103 | 104 | 105 | class PrecomputeRGBFromSH(torch.autograd.Function): 106 | @staticmethod 107 | def forward(ctx, sh_coeffs, xyz, camera_T_world): 108 | rgb = torch.zeros(xyz.shape[0], 3, dtype=sh_coeffs.dtype, device=sh_coeffs.device) 109 | precompute_rgb_from_sh_cuda(xyz, sh_coeffs, camera_T_world, rgb) 110 | 111 | if sh_coeffs.dim() == 2: 112 | num_sh_coeff = 
torch.tensor(1, dtype=torch.int, device=sh_coeffs.device) 113 | else: 114 | num_sh_coeff = torch.tensor( 115 | sh_coeffs.shape[2], dtype=torch.int, device=sh_coeffs.device 116 | ) 117 | ctx.save_for_backward(xyz, camera_T_world, num_sh_coeff) 118 | return rgb 119 | 120 | @staticmethod 121 | def backward(ctx, grad_rgb): 122 | xyz, camera_T_world, num_sh_coeff = ctx.saved_tensors 123 | grad_sh_coeffs = torch.zeros( 124 | xyz.shape[0], 3, num_sh_coeff.item(), dtype=xyz.dtype, device=xyz.device 125 | ) 126 | precompute_rgb_from_sh_backward_cuda(xyz, camera_T_world, grad_rgb, grad_sh_coeffs) 127 | return grad_sh_coeffs, None, None 128 | 129 | 130 | class RenderImage(torch.autograd.Function): 131 | @staticmethod 132 | def forward( 133 | ctx, 134 | rgb, 135 | opacity, 136 | uvs, 137 | conic, 138 | rays, 139 | splat_start_end_idx_by_tile_idx, 140 | sorted_gaussian_idx_by_splat_idx, 141 | image_size, 142 | background_rgb, 143 | ): 144 | rendered_image = torch.zeros( 145 | image_size[0], image_size[1], 3, dtype=rgb.dtype, device=rgb.device 146 | ) 147 | num_splats_per_pixel = torch.zeros( 148 | image_size[0], image_size[1], dtype=torch.int, device=rgb.device 149 | ) 150 | final_weight_per_pixel = torch.zeros( 151 | image_size[0], image_size[1], dtype=rgb.dtype, device=rgb.device 152 | ) 153 | 154 | render_tiles_cuda( 155 | uvs, 156 | opacity, 157 | rgb, 158 | conic, 159 | rays, 160 | splat_start_end_idx_by_tile_idx, 161 | sorted_gaussian_idx_by_splat_idx, 162 | background_rgb, 163 | num_splats_per_pixel, 164 | final_weight_per_pixel, 165 | rendered_image, 166 | ) 167 | ctx.save_for_backward( 168 | uvs, 169 | opacity, 170 | rgb, 171 | conic, 172 | rays, 173 | splat_start_end_idx_by_tile_idx, 174 | sorted_gaussian_idx_by_splat_idx, 175 | background_rgb, 176 | num_splats_per_pixel, 177 | final_weight_per_pixel, 178 | ) 179 | return rendered_image 180 | 181 | @staticmethod 182 | def backward(ctx, grad_rendered_image): 183 | ( 184 | uvs, 185 | opacity, 186 | rgb, 187 | conic, 188 | rays, 189 | splat_start_end_idx_by_tile_idx, 190 | sorted_gaussian_idx_by_splat_idx, 191 | background_rgb, 192 | num_splats_per_pixel, 193 | final_weight_per_pixel, 194 | ) = ctx.saved_tensors 195 | grad_rgb = torch.zeros_like(rgb) 196 | grad_opacity = torch.zeros_like(opacity) 197 | grad_uv = torch.zeros_like(uvs) 198 | grad_conic = torch.zeros_like(conic) 199 | 200 | # ensure input is contiguous 201 | grad_rendered_image = grad_rendered_image.contiguous() 202 | render_tiles_backward_cuda( 203 | uvs, 204 | opacity, 205 | rgb, 206 | conic, 207 | rays, 208 | splat_start_end_idx_by_tile_idx, 209 | sorted_gaussian_idx_by_splat_idx, 210 | background_rgb, 211 | num_splats_per_pixel, 212 | final_weight_per_pixel, 213 | grad_rendered_image, 214 | grad_rgb, 215 | grad_opacity, 216 | grad_uv, 217 | grad_conic, 218 | ) 219 | return grad_rgb, grad_opacity, grad_uv, grad_conic, None, None, None, None, None 220 | 
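Each kernel above is wrapped in a `torch.autograd.Function` so the analytic CUDA gradients plug directly into PyTorch autograd; this is also what enables the gradient checks mentioned in the README (see `test/test_rasterize_autograd.py` and `test/test_cuda_autograd_functions.py`). A minimal sketch of numerically checking one wrapper with `torch.autograd.gradcheck`, assuming the kernels accept float64 inputs (numerical gradient checks need double precision); the intrinsics below are arbitrary:

```python
import torch
from splat_py.cuda_autograd_functions import CameraPointProjection

# points in front of the camera (positive z) to avoid degenerate projections
xyz = torch.rand(8, 3, dtype=torch.float64, device="cuda", requires_grad=True) + 1.0
K = torch.tensor(
    [[430.0, 0.0, 320.0], [0.0, 430.0, 240.0], [0.0, 0.0, 1.0]],
    dtype=torch.float64, device="cuda",
)
torch.autograd.gradcheck(CameraPointProjection.apply, (xyz, K), eps=1e-6, atol=1e-4)
```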
-------------------------------------------------------------------------------- /splat_py/dataloader.py: 1 | import os 2 | import cv2 3 | import torch 4 | 5 | from splat_py.config import SplatConfig 6 | from splat_py.read_colmap import ( 7 | read_images_binary, 8 | read_points3D_binary, 9 | read_cameras_binary, 10 | qvec2rotmat, 11 | ) 12 | from splat_py.utils import inverse_sigmoid, compute_initial_scale_from_sparse_points 13 | from splat_py.structs import Gaussians, Image, Camera 14 | 15 | 16 | class GaussianSplattingDataset: 17 | """ 18 | Generic Gaussian Splatting Dataset class 19 | 20 | Classes that inherit from this class should have the following variables: 21 | 22 | device: torch device 23 | xyz: Nx3 tensor of points 24 | rgb: Nx3 tensor of rgb values 25 | 26 | images: list of Image objects 27 | cameras: dict of Camera objects 28 | 29 | """ 30 | 31 | def __init__(self, config): 32 | self.config = config 33 | 34 | def verify_loaded_points(self): 35 | """ 36 | Verify that the values loaded from the dataset are consistent 37 | """ 38 | N = self.xyz.shape[0] 39 | assert self.xyz.shape[1] == 3 40 | assert self.rgb.shape[0] == N 41 | assert self.rgb.shape[1] == 3 42 | 43 | def create_gaussians(self): 44 | """ 45 | Create gaussians object from the dataset 46 | """ 47 | self.verify_loaded_points() 48 | 49 | N = self.xyz.shape[0] 50 | initial_opacity = torch.ones(N, 1) * inverse_sigmoid(self.config.initial_opacity) 51 | # compute scale based on the density of the points around each point 52 | initial_scale = compute_initial_scale_from_sparse_points( 53 | self.xyz, 54 | num_neighbors=self.config.initial_scale_num_neighbors, 55 | neighbor_dist_to_scale_factor=self.config.initial_scale_factor, 56 | max_initial_scale=self.config.max_initial_scale, 57 | ) 58 | initial_quaternion = torch.zeros(N, 4) 59 | initial_quaternion[:, 0] = 1.0 60 | 61 | return Gaussians( 62 | xyz=self.xyz.to(self.device), 63 | rgb=self.rgb.to(self.device), 64 | opacity=initial_opacity.to(self.device), 65 | scale=initial_scale.to(self.device), 66 | quaternion=initial_quaternion.to(self.device), 67 | ) 68 | 69 | def get_images(self): 70 | """ 71 | get images from the dataset 72 | """ 73 | 74 | return self.images 75 | 76 | def get_cameras(self): 77 | """ 78 | get cameras from the dataset 79 | """ 80 | 81 | return self.cameras 82 | 83 | 84 | class ColmapData(GaussianSplattingDataset): 85 | """ 86 | This class loads data in the format of the Mip-NeRF 360 dataset generated with COLMAP 87 | 88 | Format: 89 | 90 | dataset_dir: 91 | images: full resolution images 92 | ... 93 | images_N: downsampled images by a factor of N 94 | ... 
95 | poses_bounds.npy: currently unused 96 | sparse: 97 | 0: 98 | cameras.bin 99 | images.bin 100 | points3D.bin 101 | """ 102 | 103 | def __init__( 104 | self, 105 | colmap_directory_path: str, 106 | device: torch.device, 107 | downsample_factor: int, 108 | config: SplatConfig, 109 | ) -> None: 110 | super().__init__(config) 111 | 112 | self.colmap_directory_path = colmap_directory_path 113 | self.device = device 114 | self.downsample_factor = downsample_factor 115 | 116 | # load sparse points 117 | points_path = os.path.join(colmap_directory_path, "sparse", "0", "points3D.bin") 118 | sparse_points = read_points3D_binary(points_path) 119 | num_points = len(sparse_points) 120 | 121 | self.xyz = torch.zeros(num_points, 3) 122 | self.rgb = torch.zeros(num_points, 3) 123 | row = 0 124 | for _, point in sparse_points.items(): 125 | self.xyz[row] = torch.tensor(point.xyz, dtype=torch.float32) 126 | self.rgb[row] = torch.tensor( 127 | point.rgb / 255.0 / 0.28209479177387814, dtype=torch.float32 128 | ) 129 | row += 1 130 | 131 | # load images 132 | image_info_path = os.path.join(colmap_directory_path, "sparse", "0", "images.bin") 133 | self.image_info = read_images_binary(image_info_path) 134 | 135 | self.images = [] 136 | for _, image_info in self.image_info.items(): 137 | # load image 138 | image_path = os.path.join( 139 | colmap_directory_path, 140 | f"images_{self.downsample_factor}", 141 | image_info.name, 142 | ) 143 | image = cv2.imread(image_path) 144 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 145 | 146 | # load transform 147 | camera_T_world = torch.eye(4) 148 | camera_T_world[:3, :3] = torch.tensor(qvec2rotmat(image_info.qvec), dtype=torch.float32) 149 | camera_T_world[:3, 3] = torch.tensor(image_info.tvec, dtype=torch.float32) 150 | 151 | self.images.append( 152 | Image( 153 | image=torch.from_numpy(image).to(torch.uint8).to(self.device), 154 | camera_id=image_info.camera_id, 155 | camera_T_world=camera_T_world.to(self.device), 156 | ) 157 | ) 158 | 159 | # load cameras 160 | cameras_path = os.path.join(colmap_directory_path, "sparse", "0", "cameras.bin") 161 | cameras = read_cameras_binary(cameras_path) 162 | 163 | self.cameras = {} 164 | for camera_id, camera in cameras.items(): 165 | K = torch.zeros((3, 3), dtype=torch.float32, device=self.device) 166 | if camera.model == "SIMPLE_PINHOLE": 167 | # colmap params [f, cx, cy] 168 | K[0, 0] = camera.params[0] / float(self.downsample_factor) 169 | K[1, 1] = camera.params[0] / float(self.downsample_factor) 170 | K[0, 2] = camera.params[1] / float(self.downsample_factor) 171 | K[1, 2] = camera.params[2] / float(self.downsample_factor) 172 | K[2, 2] = 1.0 173 | elif camera.model == "PINHOLE": 174 | # colmap params [fx, fy, cx, cy] 175 | K[0, 0] = camera.params[0] / float(self.downsample_factor) 176 | K[1, 1] = camera.params[1] / float(self.downsample_factor) 177 | K[0, 2] = camera.params[2] / float(self.downsample_factor) 178 | K[1, 2] = camera.params[3] / float(self.downsample_factor) 179 | K[2, 2] = 1.0 180 | else: 181 | raise NotImplementedError("Only Pinhole and Simple Pinhole cameras are supported") 182 | 183 | self.cameras[camera_id] = Camera( 184 | width=self.images[0].image.shape[1], 185 | height=self.images[0].image.shape[0], 186 | K=K, 187 | ) 188 | -------------------------------------------------------------------------------- /splat_py/depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from splat_cuda import render_depth_cuda 4 | from splat_py.utils import 
transform_points_torch 5 | from splat_py.cuda_autograd_functions import ( 6 | CameraPointProjection, 7 | ComputeSigmaWorld, 8 | ComputeProjectionJacobian, 9 | ComputeConic, 10 | ) 11 | from splat_py.structs import Gaussians, Tiles 12 | from splat_py.tile_culling import ( 13 | get_splats, 14 | ) 15 | 16 | 17 | def render_depth( 18 | gaussians, alpha_threshold, camera_T_world, camera, near_thresh, cull_mask_padding, mh_dist 19 | ): 20 | with torch.no_grad(): 21 | xyz_camera_frame = transform_points_torch(gaussians.xyz, camera_T_world) 22 | uv = CameraPointProjection.apply(xyz_camera_frame, camera.K) 23 | 24 | # perform frustum culling 25 | culling_mask = torch.zeros( 26 | xyz_camera_frame.shape[0], 27 | dtype=torch.bool, 28 | device=gaussians.xyz.device, 29 | ) 30 | culling_mask = culling_mask | (xyz_camera_frame[:, 2] < near_thresh) 31 | culling_mask = ( 32 | culling_mask 33 | | (uv[:, 0] < -1 * cull_mask_padding) 34 | | (uv[:, 0] > camera.width + cull_mask_padding) 35 | | (uv[:, 1] < -1 * cull_mask_padding) 36 | | (uv[:, 1] > camera.height + cull_mask_padding) 37 | ) 38 | 39 | # cull gaussians outside of camera frustum 40 | uv = uv[~culling_mask, :] 41 | xyz_camera_frame = xyz_camera_frame[~culling_mask, :] 42 | 43 | if gaussians.sh is not None: 44 | culled_gaussians = Gaussians( 45 | xyz=gaussians.xyz[~culling_mask, :], 46 | quaternion=gaussians.quaternion[~culling_mask, :], 47 | scale=gaussians.scale[~culling_mask, :], 48 | opacity=torch.sigmoid( 49 | gaussians.opacity[~culling_mask] 50 | ), # apply sigmoid activation to opacity 51 | rgb=gaussians.rgb[~culling_mask, :], 52 | sh=gaussians.sh[~culling_mask, :], 53 | ) 54 | else: 55 | culled_gaussians = Gaussians( 56 | xyz=gaussians.xyz[~culling_mask, :], 57 | quaternion=gaussians.quaternion[~culling_mask, :], 58 | scale=gaussians.scale[~culling_mask, :], 59 | opacity=torch.sigmoid( 60 | gaussians.opacity[~culling_mask] 61 | ), # apply sigmoid activation to opacity 62 | rgb=gaussians.rgb[~culling_mask, :], 63 | ) 64 | 65 | sigma_world = ComputeSigmaWorld.apply(culled_gaussians.quaternion, culled_gaussians.scale) 66 | J = ComputeProjectionJacobian.apply(xyz_camera_frame, camera.K) 67 | conic = ComputeConic.apply(sigma_world, J, camera_T_world) 68 | 69 | # perform tile culling 70 | tiles = Tiles(camera.height, camera.width, uv.device) 71 | sorted_gaussian_idx_by_splat_idx, splat_start_end_idx_by_tile_idx = get_splats( 72 | uv, tiles, conic, xyz_camera_frame, mh_dist 73 | ) 74 | 75 | depth_image = ( 76 | torch.ones(camera.height, camera.width, 1, dtype=torch.float32, device=uv.device) * -1.0 77 | ) 78 | render_depth_cuda( 79 | xyz_camera_frame, 80 | uv, 81 | culled_gaussians.opacity, 82 | conic, 83 | splat_start_end_idx_by_tile_idx, 84 | sorted_gaussian_idx_by_splat_idx, 85 | alpha_threshold, 86 | depth_image, 87 | ) 88 | return depth_image 89 | 
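`render_depth` reuses the projection and tile-culling pipeline above, but runs with gradients disabled and returns a depth image in which uncovered pixels keep the sentinel value `-1`. A minimal usage sketch; `gaussians`, `camera_T_world`, and `camera` are assumed to come from a loaded dataset as in `colmap_splat.py`, the threshold arguments mirror the config defaults, and `alpha_threshold=0.5` is an arbitrary assumption:

```python
from splat_py.depth import render_depth

depth_image = render_depth(
    gaussians,
    alpha_threshold=0.5,
    camera_T_world=camera_T_world,
    camera=camera,
    near_thresh=0.3,
    cull_mask_padding=100,
    mh_dist=3.0,
)
covered = depth_image[..., 0] >= 0.0  # pixels hit by at least one splat
```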
-------------------------------------------------------------------------------- /splat_py/optimizer_manager.py: 1 | import torch 2 | 3 | 4 | class OptimizerManager: 5 | """ 6 | Manages adding/deleting gaussians and updating SH Bands 7 | """ 8 | 9 | def __init__(self, gaussians, config): 10 | self.config = config 11 | self.setup_optimizer(gaussians) 12 | 13 | def setup_optimizer(self, gaussians): 14 | # add new params to optimizer 15 | self.optimizer = torch.optim.Adam( 16 | [ 17 | { 18 | "params": gaussians.xyz, 19 | "lr": self.config.base_lr * self.config.xyz_lr_multiplier, 20 | }, 21 | { 22 | "params": gaussians.quaternion, 23 | "lr": self.config.base_lr * self.config.quat_lr_multiplier, 24 | }, 25 | { 26 | "params": gaussians.scale, 27 | "lr": self.config.base_lr * self.config.scale_lr_multiplier, 28 | }, 29 | { 30 | "params": gaussians.opacity, 31 | "lr": self.config.base_lr * self.config.opacity_lr_multiplier, 32 | }, 33 | { 34 | "params": gaussians.rgb, 35 | "lr": self.config.base_lr * self.config.rgb_lr_multiplier, 36 | }, 37 | ], 38 | ) 39 | if gaussians.sh is not None: 40 | self.optimizer.add_param_group( 41 | {"params": gaussians.sh, "lr": self.config.base_lr * self.config.sh_lr_multiplier}, 42 | ) 43 | 44 | def reset_opacity_exp_avg(self, gaussians): 45 | # reset exp_avg and exp_avg_sq for opacity 46 | old_optimizer_param = self.optimizer.param_groups[3]["params"][0] 47 | optimizer_param_state = self.optimizer.state[old_optimizer_param] 48 | del self.optimizer.state[old_optimizer_param] 49 | 50 | optimizer_param_state["exp_avg"] = torch.zeros_like(optimizer_param_state["exp_avg"]) 51 | optimizer_param_state["exp_avg_sq"] = torch.zeros_like(optimizer_param_state["exp_avg_sq"]) 52 | 53 | del self.optimizer.param_groups[3]["params"][0] 54 | del self.optimizer.param_groups[3]["params"] 55 | 56 | self.optimizer.param_groups[3]["params"] = [gaussians.opacity] 57 | self.optimizer.state[gaussians.opacity] = optimizer_param_state  # state is keyed by the parameter tensor, not the group index 58 | 59 | def add_sh_to_optimizer(self, gaussians): 60 | self.optimizer.add_param_group( 61 | {"params": gaussians.sh, "lr": self.config.base_lr * self.config.sh_lr_multiplier}, 62 | ) 63 | 64 | def add_sh_band_to_optimizer(self, gaussians): 65 | old_optimizer_param = self.optimizer.param_groups[5]["params"][0] 66 | optimizer_param_state = self.optimizer.state[old_optimizer_param] 67 | del self.optimizer.state[old_optimizer_param] 68 | 69 | optimizer_param_state["exp_avg"] = torch.zeros_like(gaussians.sh) 70 | optimizer_param_state["exp_avg_sq"] = torch.zeros_like(gaussians.sh) 71 | 72 | del self.optimizer.param_groups[5]["params"][0] 73 | del self.optimizer.param_groups[5]["params"] 74 | 75 | self.optimizer.param_groups[5]["params"] = [gaussians.sh] 76 | self.optimizer.state[gaussians.sh] = optimizer_param_state  # state is keyed by the parameter tensor, not the group index 77 | 78 | def delete_param_from_optimizer(self, new_param, keep_mask, param_index): 79 | old_optimizer_param = self.optimizer.param_groups[param_index]["params"][0] 80 | optimizer_param_state = self.optimizer.state[old_optimizer_param] 81 | del self.optimizer.state[old_optimizer_param] 82 | 83 | optimizer_param_state["exp_avg"] = optimizer_param_state["exp_avg"][keep_mask, :] 84 | optimizer_param_state["exp_avg_sq"] = optimizer_param_state["exp_avg_sq"][keep_mask, :] 85 | 86 | del self.optimizer.param_groups[param_index]["params"][0] 87 | del self.optimizer.param_groups[param_index]["params"] 88 | 89 | self.optimizer.param_groups[param_index]["params"] = [new_param] 90 | self.optimizer.state[new_param] = optimizer_param_state 91 | 92 | def delete_gaussians_from_optimizer(self, updated_gaussians, keep_mask): 93 | self.delete_param_from_optimizer(updated_gaussians.xyz, keep_mask, 0) 94 | self.delete_param_from_optimizer(updated_gaussians.quaternion, keep_mask, 1) 95 | self.delete_param_from_optimizer(updated_gaussians.scale, keep_mask, 2) 96 | self.delete_param_from_optimizer(updated_gaussians.opacity, keep_mask, 3) 97 | self.delete_param_from_optimizer(updated_gaussians.rgb, keep_mask, 4) 98 | if updated_gaussians.sh is not None: 99 | self.delete_param_from_optimizer(updated_gaussians.sh, keep_mask, 5) 100 | 101 | def add_params_to_optimizer(self, new_param, num_added, param_index):
102 | old_optimizer_param = self.optimizer.param_groups[param_index]["params"][0] 103 | optimizer_param_state = self.optimizer.state[old_optimizer_param] 104 | 105 | if new_param.dim() == 2: 106 | # set exp_avg and exp_avg_sq for cloned gaussians to zero 107 | optimizer_param_state["exp_avg"] = torch.cat( 108 | [ 109 | optimizer_param_state["exp_avg"], 110 | torch.zeros( 111 | num_added, 112 | new_param.shape[1], 113 | device=optimizer_param_state["exp_avg"].device, 114 | dtype=optimizer_param_state["exp_avg"].dtype, 115 | ), 116 | ], 117 | dim=0, 118 | ) 119 | optimizer_param_state["exp_avg_sq"] = torch.cat( 120 | [ 121 | optimizer_param_state["exp_avg_sq"], 122 | torch.zeros( 123 | num_added, 124 | new_param.shape[1], 125 | device=optimizer_param_state["exp_avg_sq"].device, 126 | dtype=optimizer_param_state["exp_avg_sq"].dtype, 127 | ), 128 | ], 129 | dim=0, 130 | ) 131 | if new_param.dim() == 3: 132 | # set exp_avg and exp_avg_sq for cloned gaussians to zero 133 | optimizer_param_state["exp_avg"] = torch.cat( 134 | [ 135 | optimizer_param_state["exp_avg"], 136 | torch.zeros( 137 | num_added, 138 | new_param.shape[1], 139 | new_param.shape[2], 140 | device=optimizer_param_state["exp_avg"].device, 141 | dtype=optimizer_param_state["exp_avg"].dtype, 142 | ), 143 | ], 144 | dim=0, 145 | ) 146 | optimizer_param_state["exp_avg_sq"] = torch.cat( 147 | [ 148 | optimizer_param_state["exp_avg_sq"], 149 | torch.zeros( 150 | num_added, 151 | new_param.shape[1], 152 | new_param.shape[2], 153 | device=optimizer_param_state["exp_avg_sq"].device, 154 | dtype=optimizer_param_state["exp_avg_sq"].dtype, 155 | ), 156 | ], 157 | dim=0, 158 | ) 159 | 160 | del self.optimizer.state[old_optimizer_param] 161 | del old_optimizer_param 162 | self.optimizer.param_groups[param_index]["params"] = [new_param] 163 | self.optimizer.state[new_param] = optimizer_param_state 164 | 165 | def add_gaussians_to_optimizer(self, updated_gaussians, num_added): 166 | self.add_params_to_optimizer(updated_gaussians.xyz, num_added, 0) 167 | self.add_params_to_optimizer(updated_gaussians.quaternion, num_added, 1) 168 | self.add_params_to_optimizer(updated_gaussians.scale, num_added, 2) 169 | self.add_params_to_optimizer(updated_gaussians.opacity, num_added, 3) 170 | self.add_params_to_optimizer(updated_gaussians.rgb, num_added, 4) 171 | if updated_gaussians.sh is not None: 172 | self.add_params_to_optimizer(updated_gaussians.sh, num_added, 5) 173 | 
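`OptimizerManager` keeps the Adam state (`exp_avg`, `exp_avg_sq`) aligned with the gaussian parameters whenever gaussians are added or deleted during adaptive control. A sketch of how a deletion pass might be wired together with `Gaussians.filter_in_place` (defined in `structs.py` below); `optimizer_manager` and `config` are assumed instances, the threshold is `delete_opacity_threshold` from the config, and the sigmoid reflects that opacities are stored pre-activation elsewhere in the repo:

```python
import torch

# Build a keep-mask: drop gaussians whose activated opacity is too low.
keep_mask = (
    torch.sigmoid(gaussians.opacity.detach()).squeeze(-1)
    > config.delete_opacity_threshold
)
gaussians.filter_in_place(keep_mask)
optimizer_manager.delete_gaussians_from_optimizer(gaussians, keep_mask)
```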
-------------------------------------------------------------------------------- /splat_py/rasterize.py: 1 | import torch 2 | 3 | from splat_py.utils import transform_points_torch, compute_rays_in_world_frame 4 | from splat_py.cuda_autograd_functions import ( 5 | CameraPointProjection, 6 | ComputeSigmaWorld, 7 | ComputeProjectionJacobian, 8 | ComputeConic, 9 | RenderImage, 10 | PrecomputeRGBFromSH, 11 | ) 12 | from splat_py.structs import Gaussians, Tiles 13 | from splat_py.tile_culling import ( 14 | get_splats, 15 | ) 16 | 17 | 18 | def rasterize( 19 | gaussians, 20 | camera_T_world, 21 | camera, 22 | near_thresh, 23 | far_thresh, 24 | cull_mask_padding, 25 | mh_dist, 26 | use_sh_precompute, 27 | background_rgb, 28 | ): 29 | xyz_camera_frame = transform_points_torch(gaussians.xyz, camera_T_world) 30 | uv = CameraPointProjection.apply(xyz_camera_frame, camera.K) 31 | 32 | # perform frustum culling 33 | culling_mask = torch.zeros( 34 | xyz_camera_frame.shape[0], 35 | dtype=torch.bool, 36 | device=gaussians.xyz.device, 37 | ) 38 | culling_mask = ( 39 | culling_mask 40 | | (xyz_camera_frame[:, 2] < near_thresh) 41 | | (xyz_camera_frame[:, 2] > far_thresh) 42 | ) 43 | culling_mask = ( 44 | culling_mask 45 | | (uv[:, 0] < -1 * cull_mask_padding) 46 | | (uv[:, 0] > camera.width + cull_mask_padding) 47 | | (uv[:, 1] < -1 * cull_mask_padding) 48 | | (uv[:, 1] > camera.height + cull_mask_padding) 49 | ) 50 | 51 | # cull gaussians outside of camera frustum 52 | uv = uv[~culling_mask, :] 53 | xyz_camera_frame = xyz_camera_frame[~culling_mask, :] 54 | 55 | if gaussians.sh is not None: 56 | culled_gaussians = Gaussians( 57 | xyz=gaussians.xyz[~culling_mask, :], 58 | quaternion=gaussians.quaternion[~culling_mask, :], 59 | scale=gaussians.scale[~culling_mask, :], 60 | opacity=torch.sigmoid( 61 | gaussians.opacity[~culling_mask] 62 | ), # apply sigmoid activation to opacity 63 | rgb=gaussians.rgb[~culling_mask, :], 64 | sh=gaussians.sh[~culling_mask, :], 65 | ) 66 | else: 67 | culled_gaussians = Gaussians( 68 | xyz=gaussians.xyz[~culling_mask, :], 69 | quaternion=gaussians.quaternion[~culling_mask, :], 70 | scale=gaussians.scale[~culling_mask, :], 71 | opacity=torch.sigmoid( 72 | gaussians.opacity[~culling_mask] 73 | ), # apply sigmoid activation to opacity 74 | rgb=gaussians.rgb[~culling_mask, :], 75 | ) 76 | 77 | sigma_world = ComputeSigmaWorld.apply(culled_gaussians.quaternion, culled_gaussians.scale) 78 | J = ComputeProjectionJacobian.apply(xyz_camera_frame, camera.K) 79 | conic = ComputeConic.apply(sigma_world, J, camera_T_world) 80 | 81 | # perform tile culling 82 | tiles = Tiles(camera.height, camera.width, uv.device) 83 | 84 | sorted_gaussian_idx_by_splat_idx, splat_start_end_idx_by_tile_idx = get_splats( 85 | uv, tiles, conic, xyz_camera_frame, mh_dist 86 | ) 87 | rays = torch.zeros(1, 1, 1, dtype=gaussians.xyz.dtype, device=gaussians.xyz.device) 88 | if culled_gaussians.sh is not None: 89 | sh_coeffs = torch.cat((culled_gaussians.rgb.unsqueeze(dim=2), culled_gaussians.sh), dim=2) 90 | if use_sh_precompute: 91 | render_rgb = PrecomputeRGBFromSH.apply( 92 | sh_coeffs, culled_gaussians.xyz, torch.inverse(camera_T_world).contiguous() 93 | ) 94 | else: 95 | render_rgb = sh_coeffs 96 | # rays are needed to evaluate view-dependent SH color per pixel 97 | rays = compute_rays_in_world_frame(camera, camera_T_world) 98 | else: 99 | render_rgb = culled_gaussians.rgb 100 | 101 | image = RenderImage.apply( 102 | render_rgb, 103 | culled_gaussians.opacity, 104 | uv, 105 | conic, 106 | rays, 107 | splat_start_end_idx_by_tile_idx, 108 | sorted_gaussian_idx_by_splat_idx, 109 | torch.tensor([camera.height, camera.width], device=uv.device), 110 | background_rgb, 111 | ) 112 | return image, culling_mask, uv 113 | 
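`rasterize` is the entry point the trainer drives: it projects, frustum-culls, tile-sorts, and renders in one differentiable call, returning the image, the culling mask, and the projected `uv` coordinates. A minimal usage sketch with the config defaults; `gaussians`, `camera_T_world`, and `camera` are assumed to come from `ColmapData` as in `colmap_splat.py`, and treating `background_rgb` as a 3-vector is an assumption:

```python
import torch
from splat_py.rasterize import rasterize

image, culling_mask, uv = rasterize(
    gaussians,
    camera_T_world,
    camera,
    near_thresh=0.3,
    far_thresh=500.0,
    cull_mask_padding=100,
    mh_dist=3.0,
    use_sh_precompute=True,
    background_rgb=torch.zeros(3, dtype=torch.float32, device="cuda"),
)
```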
-------------------------------------------------------------------------------- /splat_py/read_colmap.py: 1 | # These functions are directly copied from the COLMAP source code from 2 | # colmap/scripts/python/read_write_model.py 3 | 4 | 5 | # Copyright (c) 2023, ETH Zurich and UNC Chapel Hill. 6 | # All rights reserved. 7 | # 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, are permitted provided that the following conditions are met: 10 | # 11 | # * Redistributions of source code must retain the above copyright 12 | # notice, this list of conditions and the following disclaimer. 13 | # 14 | # * Redistributions in binary form must reproduce the above copyright 15 | # notice, this list of conditions and the following disclaimer in the 16 | # documentation and/or other materials provided with the distribution. 17 | # 18 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 19 | # its contributors may be used to endorse or promote products derived 20 | # from this software without specific prior written permission. 21 | # 22 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 23 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 24 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 25 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 26 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 27 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 28 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 29 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 30 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 31 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 32 | # POSSIBILITY OF SUCH DAMAGE. 33 | 34 | 35 | import collections 36 | import struct 37 | import numpy as np 38 | 39 | 40 | CameraModel = collections.namedtuple("CameraModel", ["model_id", "model_name", "num_params"]) 41 | Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"]) 42 | BaseImage = collections.namedtuple( 43 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"] 44 | ) 45 | Point3D = collections.namedtuple( 46 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"] 47 | ) 48 | 49 | 50 | class Image(BaseImage): 51 | def qvec2rotmat(self): 52 | return qvec2rotmat(self.qvec) 53 | 54 | 55 | CAMERA_MODELS = { 56 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 57 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 58 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 59 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 60 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 61 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 62 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 63 | CameraModel(model_id=7, model_name="FOV", num_params=5), 64 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 65 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 66 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), 67 | } 68 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS]) 69 | CAMERA_MODEL_NAMES = dict( 70 | [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS] 71 | ) 72 | 73 | 74 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 75 | """Read and unpack the next bytes from a binary file. 76 | :param fid: 77 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 78 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 79 | :param endian_character: Any of {@, =, <, >, !} 80 | :return: Tuple of read and unpacked values. 
81 | """ 82 | data = fid.read(num_bytes) 83 | return struct.unpack(endian_character + format_char_sequence, data) 84 | 85 | 86 | def read_cameras_binary(path_to_model_file): 87 | """ 88 | see: src/colmap/scene/reconstruction.cc 89 | void Reconstruction::WriteCamerasBinary(const std::string& path) 90 | void Reconstruction::ReadCamerasBinary(const std::string& path) 91 | """ 92 | cameras = {} 93 | with open(path_to_model_file, "rb") as fid: 94 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 95 | for _ in range(num_cameras): 96 | camera_properties = read_next_bytes(fid, num_bytes=24, format_char_sequence="iiQQ") 97 | camera_id = camera_properties[0] 98 | model_id = camera_properties[1] 99 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 100 | width = camera_properties[2] 101 | height = camera_properties[3] 102 | num_params = CAMERA_MODEL_IDS[model_id].num_params 103 | params = read_next_bytes( 104 | fid, 105 | num_bytes=8 * num_params, 106 | format_char_sequence="d" * num_params, 107 | ) 108 | cameras[camera_id] = Camera( 109 | id=camera_id, 110 | model=model_name, 111 | width=width, 112 | height=height, 113 | params=np.array(params), 114 | ) 115 | assert len(cameras) == num_cameras 116 | return cameras 117 | 118 | 119 | def read_images_binary(path_to_model_file): 120 | """ 121 | see: src/colmap/scene/reconstruction.cc 122 | void Reconstruction::ReadImagesBinary(const std::string& path) 123 | void Reconstruction::WriteImagesBinary(const std::string& path) 124 | """ 125 | images = {} 126 | with open(path_to_model_file, "rb") as fid: 127 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 128 | for _ in range(num_reg_images): 129 | binary_image_properties = read_next_bytes( 130 | fid, num_bytes=64, format_char_sequence="idddddddi" 131 | ) 132 | image_id = binary_image_properties[0] 133 | qvec = np.array(binary_image_properties[1:5]) 134 | tvec = np.array(binary_image_properties[5:8]) 135 | camera_id = binary_image_properties[8] 136 | binary_image_name = b"" 137 | current_char = read_next_bytes(fid, 1, "c")[0] 138 | while current_char != b"\x00": # look for the ASCII 0 entry 139 | binary_image_name += current_char 140 | current_char = read_next_bytes(fid, 1, "c")[0] 141 | image_name = binary_image_name.decode("utf-8") 142 | num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[0] 143 | x_y_id_s = read_next_bytes( 144 | fid, 145 | num_bytes=24 * num_points2D, 146 | format_char_sequence="ddq" * num_points2D, 147 | ) 148 | xys = np.column_stack( 149 | [ 150 | tuple(map(float, x_y_id_s[0::3])), 151 | tuple(map(float, x_y_id_s[1::3])), 152 | ] 153 | ) 154 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 155 | images[image_id] = Image( 156 | id=image_id, 157 | qvec=qvec, 158 | tvec=tvec, 159 | camera_id=camera_id, 160 | name=image_name, 161 | xys=xys, 162 | point3D_ids=point3D_ids, 163 | ) 164 | return images 165 | 166 | 167 | def read_points3D_binary(path_to_model_file): 168 | """ 169 | see: src/colmap/scene/reconstruction.cc 170 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 171 | void Reconstruction::WritePoints3DBinary(const std::string& path) 172 | """ 173 | points3D = {} 174 | with open(path_to_model_file, "rb") as fid: 175 | num_points = read_next_bytes(fid, 8, "Q")[0] 176 | for _ in range(num_points): 177 | binary_point_line_properties = read_next_bytes( 178 | fid, num_bytes=43, format_char_sequence="QdddBBBd" 179 | ) 180 | point3D_id = binary_point_line_properties[0] 181 | xyz = np.array(binary_point_line_properties[1:4]) 182 | 
rgb = np.array(binary_point_line_properties[4:7]) 183 | error = np.array(binary_point_line_properties[7]) 184 | track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[0] 185 | track_elems = read_next_bytes( 186 | fid, 187 | num_bytes=8 * track_length, 188 | format_char_sequence="ii" * track_length, 189 | ) 190 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 191 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) 192 | points3D[point3D_id] = Point3D( 193 | id=point3D_id, 194 | xyz=xyz, 195 | rgb=rgb, 196 | error=error, 197 | image_ids=image_ids, 198 | point2D_idxs=point2D_idxs, 199 | ) 200 | return points3D 201 | 202 | 203 | def qvec2rotmat(qvec): 204 | return np.array( 205 | [ 206 | [ 207 | 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, 208 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 209 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], 210 | ], 211 | [ 212 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 213 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, 214 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], 215 | ], 216 | [ 217 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 218 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 219 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, 220 | ], 221 | ] 222 | ) 223 | -------------------------------------------------------------------------------- /splat_py/structs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | TILE_EDGE_LENGTH_PX = 16 5 | 6 | 7 | class GSMetrics: 8 | def __init__(self): 9 | self.train_psnr = [] 10 | self.test_psnr = [] 11 | self.num_gaussians = [] 12 | 13 | 14 | class Image: 15 | """ 16 | Image and Pose information 17 | """ 18 | 19 | def __init__( 20 | self, 21 | image, # loaded image [HxWx3], 8bit, RGB 22 | camera_id, # camera id associated with the image 23 | camera_T_world, # world to camera transform matrix [4x4] 24 | ): 25 | self.image = image 26 | self.camera_id = camera_id 27 | self.camera_T_world = camera_T_world 28 | 29 | 30 | class Camera: 31 | """ 32 | Basic Pinhole Camera class 33 | """ 34 | 35 | def __init__( 36 | self, 37 | width, # image width 38 | height, # image height 39 | K, # camera matrix [3x3] 40 | ): 41 | self.width = width 42 | self.height = height 43 | self.K = K 44 | 45 | 46 | class Gaussians(torch.nn.Module): 47 | """ 48 | Contains all mutable gaussian parameters 49 | """ 50 | 51 | def __init__( 52 | self, 53 | xyz, # Nx3 [x, y, z] 54 | rgb, # Nx3 [r, g, b] normalized to [0, 1] 55 | opacity, # Nx1 [opacity] from [0, 1] 56 | scale, # Nx3 [sx, sy, sz] 57 | quaternion, # Nx4 [qw, qx, qy, qz] 58 | sh=None, 59 | ): 60 | super().__init__() 61 | self.xyz = xyz 62 | self.rgb = rgb 63 | self.opacity = opacity 64 | self.scale = scale 65 | self.quaternion = quaternion 66 | self.sh = sh 67 | 68 | self.verify_sizes() 69 | 70 | def __len__(self): 71 | return self.xyz.shape[0] 72 | 73 | def verify_sizes(self): 74 | num_gaussians = self.xyz.shape[0] 75 | assert self.rgb.shape[0] == num_gaussians 76 | if self.sh is not None: 77 | assert self.sh.shape[0] == num_gaussians 78 | 79 | assert self.opacity.shape[0] == num_gaussians 80 | assert self.scale.shape[0] == num_gaussians 81 | assert self.quaternion.shape[0] == num_gaussians 82 | 83 | assert self.xyz.shape[1] == 3 84 | assert self.rgb.shape[1] == 3 85 | if self.sh is not None: 86 | assert self.sh.shape[1] == 3 87 | 88 | assert self.opacity.shape[1] == 1 89 | assert self.scale.shape[1] == 3 90 | assert self.quaternion.shape[1] == 4 91 | 92 | def 
filter_in_place(self, keep_mask): 93 | self.xyz = torch.nn.Parameter(self.xyz.detach()[keep_mask, :]) 94 | self.rgb = torch.nn.Parameter(self.rgb.detach()[keep_mask, :]) 95 | if self.sh is not None: 96 | self.sh = torch.nn.Parameter(self.sh.detach()[keep_mask, :]) 97 | self.opacity = torch.nn.Parameter(self.opacity.detach()[keep_mask]) 98 | self.scale = torch.nn.Parameter(self.scale.detach()[keep_mask, :]) 99 | self.quaternion = torch.nn.Parameter(self.quaternion.detach()[keep_mask, :]) 100 | self.verify_sizes() 101 | 102 | def append(self, xyz, rgb, opacity, scale, quaternion, sh=None): 103 | self.xyz = torch.nn.Parameter(torch.cat((self.xyz.detach(), xyz.detach()), dim=0)) 104 | self.rgb = torch.nn.Parameter(torch.cat((self.rgb.detach(), rgb.detach()), dim=0)) 105 | if sh is not None: 106 | self.sh = torch.nn.Parameter(torch.cat((self.sh.detach(), sh.detach()), dim=0)) 107 | self.opacity = torch.nn.Parameter( 108 | torch.cat((self.opacity.detach(), opacity.detach()), dim=0) 109 | ) 110 | self.scale = torch.nn.Parameter(torch.cat((self.scale.detach(), scale.detach()), dim=0)) 111 | self.quaternion = torch.nn.Parameter( 112 | torch.cat((self.quaternion.detach(), quaternion.detach()), dim=0) 113 | ) 114 | self.verify_sizes() 115 | 116 | 117 | class Tiles: 118 | """ 119 | Tiles for rasterization 120 | """ 121 | 122 | def __init__(self, image_height, image_width, device): 123 | self.image_height = image_height 124 | self.image_width = image_width 125 | self.device = device 126 | self.tile_edge_size = TILE_EDGE_LENGTH_PX 127 | 128 | # Need to round up to the nearest multiple of TILE_EDGE_LENGTH_PX to ensure all pixels are covered 129 | self.image_height_padded = int( 130 | np.ceil(image_height / self.tile_edge_size) * self.tile_edge_size 131 | ) 132 | self.image_width_padded = int( 133 | np.ceil(image_width / self.tile_edge_size) * self.tile_edge_size 134 | ) 135 | 136 | self.y_tiles_count = int(self.image_height_padded / self.tile_edge_size) 137 | self.x_tiles_count = int(self.image_width_padded / self.tile_edge_size) 138 | self.tile_count = self.y_tiles_count * self.x_tiles_count 139 | -------------------------------------------------------------------------------- /splat_py/tile_culling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from splat_cuda import ( 4 | get_sorted_gaussian_list, 5 | ) 6 | 7 | 8 | def get_splats( 9 | uvs, 10 | tiles, 11 | conic, 12 | xyz_camera_frame, 13 | mh_dist, 14 | ): 15 | # nan in xyz will cause an unrecoverable sorting failure 16 | if torch.any(~torch.isfinite(xyz_camera_frame)): 17 | print("xyz_camera_frame has NaN") 18 | exit() 19 | return get_sorted_gaussian_list( 20 | 1024, 21 | uvs, 22 | xyz_camera_frame, 23 | conic, 24 | tiles.x_tiles_count, 25 | tiles.y_tiles_count, 26 | mh_dist, 27 | ) 28 | -------------------------------------------------------------------------------- /splat_py/trainer.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from torchmetrics.image import StructuralSimilarityIndexMeasure 5 | 6 | from splat_py.optimizer_manager import OptimizerManager 7 | from splat_py.structs import GSMetrics 8 | from splat_py.rasterize import rasterize 9 | from splat_py.utils import ( 10 | inverse_sigmoid, 11 | quaternion_to_rotation_torch, 12 | ) 13 | 14 | 15 | class SplatTrainer: 16 | def __init__(self, gaussians, images, cameras, config): 17 | self.gaussians = gaussians 18 | self.images = images 19 
| self.cameras = cameras 20 | self.config = config 21 | 22 | self.optimizer_manager = OptimizerManager(gaussians, self.config) 23 | self.metrics = GSMetrics() 24 | self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.gaussians.xyz.device) 25 | 26 | self.reset_grad_accum() 27 | self.setup_test_train_split() 28 | self.setup_test_images() 29 | 30 | def setup_test_train_split(self): 31 | num_images = len(self.images) 32 | all_images = np.arange(num_images) 33 | self.test_split = np.arange(0, num_images, self.config.test_split_ratio) 34 | self.train_split = np.array(list(set(all_images) - set(self.test_split))) 35 | 36 | # setup for sampling the train split using torch.multinomial 37 | self.train_split = torch.tensor( 38 | self.train_split, dtype=torch.int, device=self.gaussians.xyz.device 39 | ) 40 | self.train_prob = torch.ones( 41 | len(self.train_split), dtype=torch.float32, device=self.gaussians.xyz.device 42 | ) / len(self.train_split) 43 | 44 | def setup_test_images(self): 45 | for image_idx in range(len(self.images)): 46 | self.images[image_idx].image = ( 47 | self.images[image_idx].image.to(torch.float32) / self.config.saturated_pixel_value 48 | ) 49 | 50 | def reset_grad_accum(self): 51 | # reset grad accumulators 52 | self.uv_grad_accum = torch.zeros( 53 | (self.gaussians.xyz.shape[0], 2), 54 | dtype=self.gaussians.xyz.dtype, 55 | device=self.gaussians.xyz.device, 56 | ) 57 | self.xyz_grad_accum = torch.zeros( 58 | self.gaussians.xyz.shape, 59 | dtype=self.gaussians.xyz.dtype, 60 | device=self.gaussians.xyz.device, 61 | ) 62 | self.grad_accum_count = torch.zeros( 63 | self.gaussians.xyz.shape[0], 64 | dtype=torch.int, 65 | device=self.gaussians.xyz.device, 66 | ) 67 | 68 | def reset_opacity(self): 69 | print("\t\tResetting opacity") 70 | self.gaussians.opacity = torch.nn.Parameter( 71 | torch.ones_like(self.gaussians.opacity) 72 | * inverse_sigmoid(self.config.reset_opacity_value) 73 | ) 74 | self.optimizer_manager.reset_opacity_exp_avg(self.gaussians) 75 | self.reset_grad_accum() 76 | 77 | def add_sh_band(self): 78 | num_gaussians = self.gaussians.xyz.shape[0] 79 | if self.config.max_sh_band == 0: 80 | return 81 | elif self.gaussians.sh is None: 82 | new_sh = torch.zeros( 83 | num_gaussians, 84 | 3, 85 | 3, 86 | dtype=self.gaussians.rgb.dtype, 87 | device=self.gaussians.rgb.device, 88 | ) 89 | self.gaussians.sh = torch.nn.Parameter(new_sh) 90 | self.optimizer_manager.add_sh_to_optimizer(self.gaussians) 91 | elif self.gaussians.sh.shape[2] == 3 and self.config.max_sh_band > 1: 92 | new_sh = torch.zeros( 93 | num_gaussians, 94 | 3, 95 | 8, 96 | dtype=self.gaussians.rgb.dtype, 97 | device=self.gaussians.rgb.device, 98 | ) 99 | new_sh[:, :, :3] = self.gaussians.sh 100 | self.gaussians.sh = torch.nn.Parameter(new_sh) 101 | self.optimizer_manager.add_sh_band_to_optimizer(self.gaussians) 102 | elif self.gaussians.sh.shape[2] == 8 and self.config.max_sh_band > 2: 103 | new_sh = torch.zeros( 104 | num_gaussians, 105 | 3, 106 | 15, 107 | dtype=self.gaussians.rgb.dtype, 108 | device=self.gaussians.rgb.device, 109 | ) 110 | new_sh[:, :, :8] = self.gaussians.sh 111 | self.gaussians.sh = torch.nn.Parameter(new_sh) 112 | self.optimizer_manager.add_sh_band_to_optimizer(self.gaussians) 113 | 114 | def delete_gaussians(self, keep_mask): 115 | self.gaussians.filter_in_place(keep_mask) 116 | self.uv_grad_accum = self.uv_grad_accum[keep_mask, :] 117 | self.xyz_grad_accum = self.xyz_grad_accum[keep_mask, :] 118 | self.grad_accum_count = self.grad_accum_count[keep_mask] 119 | 120 | # 
remove deleted gaussians from optimizer 121 | self.optimizer_manager.delete_gaussians_from_optimizer(self.gaussians, keep_mask) 122 | 123 | def clone_gaussians(self, clone_mask, xyz_grad_avg): 124 | # create cloned gaussians 125 | cloned_xyz = self.gaussians.xyz[clone_mask, :].clone().detach() 126 | cloned_xyz -= xyz_grad_avg[clone_mask, :] * 0.01 127 | cloned_quaternion = self.gaussians.quaternion[clone_mask, :].clone().detach() 128 | cloned_scale = self.gaussians.scale[clone_mask, :].clone().detach() 129 | cloned_opacity = self.gaussians.opacity[clone_mask].clone().detach() 130 | cloned_rgb = self.gaussians.rgb[clone_mask, :].clone().detach() 131 | if self.gaussians.sh is not None: 132 | cloned_sh = self.gaussians.sh[clone_mask, :].clone().detach() 133 | 134 | # keep grads up to date 135 | self.uv_grad_accum = torch.cat( 136 | [self.uv_grad_accum, self.uv_grad_accum[clone_mask, :]], dim=0 137 | ) 138 | self.xyz_grad_accum = torch.cat( 139 | [self.xyz_grad_accum, self.xyz_grad_accum[clone_mask, :]], dim=0 140 | ) 141 | self.grad_accum_count = torch.cat( 142 | [self.grad_accum_count, self.grad_accum_count[clone_mask]], dim=0 143 | ) 144 | 145 | # clone gaussians 146 | if self.gaussians.sh is not None: 147 | self.gaussians.append( 148 | cloned_xyz, 149 | cloned_rgb, 150 | cloned_opacity, 151 | cloned_scale, 152 | cloned_quaternion, 153 | cloned_sh, 154 | ) 155 | else: 156 | self.gaussians.append( 157 | cloned_xyz, cloned_rgb, cloned_opacity, cloned_scale, cloned_quaternion 158 | ) 159 | self.optimizer_manager.add_gaussians_to_optimizer( 160 | self.gaussians, torch.sum(clone_mask).detach().cpu().numpy() 161 | ) 162 | 163 | def split_gaussians(self, split_mask): 164 | samples = self.config.num_split_samples 165 | # create split gaussians 166 | split_quaternion = ( 167 | self.gaussians.quaternion[split_mask, :].clone().detach().repeat(samples, 1) 168 | ) 169 | split_scale = self.gaussians.scale[split_mask, :].clone().detach().repeat(samples, 1) 170 | split_opacity = self.gaussians.opacity[split_mask].clone().detach().repeat(samples, 1) 171 | split_rgb = self.gaussians.rgb[split_mask, :].clone().detach().repeat(samples, 1) 172 | if self.gaussians.sh is not None: 173 | split_sh = self.gaussians.sh[split_mask, :].clone().detach().repeat(samples, 1, 1) 174 | split_xyz = self.gaussians.xyz[split_mask, :].clone().detach().repeat(samples, 1) 175 | 176 | # random offsets sampled uniformly in [0, 1) along each local axis (not zero-centered) 177 | random_samples = torch.rand(split_mask.sum() * samples, 3, device=self.gaussians.xyz.device) 178 | # scale by scale factors 179 | scale_factors = torch.exp(split_scale) 180 | random_samples = random_samples * scale_factors 181 | # rotate by quaternion 182 | split_quaternion = split_quaternion / torch.norm(split_quaternion, dim=1, keepdim=True) 183 | split_rotations = quaternion_to_rotation_torch(split_quaternion) 184 | 185 | random_samples = torch.bmm(split_rotations, random_samples.unsqueeze(-1)).squeeze(-1) 186 | # translate by original mean locations 187 | split_xyz += random_samples 188 | 189 | # update scale 190 | split_scale = torch.log(torch.exp(split_scale) / self.config.split_scale_factor) 191 | 192 | # delete original split gaussians 193 | self.delete_gaussians(~split_mask) 194 | 195 | # add split gaussians 196 | if self.gaussians.sh is not None: 197 | self.gaussians.append( 198 | split_xyz, split_rgb, split_opacity, split_scale, split_quaternion, split_sh 199 | ) 200 | else: 201 | self.gaussians.append( 202 | split_xyz, split_rgb, split_opacity, split_scale, split_quaternion 203 | ) 204 | 
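        # register the newly appended gaussians with the optimizer so fresh
        # per-parameter Adam state (exp_avg / exp_avg_sq) is allocated for them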
self.optimizer_manager.add_gaussians_to_optimizer( 205 | self.gaussians, torch.sum(split_mask).detach().cpu().numpy() * samples 206 | ) 207 | 208 | def adaptive_density_control(self, iter): 209 | if not (self.config.use_delete or self.config.use_clone or self.config.use_split): 210 | return 211 | print("Adaptive_density control update") 212 | 213 | # Step 1. Delete gaussians 214 | # low opacity 215 | keep_mask = self.gaussians.opacity > inverse_sigmoid(self.config.delete_opacity_threshold) 216 | keep_mask = keep_mask.squeeze(1) 217 | print("\tlow opacity mask: ", torch.sum(~keep_mask).detach().cpu().numpy()) 218 | # no views or grad 219 | zero_view_mask = self.grad_accum_count == 0 220 | zero_grad_mask = torch.norm(self.uv_grad_accum, dim=1) == 0.0 221 | print("\tzero view mask: ", torch.sum(zero_view_mask).detach().cpu().numpy()) 222 | print("\tzero grad mask: ", torch.sum(zero_grad_mask).detach().cpu().numpy()) 223 | keep_mask &= ~zero_view_mask 224 | keep_mask &= ~zero_grad_mask 225 | 226 | delete_count = torch.sum(~keep_mask).detach().cpu().numpy() 227 | print("\tDeleting: ", delete_count) 228 | if (delete_count > 0) and self.config.use_delete: 229 | self.delete_gaussians(keep_mask) 230 | 231 | if len(self.gaussians) > self.config.max_gaussians: 232 | print("Max gaussians exceeded, skipping densification") 233 | self.reset_grad_accum() 234 | return 235 | 236 | # Step 2. Densify gaussians 237 | uv_grad_avg = self.uv_grad_accum / self.grad_accum_count.unsqueeze(1).float() 238 | xyz_grad_avg = self.xyz_grad_accum / self.grad_accum_count.unsqueeze(1).float() 239 | 240 | uv_grad_avg_norm = torch.norm(uv_grad_avg, dim=1) 241 | 242 | if self.config.use_fractional_densification: 243 | if self.config.use_adaptive_fractional_densification: 244 | scale_factor = ( 245 | float(self.config.adaptive_control_end - iter) 246 | / float(self.config.adaptive_control_end - self.config.adaptive_control_start) 247 | * 2.0 248 | ) 249 | else: 250 | scale_factor = 1.0 251 | uv_percentile = 1.0 - (1.0 - self.config.uv_grad_percentile) * scale_factor 252 | uv_split_val = torch.quantile(uv_grad_avg_norm, uv_percentile).item() 253 | else: 254 | uv_split_val = self.config.uv_grad_threshold 255 | densify_mask = uv_grad_avg_norm > uv_split_val 256 | print( 257 | "\tDensify mask: ", 258 | torch.sum(densify_mask).detach().cpu().numpy(), 259 | "split_val", 260 | uv_split_val, 261 | ) 262 | 263 | scale_max = self.gaussians.scale.exp().max(dim=-1).values 264 | clone_mask = densify_mask & (scale_max <= self.config.clone_scale_threshold) 265 | print("\tClone Mask: ", torch.sum(clone_mask).detach().cpu().numpy()) 266 | 267 | # Step 2.1 clone gaussians 268 | if clone_mask.any() and self.config.use_clone: 269 | self.clone_gaussians(clone_mask, xyz_grad_avg) 270 | # keep masks up to date 271 | densify_mask = torch.cat([densify_mask, densify_mask[clone_mask]], dim=0) 272 | scale_max = torch.cat([scale_max, scale_max[clone_mask]], dim=0) 273 | 274 | split_mask = densify_mask & (scale_max > self.config.clone_scale_threshold) 275 | 276 | if self.config.use_adaptive_fractional_densification: 277 | scale_factor = ( 278 | float(self.config.adaptive_control_end - iter) 279 | / float(self.config.adaptive_control_end - self.config.adaptive_control_start) 280 | * 2.0 281 | ) 282 | else: 283 | scale_factor = 1.0 284 | scale_percentile = 1.0 - (1.0 - self.config.scale_norm_percentile) * scale_factor 285 | 286 | scale_split = torch.quantile(scale_max, scale_percentile).item() 287 | too_big_mask = scale_max > scale_split 288 | split_mask = 
split_mask | too_big_mask 289 | 290 | print("\tSplit Mask: ", torch.sum(split_mask).detach().cpu().numpy()) 291 | # Step 2.2 split gaussians 292 | if split_mask.any() and self.config.use_split: 293 | self.split_gaussians(split_mask) 294 | 295 | self.reset_grad_accum() 296 | 297 | def compute_test_psnr(self, save_test_images=False, iter=0): 298 | with torch.no_grad(): 299 | test_psnrs = [] 300 | test_ssim = [] 301 | for test_img_idx in self.test_split: 302 | test_camera_T_world = self.images[test_img_idx].camera_T_world 303 | test_camera = self.cameras[self.images[test_img_idx].camera_id] 304 | 305 | ( 306 | test_image, 307 | _, 308 | _, 309 | ) = rasterize( 310 | self.gaussians, 311 | test_camera_T_world, 312 | test_camera, 313 | near_thresh=self.config.near_thresh, 314 | far_thresh=self.config.far_thresh, 315 | cull_mask_padding=self.config.cull_mask_padding, 316 | mh_dist=self.config.mh_dist, 317 | use_sh_precompute=self.config.use_sh_precompute, 318 | background_rgb=torch.zeros( 319 | 3, device=self.gaussians.xyz.device, dtype=self.gaussians.xyz.dtype 320 | ), 321 | ) 322 | gt_image = self.images[test_img_idx].image.to(torch.device("cuda")) 323 | 324 | l2_loss = torch.nn.functional.mse_loss(test_image.clip(0, 1), gt_image) 325 | psnr = -10 * torch.log10(l2_loss).item() 326 | 327 | ssim = self.ssim( 328 | test_image.unsqueeze(0).permute(0, 3, 1, 2).clip(0, 1), 329 | gt_image.unsqueeze(0).permute(0, 3, 1, 2), 330 | ) 331 | 332 | test_psnrs.append(psnr) 333 | test_ssim.append(ssim.item()) 334 | 335 | if save_test_images: 336 | debug_image = test_image.clip(0, 1).detach().cpu().numpy() 337 | cv2.imwrite( 338 | "{}/iter{}_test_image_{}.png".format( 339 | self.config.output_dir, iter, test_img_idx 340 | ), 341 | (debug_image * self.config.saturated_pixel_value).astype(np.uint8)[ 342 | ..., ::-1 343 | ], 344 | ) 345 | 346 | return torch.tensor(test_psnrs), torch.tensor(test_ssim) 347 | 348 | def splat_and_compute_loss(self, image_idx, camera_T_world, camera, background_rgb): 349 | image, culling_mask, uv = rasterize( 350 | self.gaussians, 351 | camera_T_world, 352 | camera, 353 | near_thresh=self.config.near_thresh, 354 | far_thresh=self.config.far_thresh, 355 | cull_mask_padding=self.config.cull_mask_padding, 356 | mh_dist=self.config.mh_dist, 357 | use_sh_precompute=self.config.use_sh_precompute, 358 | background_rgb=background_rgb, 359 | ) 360 | uv.retain_grad() 361 | 362 | gt_image = self.images[image_idx].image.to(torch.device("cuda")) 363 | l1_loss = torch.nn.functional.l1_loss(image, gt_image) 364 | 365 | # for debug only 366 | l2_loss = torch.nn.functional.mse_loss(image, gt_image) 367 | psnr = -10 * torch.log10(l2_loss) 368 | 369 | # channel first tensor for SSIM 370 | ssim_loss = 1.0 - self.ssim( 371 | image.unsqueeze(0).permute(0, 3, 1, 2), 372 | gt_image.unsqueeze(0).permute(0, 3, 1, 2), 373 | ) 374 | loss = (1.0 - self.config.ssim_frac) * l1_loss + self.config.ssim_frac * ssim_loss 375 | loss.backward() 376 | self.optimizer_manager.optimizer.step() 377 | 378 | # scale uv grad back to world coordinates - this way, uv grad is consistent across multiple cameras 379 | uv_grad = uv.grad.detach() 380 | uv_grad[:, 0] = uv_grad[:, 0] * camera.K[0, 0] 381 | uv_grad[:, 1] = uv_grad[:, 1] * camera.K[1, 1] 382 | 383 | self.uv_grad_accum[~culling_mask] += torch.abs(uv_grad) 384 | self.xyz_grad_accum += torch.abs(self.gaussians.xyz.grad.detach()) 385 | self.grad_accum_count += (~culling_mask).int() 386 | 387 | return image, psnr 388 | 389 | def train(self): 390 | for i in 
range(self.config.num_iters): 391 | 392 | self.optimizer_manager.optimizer.zero_grad() 393 | # compute test PSNR right after zero grad to minimize memory usage 394 | if i % self.config.test_eval_interval == 0: 395 | test_psnrs, test_ssims = self.compute_test_psnr() 396 | self.metrics.test_psnr.append(test_psnrs.mean().item()) 397 | print( 398 | "\t\t\t\t\t\tTEST SPLIT PSNR: {}, SSIM: {}".format( 399 | test_psnrs.mean().item(), test_ssims.mean().item() 400 | ) 401 | ) 402 | 403 | image_idx = self.train_split[ 404 | torch.multinomial(self.train_prob, num_samples=1, replacement=False).item() 405 | ] 406 | camera_T_world = self.images[image_idx].camera_T_world 407 | camera = self.cameras[self.images[image_idx].camera_id] 408 | 409 | background_rgb = torch.zeros( 410 | 3, device=self.gaussians.xyz.device, dtype=self.gaussians.xyz.dtype 411 | ) 412 | if self.config.use_background and i < self.config.use_background_end: 413 | background_rgb = ( 414 | torch.ones(3, device=self.gaussians.xyz.device, dtype=self.gaussians.xyz.dtype) 415 | * float(i % 255) 416 | / 255.0 417 | ) 418 | image, psnr = self.splat_and_compute_loss( 419 | image_idx, camera_T_world, camera, background_rgb=background_rgb 420 | ) 421 | self.metrics.train_psnr.append(psnr.item()) 422 | self.metrics.num_gaussians.append(self.gaussians.xyz.shape[0]) 423 | 424 | if i % self.config.print_interval == 0: 425 | print( 426 | "Iter: {}, PSNR: {}, N: {}".format( 427 | i, psnr.detach().cpu().numpy(), self.gaussians.xyz.shape[0] 428 | ) 429 | ) 430 | 431 | if ( 432 | i > self.config.adaptive_control_start 433 | and i % self.config.adaptive_control_interval == 0 434 | and i < self.config.adaptive_control_end 435 | ): 436 | self.adaptive_density_control(i) 437 | 438 | if ( 439 | i > self.config.reset_opacity_start 440 | and i < self.config.reset_opacity_end 441 | and i % self.config.reset_opacity_interval == 0 442 | ): 443 | self.reset_opacity() 444 | 445 | if i > 0 and i % self.config.add_sh_band_interval == 0: 446 | self.add_sh_band() 447 | 448 | if i % self.config.save_debug_image_interval == 0: 449 | debug_image = image.clip(0, 1).detach().cpu().numpy() 450 | cv2.imwrite( 451 | "{}/iter{}_image_{}.png".format(self.config.output_dir, i, image_idx), 452 | (debug_image * self.config.saturated_pixel_value).astype(np.uint8)[..., ::-1], 453 | ) 454 | if i > 0 and i % self.config.checkpoint_interval == 0: 455 | with torch.no_grad(): 456 | torch.save( 457 | self.gaussians, 458 | "{}/gaussians_iter_{}.pt".format(self.config.output_dir, i), 459 | ) 460 | final_psnrs, final_ssim = self.compute_test_psnr(save_test_images=True, iter=i) 461 | print( 462 | "Final PSNR: {}, SSIM: {}".format(final_psnrs.mean().item(), final_ssim.mean().item()) 463 | ) 464 | -------------------------------------------------------------------------------- /splat_py/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from scipy.spatial import KDTree 4 | 5 | 6 | def inverse_sigmoid(x): 7 | """ 8 | Inverse of sigmoid activation 9 | """ 10 | clipped = np.clip(x, 1e-4, 1 - 1e-4) 11 | return np.log(clipped / (1.0 - (clipped))) 12 | 13 | 14 | def inverse_sigmoid_torch(x): 15 | clipped = torch.clip(x, 1e-4, 1 - 1e-4) 16 | return torch.log(clipped / (1.0 - (clipped))) 17 | 18 | 19 | def compute_initial_scale_from_sparse_points( 20 | points, num_neighbors, neighbor_dist_to_scale_factor, max_initial_scale 21 | ): 22 | """ 23 | Computes the initial gaussian scale from the distance to the nearest points 
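    The mean distance to the num_neighbors nearest points is clamped to
    max_initial_scale, multiplied by neighbor_dist_to_scale_factor, and
    returned in log-space since the trainer applies an exp() activation to scale.

    Example (hypothetical values):
        >>> import torch
        >>> pts = torch.rand(100, 3)
        >>> scale = compute_initial_scale_from_sparse_points(pts, 3, 1.0, 0.1)
        >>> tuple(scale.shape)
        (100, 3)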
24 | """ 25 | points_np = points.cpu().numpy() 26 | tree = KDTree(points_np) 27 | 28 | n_pts = points_np.shape[0] 29 | scale = torch.zeros(n_pts, 3, dtype=torch.float32) 30 | for pt_idx in range(n_pts): 31 | neighbor_dist_vect, _ = tree.query(points_np[pt_idx, :], k=num_neighbors, workers=-1) 32 | initial_scale = min(np.mean(neighbor_dist_vect), max_initial_scale) 33 | # use log since scale has exp activation 34 | scale[pt_idx, :] = torch.ones(3, dtype=torch.float32) * np.log( 35 | initial_scale * neighbor_dist_to_scale_factor 36 | ) 37 | return scale 38 | 39 | 40 | def quaternion_to_rotation_torch(q): 41 | """' 42 | Convert tensor of normalized quaternion [N, 4] in [w, x, y, z] format to rotation matrices 43 | [N, 3, 3] 44 | """ 45 | rot = [ 46 | 1 - 2 * q[:, 2] ** 2 - 2 * q[:, 3] ** 2, 47 | 2 * q[:, 1] * q[:, 2] - 2 * q[:, 0] * q[:, 3], 48 | 2 * q[:, 3] * q[:, 1] + 2 * q[:, 0] * q[:, 2], 49 | 2 * q[:, 1] * q[:, 2] + 2 * q[:, 0] * q[:, 3], 50 | 1 - 2 * q[:, 1] ** 2 - 2 * q[:, 3] ** 2, 51 | 2 * q[:, 2] * q[:, 3] - 2 * q[:, 0] * q[:, 1], 52 | 2 * q[:, 3] * q[:, 1] - 2 * q[:, 0] * q[:, 2], 53 | 2 * q[:, 2] * q[:, 3] + 2 * q[:, 0] * q[:, 1], 54 | 1 - 2 * q[:, 1] ** 2 - 2 * q[:, 2] ** 2, 55 | ] 56 | rot = torch.stack(rot, dim=1).reshape(-1, 3, 3) 57 | return rot 58 | 59 | 60 | def transform_points_torch(pts, transform): # N x 3 # N x 4 x 4 61 | """ 62 | Transform points by a 4x4 matrix 63 | """ 64 | pts = torch.cat([pts, torch.ones(pts.shape[0], 1, dtype=pts.dtype, device=pts.device)], dim=1) 65 | transformed_pts = torch.matmul(transform, pts.unsqueeze(-1)).squeeze(-1)[:, :3] 66 | 67 | if torch.isnan(transformed_pts).any(): 68 | print("NaN in transform_points_torch") 69 | filtered_tensor = pts[torch.any(transformed_pts.isnan(), dim=1)] 70 | print(filtered_tensor.detach().cpu().numpy()) 71 | 72 | return transformed_pts.contiguous() 73 | 74 | 75 | def compute_rays(camera): 76 | """ 77 | Compute rays in camera space 78 | """ 79 | # grid of uv coordinates 80 | u = torch.linspace( 81 | 0, camera.width - 1, camera.width, dtype=camera.K.dtype, device=camera.K.device 82 | ) 83 | v = torch.linspace( 84 | 0, 85 | camera.height - 1, 86 | camera.height, 87 | dtype=camera.K.dtype, 88 | device=camera.K.device, 89 | ) 90 | 91 | # use (v, u) order to preserve row-major order 92 | v, u = torch.meshgrid(v, u, indexing="ij") 93 | v = v.flatten() 94 | u = u.flatten() 95 | 96 | K = camera.K 97 | # Inverse pinhole projection 98 | # fx * x/z + cx = u => x/z = (u - cx) / fx, z = 1 99 | # fy * y/z + cy = v => y/z = (v - cy) / fy, z = 1 100 | ray_dir = torch.stack( 101 | [ 102 | (u - K[0, 2]) / K[0, 0], 103 | (v - K[1, 2]) / K[1, 1], 104 | torch.ones_like(u), 105 | ], 106 | dim=-1, 107 | ) 108 | ray_dir = ray_dir / torch.norm(ray_dir, dim=1, keepdim=True) 109 | return ray_dir 110 | 111 | 112 | def compute_rays_in_world_frame(camera, camera_T_world): 113 | """ 114 | Compute rays in world space 115 | """ 116 | rays = compute_rays(camera) 117 | # transform rays to world space 118 | world_T_camera = torch.inverse(camera_T_world) 119 | rays = (world_T_camera[:3, :3] @ rays.T).T 120 | rays = rays / torch.norm(rays, dim=1, keepdim=True) 121 | rays = rays.reshape(camera.height, camera.width, 3) 122 | rays = rays.contiguous() 123 | return rays 124 | -------------------------------------------------------------------------------- /src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | void render_tiles_cuda( 4 | torch::Tensor uvs, 5 | torch::Tensor opacity, 6 | 
torch::Tensor rgb, 7 | torch::Tensor conic, 8 | torch::Tensor view_dir_by_pixel, 9 | torch::Tensor splat_start_end_idx_by_tile_idx, 10 | torch::Tensor gaussian_idx_by_splat_idx, 11 | torch::Tensor background_rgb, 12 | torch::Tensor num_splats_per_pixel, 13 | torch::Tensor final_weight_per_pixel, 14 | torch::Tensor rendered_image 15 | ); 16 | 17 | void render_tiles_backward_cuda( 18 | torch::Tensor uvs, 19 | torch::Tensor opacity, 20 | torch::Tensor rgb, 21 | torch::Tensor conic, 22 | torch::Tensor view_dir_by_pixel, 23 | torch::Tensor splat_start_end_idx_by_tile_idx, 24 | torch::Tensor gaussian_idx_by_splat_idx, 25 | torch::Tensor background_rgb, 26 | torch::Tensor num_splats_per_pixel, 27 | torch::Tensor final_weight_per_pixel, 28 | torch::Tensor grad_image, 29 | torch::Tensor grad_rgb, 30 | torch::Tensor grad_opacity, 31 | torch::Tensor grad_uvs, 32 | torch::Tensor grad_conic 33 | ); 34 | 35 | void camera_projection_cuda(torch::Tensor xyz, torch::Tensor K, torch::Tensor uv); 36 | 37 | void camera_projection_backward_cuda( 38 | torch::Tensor xyz, 39 | torch::Tensor K, 40 | torch::Tensor uv_grad_out, 41 | torch::Tensor xyz_grad_in 42 | ); 43 | 44 | void compute_sigma_world_cuda( 45 | torch::Tensor quaternion, 46 | torch::Tensor scale, 47 | torch::Tensor sigma_world 48 | ); 49 | 50 | void compute_sigma_world_backward_cuda( 51 | torch::Tensor quaternion, 52 | torch::Tensor scale, 53 | torch::Tensor sigma_world_grad_out, 54 | torch::Tensor quaternion_grad_in, 55 | torch::Tensor scale_grad_in 56 | ); 57 | 58 | void compute_projection_jacobian_cuda(torch::Tensor xyz, torch::Tensor K, torch::Tensor J); 59 | 60 | void compute_projection_jacobian_backward_cuda( 61 | torch::Tensor xyz, 62 | torch::Tensor K, 63 | torch::Tensor jac_grad_out, 64 | torch::Tensor xyz_grad_in 65 | ); 66 | 67 | void compute_conic_cuda( 68 | torch::Tensor sigma_world, 69 | torch::Tensor J, 70 | torch::Tensor camera_T_world, 71 | torch::Tensor conic 72 | ); 73 | 74 | void compute_conic_backward_cuda( 75 | torch::Tensor sigma_world, 76 | torch::Tensor J, 77 | torch::Tensor camera_T_world, 78 | torch::Tensor conic_grad_out, 79 | torch::Tensor sigma_world_grad_in, 80 | torch::Tensor J_grad_in 81 | ); 82 | 83 | std::tuple get_sorted_gaussian_list( 84 | const int max_tiles_per_gaussian, 85 | torch::Tensor uvs, 86 | torch::Tensor xyz_camera_frame, 87 | torch::Tensor conic, 88 | const int n_tiles_x, 89 | const int n_tiles_y, 90 | const float mh_dist 91 | ); 92 | 93 | void precompute_rgb_from_sh_cuda( 94 | const torch::Tensor xyz, 95 | const torch::Tensor sh_coeff, 96 | const torch::Tensor camera_T_world, 97 | torch::Tensor rgb 98 | ); 99 | 100 | void precompute_rgb_from_sh_backward_cuda( 101 | const torch::Tensor xyz, 102 | const torch::Tensor camera_T_world, 103 | const torch::Tensor grad_rgb, 104 | torch::Tensor grad_sh 105 | ); 106 | 107 | void render_depth_cuda( 108 | torch::Tensor xyz_camera_frame, 109 | torch::Tensor uvs, 110 | torch::Tensor opacity, 111 | torch::Tensor conic, 112 | torch::Tensor splat_start_end_idx_by_tile_idx, 113 | torch::Tensor gaussian_idx_by_splat_idx, 114 | const float alpha_threshold, 115 | torch::Tensor depth_image 116 | ); 117 | 118 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 119 | m.def("render_tiles_cuda", &render_tiles_cuda, "Render tiles CUDA"); 120 | m.def("render_tiles_backward_cuda", &render_tiles_backward_cuda, "Render tiles backward"); 121 | m.def("camera_projection_cuda", &camera_projection_cuda, "project point into image CUDA"); 122 | m.def( 123 | "camera_projection_backward_cuda", 124 
| &camera_projection_backward_cuda, 125 | "project point into image backward CUDA" 126 | ); 127 | m.def("compute_sigma_world_cuda", &compute_sigma_world_cuda, "compute sigma world CUDA"); 128 | m.def( 129 | "compute_sigma_world_backward_cuda", 130 | &compute_sigma_world_backward_cuda, 131 | "compute sigma world backward CUDA" 132 | ); 133 | m.def( 134 | "compute_projection_jacobian_cuda", 135 | &compute_projection_jacobian_cuda, 136 | "compute projection jacobian CUDA" 137 | ); 138 | m.def( 139 | "compute_projection_jacobian_backward_cuda", 140 | &compute_projection_jacobian_backward_cuda, 141 | "compute projection jacobian backward CUDA" 142 | ); 143 | m.def("compute_conic_cuda", &compute_conic_cuda, "compute conic CUDA"); 144 | m.def( 145 | "compute_conic_backward_cuda", &compute_conic_backward_cuda, "compute conic backward CUDA" 146 | ); 147 | m.def("get_sorted_gaussian_list", &get_sorted_gaussian_list, "get sorted gaussian list"); 148 | m.def( 149 | "precompute_rgb_from_sh_cuda", 150 | &precompute_rgb_from_sh_cuda, 151 | "precompute rgb from sh per gaussian" 152 | ); 153 | m.def( 154 | "precompute_rgb_from_sh_backward_cuda", 155 | &precompute_rgb_from_sh_backward_cuda, 156 | "precompute rgb from sh per gaussian backward" 157 | ); 158 | m.def("render_depth_cuda", &render_depth_cuda, "Render depth CUDA"); 159 | } 160 | -------------------------------------------------------------------------------- /src/checks.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef CHECKS_H 3 | #define CHECKS_H 4 | 5 | #define CHECK_IS_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " is not a CUDA tensor") 6 | #define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " is not a contiguous tensor") 7 | #define CHECK_VALID_INPUT(x) \ 8 | CHECK_IS_CUDA(x); \ 9 | CHECK_IS_CONTIGUOUS(x) 10 | 11 | #define CHECK_FLOAT_TENSOR(x) TORCH_CHECK(x.dtype() == torch::kFloat32, #x " is not a float tensor") 12 | #define CHECK_DOUBLE_TENSOR(x) \ 13 | TORCH_CHECK(x.dtype() == torch::kFloat64, #x " is not a double tensor") 14 | #define CHECK_INT_TENSOR(x) TORCH_CHECK(x.dtype() == torch::kInt32, #x " is not an int tensor") 15 | 16 | #endif 17 | -------------------------------------------------------------------------------- /src/depth.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "checks.cuh" 6 | 7 | template 8 | __global__ void render_depth_kernel( 9 | const float* __restrict__ xyz_camera_frame, 10 | const float* __restrict__ uvs, 11 | const float* __restrict__ opacity, 12 | const float* __restrict__ conic, 13 | const int* __restrict__ splat_start_end_idx_by_tile_idx, 14 | const int* __restrict__ gaussian_idx_by_splat_idx, 15 | const int image_width, 16 | const int image_height, 17 | const float alpha_threshold, 18 | float* __restrict__ depth_image 19 | ) { 20 | // grid = tiles, blocks = pixels within each tile 21 | const int u_splat = blockIdx.x * blockDim.x + threadIdx.x; 22 | const int v_splat = blockIdx.y * blockDim.y + threadIdx.y; 23 | const int tile_idx = blockIdx.x + blockIdx.y * gridDim.x; 24 | 25 | // keep threads around even if pixel is not valid for copying data 26 | bool valid_pixel = u_splat < image_width && v_splat < image_height; 27 | 28 | const int splat_idx_start = splat_start_end_idx_by_tile_idx[tile_idx]; 29 | const int splat_idx_end = splat_start_end_idx_by_tile_idx[tile_idx + 1]; 30 | int num_splats_this_tile = splat_idx_end - splat_idx_start; 31 | 32 | 
const int thread_id = threadIdx.x + threadIdx.y * blockDim.x; 33 | const int block_size = blockDim.x * blockDim.y; 34 | 35 | float alpha_accum = 0.0; 36 | 37 | // shared memory copies of inputs 38 | __shared__ int _gaussian_idx_by_splat_idx[CHUNK_SIZE]; 39 | __shared__ float _uvs[CHUNK_SIZE * 2]; 40 | __shared__ float _opacity[CHUNK_SIZE]; 41 | __shared__ float _conic[CHUNK_SIZE * 3]; 42 | 43 | const int num_chunks = (num_splats_this_tile + CHUNK_SIZE - 1) / CHUNK_SIZE; 44 | bool found_depth = false; 45 | // copy chunks 46 | for (int chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++) { 47 | __syncthreads(); // make sure previous iteration is complete before 48 | // modifying inputs 49 | for (int i = thread_id; i < CHUNK_SIZE; i += block_size) { 50 | const int tile_splat_idx = chunk_idx * CHUNK_SIZE + i; 51 | if (tile_splat_idx >= num_splats_this_tile) { 52 | break; 53 | } 54 | const int global_splat_idx = splat_idx_start + tile_splat_idx; 55 | 56 | const int gaussian_idx = gaussian_idx_by_splat_idx[global_splat_idx]; 57 | _gaussian_idx_by_splat_idx[i] = gaussian_idx; 58 | _uvs[i * 2 + 0] = uvs[gaussian_idx * 2 + 0]; 59 | _uvs[i * 2 + 1] = uvs[gaussian_idx * 2 + 1]; 60 | _opacity[i] = opacity[gaussian_idx]; 61 | 62 | #pragma unroll 63 | for (int j = 0; j < 3; j++) { 64 | _conic[i * 3 + j] = conic[gaussian_idx * 3 + j]; 65 | } 66 | } 67 | __syncthreads(); // wait for copying to complete before attempting to 68 | // use data 69 | if (valid_pixel && !found_depth) { 70 | int chunk_start = chunk_idx * CHUNK_SIZE; 71 | int chunk_end = min((chunk_idx + 1) * CHUNK_SIZE, num_splats_this_tile); 72 | int num_splats_this_chunk = chunk_end - chunk_start; 73 | for (int i = 0; i < num_splats_this_chunk; i++) { 74 | const float u_mean = _uvs[i * 2 + 0]; 75 | const float v_mean = _uvs[i * 2 + 1]; 76 | 77 | const float u_diff = __int2float_rn(u_splat) - u_mean; 78 | const float v_diff = __int2float_rn(v_splat) - v_mean; 79 | 80 | // 2d covariance matrix - add 0.25 to diagonal to make it positive definite rather 81 | // than semi-definite 82 | const float a = _conic[i * 3 + 0] + 0.25; 83 | const float b = _conic[i * 3 + 1] * 0.5; 84 | const float c = _conic[i * 3 + 2] + 0.25; 85 | const float det = a * c - b * b; 86 | 87 | float alpha = 0.0; 88 | // compute mahalanobis distance 89 | const float mh_sq = 90 | (c * u_diff * u_diff - (b + b) * u_diff * v_diff + a * v_diff * v_diff) / det; 91 | if (mh_sq > 0.0) { 92 | // probability at this pixel, normalized so that the 93 | // probability at the center of the gaussian is 1.0 94 | const float norm_prob = __expf(-0.5 * mh_sq); 95 | alpha = _opacity[i] * norm_prob; 96 | } 97 | const float weight = alpha * (1.0 - alpha_accum); 98 | alpha_accum += weight; 99 | 100 | if (alpha_accum > alpha_threshold) { 101 | // take the depth (Euclidean ray distance, not z-depth) from this gaussian 102 | const int gaussian_idx = _gaussian_idx_by_splat_idx[i]; 103 | const float x = xyz_camera_frame[gaussian_idx * 3 + 0]; 104 | const float y = xyz_camera_frame[gaussian_idx * 3 + 1]; 105 | const float z = xyz_camera_frame[gaussian_idx * 3 + 2]; 106 | const float depth = sqrt(x * x + y * y + z * z); 107 | 108 | depth_image[v_splat * image_width + u_splat] = depth; 109 | found_depth = true; 110 | break; 111 | } 112 | } // end splat loop 113 | } // valid pixel check 114 | } // end chunk loop 115 | } 116 | 117 | void render_depth_cuda( 118 | torch::Tensor xyz_camera_frame, 119 | torch::Tensor uvs, 120 | torch::Tensor opacity, 121 | torch::Tensor conic, 122 | torch::Tensor splat_start_end_idx_by_tile_idx, 123 | torch::Tensor 
gaussian_idx_by_splat_idx, 124 | const float alpha_threshold, 125 | torch::Tensor depth_image 126 | ) { 127 | CHECK_VALID_INPUT(xyz_camera_frame); 128 | CHECK_VALID_INPUT(uvs); 129 | CHECK_VALID_INPUT(opacity); 130 | CHECK_VALID_INPUT(conic); 131 | CHECK_VALID_INPUT(splat_start_end_idx_by_tile_idx); 132 | CHECK_VALID_INPUT(gaussian_idx_by_splat_idx); 133 | CHECK_VALID_INPUT(depth_image); 134 | 135 | int N = uvs.size(0); 136 | TORCH_CHECK(uvs.size(1) == 2, "uvs must be Nx2 (u, v)"); 137 | TORCH_CHECK( 138 | xyz_camera_frame.size(0) == N, 139 | "xyz_camera_frame must have the same number of elements as uvs" 140 | ); 141 | TORCH_CHECK(xyz_camera_frame.size(1) == 3, "xyz_camera_frame must be Nx3"); 142 | TORCH_CHECK(opacity.size(0) == N, "Opacity must have the same number of elements as uvs"); 143 | TORCH_CHECK(opacity.size(1) == 1, "Opacity must be Nx1"); 144 | TORCH_CHECK(conic.size(0) == N, "Conic must have the same number of elements as uvs"); 145 | TORCH_CHECK(conic.size(1) == 3, "Conic must be Nx3"); 146 | int image_height = depth_image.size(0); 147 | int image_width = depth_image.size(1); 148 | TORCH_CHECK(depth_image.size(2) == 1, "Depth Image must be HxWx1"); 149 | 150 | int num_tiles_x = (image_width + 16 - 1) / 16; 151 | int num_tiles_y = (image_height + 16 - 1) / 16; 152 | 153 | dim3 block_size(16, 16, 1); 154 | dim3 grid_size(num_tiles_x, num_tiles_y, 1); 155 | 156 | CHECK_FLOAT_TENSOR(xyz_camera_frame); 157 | CHECK_FLOAT_TENSOR(uvs); 158 | CHECK_FLOAT_TENSOR(opacity); 159 | CHECK_FLOAT_TENSOR(conic); 160 | CHECK_INT_TENSOR(splat_start_end_idx_by_tile_idx); 161 | CHECK_INT_TENSOR(gaussian_idx_by_splat_idx); 162 | CHECK_FLOAT_TENSOR(depth_image); 163 | 164 | render_depth_kernel<960><<>>( 165 | xyz_camera_frame.data_ptr(), 166 | uvs.data_ptr(), 167 | opacity.data_ptr(), 168 | conic.data_ptr(), 169 | splat_start_end_idx_by_tile_idx.data_ptr(), 170 | gaussian_idx_by_splat_idx.data_ptr(), 171 | image_width, 172 | image_height, 173 | alpha_threshold, 174 | depth_image.data_ptr() 175 | ); 176 | cudaDeviceSynchronize(); 177 | } 178 | -------------------------------------------------------------------------------- /src/matrix.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | template 5 | __device__ void transpose(const T* A, T* A_T, int num_rows_input, int num_cols_input) { 6 | #pragma unroll 7 | for (int row = 0; row < num_rows_input; row++) { 8 | #pragma unroll 9 | for (int col = 0; col < num_cols_input; col++) { 10 | A_T[col * num_rows_input + row] = A[row * num_cols_input + col]; 11 | } 12 | } 13 | } 14 | 15 | template 16 | __device__ void 17 | matrix_multiply(const T* A, const T* B, T* C, int num_rows_A, int num_cols_A, int num_cols_B) { 18 | #pragma unroll 19 | for (int row_a = 0; row_a < num_rows_A; row_a++) { 20 | #pragma unroll 21 | for (int col_b = 0; col_b < num_cols_B; col_b++) { 22 | T sum = 0; 23 | #pragma unroll 24 | for (int cols_A = 0; cols_A < num_cols_A; cols_A++) { 25 | sum += A[row_a * num_cols_A + cols_A] * B[cols_A * num_cols_B + col_b]; 26 | } 27 | C[row_a * num_cols_B + col_b] = sum; 28 | } 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/precompute_sh.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "checks.cuh" 4 | #include "spherical_harmonics.cuh" 5 | 6 | //(TODO) remove N_SH templating 7 | template 8 | __global__ void precompute_rgb_from_sh_kernel( 9 | const T* 
__restrict__ xyz, 10 | const T* __restrict__ sh_coeff, 11 | const T camera_x, 12 | const T camera_y, 13 | const T camera_z, 14 | const unsigned int N, 15 | T* __restrict__ rgb 16 | ) { 17 | const int gaussian_idx = blockIdx.x * blockDim.x + threadIdx.x; 18 | if (gaussian_idx >= N) { 19 | return; 20 | } 21 | 22 | if (N_SH == 1) { 23 | #pragma unroll 24 | for (int channel = 0; channel < 3; channel++) { 25 | rgb[gaussian_idx * 3 + channel] = sh_coeff[gaussian_idx * 3 + channel]; 26 | } 27 | } else { 28 | // compute normalized view direction 29 | T view_dir[3] = { 30 | xyz[gaussian_idx * 3 + 0] - camera_x, 31 | xyz[gaussian_idx * 3 + 1] - camera_y, 32 | xyz[gaussian_idx * 3 + 2] - camera_z}; 33 | const T r_view_dir_norm = rsqrt( 34 | view_dir[0] * view_dir[0] + view_dir[1] * view_dir[1] + view_dir[2] * view_dir[2] 35 | ); 36 | #pragma unroll 37 | for (int i = 0; i < 3; ++i) { 38 | view_dir[i] *= r_view_dir_norm; 39 | } 40 | 41 | T sh_at_view_dir[N_SH]; 42 | compute_sh_coeffs_for_view_dir(view_dir, sh_at_view_dir); 43 | 44 | #pragma unroll 45 | for (int channel = 0; channel < 3; channel++) { 46 | T temp_rgb = 0.0; 47 | #pragma unroll 48 | for (int sh_idx = 0; sh_idx < N_SH; sh_idx++) { 49 | temp_rgb += sh_at_view_dir[sh_idx] * 50 | sh_coeff[gaussian_idx * N_SH * 3 + N_SH * channel + sh_idx]; 51 | } 52 | // divide by SH_0 to maintain compatibility with downstream rasterizer 53 | temp_rgb *= r_SH_0; 54 | // set value on output 55 | rgb[gaussian_idx * 3 + channel] = temp_rgb; 56 | } 57 | } 58 | } 59 | 60 | template 61 | __global__ void precompute_rgb_from_sh_backward_kernel( 62 | const T* __restrict__ xyz, 63 | const T camera_x, 64 | const T camera_y, 65 | const T camera_z, 66 | const T* __restrict__ grad_rgb, 67 | const unsigned int N, 68 | T* __restrict__ grad_sh 69 | ) { 70 | const int gaussian_idx = blockIdx.x * blockDim.x + threadIdx.x; 71 | if (gaussian_idx >= N) { 72 | return; 73 | } 74 | if (N_SH == 1) { 75 | #pragma unroll 76 | for (int channel = 0; channel < 3; channel++) { 77 | grad_sh[gaussian_idx * 3 + channel] = grad_rgb[gaussian_idx * 3 + channel]; 78 | } 79 | } else { 80 | // compute normalized view direction 81 | T view_dir[3] = { 82 | xyz[gaussian_idx * 3 + 0] - camera_x, 83 | xyz[gaussian_idx * 3 + 1] - camera_y, 84 | xyz[gaussian_idx * 3 + 2] - camera_z}; 85 | const T r_view_dir_norm = rsqrt( 86 | view_dir[0] * view_dir[0] + view_dir[1] * view_dir[1] + view_dir[2] * view_dir[2] 87 | ); 88 | #pragma unroll 89 | for (int i = 0; i < 3; ++i) { 90 | view_dir[i] *= r_view_dir_norm; 91 | } 92 | 93 | T sh_at_view_dir[N_SH]; 94 | compute_sh_coeffs_for_view_dir(view_dir, sh_at_view_dir); 95 | 96 | // make local copy and undo scaling by SH_0 97 | T grad_rgb_local[3] = { 98 | grad_rgb[gaussian_idx * 3 + 0] * r_SH_0, 99 | grad_rgb[gaussian_idx * 3 + 1] * r_SH_0, 100 | grad_rgb[gaussian_idx * 3 + 2] * r_SH_0}; 101 | 102 | #pragma unroll 103 | for (int channel = 0; channel < 3; channel++) { 104 | #pragma unroll 105 | for (int sh_idx = 0; sh_idx < N_SH; sh_idx++) { 106 | grad_sh[gaussian_idx * N_SH * 3 + N_SH * channel + sh_idx] = 107 | grad_rgb_local[channel] * sh_at_view_dir[sh_idx]; 108 | } 109 | } 110 | } 111 | } 112 | 113 | void precompute_rgb_from_sh_cuda( 114 | const torch::Tensor xyz, 115 | const torch::Tensor sh_coeff, 116 | const torch::Tensor camera_T_world, 117 | torch::Tensor rgb 118 | ) { 119 | CHECK_VALID_INPUT(xyz); 120 | CHECK_VALID_INPUT(sh_coeff); 121 | CHECK_VALID_INPUT(camera_T_world); 122 | CHECK_VALID_INPUT(rgb); 123 | 124 | const int N = xyz.size(0); 125 | 
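    // sh_coeff is either [N, 3] (DC-only color) or [N, 3, num_sh_coeff] with
    // num_sh_coeff in {1, 4, 9, 16}, i.e. spherical harmonics bands 0 through 3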
TORCH_CHECK(xyz.size(1) == 3, "Input xyz should have 3 channels"); 126 | TORCH_CHECK(sh_coeff.size(0) == N, "N xyz and sh_coeff should match"); 127 | TORCH_CHECK(sh_coeff.size(1) == 3, "SH coefficients should have 3 channels"); 128 | int num_sh_coeff; 129 | if (sh_coeff.dim() == 3) { 130 | num_sh_coeff = sh_coeff.size(2); 131 | } else { 132 | num_sh_coeff = 1; 133 | } 134 | TORCH_CHECK(camera_T_world.size(0) == 4, "camera_T_world should be 4x4 transformation matrix"); 135 | TORCH_CHECK(camera_T_world.size(1) == 4, "camera_T_world should be 4x4 transformation matrix"); 136 | TORCH_CHECK(rgb.size(0) == N, "N xyz and rgb should match"); 137 | TORCH_CHECK(rgb.size(1) == 3, "Output rgb should have 3 channels"); 138 | 139 | const int max_threads_per_block = 1024; 140 | const int num_blocks = (N + max_threads_per_block - 1) / max_threads_per_block; 141 | dim3 gridsize(num_blocks, 1, 1); 142 | dim3 blocksize(max_threads_per_block, 1, 1); 143 | 144 | if (xyz.dtype() == torch::kFloat32) { 145 | CHECK_FLOAT_TENSOR(sh_coeff); 146 | CHECK_FLOAT_TENSOR(camera_T_world); 147 | CHECK_FLOAT_TENSOR(rgb); 148 | 149 | const float camera_x = camera_T_world[0][3].item(); 150 | const float camera_y = camera_T_world[1][3].item(); 151 | const float camera_z = camera_T_world[2][3].item(); 152 | if (num_sh_coeff == 1) { 153 | precompute_rgb_from_sh_kernel<<>>( 154 | xyz.data_ptr(), 155 | sh_coeff.data_ptr(), 156 | camera_x, 157 | camera_y, 158 | camera_z, 159 | N, 160 | rgb.data_ptr() 161 | ); 162 | } else if (num_sh_coeff == 4) { 163 | precompute_rgb_from_sh_kernel<<>>( 164 | xyz.data_ptr(), 165 | sh_coeff.data_ptr(), 166 | camera_x, 167 | camera_y, 168 | camera_z, 169 | N, 170 | rgb.data_ptr() 171 | ); 172 | } else if (num_sh_coeff == 9) { 173 | precompute_rgb_from_sh_kernel<<>>( 174 | xyz.data_ptr(), 175 | sh_coeff.data_ptr(), 176 | camera_x, 177 | camera_y, 178 | camera_z, 179 | N, 180 | rgb.data_ptr() 181 | ); 182 | } else if (num_sh_coeff == 16) { 183 | precompute_rgb_from_sh_kernel<<>>( 184 | xyz.data_ptr(), 185 | sh_coeff.data_ptr(), 186 | camera_x, 187 | camera_y, 188 | camera_z, 189 | N, 190 | rgb.data_ptr() 191 | ); 192 | } else { 193 | AT_ERROR("Unsupported number of SH coefficients: ", num_sh_coeff); 194 | } 195 | } else if (xyz.dtype() == torch::kFloat64) { 196 | CHECK_DOUBLE_TENSOR(sh_coeff); 197 | CHECK_DOUBLE_TENSOR(camera_T_world); 198 | CHECK_DOUBLE_TENSOR(rgb); 199 | 200 | const double camera_x = camera_T_world[0][3].item(); 201 | const double camera_y = camera_T_world[1][3].item(); 202 | const double camera_z = camera_T_world[2][3].item(); 203 | if (num_sh_coeff == 1) { 204 | precompute_rgb_from_sh_kernel<<>>( 205 | xyz.data_ptr(), 206 | sh_coeff.data_ptr(), 207 | camera_x, 208 | camera_y, 209 | camera_z, 210 | N, 211 | rgb.data_ptr() 212 | ); 213 | } else if (num_sh_coeff == 4) { 214 | precompute_rgb_from_sh_kernel<<>>( 215 | xyz.data_ptr(), 216 | sh_coeff.data_ptr(), 217 | camera_x, 218 | camera_y, 219 | camera_z, 220 | N, 221 | rgb.data_ptr() 222 | ); 223 | } else if (num_sh_coeff == 9) { 224 | precompute_rgb_from_sh_kernel<<>>( 225 | xyz.data_ptr(), 226 | sh_coeff.data_ptr(), 227 | camera_x, 228 | camera_y, 229 | camera_z, 230 | N, 231 | rgb.data_ptr() 232 | ); 233 | } else if (num_sh_coeff == 16) { 234 | precompute_rgb_from_sh_kernel<<>>( 235 | xyz.data_ptr(), 236 | sh_coeff.data_ptr(), 237 | camera_x, 238 | camera_y, 239 | camera_z, 240 | N, 241 | rgb.data_ptr() 242 | ); 243 | } else { 244 | AT_ERROR("Unsupported number of SH coefficients: ", num_sh_coeff); 245 | } 246 | } else { 247 | 
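        // the kernels above are only instantiated for float32 and float64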
AT_ERROR("Unsupported data type: ", xyz.dtype()); 248 | } 249 | cudaDeviceSynchronize(); 250 | } 251 | 252 | void precompute_rgb_from_sh_backward_cuda( 253 | const torch::Tensor xyz, 254 | const torch::Tensor camera_T_world, 255 | const torch::Tensor grad_rgb, 256 | torch::Tensor grad_sh 257 | ) { 258 | CHECK_VALID_INPUT(xyz); 259 | CHECK_VALID_INPUT(camera_T_world); 260 | CHECK_VALID_INPUT(grad_rgb); 261 | CHECK_VALID_INPUT(grad_sh); 262 | 263 | const int N = xyz.size(0); 264 | TORCH_CHECK(xyz.size(1) == 3, "Input xyz should have 3 channels"); 265 | TORCH_CHECK(camera_T_world.size(0) == 4, "camera_T_world should be 4x4 transformation matrix"); 266 | TORCH_CHECK(camera_T_world.size(1) == 4, "camera_T_world should be 4x4 transformation matrix"); 267 | TORCH_CHECK(grad_rgb.size(0) == N, "N xyz and grad_rgb should match"); 268 | TORCH_CHECK(grad_rgb.size(1) == 3, "Input grad_rgb should have 3 channels"); 269 | TORCH_CHECK(grad_sh.size(0) == N, "N xyz and grad_sh should match"); 270 | TORCH_CHECK(grad_sh.size(1) == 3, "Output grad_sh should have 3 channels"); 271 | int num_sh_coeff; 272 | if (grad_sh.dim() == 3) { 273 | num_sh_coeff = grad_sh.size(2); 274 | } else { 275 | num_sh_coeff = 1; 276 | } 277 | 278 | const int max_threads_per_block = 1024; 279 | const int num_blocks = (N + max_threads_per_block - 1) / max_threads_per_block; 280 | dim3 gridsize(num_blocks, 1, 1); 281 | dim3 blocksize(max_threads_per_block, 1, 1); 282 | 283 | if (xyz.dtype() == torch::kFloat32) { 284 | CHECK_FLOAT_TENSOR(camera_T_world); 285 | CHECK_FLOAT_TENSOR(grad_rgb); 286 | CHECK_FLOAT_TENSOR(grad_sh); 287 | 288 | const float camera_x = camera_T_world[0][3].item(); 289 | const float camera_y = camera_T_world[1][3].item(); 290 | const float camera_z = camera_T_world[2][3].item(); 291 | if (num_sh_coeff == 1) { 292 | precompute_rgb_from_sh_backward_kernel<<>>( 293 | xyz.data_ptr(), 294 | camera_x, 295 | camera_y, 296 | camera_z, 297 | grad_rgb.data_ptr(), 298 | N, 299 | grad_sh.data_ptr() 300 | ); 301 | } else if (num_sh_coeff == 4) { 302 | precompute_rgb_from_sh_backward_kernel<<>>( 303 | xyz.data_ptr(), 304 | camera_x, 305 | camera_y, 306 | camera_z, 307 | grad_rgb.data_ptr(), 308 | N, 309 | grad_sh.data_ptr() 310 | ); 311 | } else if (num_sh_coeff == 9) { 312 | precompute_rgb_from_sh_backward_kernel<<>>( 313 | xyz.data_ptr(), 314 | camera_x, 315 | camera_y, 316 | camera_z, 317 | grad_rgb.data_ptr(), 318 | N, 319 | grad_sh.data_ptr() 320 | ); 321 | } else if (num_sh_coeff == 16) { 322 | precompute_rgb_from_sh_backward_kernel<<>>( 323 | xyz.data_ptr(), 324 | camera_x, 325 | camera_y, 326 | camera_z, 327 | grad_rgb.data_ptr(), 328 | N, 329 | grad_sh.data_ptr() 330 | ); 331 | } else { 332 | AT_ERROR("Unsupported number of SH coefficients: ", num_sh_coeff); 333 | } 334 | } else if (xyz.dtype() == torch::kFloat64) { 335 | CHECK_DOUBLE_TENSOR(camera_T_world); 336 | CHECK_DOUBLE_TENSOR(grad_rgb); 337 | CHECK_DOUBLE_TENSOR(grad_sh); 338 | 339 | const double camera_x = camera_T_world[0][3].item(); 340 | const double camera_y = camera_T_world[1][3].item(); 341 | const double camera_z = camera_T_world[2][3].item(); 342 | if (num_sh_coeff == 1) { 343 | precompute_rgb_from_sh_backward_kernel<<>>( 344 | xyz.data_ptr(), 345 | camera_x, 346 | camera_y, 347 | camera_z, 348 | grad_rgb.data_ptr(), 349 | N, 350 | grad_sh.data_ptr() 351 | ); 352 | } else if (num_sh_coeff == 4) { 353 | precompute_rgb_from_sh_backward_kernel<<>>( 354 | xyz.data_ptr(), 355 | camera_x, 356 | camera_y, 357 | camera_z, 358 | grad_rgb.data_ptr(), 359 | N, 360 
| grad_sh.data_ptr() 361 | ); 362 | } else if (num_sh_coeff == 9) { 363 | precompute_rgb_from_sh_backward_kernel<<>>( 364 | xyz.data_ptr(), 365 | camera_x, 366 | camera_y, 367 | camera_z, 368 | grad_rgb.data_ptr(), 369 | N, 370 | grad_sh.data_ptr() 371 | ); 372 | } else if (num_sh_coeff == 16) { 373 | precompute_rgb_from_sh_backward_kernel<<>>( 374 | xyz.data_ptr(), 375 | camera_x, 376 | camera_y, 377 | camera_z, 378 | grad_rgb.data_ptr(), 379 | N, 380 | grad_sh.data_ptr() 381 | ); 382 | } else { 383 | AT_ERROR("Unsupported number of SH coefficients: ", num_sh_coeff); 384 | } 385 | } else { 386 | AT_ERROR("Unsupported data type: ", xyz.dtype()); 387 | } 388 | cudaDeviceSynchronize(); 389 | } 390 | -------------------------------------------------------------------------------- /src/projection.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "checks.cuh" 6 | #include "matrix.cuh" 7 | 8 | template 9 | __global__ void 10 | camera_projection_kernel(const T* __restrict__ xyz, const T* __restrict__ K, const int N, T* uv) { 11 | const int i = blockIdx.x * blockDim.x + threadIdx.x; 12 | if (i >= N) { 13 | return; 14 | } 15 | // u = fx * X / Z + cx 16 | uv[i * 2 + 0] = K[0] * xyz[i * 3 + 0] / xyz[i * 3 + 2] + K[2]; 17 | // v = fy * Y / Z + cy 18 | uv[i * 2 + 1] = K[4] * xyz[i * 3 + 1] / xyz[i * 3 + 2] + K[5]; 19 | } 20 | 21 | void camera_projection_cuda(torch::Tensor xyz, torch::Tensor K, torch::Tensor uv) { 22 | CHECK_VALID_INPUT(xyz); 23 | CHECK_VALID_INPUT(K); 24 | CHECK_VALID_INPUT(uv); 25 | 26 | const int N = xyz.size(0); 27 | TORCH_CHECK(xyz.size(1) == 3, "xyz must have shape Nx3"); 28 | TORCH_CHECK(K.size(0) == 3, "K must have shape 3x3"); 29 | TORCH_CHECK(K.size(1) == 3, "K must have shape 3x3"); 30 | TORCH_CHECK(uv.size(0) == N, "uv must have shape Nx2"); 31 | TORCH_CHECK(uv.size(1) == 2, "uv must have shape Nx2"); 32 | 33 | const int max_threads_per_block = 1024; 34 | const int num_blocks = (N + max_threads_per_block - 1) / max_threads_per_block; 35 | dim3 gridsize(num_blocks, 1, 1); 36 | dim3 blocksize(max_threads_per_block, 1, 1); 37 | 38 | if (xyz.dtype() == torch::kFloat32) { 39 | CHECK_FLOAT_TENSOR(K); 40 | CHECK_FLOAT_TENSOR(uv); 41 | camera_projection_kernel<<>>( 42 | xyz.data_ptr(), K.data_ptr(), N, uv.data_ptr() 43 | ); 44 | } else if (xyz.dtype() == torch::kFloat64) { 45 | CHECK_DOUBLE_TENSOR(K); 46 | CHECK_DOUBLE_TENSOR(uv); 47 | camera_projection_kernel<<>>( 48 | xyz.data_ptr(), K.data_ptr(), N, uv.data_ptr() 49 | ); 50 | } else { 51 | AT_ERROR("Inputs must be float32 or float64"); 52 | } 53 | cudaDeviceSynchronize(); 54 | } 55 | 56 | template 57 | __global__ void compute_sigma_world_kernel( 58 | const T* __restrict__ quaternion, 59 | const T* __restrict__ scale, 60 | const int N, 61 | T* sigma_world 62 | ) { 63 | const int i = blockIdx.x * blockDim.x + threadIdx.x; 64 | if (i >= N) { 65 | return; 66 | } 67 | T qw = quaternion[i * 4 + 0]; 68 | T qx = quaternion[i * 4 + 1]; 69 | T qy = quaternion[i * 4 + 2]; 70 | T qz = quaternion[i * 4 + 3]; 71 | 72 | T norm = sqrt(qx * qx + qy * qy + qz * qz + qw * qw); 73 | 74 | // // zero magnitude quaternion is not valid 75 | qx /= norm; 76 | qy /= norm; 77 | qz /= norm; 78 | qw /= norm; 79 | 80 | T r00 = 1 - 2 * qy * qy - 2 * qz * qz; 81 | T r01 = 2 * qx * qy - 2 * qz * qw; 82 | T r02 = 2 * qx * qz + 2 * qy * qw; 83 | T r10 = 2 * qx * qy + 2 * qz * qw; 84 | T r11 = 1 - 2 * qx * qx - 2 * qz * qz; 85 | T r12 = 2 * qy * qz - 2 * qx * qw; 86 | T r20 = 
2 * qx * qz - 2 * qy * qw; 87 | T r21 = 2 * qy * qz + 2 * qx * qw; 88 | T r22 = 1 - 2 * qx * qx - 2 * qy * qy; 89 | 90 | T sx = exp(scale[i * 3 + 0]); 91 | T sy = exp(scale[i * 3 + 1]); 92 | T sz = exp(scale[i * 3 + 2]); 93 | 94 | T sx_sq = sx * sx; 95 | T sy_sq = sy * sy; 96 | T sz_sq = sz * sz; 97 | 98 | sigma_world[i * 9 + 0] = r00 * r00 * sx_sq + r01 * r01 * sy_sq + r02 * r02 * sz_sq; 99 | sigma_world[i * 9 + 1] = r00 * r10 * sx_sq + r01 * r11 * sy_sq + r02 * r12 * sz_sq; 100 | sigma_world[i * 9 + 2] = r00 * r20 * sx_sq + r01 * r21 * sy_sq + r02 * r22 * sz_sq; 101 | 102 | sigma_world[i * 9 + 3] = r00 * r10 * sx_sq + r01 * r11 * sy_sq + r02 * r12 * sz_sq; 103 | sigma_world[i * 9 + 4] = r10 * r10 * sx_sq + r11 * r11 * sy_sq + r12 * r12 * sz_sq; 104 | sigma_world[i * 9 + 5] = r10 * r20 * sx_sq + r11 * r21 * sy_sq + r12 * r22 * sz_sq; 105 | 106 | sigma_world[i * 9 + 6] = r00 * r20 * sx_sq + r01 * r21 * sy_sq + r02 * r22 * sz_sq; 107 | sigma_world[i * 9 + 7] = r10 * r20 * sx_sq + r11 * r21 * sy_sq + r12 * r22 * sz_sq; 108 | sigma_world[i * 9 + 8] = r20 * r20 * sx_sq + r21 * r21 * sy_sq + r22 * r22 * sz_sq; 109 | } 110 | 111 | void compute_sigma_world_cuda( 112 | torch::Tensor quaternion, 113 | torch::Tensor scale, 114 | torch::Tensor sigma_world 115 | ) { 116 | CHECK_VALID_INPUT(quaternion); 117 | CHECK_VALID_INPUT(scale); 118 | CHECK_VALID_INPUT(sigma_world); 119 | 120 | const int N = quaternion.size(0); 121 | TORCH_CHECK(quaternion.size(1) == 4, "quaternion must have shape Nx4"); 122 | TORCH_CHECK(scale.size(0) == N, "scale must have shape Nx1"); 123 | TORCH_CHECK(sigma_world.size(0) == N, "sigma_world must have shape Nx3x3"); 124 | TORCH_CHECK(sigma_world.size(1) == 3, "sigma_world must have shape Nx3x3"); 125 | TORCH_CHECK(sigma_world.size(2) == 3, "sigma_world must have shape Nx3x3"); 126 | 127 | // can probably update this to improve perf 128 | const int max_threads_per_block = 1024; 129 | const int num_blocks = (N + max_threads_per_block - 1) / max_threads_per_block; 130 | dim3 gridsize(num_blocks, 1, 1); 131 | dim3 blocksize(max_threads_per_block, 1, 1); 132 | 133 | if (quaternion.dtype() == torch::kFloat32) { 134 | CHECK_FLOAT_TENSOR(scale); 135 | CHECK_FLOAT_TENSOR(sigma_world); 136 | compute_sigma_world_kernel<<>>( 137 | quaternion.data_ptr(), scale.data_ptr(), N, sigma_world.data_ptr() 138 | ); 139 | } else if (quaternion.dtype() == torch::kFloat64) { 140 | CHECK_DOUBLE_TENSOR(scale); 141 | CHECK_DOUBLE_TENSOR(sigma_world); 142 | compute_sigma_world_kernel<<>>( 143 | quaternion.data_ptr(), 144 | scale.data_ptr(), 145 | N, 146 | sigma_world.data_ptr() 147 | ); 148 | } else { 149 | AT_ERROR("Inputs must be float32 or float64"); 150 | } 151 | cudaDeviceSynchronize(); 152 | } 153 | 154 | template 155 | __global__ void compute_projection_jacobian_kernel( 156 | const T* __restrict__ xyz, 157 | const T* __restrict__ K, 158 | const int N, 159 | T* J 160 | ) { 161 | const int i = blockIdx.x * blockDim.x + threadIdx.x; 162 | if (i >= N) { 163 | return; 164 | } 165 | T x = xyz[i * 3 + 0]; 166 | T y = xyz[i * 3 + 1]; 167 | T z = xyz[i * 3 + 2]; 168 | 169 | J[i * 6 + 0] = K[0] / z; 170 | J[i * 6 + 1] = 0; 171 | J[i * 6 + 2] = -K[0] * x / (z * z); 172 | J[i * 6 + 3] = 0; 173 | J[i * 6 + 4] = K[4] / z; 174 | J[i * 6 + 5] = -K[4] * y / (z * z); 175 | } 176 | 177 | void compute_projection_jacobian_cuda(torch::Tensor xyz, torch::Tensor K, torch::Tensor J) { 178 | CHECK_VALID_INPUT(xyz); 179 | CHECK_VALID_INPUT(K); 180 | CHECK_VALID_INPUT(J); 181 | 182 | const int N = xyz.size(0); 183 | 
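    // Jacobian of the pinhole projection u = fx * x / z + cx, v = fy * y / z + cy,
    // as computed by the kernel above:
    //     J = [ fx / z,   0,      -fx * x / z^2 ]
    //         [ 0,        fy / z, -fy * y / z^2 ]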
TORCH_CHECK(xyz.size(1) == 3, "xyz must have shape Nx3"); 184 | TORCH_CHECK(K.size(0) == 3, "K must have shape 3x3"); 185 | TORCH_CHECK(K.size(1) == 3, "K must have shape 3x3"); 186 | TORCH_CHECK(J.size(0) == N, "J must have shape Nx2x3"); 187 | TORCH_CHECK(J.size(1) == 2, "J must have shape Nx2x3"); 188 | TORCH_CHECK(J.size(2) == 3, "J must have shape Nx2x3"); 189 | 190 | const int max_threads_per_block = 1024; 191 | const int num_blocks = (N + max_threads_per_block - 1) / max_threads_per_block; 192 | dim3 gridsize(num_blocks, 1, 1); 193 | dim3 blocksize(max_threads_per_block, 1, 1); 194 | 195 | if (xyz.dtype() == torch::kFloat32) { 196 | CHECK_FLOAT_TENSOR(K); 197 | CHECK_FLOAT_TENSOR(J); 198 | 199 | compute_projection_jacobian_kernel<<>>( 200 | xyz.data_ptr(), K.data_ptr(), N, J.data_ptr() 201 | ); 202 | } else if (xyz.dtype() == torch::kFloat64) { 203 | CHECK_DOUBLE_TENSOR(K); 204 | CHECK_DOUBLE_TENSOR(J); 205 | compute_projection_jacobian_kernel<<>>( 206 | xyz.data_ptr(), K.data_ptr(), N, J.data_ptr() 207 | ); 208 | } else { 209 | AT_ERROR("Inputs must be float32 or float64"); 210 | } 211 | } 212 | 213 | template 214 | __global__ void compute_conic_kernel( 215 | const T* __restrict__ sigma_world, 216 | const T* __restrict__ J, 217 | const T* __restrict__ camera_T_world, 218 | const int N, 219 | T* conic 220 | ) { 221 | const int i = blockIdx.x * blockDim.x + threadIdx.x; 222 | if (i >= N) { 223 | return; 224 | } 225 | // get rotation matrix 226 | T W[9]; 227 | W[0] = camera_T_world[0]; 228 | W[1] = camera_T_world[1]; 229 | W[2] = camera_T_world[2]; 230 | W[3] = camera_T_world[4]; 231 | W[4] = camera_T_world[5]; 232 | W[5] = camera_T_world[6]; 233 | W[6] = camera_T_world[8]; 234 | W[7] = camera_T_world[9]; 235 | W[8] = camera_T_world[10]; 236 | 237 | // compute JW = J * W) 238 | T JW[6]; 239 | matrix_multiply(J + i * 6, W, JW, 2, 3, 3); 240 | 241 | // compute JWSigma = JW * sigma_world 242 | T JWSigma[6]; 243 | matrix_multiply(JW, sigma_world + i * 9, JWSigma, 2, 3, 3); 244 | 245 | T JW_t[6]; 246 | transpose(JW, JW_t, 2, 3); 247 | 248 | // compute sigma_image = JWSigma @ JW_t 249 | T sigma_image[4]; 250 | matrix_multiply(JWSigma, JW_t, sigma_image, 2, 3, 2); 251 | 252 | // write to conic 253 | conic[i * 3 + 0] = sigma_image[0]; 254 | // they are also equal but this keeps the pytorch autograd check happy 255 | conic[i * 3 + 1] = sigma_image[1] + sigma_image[2]; 256 | conic[i * 3 + 2] = sigma_image[3]; 257 | } 258 | 259 | void compute_conic_cuda( 260 | torch::Tensor sigma_world, 261 | torch::Tensor J, 262 | torch::Tensor camera_T_world, 263 | torch::Tensor conic 264 | ) { 265 | CHECK_VALID_INPUT(sigma_world); 266 | CHECK_VALID_INPUT(J); 267 | CHECK_VALID_INPUT(camera_T_world); 268 | CHECK_VALID_INPUT(conic); 269 | 270 | const int N = sigma_world.size(0); 271 | TORCH_CHECK(sigma_world.size(1) == 3, "sigma_world must have shape Nx3x3"); 272 | TORCH_CHECK(sigma_world.size(2) == 3, "sigma_world must have shape Nx3x3"); 273 | TORCH_CHECK(J.size(0) == N, "J must have shape Nx2x3"); 274 | TORCH_CHECK(J.size(1) == 2, "J must have shape Nx2x3"); 275 | TORCH_CHECK(J.size(2) == 3, "J must have shape Nx2x3"); 276 | TORCH_CHECK(camera_T_world.size(0) == 4, "camera_T_world must have shape 4x4"); 277 | TORCH_CHECK(camera_T_world.size(1) == 4, "camera_T_world must have shape 4x4"); 278 | TORCH_CHECK(conic.size(0) == N, "conic must have shape Nx3"); 279 | TORCH_CHECK(conic.size(1) == 3, "conic must have shape Nx3"); 280 | 281 | const int max_threads_per_block = 1024; 282 | const int num_blocks = (N + 
259 | void compute_conic_cuda(
260 | torch::Tensor sigma_world,
261 | torch::Tensor J,
262 | torch::Tensor camera_T_world,
263 | torch::Tensor conic
264 | ) {
265 | CHECK_VALID_INPUT(sigma_world);
266 | CHECK_VALID_INPUT(J);
267 | CHECK_VALID_INPUT(camera_T_world);
268 | CHECK_VALID_INPUT(conic);
269 | 
270 | const int N = sigma_world.size(0);
271 | TORCH_CHECK(sigma_world.size(1) == 3, "sigma_world must have shape Nx3x3");
272 | TORCH_CHECK(sigma_world.size(2) == 3, "sigma_world must have shape Nx3x3");
273 | TORCH_CHECK(J.size(0) == N, "J must have shape Nx2x3");
274 | TORCH_CHECK(J.size(1) == 2, "J must have shape Nx2x3");
275 | TORCH_CHECK(J.size(2) == 3, "J must have shape Nx2x3");
276 | TORCH_CHECK(camera_T_world.size(0) == 4, "camera_T_world must have shape 4x4");
277 | TORCH_CHECK(camera_T_world.size(1) == 4, "camera_T_world must have shape 4x4");
278 | TORCH_CHECK(conic.size(0) == N, "conic must have shape Nx3");
279 | TORCH_CHECK(conic.size(1) == 3, "conic must have shape Nx3");
280 | 
281 | const int max_threads_per_block = 1024;
282 | const int num_blocks = (N + max_threads_per_block - 1) / max_threads_per_block;
283 | dim3 gridsize(num_blocks, 1, 1);
284 | dim3 blocksize(max_threads_per_block, 1, 1);
285 | 
286 | if (sigma_world.dtype() == torch::kFloat32) {
287 | CHECK_FLOAT_TENSOR(J);
288 | CHECK_FLOAT_TENSOR(camera_T_world);
289 | CHECK_FLOAT_TENSOR(conic);
290 | compute_conic_kernel<<<gridsize, blocksize>>>(
291 | sigma_world.data_ptr<float>(),
292 | J.data_ptr<float>(),
293 | camera_T_world.data_ptr<float>(),
294 | N,
295 | conic.data_ptr<float>()
296 | );
297 | } else if (sigma_world.dtype() == torch::kFloat64) {
298 | CHECK_DOUBLE_TENSOR(J);
299 | CHECK_DOUBLE_TENSOR(camera_T_world);
300 | CHECK_DOUBLE_TENSOR(conic);
301 | compute_conic_kernel<<<gridsize, blocksize>>>(
302 | sigma_world.data_ptr<double>(),
303 | J.data_ptr<double>(),
304 | camera_T_world.data_ptr<double>(),
305 | N,
306 | conic.data_ptr<double>()
307 | );
308 | } else {
309 | AT_ERROR("Inputs must be float32 or float64");
310 | }
311 | }
312 | 
--------------------------------------------------------------------------------
/src/render.cu:
--------------------------------------------------------------------------------
1 | #include <cuda.h>
2 | #include <cuda_runtime.h>
3 | #include <torch/extension.h>
4 | 
5 | #include "checks.cuh"
6 | #include "spherical_harmonics.cuh"
7 | 
8 | template <typename T, unsigned int CHUNK_SIZE, unsigned int N_SH>
9 | __global__ void render_tiles_kernel(
10 | const T* __restrict__ uvs,
11 | const T* __restrict__ opacity,
12 | const T* __restrict__ rgb,
13 | const T* __restrict__ conic,
14 | const T* __restrict__ view_dir_by_pixel,
15 | const int* __restrict__ splat_start_end_idx_by_tile_idx,
16 | const int* __restrict__ gaussian_idx_by_splat_idx,
17 | const T* __restrict__ background_rgb,
18 | const int image_width,
19 | const int image_height,
20 | const bool use_fast_exp,
21 | int* num_splats_per_pixel,
22 | T* __restrict__ final_weight_per_pixel,
23 | T* __restrict__ image
24 | ) {
25 | // grid = tiles, blocks = pixels within each tile
26 | const int u_splat = blockIdx.x * blockDim.x + threadIdx.x;
27 | const int v_splat = blockIdx.y * blockDim.y + threadIdx.y;
28 | const int tile_idx = blockIdx.x + blockIdx.y * gridDim.x;
29 | 
30 | // keep threads around even if pixel is not valid for copying data
31 | bool valid_pixel = u_splat < image_width && v_splat < image_height;
32 | 
33 | const int splat_idx_start = splat_start_end_idx_by_tile_idx[tile_idx];
34 | const int splat_idx_end = splat_start_end_idx_by_tile_idx[tile_idx + 1];
35 | int num_splats_this_tile = splat_idx_end - splat_idx_start;
36 | 
37 | const int thread_id = threadIdx.x + threadIdx.y * blockDim.x;
38 | const int block_size = blockDim.x * blockDim.y;
39 | 
40 | T alpha_accum = 0.0;
41 | T alpha_weight = 0.0;
42 | int num_splats = 0;
43 | 
44 | T view_dir[3];
45 | T sh_at_view_dir[N_SH];
46 | if (valid_pixel) {
47 | #pragma unroll
48 | for (int axis = 0; axis < 3; axis++) {
49 | view_dir[axis] = view_dir_by_pixel[(v_splat * image_width + u_splat) * 3 + axis];
50 | }
51 | compute_sh_coeffs_for_view_dir<T, N_SH>(view_dir, sh_at_view_dir);
52 | }
53 | 
54 | // shared memory copies of inputs
55 | __shared__ T _uvs[CHUNK_SIZE * 2];
56 | __shared__ T _opacity[CHUNK_SIZE];
57 | __shared__ T _rgb[CHUNK_SIZE * 3 * N_SH];
58 | __shared__ T _conic[CHUNK_SIZE * 3];
59 | 
60 | const int shared_image_size = 16 * 16 * 3;
61 | __shared__ T _image[shared_image_size];
62 | 
63 | #pragma unroll
64 | for (int i = thread_id; i < shared_image_size; i += block_size) {
65 | _image[i] = 0.0;
66 | }
67 | 
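// Annotation (added): the loop below processes this tile's splat list in
// chunks of CHUNK_SIZE. All 16x16 threads of the block first cooperatively
// stage one chunk of gaussian parameters into shared memory, synchronize,
// and then each thread alpha-composites the staged splats front to back for
// its own pixel before the next chunk is loaded.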
68 | const int num_chunks = (num_splats_this_tile + CHUNK_SIZE - 1) / CHUNK_SIZE;
69 | // copy chunks
70 | for (int chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++) {
71 | __syncthreads(); // make sure previous iteration is complete before
72 | // modifying inputs
73 | for (int i = thread_id; i < CHUNK_SIZE; i += block_size) {
74 | const int tile_splat_idx = chunk_idx * CHUNK_SIZE + i;
75 | if (tile_splat_idx >= num_splats_this_tile) {
76 | break;
77 | }
78 | const int global_splat_idx = splat_idx_start + tile_splat_idx;
79 | 
80 | const int gaussian_idx = gaussian_idx_by_splat_idx[global_splat_idx];
81 | _uvs[i * 2 + 0] = uvs[gaussian_idx * 2 + 0];
82 | _uvs[i * 2 + 1] = uvs[gaussian_idx * 2 + 1];
83 | _opacity[i] = opacity[gaussian_idx];
84 | 
85 | #pragma unroll
86 | for (int sh = 0; sh < N_SH; sh++) {
87 | #pragma unroll
88 | for (int channel = 0; channel < 3; channel++) {
89 | // rgb dimensions = (splat_idx, channel_idx, sh_coeff_idx)
90 | _rgb[(i * 3 + channel) * N_SH + sh] =
91 | rgb[(gaussian_idx * 3 + channel) * N_SH + sh];
92 | }
93 | }
94 | 
95 | #pragma unroll
96 | for (int j = 0; j < 3; j++) {
97 | _conic[i * 3 + j] = conic[gaussian_idx * 3 + j];
98 | }
99 | }
100 | __syncthreads(); // wait for copying to complete before attempting to
101 | // use data
102 | if (valid_pixel) {
103 | int chunk_start = chunk_idx * CHUNK_SIZE;
104 | int chunk_end = min((chunk_idx + 1) * CHUNK_SIZE, num_splats_this_tile);
105 | int num_splats_this_chunk = chunk_end - chunk_start;
106 | for (int i = 0; i < num_splats_this_chunk; i++) {
107 | if (alpha_accum > 0.9999) {
108 | break;
109 | }
110 | const T u_mean = _uvs[i * 2 + 0];
111 | const T v_mean = _uvs[i * 2 + 1];
112 | 
113 | const T u_diff = T(u_splat) - u_mean;
114 | const T v_diff = T(v_splat) - v_mean;
115 | 
116 | // 2d covariance matrix - add 0.25 to diagonal to make it positive definite rather
117 | // than semi-definite
118 | T a;
119 | T c;
120 | const T b = _conic[i * 3 + 1] * 0.5;
121 | if (use_fast_exp) {
122 | a = _conic[i * 3 + 0] + 0.25;
123 | c = _conic[i * 3 + 2] + 0.25;
124 | } else {
125 | a = _conic[i * 3 + 0];
126 | c = _conic[i * 3 + 2];
127 | }
128 | const T det = a * c - b * b;
129 | 
130 | T alpha = 0.0;
131 | // compute mahalanobis distance
132 | const T mh_sq =
133 | (c * u_diff * u_diff - (b + b) * u_diff * v_diff + a * v_diff * v_diff) / det;
134 | if (mh_sq > 0.0) {
135 | // probability at this pixel normalized so that the
136 | // probability at the center of the gaussian is 1.0
137 | T norm_prob = 0.0;
138 | if (use_fast_exp) {
139 | norm_prob = __expf(-0.5 * mh_sq);
140 | } else {
141 | norm_prob = exp(-0.5 * mh_sq);
142 | }
143 | alpha = _opacity[i] * norm_prob;
144 | }
145 | if (alpha < 0.00392156862 && use_fast_exp) {
146 | num_splats++;
147 | continue;
148 | }
149 | alpha_weight = 1.0 - alpha_accum;
150 | const T weight = alpha * (1.0 - alpha_accum);
151 | // compute rgb
152 | T computed_rgb[3];
153 | sh_to_rgb<T, N_SH>(_rgb + i * 3 * N_SH, sh_at_view_dir, computed_rgb);
154 | 
155 | // update image
156 | #pragma unroll
157 | for (int channel = 0; channel < 3; channel++) {
158 | _image[(threadIdx.y * 16 + threadIdx.x) * 3 + channel] +=
159 | computed_rgb[channel] * weight;
160 | }
161 | alpha_accum += weight;
162 | num_splats++;
163 | } // end splat loop
164 | } // valid pixel check
165 | } // end chunk loop
166 | 
167 | // add background if the pixel is not saturated
168 | if (valid_pixel && alpha_accum < 0.999) {
169 | #pragma unroll
170 | for (int channel = 0; channel < 3; channel++) {
171 | _image[(threadIdx.y * 16 + threadIdx.x) * 3 + channel] +=
172 | background_rgb[channel] * (1.0 - alpha_accum);
173 | }
174 | }
175 | 
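// Annotation (added): the compositing loop above accumulates
// weight_i = alpha_i * prod_{j < i} (1 - alpha_j) via alpha_accum, stopping
// once a pixel is effectively opaque (alpha_accum > 0.9999). Splats with
// alpha below 1/255 (0.00392...) are skipped only in the fast-exp float path;
// the exact path composites every splat so the float64 autograd checks see
// every term. Background is then blended into any pixel not yet saturated.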
176 | // copy back to global memory
177 | __syncthreads(); // wait for splatting to complete
178 | if (valid_pixel) {
179 | num_splats_per_pixel[v_splat * image_width + u_splat] = num_splats;
180 | final_weight_per_pixel[v_splat * image_width + u_splat] = alpha_weight;
181 | 
182 | #pragma unroll
183 | for (int channel = 0; channel < 3; channel++) {
184 | image[(v_splat * image_width + u_splat) * 3 + channel] =
185 | _image[(threadIdx.y * 16 + threadIdx.x) * 3 + channel];
186 | }
187 | }
188 | }
189 | 
190 | void render_tiles_cuda(
191 | torch::Tensor uvs,
192 | torch::Tensor opacity,
193 | torch::Tensor rgb,
194 | torch::Tensor conic,
195 | torch::Tensor view_dir_by_pixel,
196 | torch::Tensor splat_start_end_idx_by_tile_idx,
197 | torch::Tensor gaussian_idx_by_splat_idx,
198 | torch::Tensor background_rgb,
199 | torch::Tensor num_splats_per_pixel,
200 | torch::Tensor final_weight_per_pixel,
201 | torch::Tensor rendered_image
202 | ) {
203 | CHECK_VALID_INPUT(uvs);
204 | CHECK_VALID_INPUT(opacity);
205 | CHECK_VALID_INPUT(rgb);
206 | CHECK_VALID_INPUT(conic);
207 | CHECK_VALID_INPUT(view_dir_by_pixel);
208 | CHECK_VALID_INPUT(splat_start_end_idx_by_tile_idx);
209 | CHECK_VALID_INPUT(gaussian_idx_by_splat_idx);
210 | CHECK_VALID_INPUT(background_rgb);
211 | CHECK_VALID_INPUT(num_splats_per_pixel);
212 | CHECK_VALID_INPUT(final_weight_per_pixel);
213 | CHECK_VALID_INPUT(rendered_image);
214 | 
215 | int N = uvs.size(0);
216 | TORCH_CHECK(uvs.size(1) == 2, "uvs must be Nx2 (u, v)");
217 | TORCH_CHECK(opacity.size(0) == N, "Opacity must have the same number of elements as uvs");
218 | TORCH_CHECK(opacity.size(1) == 1, "Opacity must be Nx1");
219 | TORCH_CHECK(rgb.size(0) == N, "RGB must have the same number of elements as uvs");
220 | TORCH_CHECK(rgb.size(1) == 3, "RGB must be Nx3");
221 | TORCH_CHECK(conic.size(0) == N, "Conic must have the same number of elements as uvs");
222 | TORCH_CHECK(conic.size(1) == 3, "Conic must be Nx3");
223 | TORCH_CHECK(rendered_image.size(2) == 3, "Image must be HxWx3");
224 | TORCH_CHECK(background_rgb.dim() == 1, "Background RGB must be 1D");
225 | TORCH_CHECK(background_rgb.size(0) == 3, "Background RGB must have 3 elements");
226 | 
227 | int image_height = rendered_image.size(0);
228 | int image_width = rendered_image.size(1);
229 | int num_sh_coeff;
230 | if (rgb.dim() == 3) {
231 | num_sh_coeff = rgb.size(2);
232 | } else {
233 | num_sh_coeff = 1;
234 | }
235 | if (num_sh_coeff > 1) {
236 | TORCH_CHECK(
237 | view_dir_by_pixel.size(0) == image_height,
238 | "view_dir_by_pixel must have the same size as the image"
239 | );
240 | TORCH_CHECK(
241 | view_dir_by_pixel.size(1) == image_width,
242 | "view_dir_by_pixel must have the same size as the image"
243 | );
244 | TORCH_CHECK(view_dir_by_pixel.size(2) == 3, "view_dir_by_pixel must have 3 channels");
245 | }
246 | 
247 | int num_tiles_x = (image_width + 16 - 1) / 16;
248 | int num_tiles_y = (image_height + 16 - 1) / 16;
249 | 
250 | dim3 block_size(16, 16, 1);
251 | dim3 grid_size(num_tiles_x, num_tiles_y, 1);
252 | 
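// Note (added): CHUNK_SIZE below is an assumed value, not confirmed by the
// original source; it is chosen so that every instantiation launched here
// (float or double, 1 to 16 SH coefficients) fits within the default 48 KB
// static shared memory budget of a block.
constexpr unsigned int CHUNK_SIZE = 96;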
253 | if (uvs.dtype() == torch::kFloat32) {
254 | CHECK_FLOAT_TENSOR(opacity);
255 | CHECK_FLOAT_TENSOR(rgb);
256 | CHECK_FLOAT_TENSOR(conic);
257 | CHECK_FLOAT_TENSOR(view_dir_by_pixel);
258 | CHECK_INT_TENSOR(splat_start_end_idx_by_tile_idx);
259 | CHECK_INT_TENSOR(gaussian_idx_by_splat_idx);
260 | CHECK_FLOAT_TENSOR(background_rgb);
261 | CHECK_INT_TENSOR(num_splats_per_pixel);
262 | CHECK_FLOAT_TENSOR(final_weight_per_pixel);
263 | CHECK_FLOAT_TENSOR(rendered_image);
264 | 
265 | if (num_sh_coeff == 1) {
266 | render_tiles_kernel<float, CHUNK_SIZE, 1><<<grid_size, block_size>>>(
267 | uvs.data_ptr<float>(),
268 | opacity.data_ptr<float>(),
269 | rgb.data_ptr<float>(),
270 | conic.data_ptr<float>(),
271 | view_dir_by_pixel.data_ptr<float>(),
272 | splat_start_end_idx_by_tile_idx.data_ptr<int>(),
273 | gaussian_idx_by_splat_idx.data_ptr<int>(),
274 | background_rgb.data_ptr<float>(),
275 | image_width,
276 | image_height,
277 | true,
278 | num_splats_per_pixel.data_ptr<int>(),
279 | final_weight_per_pixel.data_ptr<float>(),
280 | rendered_image.data_ptr<float>()
281 | );
282 | } else if (num_sh_coeff == 4) {
283 | render_tiles_kernel<float, CHUNK_SIZE, 4><<<grid_size, block_size>>>(
284 | uvs.data_ptr<float>(),
285 | opacity.data_ptr<float>(),
286 | rgb.data_ptr<float>(),
287 | conic.data_ptr<float>(),
288 | view_dir_by_pixel.data_ptr<float>(),
289 | splat_start_end_idx_by_tile_idx.data_ptr<int>(),
290 | gaussian_idx_by_splat_idx.data_ptr<int>(),
291 | background_rgb.data_ptr<float>(),
292 | image_width,
293 | image_height,
294 | true,
295 | num_splats_per_pixel.data_ptr<int>(),
296 | final_weight_per_pixel.data_ptr<float>(),
297 | rendered_image.data_ptr<float>()
298 | );
299 | } else if (num_sh_coeff == 9) {
300 | render_tiles_kernel<float, CHUNK_SIZE, 9><<<grid_size, block_size>>>(
301 | uvs.data_ptr<float>(),
302 | opacity.data_ptr<float>(),
303 | rgb.data_ptr<float>(),
304 | conic.data_ptr<float>(),
305 | view_dir_by_pixel.data_ptr<float>(),
306 | splat_start_end_idx_by_tile_idx.data_ptr<int>(),
307 | gaussian_idx_by_splat_idx.data_ptr<int>(),
308 | background_rgb.data_ptr<float>(),
309 | image_width,
310 | image_height,
311 | true,
312 | num_splats_per_pixel.data_ptr<int>(),
313 | final_weight_per_pixel.data_ptr<float>(),
314 | rendered_image.data_ptr<float>()
315 | );
316 | } else if (num_sh_coeff == 16) {
317 | render_tiles_kernel<float, CHUNK_SIZE, 16><<<grid_size, block_size>>>(
318 | uvs.data_ptr<float>(),
319 | opacity.data_ptr<float>(),
320 | rgb.data_ptr<float>(),
321 | conic.data_ptr<float>(),
322 | view_dir_by_pixel.data_ptr<float>(),
323 | splat_start_end_idx_by_tile_idx.data_ptr<int>(),
324 | gaussian_idx_by_splat_idx.data_ptr<int>(),
325 | background_rgb.data_ptr<float>(),
326 | image_width,
327 | image_height,
328 | true,
329 | num_splats_per_pixel.data_ptr<int>(),
330 | final_weight_per_pixel.data_ptr<float>(),
331 | rendered_image.data_ptr<float>()
332 | );
333 | } else {
334 | AT_ERROR("Unsupported number of SH coefficients: ", num_sh_coeff);
335 | }
336 | } else if (uvs.dtype() == torch::kFloat64) {
337 | CHECK_DOUBLE_TENSOR(opacity);
338 | CHECK_DOUBLE_TENSOR(rgb);
339 | CHECK_DOUBLE_TENSOR(conic);
340 | CHECK_INT_TENSOR(splat_start_end_idx_by_tile_idx);
341 | CHECK_INT_TENSOR(gaussian_idx_by_splat_idx);
342 | CHECK_DOUBLE_TENSOR(background_rgb);
343 | CHECK_INT_TENSOR(num_splats_per_pixel);
344 | CHECK_DOUBLE_TENSOR(final_weight_per_pixel);
345 | CHECK_DOUBLE_TENSOR(rendered_image);
346 | if (num_sh_coeff == 1) {
347 | render_tiles_kernel<double, CHUNK_SIZE, 1><<<grid_size, block_size>>>(
348 | uvs.data_ptr<double>(),
349 | opacity.data_ptr<double>(),
350 | rgb.data_ptr<double>(),
351 | conic.data_ptr<double>(),
352 | view_dir_by_pixel.data_ptr<double>(),
353 | splat_start_end_idx_by_tile_idx.data_ptr<int>(),
354 | gaussian_idx_by_splat_idx.data_ptr<int>(),
355 | background_rgb.data_ptr<double>(),
356 | image_width,
357 | image_height,
358 | false,
359 | num_splats_per_pixel.data_ptr<int>(),
360 | final_weight_per_pixel.data_ptr<double>(),
361 | rendered_image.data_ptr<double>()
362 | );
363 | } else if (num_sh_coeff == 4) {
364 | render_tiles_kernel<double, CHUNK_SIZE, 4><<<grid_size, block_size>>>(
365 | uvs.data_ptr<double>(),
366 | opacity.data_ptr<double>(),
367 | rgb.data_ptr<double>(),
368 | conic.data_ptr<double>(),
369 | view_dir_by_pixel.data_ptr<double>(),
370 | splat_start_end_idx_by_tile_idx.data_ptr<int>(),
371 | gaussian_idx_by_splat_idx.data_ptr<int>(),
372 | background_rgb.data_ptr<double>(),
373 | image_width,
374 | image_height,
375 | false,
376 | num_splats_per_pixel.data_ptr<int>(),
377 | final_weight_per_pixel.data_ptr<double>(),
378 | rendered_image.data_ptr<double>()
379 | );
380 | } else if (num_sh_coeff == 9) {
381 | render_tiles_kernel<double, CHUNK_SIZE, 9><<<grid_size, block_size>>>(
382 | uvs.data_ptr<double>(),
383 | opacity.data_ptr<double>(),
384 | rgb.data_ptr<double>(),
385 | conic.data_ptr<double>(),
386 | view_dir_by_pixel.data_ptr<double>(),
387 | splat_start_end_idx_by_tile_idx.data_ptr<int>(),
388 | gaussian_idx_by_splat_idx.data_ptr<int>(),
389 | background_rgb.data_ptr<double>(),
390 | image_width,
391 | image_height,
392 | false,
393 | num_splats_per_pixel.data_ptr<int>(),
394 | final_weight_per_pixel.data_ptr<double>(),
395 | rendered_image.data_ptr<double>()
396 | );
397 | } else if (num_sh_coeff == 16) {
398 | render_tiles_kernel<double, CHUNK_SIZE, 16><<<grid_size, block_size>>>(
399 | uvs.data_ptr<double>(),
400 | opacity.data_ptr<double>(),
401 | rgb.data_ptr<double>(),
402 | conic.data_ptr<double>(),
403 | view_dir_by_pixel.data_ptr<double>(),
404 | splat_start_end_idx_by_tile_idx.data_ptr<int>(),
405 | gaussian_idx_by_splat_idx.data_ptr<int>(),
406 | background_rgb.data_ptr<double>(),
407 | image_width,
408 | image_height,
409 | false,
410 | num_splats_per_pixel.data_ptr<int>(),
411 | final_weight_per_pixel.data_ptr<double>(),
412 | rendered_image.data_ptr<double>()
413 | );
414 | } else {
415 | AT_ERROR("Unsupported number of SH coefficients: ", num_sh_coeff);
416 | }
417 | } else {
418 | AT_ERROR("Inputs must be float32 or float64");
419 | }
420 | cudaDeviceSynchronize();
421 | }
422 | 
--------------------------------------------------------------------------------
/src/spherical_harmonics.cuh:
--------------------------------------------------------------------------------
1 | #ifndef SPHERICAL_HARMONICS_CUH
2 | #define SPHERICAL_HARMONICS_CUH
3 | 
4 | __device__ __constant__ const float SH_0 = 0.28209479177387814;
5 | __device__ __constant__ const float r_SH_0 = 3.544907701811032;
6 | // repeat same value to make sign management easier during SH calculation
7 | __device__ __constant__ const float SH_1[3] = {
8 | -0.4886025119029199,
9 | 0.4886025119029199,
10 | -0.4886025119029199};
11 | __device__ __constant__ const float SH_2[5] = {
12 | 1.0925484305920792,
13 | -1.0925484305920792,
14 | 0.31539156525252005,
15 | -1.0925484305920792,
16 | 0.5462742152960396};
17 | __device__ __constant__ const float SH_3[7] = {
18 | -0.5900435899266435,
19 | 2.890611442640554,
20 | -0.4570457994644658,
21 | 0.263875515352797,
22 | -0.4570457994644658,
23 | 1.445305721320277,
24 | -0.5900435899266435};
25 | 
26 | template <typename T, unsigned int N_SH>
27 | __device__ __inline__ void
28 | compute_sh_coeffs_for_view_dir(const T* __restrict__ view_dir, T* __restrict__ sh_at_view_dir) {
29 | // Band 0
30 | sh_at_view_dir[0] = T(SH_0);
31 | 
32 | if (N_SH < 4)
33 | return;
34 | 
35 | const T x = view_dir[0];
36 | const T y = view_dir[1];
37 | const T z = view_dir[2];
38 | 
39 | // Band 1
40 | sh_at_view_dir[1] = T(SH_1[0]) * x;
41 | sh_at_view_dir[2] = T(SH_1[1]) * y;
42 | sh_at_view_dir[3] = T(SH_1[2]) * z;
43 | 
44 | if (N_SH < 9)
45 | return;
46 | 
47 | const T xy = x * y;
48 | const T yz = y * z;
49 | const T xz = x * z;
50 | const T xx = x * x;
51 | const T yy = y * y;
52 | const T zz = z * z;
53 | 
54 | // Band 2
55 | sh_at_view_dir[4] = T(SH_2[0]) * xy; // xy
56 | sh_at_view_dir[5] = T(SH_2[1]) * yz; // yz
57 | sh_at_view_dir[6] = T(SH_2[2]) * (3 * zz - 1.0); // 3z^2 - 1
58 | sh_at_view_dir[7] = T(SH_2[3]) * xz; // xz
59 | sh_at_view_dir[8] = T(SH_2[4]) * (xx - yy); // x^2 - y^2
60 | 
61 | if (N_SH < 16)
62 | return;
63 | 
64 | // Band 3
65 | sh_at_view_dir[9] = T(SH_3[0]) * y * (3 * xx - yy); // y * (3x^2 - y^2)
66 | sh_at_view_dir[10] = T(SH_3[1]) * xy * z; // xyz
67 | sh_at_view_dir[11] = T(SH_3[2]) * y * (5 * zz - 1.0); // y(5z^2 - 1)
68 | sh_at_view_dir[12] = T(SH_3[3]) * z * (5 * zz - 3.0); // z(5z^2 - 3)
69 | sh_at_view_dir[13] = T(SH_3[4]) * x * (5 * zz - 1.0); // x(5z^2 - 1)
70 | sh_at_view_dir[14] = T(SH_3[5]) * z * (xx - yy); // z(x^2 - y^2)
71 | sh_at_view_dir[15] = T(SH_3[6]) * x * (xx - 3 * yy); // x(x^2 - 3y^2)
72 | }
73 | 
74 | template <typename T, unsigned int N_SH>
75 | __device__ __inline__ void sh_to_rgb(
76 | const T* __restrict__ sh_coeff,
77 | const T* __restrict__ sh_at_view_dir,
78 | T* __restrict__ rgb
79 | ) {
80 | // set rgb to zero order value
81 | #pragma unroll
82 | for (int channel = 0; channel < 3; channel++) {
83 | rgb[channel] = sh_at_view_dir[0] * sh_coeff[N_SH * channel];
84 | }
85 | 
86 | // add higher order values if needed
87 | if (N_SH < 4)
88 | return;
89 | #pragma unroll
90 | for (int sh = 1; sh < N_SH; sh++) {
91 | #pragma unroll
92 | for (int channel = 0; channel < 3; channel++) {
93 | rgb[channel] += sh_at_view_dir[sh] * sh_coeff[N_SH * channel + sh];
94 | }
95 | }
96 | }
97 | 
98 | template <typename T, unsigned int N_SH>
99 | __device__ __inline__ void compute_sh_grad(
100 | const T* __restrict__ grad_rgb,
101 | const T* __restrict__ sh_at_view_dir,
102 | T* __restrict__ grad_sh
103 | ) {
104 | #pragma unroll
105 | for (int sh = 0; sh < N_SH; sh++) {
106 | #pragma unroll
107 | for (int channel = 0; channel < 3; channel++) {
108 | grad_sh[N_SH * channel + sh] = sh_at_view_dir[sh] * grad_rgb[channel];
109 | }
110 | }
111 | }
112 | #endif // SPHERICAL_HARMONICS_CUH
113 | 
--------------------------------------------------------------------------------
/src/tile_culling.cu:
--------------------------------------------------------------------------------
1 | #include <cuda.h>
2 | #include <cuda_runtime.h>
3 | #include <torch/extension.h>
4 | 
5 | #include "checks.cuh"
6 | 
7 | // returns true if there is overlap between the two bboxes, false otherwise
8 | __device__ __forceinline__ bool split_axis_test(
9 | const float* __restrict__ obb, // [tl_x, tl_y, tr_x, tr_y, bl_x, bl_y, br_x, br_y]
10 | const float* __restrict__ tile_bounds // [left, right, top, bottom]
11 | ) {
12 | // from the separating axis theorem, need overlap on all axes
13 | // axis 0 - X axis
14 | const float obb_min_x = fminf(fminf(obb[0], obb[2]), fminf(obb[4], obb[6]));
15 | const float obb_max_x = fmaxf(fmaxf(obb[0], obb[2]), fmaxf(obb[4], obb[6]));
16 | if (obb_min_x > tile_bounds[1] || obb_max_x < tile_bounds[0]) {
17 | return false;
18 | }
19 | // axis 1 - Y axis
20 | const float obb_min_y = fminf(fminf(obb[1], obb[3]), fminf(obb[5], obb[7]));
21 | const float obb_max_y = fmaxf(fmaxf(obb[1], obb[3]), fmaxf(obb[5], obb[7]));
22 | if (obb_min_y > tile_bounds[3] || obb_max_y < tile_bounds[2]) {
23 | return false;
24 | }
25 | // axis 2 - obb major axis
26 | const float obb_major_axis_x = obb[2] - obb[0];
27 | const float obb_major_axis_y = obb[3] - obb[1];
28 | float tl_ax2 = obb_major_axis_x * tile_bounds[0] + obb_major_axis_y * tile_bounds[2]; // tl
29 | float tr_ax2 = obb_major_axis_x * tile_bounds[1] + obb_major_axis_y * tile_bounds[2]; // tr
30 | float bl_ax2 = obb_major_axis_x * tile_bounds[0] + obb_major_axis_y * tile_bounds[3]; // bl
31 | float br_ax2 = obb_major_axis_x * tile_bounds[1] + obb_major_axis_y * tile_bounds[3]; // br
32 | 
33 | float min_tile = fminf(fminf(tl_ax2, tr_ax2), fminf(bl_ax2, br_ax2));
34 | float max_tile = fmaxf(fmaxf(tl_ax2, tr_ax2), fmaxf(bl_ax2, br_ax2));
35 | 
36 | // top and bottom corners of obb project to same points on ax2
37 | const float obb_r_ax2 = obb_major_axis_x * obb[2] + obb_major_axis_y * obb[3]; // obb top right
38 | const float obb_l_ax2 = obb_major_axis_x * obb[0] + obb_major_axis_y * obb[1]; // obb top left
39 | float min_obb = fminf(obb_r_ax2, obb_l_ax2);
40 | float max_obb = fmaxf(obb_r_ax2, obb_l_ax2);
41 | 
42 | if (min_tile > max_obb || max_tile < min_obb) {
43 | return false;
44 | }
45 | // axis 3 - obb minor axis
46 | const float obb_minor_axis_x = obb[2] - obb[6];
47 | const float obb_minor_axis_y = obb[3] - obb[7];
48 | tl_ax2 = obb_minor_axis_x * tile_bounds[0] + obb_minor_axis_y * tile_bounds[2]; // tl
49 | tr_ax2 = obb_minor_axis_x * tile_bounds[1] + obb_minor_axis_y * tile_bounds[2]; // tr
50 | bl_ax2 = obb_minor_axis_x * tile_bounds[0] + obb_minor_axis_y * tile_bounds[3]; // bl
51 | br_ax2 = obb_minor_axis_x * tile_bounds[1] + obb_minor_axis_y * tile_bounds[3]; // br
52 | 
53 | min_tile = fminf(fminf(tl_ax2, tr_ax2), fminf(bl_ax2, br_ax2));
54 | max_tile = fmaxf(fmaxf(tl_ax2, tr_ax2), fmaxf(bl_ax2, br_ax2));
55 | 
56 | // left and right corners of obb project to the same points on this axis
57 | const float obb_t_ax2 = obb_minor_axis_x * obb[2] + obb_minor_axis_y * obb[3]; // obb top right
58 | const float obb_b_ax2 =
59 | obb_minor_axis_x * obb[6] + obb_minor_axis_y * obb[7]; // obb bottom right
60 | min_obb = fminf(obb_t_ax2, obb_b_ax2);
61 | max_obb = fmaxf(obb_t_ax2, obb_b_ax2);
62 | if (min_tile > max_obb || max_tile < min_obb) {
63 | return false;
64 | }
65 | return true;
66 | }
67 | 
68 | // returns tile search radius and computes oriented bounding box
69 | __device__ __forceinline__ int compute_obb(
70 | const float u,
71 | const float v,
72 | const float a,
73 | const float b,
74 | const float c,
75 | const float mh_dist,
76 | float* obb // [tl_x, tl_y, tr_x, tr_y, bl_x, bl_y, br_x, br_y]
77 | ) {
78 | // compute major and minor axis radii of the ellipse
79 | const float left = (a + c) / 2;
80 | const float right = sqrtf((a - c) * (a - c) / 4.0f + b * b);
81 | const float lambda1 = left + right;
82 | const float lambda2 = left - right;
83 | 
84 | const float r_major = mh_dist * sqrtf(lambda1);
85 | const float r_minor = mh_dist * sqrtf(lambda2);
86 | 
87 | // compute theta
88 | float theta;
89 | if (fabsf(b) < 1e-16) {
90 | if (a >= c) {
91 | theta = 0.0f;
92 | } else {
93 | theta = M_PI / 2;
94 | }
95 | } else {
96 | theta = atan2f(lambda1 - a, b);
97 | }
98 | const float cos_theta = cosf(theta);
99 | const float sin_theta = sinf(theta);
100 | 
101 | // compute obb
102 | // top_left aabb [-r_major, -r_minor]
103 | obb[0] = -1 * r_major * cos_theta + r_minor * sin_theta + u;
104 | obb[1] = -1 * r_major * sin_theta - r_minor * cos_theta + v;
105 | 
106 | // top_right aabb [r_major, -r_minor]
107 | obb[2] = r_major * cos_theta + r_minor * sin_theta + u;
108 | obb[3] = r_major * sin_theta - r_minor * cos_theta + v;
109 | 
110 | // bottom_left aabb [-r_major, r_minor]
111 | obb[4] = -1 * r_major * cos_theta - r_minor * sin_theta + u;
112 | obb[5] = -1 * r_major * sin_theta + r_minor * cos_theta + v;
113 | 
114 | // bottom_right aabb [r_major, r_minor]
115 | obb[6] = r_major * cos_theta - r_minor * sin_theta + u;
116 | obb[7] = r_major * sin_theta + r_minor * cos_theta + v;
117 | 
118 | // don't need to search the entire image, only need to look at all tiles
119 | // within max radius of the projected center of the gaussian
120 | const int radius_tiles = ceilf(r_major / 16.0f) + 1;
121 | return radius_tiles;
122 | }
123 | 
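// Annotation (added): compute_obb above diagonalizes the 2x2 covariance
// [[a, b], [b, c]]; its eigenvalues are
//
//     lambda_{1,2} = (a + c) / 2 +/- sqrt(((a - c) / 2)^2 + b^2)
//
// so r_major and r_minor are the mh_dist-sigma half-extents of the ellipse,
// and theta = atan2(lambda1 - a, b) orients the major axis. The returned
// radius (in 16-pixel tiles) bounds the tile search window around the
// projected gaussian center.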
124 | __global__ void compute_num_splats_kernel(
125 | const float* __restrict__ uvs,
126 | const float* __restrict__ conic,
127 | const int n_tiles_x,
128 | const int n_tiles_y,
129 | const float mh_dist,
130 | const int N,
131 | int* __restrict__ num_tiles_per_gaussian,
132 | int* __restrict__ num_gaussians_per_tile
133 | ) {
134 | int gaussian_idx = blockIdx.x * blockDim.x + threadIdx.x;
135 | if (gaussian_idx >= N) {
136 | return;
137 | }
138 | 
139 | const float u = uvs[gaussian_idx * 2];
140 | const float v = uvs[gaussian_idx * 2 + 1];
141 | 
142 | const float a = conic[gaussian_idx * 3] + 0.25f;
143 | const float b = conic[gaussian_idx * 3 + 1] / 2.0f;
144 | const float c = conic[gaussian_idx * 3 + 2] + 0.25f;
145 | 
146 | float obb[8];
147 | const int radius_tiles = compute_obb(u, v, a, b, c, mh_dist, obb);
148 | 
149 | const int projected_tile_x = floorf(u / 16.0f);
150 | const int start_tile_x = fmaxf(0, projected_tile_x - radius_tiles);
151 | const int end_tile_x = fminf(n_tiles_x, projected_tile_x + radius_tiles);
152 | 
153 | const int projected_tile_y = floorf(v / 16.0f);
154 | const int start_tile_y = fmaxf(0, projected_tile_y - radius_tiles);
155 | const int end_tile_y = fminf(n_tiles_y, projected_tile_y + radius_tiles);
156 | 
157 | int n_tiles = 0;
158 | // iterate through tiles
159 | for (int tile_x = start_tile_x; tile_x < end_tile_x; tile_x++) {
160 | for (int tile_y = start_tile_y; tile_y < end_tile_y; tile_y++) {
161 | const int tile_idx = tile_y * n_tiles_x + tile_x;
162 | 
163 | float tile_bounds[4]; // [left, right, top, bottom]
164 | tile_bounds[0] = __int2float_rn(tile_x) * 16.0f;
165 | tile_bounds[1] = __int2float_rn(tile_x + 1) * 16.0f;
166 | tile_bounds[2] = __int2float_rn(tile_y) * 16.0f;
167 | tile_bounds[3] = __int2float_rn(tile_y + 1) * 16.0f;
168 | 
169 | if (split_axis_test(obb, tile_bounds)) {
170 | // update tile counts
171 | atomicAdd(num_gaussians_per_tile + tile_idx, 1);
172 | n_tiles++;
173 | }
174 | }
175 | }
176 | num_tiles_per_gaussian[gaussian_idx] = n_tiles;
177 | }
178 | 
179 | __global__ void compute_tiles_kernel(
180 | const float* __restrict__ uvs,
181 | const float* __restrict__ xyz_camera_frame,
182 | const float* __restrict__ conic,
183 | const int* __restrict__ splat_start_end_idx_by_gaussian_idx,
184 | const int n_tiles_x,
185 | const int n_tiles_y,
186 | const float mh_dist,
187 | const int N,
188 | const double tile_idx_key_multiplier,
189 | int* __restrict__ gaussian_idx_by_splat_idx,
190 | double* __restrict__ sort_keys
191 | ) {
192 | int gaussian_idx = blockIdx.x * blockDim.x + threadIdx.x;
193 | if (gaussian_idx >= N) {
194 | return;
195 | }
196 | 
197 | // get per gaussian values
198 | const float u = uvs[gaussian_idx * 2];
199 | const float v = uvs[gaussian_idx * 2 + 1];
200 | const double z = (double)(xyz_camera_frame[gaussian_idx * 3 + 2]);
201 | 
202 | const float a = conic[gaussian_idx * 3] + 0.25f;
203 | const float b = conic[gaussian_idx * 3 + 1] / 2.0f;
204 | const float c = conic[gaussian_idx * 3 + 2] + 0.25f;
205 | 
206 | const int output_start_idx = splat_start_end_idx_by_gaussian_idx[gaussian_idx];
207 | const int output_end_idx = splat_start_end_idx_by_gaussian_idx[gaussian_idx + 1];
208 | 
209 | float obb[8];
210 | const int radius_tiles = compute_obb(u, v, a, b, c, mh_dist, obb);
211 | 
212 | const int projected_tile_x = floorf(u / 16.0f);
213 | const int start_tile_x = fmaxf(0, projected_tile_x - radius_tiles);
214 | const int end_tile_x = fminf(n_tiles_x, projected_tile_x + radius_tiles);
215 | 
216 | const int projected_tile_y = floorf(v / 16.0f);
217 | const int start_tile_y = fmaxf(0, projected_tile_y - radius_tiles);
218 | const int end_tile_y = fminf(n_tiles_y, projected_tile_y + radius_tiles);
219 | 
220 | int n_tiles = 0;
221 | // iterate through tiles
222 | for (int tile_x = start_tile_x; tile_x < end_tile_x; tile_x++) {
223 | for (int tile_y = start_tile_y; tile_y < end_tile_y; tile_y++) {
224 | const int tile_idx = tile_y * n_tiles_x + tile_x;
225 | 
226 | float tile_bounds[4]; // [left, right, top, bottom]
227 | tile_bounds[0] = __int2float_rn(tile_x) * 16.0f;
228 | tile_bounds[1] = __int2float_rn(tile_x + 1) * 16.0f;
229 | tile_bounds[2] = __int2float_rn(tile_y) * 16.0f;
230 | tile_bounds[3] = __int2float_rn(tile_y + 1) * 16.0f;
231 | 
232 | if (split_axis_test(obb, tile_bounds) &&
233 | (output_start_idx + n_tiles) < output_end_idx) {
234 | // update gaussian index by splat index
235 | gaussian_idx_by_splat_idx[output_start_idx + n_tiles] = gaussian_idx;
236 | sort_keys[output_start_idx + n_tiles] =
237 | z + tile_idx_key_multiplier * __int2double_rn(tile_idx);
238 | n_tiles++;
239 | }
240 | }
241 | }
242 | }
243 | 
244 | std::tuple<torch::Tensor, torch::Tensor> get_sorted_gaussian_list(
245 | const int max_tiles_per_gaussian,
246 | torch::Tensor uvs,
247 | torch::Tensor xyz_camera_frame,
248 | torch::Tensor conic,
249 | const int n_tiles_x,
250 | const int n_tiles_y,
251 | const float mh_dist
252 | ) {
253 | CHECK_VALID_INPUT(uvs);
254 | CHECK_VALID_INPUT(xyz_camera_frame);
255 | CHECK_VALID_INPUT(conic);
256 | 
257 | CHECK_FLOAT_TENSOR(uvs);
258 | CHECK_FLOAT_TENSOR(xyz_camera_frame);
259 | CHECK_FLOAT_TENSOR(conic);
260 | 
261 | const int N = uvs.size(0);
262 | 
263 | const int max_threads_per_block = 1024;
264 | const int num_blocks = (N + max_threads_per_block - 1) / max_threads_per_block;
265 | dim3 gridsize(num_blocks, 1, 1);
266 | dim3 blocksize(max_threads_per_block, 1, 1);
267 | 
268 | // compute number of splats per gaussian/tile
269 | torch::Tensor num_tiles_per_gaussian =
270 | torch::zeros({N}, torch::dtype(torch::kInt32).device(uvs.device()));
271 | 
272 | torch::Tensor num_gaussians_per_tile =
273 | torch::zeros({n_tiles_x * n_tiles_y}, torch::dtype(torch::kInt32).device(uvs.device()));
274 | 
275 | compute_num_splats_kernel<<<gridsize, blocksize>>>(
276 | uvs.data_ptr<float>(),
277 | conic.data_ptr<float>(),
278 | n_tiles_x,
279 | n_tiles_y,
280 | mh_dist,
281 | N,
282 | num_tiles_per_gaussian.data_ptr<int>(),
283 | num_gaussians_per_tile.data_ptr<int>()
284 | );
285 | cudaDeviceSynchronize();
286 | 
287 | // create vector of gaussian indices for each splat and sort keys
288 | torch::Tensor cumsum = num_tiles_per_gaussian.cumsum(0);
289 | torch::Tensor splat_start_end_idx_by_gaussian_idx = torch::cat(
290 | {torch::zeros({1}, torch::dtype(torch::kInt32).device(uvs.device())),
291 | cumsum.to(torch::kInt32)},
292 | 0
293 | );
294 | 
295 | // create output gaussian idx vector and sort key vector
296 | const int num_splats = splat_start_end_idx_by_gaussian_idx[N].item<int>();
297 | torch::Tensor gaussian_idx_by_splat_idx =
298 | torch::zeros({num_splats}, torch::dtype(torch::kInt32).device(uvs.device()));
299 | torch::Tensor sort_keys =
300 | torch::zeros({num_splats}, torch::dtype(torch::kFloat64).device(uvs.device()));
301 | 
302 | CHECK_FLOAT_TENSOR(xyz_camera_frame);
303 | CHECK_INT_TENSOR(splat_start_end_idx_by_gaussian_idx);
304 | CHECK_INT_TENSOR(gaussian_idx_by_splat_idx);
305 | CHECK_DOUBLE_TENSOR(sort_keys);
306 | 
307 | // max_depth + 1.0
308 | const double tile_idx_key_multiplier =
309 | (double)(xyz_camera_frame.select(1, 2).max().item<float>() + 1.0f);
310 | 
311 | // compute gaussian index and key for each gaussian-tile intersection
312 | compute_tiles_kernel<<<gridsize, blocksize>>>(
313 | uvs.data_ptr<float>(),
314 | xyz_camera_frame.data_ptr<float>(),
315 | conic.data_ptr<float>(),
316 | splat_start_end_idx_by_gaussian_idx.data_ptr<int>(),
317 | n_tiles_x,
318 | n_tiles_y,
319 | mh_dist,
320 | N,
321 | tile_idx_key_multiplier,
322 | gaussian_idx_by_splat_idx.data_ptr<int>(),
323 | sort_keys.data_ptr<double>()
324 | );
325 | cudaDeviceSynchronize();
326 | 
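// Annotation (added): each gaussian-tile intersection gets the sort key
//
//     key = z + (max_depth + 1.0) * tile_idx
//
// Because the multiplier exceeds every camera-frame depth, a single sort of
// the keys groups splats by tile first and orders them near-to-far within
// each tile; splat_start_end_idx_by_tile_idx below then delimits each tile's
// range in the sorted list.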
327 | auto result = torch::sort(sort_keys);
328 | torch::Tensor sorted_indices = std::get<1>(result);
329 | torch::Tensor sorted_gaussians = gaussian_idx_by_splat_idx.index_select(0, sorted_indices);
330 | 
331 | // compute indices for each tile
332 | torch::Tensor num_gaussians_per_tile_cumsum = num_gaussians_per_tile.cumsum(0);
333 | torch::Tensor splat_start_end_idx_by_tile_idx = torch::cat(
334 | {torch::zeros({1}, torch::dtype(torch::kInt32).device(uvs.device())),
335 | num_gaussians_per_tile_cumsum.to(torch::kInt32)},
336 | 0
337 | );
338 | 
339 | return std::make_tuple(sorted_gaussians, splat_start_end_idx_by_tile_idx);
340 | }
341 | 
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joeyan/gaussian_splatting/c1f5a71e3549d8bd089be6e9777c16f8e1bc333f/test/__init__.py
--------------------------------------------------------------------------------
/test/gaussian_test_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from splat_py.structs import Gaussians, Camera
4 | 
5 | 
6 | def get_test_gaussians(device):
7 | xyz = torch.tensor(
8 | [
9 | [1.0, 2.0, -4.0],
10 | [4.0, 5.0, 6.0],
11 | [7.0, 8.0, -9.0],
12 | [1.0, 2.0, 15.0],
13 | [2.5, -1.0, 4.0],
14 | [-1.0, -2.0, 10.0],
15 | ],
16 | dtype=torch.float32,
17 | device=device,
18 | )
19 | rgb = torch.ones(xyz.shape, dtype=torch.float32, device=device) * 0.5
20 | rgb[3, :] = torch.tensor([0.5, 0.0, 0.0], dtype=torch.float32, device=device)
21 | rgb[4, :] = torch.tensor([0.0, 0.5, 0.0], dtype=torch.float32, device=device)
22 | rgb[5, :] = torch.tensor([0.0, 0.0, 0.5], dtype=torch.float32, device=device)
23 | rgb = rgb / 0.28209479177387814
24 | 
25 | opacity = torch.ones(xyz.shape[0], 1, dtype=torch.float32, device=device)
26 | scale = torch.tensor(
27 | [
28 | [0.02, 0.03, 0.04],
29 | [0.01, 0.05, 0.02],
30 | [0.09, 0.03, 0.01],
31 | [1.0, 3.0, 0.1],
32 | [2.0, 0.2, 0.1],
33 | [2.0, 1.0, 0.1],
34 | ],
35 | dtype=torch.float32,
36 | device=device,
37 | )
38 | # using exp activation
39 | scale = torch.log(scale)
40 | quaternion = torch.tensor(
41 | [
42 | [1.0, 0.0, 0.0, 0.0],
43 | [0.0, 1.0, 0.0, 0.0],
44 | [0.0, 0.0, 1.0, 0.0],
45 | [1.0, 0.0, 0.0, 0.0],
46 | [0.714, -0.002, -0.664, 0.221],
47 | [1.0, 0.0, 0.0, 0.0],
48 | ],
49 | dtype=torch.float32,
50 | device=device,
51 | )
52 | return Gaussians(xyz, rgb, opacity, scale, quaternion)
53 | 
54 | 
55 | def get_test_camera(device):
56 | # different fx and fy to test computation of gaussian projection
57 | K = torch.tensor(
58 | [
59 | [430.0, 0.0, 320.0],
60 | [0.0, 410.0, 240.0],
61 | [0.0, 0.0, 1.0],
62 | ],
63 | dtype=torch.float32,
64 | device=device,
65 | )
66 | return Camera(640, 480, K)
67 | 
68 | 
69 | def get_test_camera_T_world(device):
70 | return torch.tensor(
71 | [
72 | [0.9999, 0.0089, 0.0073, -0.3283],
73 | [-0.0106, 0.9568, 0.2905, -1.9260],
74 | [-0.0044, -0.2906, 0.9568, 2.9581],
75 | [0.0000, 0.0000, 0.0000, 1.0000],
76 | ],
77 | dtype=torch.float32,
78 | device=device,
79 | )
80 | 
81 | 
82 | def get_test_data(device):
83 | gaussians = get_test_gaussians(device)
84 | camera = get_test_camera(device)
85 | camera_T_world = get_test_camera_T_world(device)
86 | return gaussians, camera, camera_T_world
87 | 
--------------------------------------------------------------------------------
/test/test_cuda_autograd_functions.py:
--------------------------------------------------------------------------------
1
| import unittest 2 | import torch 3 | from splat_py.cuda_autograd_functions import ( 4 | CameraPointProjection, 5 | ComputeProjectionJacobian, 6 | ComputeSigmaWorld, 7 | ComputeConic, 8 | PrecomputeRGBFromSH, 9 | ) 10 | 11 | 12 | class TestAutogradFunctions(unittest.TestCase): 13 | def setUp(self): 14 | self.assertTrue(torch.cuda.is_available()) 15 | self.device = torch.device("cuda") 16 | self.xyz = torch.tensor( 17 | [ 18 | [1.0, 2.0, 15.0], 19 | [2.5, -1.0, 4.0], 20 | [-1.0, -2.0, 10.0], 21 | ], 22 | dtype=torch.float64, 23 | device=self.device, 24 | requires_grad=True, 25 | ) 26 | self.K = torch.tensor( 27 | [ 28 | [430.0, 0.0, 320.0], 29 | [0.0, 410.0, 240.0], 30 | [0.0, 0.0, 1.0], 31 | ], 32 | dtype=torch.float64, 33 | device=self.device, 34 | requires_grad=False, 35 | ) 36 | self.quaternion = torch.tensor( 37 | [ 38 | [0.8, 0.2, 0.2, 0.2], 39 | [0.714, -0.002, -0.664, 0.221], 40 | [0.0, 0.0, 1.0, 0.0], 41 | ], 42 | dtype=torch.float64, 43 | device=self.device, 44 | requires_grad=True, 45 | ) 46 | self.scale = torch.tensor( 47 | [ 48 | [0.02, 0.03, 0.04], 49 | [0.09, 0.03, 0.01], 50 | [2.0, 1.0, 0.1], 51 | ], 52 | dtype=torch.float64, 53 | device=self.device, 54 | requires_grad=True, 55 | ) 56 | self.camera_T_world = torch.tensor( 57 | [ 58 | [0.9999, 0.0089, 0.0073, -0.3283], 59 | [-0.0106, 0.9568, 0.2905, -1.9260], 60 | [-0.0044, -0.2906, 0.9568, 2.9581], 61 | [0.0000, 0.0000, 0.0000, 1.0000], 62 | ], 63 | dtype=torch.float64, 64 | device=self.device, 65 | requires_grad=False, 66 | ) 67 | 68 | def test_camera_point_projection(self): 69 | test = torch.autograd.gradcheck( 70 | CameraPointProjection.apply, (self.xyz, self.K), raise_exception=True 71 | ) 72 | self.assertTrue(test) 73 | 74 | def test_compute_gaussian_projection_jacobian(self): 75 | test = torch.autograd.gradcheck( 76 | ComputeProjectionJacobian.apply, 77 | (self.xyz, self.K), 78 | raise_exception=True, 79 | ) 80 | self.assertTrue(test) 81 | 82 | def test_compute_sigma_world(self): 83 | test = torch.autograd.gradcheck( 84 | ComputeSigmaWorld.apply, 85 | (self.quaternion, self.scale), 86 | raise_exception=True, 87 | ) 88 | self.assertTrue(test) 89 | 90 | def test_compute_conic(self): 91 | sigma_world = torch.rand( 92 | 1, 93 | 3, 94 | 3, 95 | dtype=self.quaternion.dtype, 96 | device=self.quaternion.device, 97 | requires_grad=True, 98 | ) 99 | projection_jacobian = torch.rand( 100 | 1, 101 | 2, 102 | 3, 103 | dtype=self.quaternion.dtype, 104 | device=self.quaternion.device, 105 | requires_grad=True, 106 | ) 107 | test = torch.autograd.gradcheck( 108 | ComputeConic.apply, 109 | (sigma_world, projection_jacobian, self.camera_T_world), 110 | raise_exception=True, 111 | ) 112 | self.assertTrue(test) 113 | 114 | def test_compute_rgb_from_sh_1(self): 115 | N_gaussians = 100 116 | sh_coeffs = torch.ones( 117 | N_gaussians, 118 | 3, 119 | 1, 120 | dtype=torch.float64, 121 | device=self.device, 122 | requires_grad=True, 123 | ) 124 | xyz = torch.rand( 125 | N_gaussians, 126 | 3, 127 | dtype=torch.float64, 128 | device=self.device, 129 | requires_grad=False, 130 | ) 131 | camera_T_world = torch.zeros( 132 | 4, 133 | 4, 134 | dtype=torch.float64, 135 | device=self.device, 136 | requires_grad=False, 137 | ) 138 | test = torch.autograd.gradcheck( 139 | PrecomputeRGBFromSH.apply, 140 | (sh_coeffs, xyz, camera_T_world), 141 | raise_exception=True, 142 | ) 143 | self.assertTrue(test) 144 | 145 | def test_compute_rgb_from_sh_4(self): 146 | N_gaussians = 100 147 | sh_coeffs = torch.ones( 148 | N_gaussians, 149 | 3, 150 | 4, 151 | 
dtype=torch.float64, 152 | device=self.device, 153 | requires_grad=True, 154 | ) 155 | xyz = torch.rand( 156 | N_gaussians, 157 | 3, 158 | dtype=torch.float64, 159 | device=self.device, 160 | requires_grad=False, 161 | ) 162 | camera_T_world = torch.zeros( 163 | 4, 164 | 4, 165 | dtype=torch.float64, 166 | device=self.device, 167 | requires_grad=False, 168 | ) 169 | test = torch.autograd.gradcheck( 170 | PrecomputeRGBFromSH.apply, 171 | (sh_coeffs, xyz, camera_T_world), 172 | raise_exception=True, 173 | ) 174 | self.assertTrue(test) 175 | 176 | def test_compute_rgb_from_sh_9(self): 177 | N_gaussians = 100 178 | sh_coeffs = torch.ones( 179 | N_gaussians, 180 | 3, 181 | 9, 182 | dtype=torch.float64, 183 | device=self.device, 184 | requires_grad=True, 185 | ) 186 | xyz = torch.rand( 187 | N_gaussians, 188 | 3, 189 | dtype=torch.float64, 190 | device=self.device, 191 | requires_grad=False, 192 | ) 193 | camera_T_world = torch.zeros( 194 | 4, 195 | 4, 196 | dtype=torch.float64, 197 | device=self.device, 198 | requires_grad=False, 199 | ) 200 | test = torch.autograd.gradcheck( 201 | PrecomputeRGBFromSH.apply, 202 | (sh_coeffs, xyz, camera_T_world), 203 | raise_exception=True, 204 | ) 205 | self.assertTrue(test) 206 | 207 | def test_compute_rgb_from_sh_16(self): 208 | N_gaussians = 100 209 | sh_coeffs = torch.ones( 210 | N_gaussians, 211 | 3, 212 | 16, 213 | dtype=torch.float64, 214 | device=self.device, 215 | requires_grad=True, 216 | ) 217 | xyz = torch.rand( 218 | N_gaussians, 219 | 3, 220 | dtype=torch.float64, 221 | device=self.device, 222 | requires_grad=False, 223 | ) 224 | camera_T_world = torch.zeros( 225 | 4, 226 | 4, 227 | dtype=torch.float64, 228 | device=self.device, 229 | requires_grad=False, 230 | ) 231 | test = torch.autograd.gradcheck( 232 | PrecomputeRGBFromSH.apply, 233 | (sh_coeffs, xyz, camera_T_world), 234 | raise_exception=True, 235 | ) 236 | self.assertTrue(test) 237 | 238 | 239 | if __name__ == "__main__": 240 | unittest.main() 241 | -------------------------------------------------------------------------------- /test/test_dataloader.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | 4 | from splat_py.config import SplatConfig 5 | from splat_py.dataloader import ColmapData 6 | 7 | TEST_DATASET_PATH = "/home/joe/Downloads/garden" 8 | 9 | 10 | class TestColmapData(unittest.TestCase): 11 | """Test Colmap dataloader""" 12 | 13 | def setUp(self): 14 | self.colmap_directory_path = TEST_DATASET_PATH 15 | self.config = SplatConfig() 16 | self.device = torch.device("cpu") 17 | self.colmap_data = ColmapData( 18 | self.colmap_directory_path, self.device, downsample_factor=8, config=self.config 19 | ) 20 | 21 | def test_init(self): 22 | """Test Data Loading""" 23 | self.assertEqual(self.colmap_data.colmap_directory_path, self.colmap_directory_path) 24 | 25 | # values for garden dataset 26 | self.assertEqual(len(self.colmap_data.image_info), 185) 27 | self.assertEqual(len(self.colmap_data.xyz), 138766) 28 | self.assertEqual(len(self.colmap_data.rgb), 138766) 29 | self.assertEqual(len(self.colmap_data.cameras), 1) 30 | 31 | # test image data 32 | self.assertEqual(self.colmap_data.image_info[1].id, 1) 33 | self.assertEqual(self.colmap_data.image_info[1].camera_id, 1) 34 | self.assertEqual(self.colmap_data.image_info[1].name, "DSC07956.JPG") 35 | self.assertEqual(len(self.colmap_data.image_info[1].point3D_ids), 11193) 36 | 37 | def test_create_gaussians(self): 38 | """Test Gaussian Creation from colmap dataset""" 
39 | 40 | gaussians = self.colmap_data.create_gaussians() 41 | self.assertEqual(gaussians.xyz.shape[0], 138766) 42 | self.assertEqual(gaussians.xyz.shape[1], 3) 43 | self.assertEqual(gaussians.rgb.shape[0], 138766) 44 | self.assertEqual(gaussians.rgb.shape[1], 3) 45 | self.assertEqual(gaussians.opacity.shape[0], 138766) 46 | self.assertEqual(gaussians.opacity.shape[1], 1) 47 | self.assertEqual(gaussians.scale.shape[0], 138766) 48 | self.assertEqual(gaussians.scale.shape[1], 3) 49 | self.assertEqual(gaussians.quaternion.shape[0], 138766) 50 | self.assertEqual(gaussians.quaternion.shape[1], 4) 51 | 52 | self.assertAlmostEqual(gaussians.xyz[0, 0].item(), 5.048415184) 53 | self.assertAlmostEqual(gaussians.xyz[0, 1].item(), 1.673997640) 54 | self.assertAlmostEqual(gaussians.xyz[0, 2].item(), -1.014126658) 55 | 56 | self.assertAlmostEqual(gaussians.rgb[0, 0].item(), 0.27803197503089905) 57 | self.assertAlmostEqual(gaussians.rgb[0, 1].item(), 0.48655596375465393) 58 | self.assertAlmostEqual(gaussians.rgb[0, 2].item(), 0.06950799375772476) 59 | 60 | self.assertAlmostEqual(gaussians.opacity[0, 0].item(), -1.3862943649) 61 | 62 | self.assertAlmostEqual(gaussians.scale[0, 0].item(), -3.722839117050171) 63 | self.assertAlmostEqual(gaussians.scale[0, 1].item(), -3.722839117050171) 64 | self.assertAlmostEqual(gaussians.scale[0, 2].item(), -3.722839117050171) 65 | 66 | self.assertAlmostEqual(gaussians.quaternion[0, 0].item(), 1.0) 67 | self.assertAlmostEqual(gaussians.quaternion[0, 1].item(), 0.0) 68 | self.assertAlmostEqual(gaussians.quaternion[0, 2].item(), 0.0) 69 | self.assertAlmostEqual(gaussians.quaternion[0, 3].item(), 0.0) 70 | 71 | def test_load_capture_info(self): 72 | """Test loading Images, Cameras""" 73 | images = self.colmap_data.get_images() 74 | self.assertEqual(len(images), 185) 75 | self.assertEqual(images[0].image.shape[0], 420) 76 | self.assertEqual(images[0].image.shape[1], 648) 77 | self.assertEqual(images[0].camera_id, 1) 78 | self.assertEqual(images[0].camera_T_world.shape[0], 4) 79 | self.assertEqual(images[0].camera_T_world.shape[1], 4) 80 | 81 | expected_camera_T_world = torch.tensor( 82 | [ 83 | [0.9999, 0.0089, 0.0073, -0.3283], 84 | [-0.0106, 0.9568, 0.2905, -1.9260], 85 | [-0.0044, -0.2906, 0.9568, 3.9581], 86 | [0.0000, 0.0000, 0.0000, 1.0000], 87 | ], 88 | dtype=torch.float32, 89 | ) 90 | 91 | self.assertTrue( 92 | torch.allclose(images[0].camera_T_world, expected_camera_T_world, atol=1e-4) 93 | ) 94 | 95 | cameras = self.colmap_data.get_cameras() 96 | self.assertEqual(len(cameras), 1) 97 | 98 | # using 8x downsample factor 99 | expected_K = torch.tensor( 100 | [[480.6123, 0.0, 324.1875], [0.0, 481.5445, 210.0625], [0.0, 0.0, 1.0]], 101 | dtype=torch.float32, 102 | ) 103 | self.assertTrue(torch.allclose(cameras[1].K, expected_K, atol=1e-6)) 104 | 105 | 106 | if __name__ == "__main__": 107 | unittest.main() 108 | -------------------------------------------------------------------------------- /test/test_depth.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import unittest 4 | 5 | from splat_py.depth import render_depth 6 | from splat_py.utils import inverse_sigmoid_torch 7 | from gaussian_test_data import get_test_data 8 | 9 | 10 | class TestRenderDepth(unittest.TestCase): 11 | def setUp(self): 12 | self.assertTrue(torch.cuda.is_available()) 13 | self.device = torch.device("cuda") 14 | self.gaussians, self.camera, self.camera_T_world = get_test_data(self.device) 15 | self.gaussians.opacity = 
inverse_sigmoid_torch(self.gaussians.opacity)
16 | 
17 | def test_rasterize_no_sh(self):
18 | near_thresh = 0.3
19 | cull_mask_padding = 10
20 | mh_dist = 3.0
21 | 
22 | alpha_threshold = 0.2
23 | depth_image = render_depth(
24 | self.gaussians,
25 | alpha_threshold,
26 | self.camera_T_world,
27 | self.camera,
28 | near_thresh,
29 | cull_mask_padding,
30 | mh_dist,
31 | )
32 | # near red gaussian center
33 | self.assertAlmostEqual(depth_image[340, 348].item(), 17.29551887512207, places=5)
34 | 
35 | # overlap of red and blue gaussian, blue is in front of red
36 | self.assertAlmostEqual(depth_image[200, 348].item(), 13.205718040466309, places=5)
--------------------------------------------------------------------------------
/test/test_projection.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | 
4 | from splat_py.cuda_autograd_functions import (
5 | CameraPointProjection,
6 | ComputeSigmaWorld,
7 | ComputeProjectionJacobian,
8 | ComputeConic,
9 | )
10 | from splat_py.utils import transform_points_torch
11 | 
12 | from gaussian_test_data import get_test_data
13 | 
14 | 
15 | class ProjectionTest(unittest.TestCase):
16 | def setUp(self):
17 | self.assertTrue(torch.cuda.is_available())
18 | self.device = torch.device("cuda")
19 | self.gaussians, self.camera, self.camera_T_world = get_test_data(self.device)
20 | 
21 | def test_project_points(self):
22 | xyz_camera_frame = transform_points_torch(self.gaussians.xyz, self.camera_T_world)
23 | 
24 | self.assertAlmostEqual(xyz_camera_frame[0, 0].item(), 0.6602, places=4)
25 | self.assertAlmostEqual(xyz_camera_frame[0, 1].item(), -1.1849998, places=4)
26 | self.assertAlmostEqual(xyz_camera_frame[0, 2].item(), -1.4546999, places=4)
27 | self.assertAlmostEqual(xyz_camera_frame[1, 0].item(), 3.7595997, places=4)
28 | self.assertAlmostEqual(xyz_camera_frame[1, 1].item(), 4.5586, places=4)
29 | self.assertAlmostEqual(xyz_camera_frame[1, 2].item(), 7.2283, places=4)
30 | 
31 | uv = CameraPointProjection.apply(xyz_camera_frame, self.camera.K)
32 | 
33 | self.assertEqual(uv.shape, (6, 2))
34 | self.assertAlmostEqual(uv[0, 0].item(), 124.849106, places=4)
35 | self.assertAlmostEqual(uv[0, 1].item(), 573.9863, places=4)
36 | self.assertAlmostEqual(uv[1, 0].item(), 543.6526, places=4)
37 | self.assertAlmostEqual(uv[1, 1].item(), 498.57062, places=4)
38 | 
39 | # perform frustum culling
40 | # (TODO) move frustum culling to function
41 | culling_mask = torch.zeros(
42 | xyz_camera_frame.shape[0],
43 | dtype=torch.bool,
44 | device=self.device,
45 | )
46 | near_thresh = 0.3
47 | culling_mask = culling_mask | (xyz_camera_frame[:, 2] < near_thresh)
48 | culling_mask = (
49 | culling_mask
50 | | (uv[:, 0] < 0)
51 | | (uv[:, 0] > self.camera.width)
52 | | (uv[:, 1] < 0)
53 | | (uv[:, 1] > self.camera.height)
54 | )
55 | 
56 | self.assertEqual(uv.shape, (6, 2))
57 | self.assertEqual(xyz_camera_frame.shape, (6, 3))
58 | self.assertEqual(culling_mask.shape, (6,))
59 | 
60 | self.assertTrue(
61 | torch.all(
62 | culling_mask
63 | == torch.tensor([True, True, True, False, False, False], device=self.device)
64 | )
65 | )
66 | 
67 | def test_compute_sigma_world(self):
68 | sigma_world = ComputeSigmaWorld.apply(self.gaussians.quaternion, self.gaussians.scale)
69 | 
70 | self.assertEqual(sigma_world.shape, (6, 3, 3))
71 | # check first sigma_world
72 | self.assertAlmostEqual(sigma_world[0, 0, 0].item(), 0.0004, places=4)
73 | self.assertAlmostEqual(sigma_world[0, 0, 1].item(), 0.0, places=4)
74 |
self.assertAlmostEqual(sigma_world[0, 0, 2].item(), 0.0, places=4) 75 | 76 | self.assertAlmostEqual(sigma_world[0, 1, 0].item(), 0.0, places=4) 77 | self.assertAlmostEqual(sigma_world[0, 1, 1].item(), 0.0009, places=4) 78 | self.assertAlmostEqual(sigma_world[0, 1, 2].item(), 0.0, places=4) 79 | 80 | self.assertAlmostEqual(sigma_world[0, 2, 0].item(), 0.0, places=4) 81 | self.assertAlmostEqual(sigma_world[0, 2, 1].item(), 0.0, places=4) 82 | self.assertAlmostEqual(sigma_world[0, 2, 2].item(), 0.0016, places=4) 83 | 84 | # sigma world 85 | self.assertAlmostEqual(sigma_world[4, 0, 0].item(), 0.01454808, places=4) 86 | self.assertAlmostEqual(sigma_world[4, 0, 1].item(), 0.01702517, places=4) 87 | self.assertAlmostEqual(sigma_world[4, 0, 2].item(), 0.07868834, places=4) 88 | self.assertAlmostEqual(sigma_world[4, 1, 0].item(), 0.01702517, places=4) 89 | self.assertAlmostEqual(sigma_world[4, 1, 1].item(), 0.4389012, places=4) 90 | self.assertAlmostEqual(sigma_world[4, 1, 2].item(), 1.1959752, places=4) 91 | self.assertAlmostEqual(sigma_world[4, 2, 0].item(), 0.07868834, places=4) 92 | self.assertAlmostEqual(sigma_world[4, 2, 1].item(), 1.1959752, places=4) 93 | self.assertAlmostEqual(sigma_world[4, 2, 2].item(), 3.5965507, places=4) 94 | 95 | def test_compute_projection_jacobian(self): 96 | xyz_camera_frame = transform_points_torch(self.gaussians.xyz, self.camera_T_world) 97 | 98 | jacobian = ComputeProjectionJacobian.apply(xyz_camera_frame, self.camera.K) 99 | 100 | self.assertEqual(jacobian.shape, (6, 2, 3)) 101 | self.assertAlmostEqual(jacobian[0, 0, 0].item(), -295.5936, places=4) 102 | self.assertAlmostEqual(jacobian[0, 0, 1].item(), 0.0, places=4) 103 | self.assertAlmostEqual(jacobian[0, 0, 2].item(), -134.1520, places=4) 104 | self.assertAlmostEqual(jacobian[0, 1, 0].item(), 0.0, places=4) 105 | self.assertAlmostEqual(jacobian[0, 1, 1].item(), -281.8451, places=4) 106 | self.assertAlmostEqual(jacobian[0, 1, 2].item(), 229.5912, places=4) 107 | 108 | def test_compute_conic(self): 109 | # compute inputs (tested in previous tests) 110 | sigma_world = ComputeSigmaWorld.apply(self.gaussians.quaternion, self.gaussians.scale) 111 | xyz_camera_frame = transform_points_torch(self.gaussians.xyz, self.camera_T_world) 112 | jacobian = ComputeProjectionJacobian.apply(xyz_camera_frame, self.camera.K) 113 | 114 | # compute conic 115 | conic = ComputeConic.apply(sigma_world, jacobian, self.camera_T_world) 116 | 117 | self.assertEqual(conic.shape, (6, 3)) 118 | self.assertAlmostEqual(conic[3, 0].item(), 664.28760, places=4) 119 | self.assertAlmostEqual(conic[3, 1].item(), 254.81781, places=4) 120 | self.assertAlmostEqual(conic[3, 2].item(), 5761.8906, places=4) 121 | 122 | 123 | if __name__ == "__main__": 124 | unittest.main() 125 | -------------------------------------------------------------------------------- /test/test_rasterize.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import unittest 5 | 6 | from splat_py.rasterize import rasterize 7 | from splat_py.utils import inverse_sigmoid_torch 8 | 9 | from gaussian_test_data import get_test_data 10 | 11 | SAVE_DEBUG = False 12 | 13 | 14 | class TestRasterize(unittest.TestCase): 15 | def setUp(self): 16 | self.assertTrue(torch.cuda.is_available()) 17 | self.device = torch.device("cuda") 18 | self.gaussians, self.camera, self.camera_T_world = get_test_data(self.device) 19 | self.gaussians.opacity = inverse_sigmoid_torch(self.gaussians.opacity) 20 | 21 | def 
test_rasterize_no_sh(self): 22 | near_thresh = 0.3 23 | far_thresh = 100.0 24 | cull_mask_padding = 10 25 | mh_dist = 3.0 26 | use_sh_precompute = True 27 | 28 | background_rgb = torch.zeros(3, device=self.device, dtype=self.gaussians.rgb.dtype) 29 | image, _, _ = rasterize( 30 | self.gaussians, 31 | self.camera_T_world, 32 | self.camera, 33 | near_thresh, 34 | far_thresh, 35 | cull_mask_padding, 36 | mh_dist, 37 | use_sh_precompute, 38 | background_rgb, 39 | ) 40 | if SAVE_DEBUG: 41 | debug_image = image.clip(0, 1).detach().cpu().numpy() 42 | cv2.imwrite( 43 | "/tmp/test_rasterize_no_sh.png", (debug_image * 255).astype(np.uint8)[..., ::-1] 44 | ) 45 | 46 | # near red gaussian center 47 | self.assertAlmostEqual(image[340, 348, 0].item(), 0.47698545455932617, places=5) 48 | self.assertAlmostEqual(image[340, 348, 1].item(), 0.0, places=5) 49 | self.assertAlmostEqual(image[340, 348, 2].item(), 0.0, places=5) 50 | 51 | # overlap of red and blue gaussian, blue is in front of red 52 | self.assertAlmostEqual(image[200, 348, 0].item(), 0.03330837935209274, places=5) 53 | self.assertAlmostEqual(image[200, 348, 1].item(), 0.0, places=5) 54 | self.assertAlmostEqual(image[200, 348, 2].item(), 0.267561137676239, places=5) 55 | 56 | def test_rasterize_full_sh_use_precompute(self): 57 | near_thresh = 0.3 58 | far_thresh = 100.0 59 | cull_mask_padding = 10 60 | mh_dist = 3.0 61 | use_sh_precompute = True 62 | self.gaussians.sh = ( 63 | torch.ones((self.gaussians.xyz.shape[0], 3, 15), device=self.device) * 0.1 64 | ) 65 | background_rgb = torch.zeros(3, device=self.device, dtype=self.gaussians.rgb.dtype) 66 | image, _, _ = rasterize( 67 | self.gaussians, 68 | self.camera_T_world, 69 | self.camera, 70 | near_thresh, 71 | far_thresh, 72 | cull_mask_padding, 73 | mh_dist, 74 | use_sh_precompute, 75 | background_rgb, 76 | ) 77 | if SAVE_DEBUG: 78 | debug_image = image.clip(0, 1).detach().cpu().numpy() 79 | cv2.imwrite( 80 | "/tmp/test_rasterize_full_sh_use_precompute.png", 81 | (debug_image * 255).astype(np.uint8)[..., ::-1], 82 | ) 83 | 84 | # near red gaussian center 85 | self.assertAlmostEqual(image[340, 348, 0].item(), 0.5362688899040222, places=5) 86 | self.assertAlmostEqual(image[340, 348, 1].item(), 0.05928343906998634, places=5) 87 | self.assertAlmostEqual(image[340, 348, 2].item(), 0.05928343906998634, places=5) 88 | 89 | # overlap of red and blue gaussian, blue is in front of red 90 | self.assertAlmostEqual(image[200, 348, 0].item(), 0.10543855279684067, places=5) 91 | self.assertAlmostEqual(image[200, 348, 1].item(), 0.07212823629379272, places=5) 92 | self.assertAlmostEqual(image[200, 348, 2].item(), 0.3396894335746765, places=5) 93 | 94 | def test_rasterize_full_sh_use_per_pixel_viewdir(self): 95 | near_thresh = 0.3 96 | far_thresh = 100.0 97 | cull_mask_padding = 10 98 | mh_dist = 3.0 99 | use_sh_precompute = False 100 | self.gaussians.sh = ( 101 | torch.ones((self.gaussians.xyz.shape[0], 3, 15), device=self.device) * 0.1 102 | ) 103 | 104 | background_rgb = torch.zeros(3, device=self.device, dtype=self.gaussians.rgb.dtype) 105 | image, _, _ = rasterize( 106 | self.gaussians, 107 | self.camera_T_world, 108 | self.camera, 109 | near_thresh, 110 | far_thresh, 111 | cull_mask_padding, 112 | mh_dist, 113 | use_sh_precompute, 114 | background_rgb, 115 | ) 116 | if SAVE_DEBUG: 117 | debug_image = image.clip(0, 1).detach().cpu().numpy() 118 | cv2.imwrite( 119 | "/tmp/test_rasterize_full_sh_use_per_pixel_viewdir.png", 120 | (debug_image * 255).astype(np.uint8)[..., ::-1], 121 | ) 122 | 123 | # near red 
gaussian center 124 | self.assertAlmostEqual(image[340, 348, 0].item(), 0.5328576564788818, places=5) 125 | self.assertAlmostEqual(image[340, 348, 1].item(), 0.05587226152420044, places=5) 126 | self.assertAlmostEqual(image[340, 348, 2].item(), 0.05587226152420044, places=5) 127 | 128 | # overlap of red and blue gaussian, blue is in front of red 129 | self.assertAlmostEqual(image[200, 348, 0].item(), 0.06694115698337555, places=5) 130 | self.assertAlmostEqual(image[200, 348, 1].item(), 0.033630844205617905, places=5) 131 | self.assertAlmostEqual(image[200, 348, 2].item(), 0.30119192600250244, places=5) 132 | 133 | 134 | if __name__ == "__main__": 135 | unittest.main() 136 | -------------------------------------------------------------------------------- /test/test_rasterize_autograd.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | 4 | from splat_py.cuda_autograd_functions import RenderImage 5 | from splat_py.structs import Tiles, Camera 6 | from splat_py.tile_culling import get_splats 7 | from splat_py.utils import compute_rays_in_world_frame, transform_points_torch 8 | 9 | 10 | class TestRasterizeAutograd(unittest.TestCase): 11 | def setUp(self): 12 | self.assertTrue(torch.cuda.is_available()) 13 | self.device = torch.device("cuda") 14 | 15 | K = torch.tensor( 16 | [ 17 | [43.0, 0.0, 30.0], 18 | [0.0, 41.0, 20.0], 19 | [0.0, 0.0, 1.0], 20 | ], 21 | dtype=torch.float64, 22 | device=self.device, 23 | ) 24 | self.camera = Camera(60, 40, K) 25 | self.camera_T_world = torch.eye(4, dtype=torch.float64, device=self.device) 26 | self.rays = compute_rays_in_world_frame(self.camera, self.camera_T_world) 27 | 28 | self.xyz = torch.tensor( 29 | [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], 30 | dtype=torch.float64, 31 | device=self.device, 32 | requires_grad=False, 33 | ) 34 | self.xyz_camera_frame = transform_points_torch(self.xyz, self.camera_T_world) 35 | 36 | self.uv = torch.tensor( 37 | [ 38 | [32.8523, 24.88553], 39 | [25.0, 25.0], 40 | [45.339926, 13.85983], 41 | ], 42 | dtype=torch.float64, 43 | device=self.device, 44 | requires_grad=True, 45 | ) 46 | self.conic = torch.tensor( 47 | [ 48 | [1.3287e03, 9.7362e02 * 2, 7.3605e02], 49 | [90.0, 20.0 * 2, 60.0], 50 | [776.215, -2464.463 * 2, 8276.755], 51 | ], 52 | dtype=torch.float64, 53 | device=self.device, 54 | requires_grad=True, 55 | ) 56 | self.tiles = Tiles(40, 60, self.device) 57 | self.sorted_gaussian_idx_by_splat_idx, self.splat_start_end_idx_by_tile_idx = get_splats( 58 | self.uv.to(torch.float), 59 | self.tiles, 60 | self.conic.to(torch.float), 61 | self.xyz_camera_frame.to(torch.float), 62 | mh_dist=3.0, 63 | ) 64 | 65 | self.opacity = torch.ones( 66 | self.uv.shape[0], 67 | 1, 68 | dtype=torch.float64, 69 | device=self.device, 70 | requires_grad=True, 71 | ) 72 | 73 | def test_rasterize_image_grad_SH_0(self): 74 | image_size = torch.tensor([40, 60], dtype=torch.int, device=self.device) 75 | rgb = ( 76 | torch.ones( 77 | self.uv.shape[0], 78 | 3, 79 | dtype=torch.float64, 80 | device=self.device, 81 | requires_grad=True, 82 | ) 83 | * 0.5 84 | ) 85 | rgb[0, 0] = 0.0 86 | rgb[1, 1] = 0.0 87 | 88 | background_rgb = ( 89 | torch.ones(3, dtype=torch.float64, device=self.device, requires_grad=False) * 0.5 90 | ) 91 | test = torch.autograd.gradcheck( 92 | RenderImage.apply, 93 | ( 94 | rgb, 95 | self.opacity, 96 | self.uv, 97 | self.conic, 98 | self.rays, 99 | self.splat_start_end_idx_by_tile_idx, 100 | self.sorted_gaussian_idx_by_splat_idx, 101 | 
image_size, 102 | background_rgb, 103 | ), 104 | raise_exception=True, 105 | ) 106 | self.assertTrue(test) 107 | 108 | def test_rasterize_image_grad_SH_0_no_background(self): 109 | image_size = torch.tensor([40, 60], dtype=torch.int, device=self.device) 110 | rgb = ( 111 | torch.ones( 112 | self.uv.shape[0], 113 | 3, 114 | dtype=torch.float64, 115 | device=self.device, 116 | requires_grad=True, 117 | ) 118 | * 0.5 119 | ) 120 | rgb[0, 0] = 0.0 121 | rgb[1, 1] = 0.0 122 | 123 | background_rgb = torch.zeros( 124 | 3, dtype=torch.float64, device=self.device, requires_grad=False 125 | ) 126 | test = torch.autograd.gradcheck( 127 | RenderImage.apply, 128 | ( 129 | rgb, 130 | self.opacity, 131 | self.uv, 132 | self.conic, 133 | self.rays, 134 | self.splat_start_end_idx_by_tile_idx, 135 | self.sorted_gaussian_idx_by_splat_idx, 136 | image_size, 137 | background_rgb, 138 | ), 139 | raise_exception=True, 140 | ) 141 | self.assertTrue(test) 142 | 143 | def test_rasterize_image_grad_SH_4(self): 144 | test_size = torch.tensor([40, 60], dtype=torch.int, device=self.device) 145 | sh_coeff_4 = ( 146 | torch.ones( 147 | self.uv.shape[0], 148 | 3, 149 | 4, 150 | dtype=torch.float64, 151 | device=self.device, 152 | requires_grad=True, 153 | ) 154 | * 0.5 155 | ) 156 | background_rgb = ( 157 | torch.ones(3, dtype=torch.float64, device=self.device, requires_grad=False) * 0.5 158 | ) 159 | test = torch.autograd.gradcheck( 160 | RenderImage.apply, 161 | ( 162 | sh_coeff_4, 163 | self.opacity, 164 | self.uv, 165 | self.conic, 166 | self.rays, 167 | self.splat_start_end_idx_by_tile_idx, 168 | self.sorted_gaussian_idx_by_splat_idx, 169 | test_size, 170 | background_rgb, 171 | ), 172 | raise_exception=True, 173 | ) 174 | self.assertTrue(test) 175 | 176 | def test_rasterize_image_grad_SH_4_no_background(self): 177 | test_size = torch.tensor([40, 60], dtype=torch.int, device=self.device) 178 | sh_coeff_4 = ( 179 | torch.ones( 180 | self.uv.shape[0], 181 | 3, 182 | 4, 183 | dtype=torch.float64, 184 | device=self.device, 185 | requires_grad=True, 186 | ) 187 | * 0.5 188 | ) 189 | background_rgb = torch.zeros( 190 | 3, dtype=torch.float64, device=self.device, requires_grad=False 191 | ) 192 | test = torch.autograd.gradcheck( 193 | RenderImage.apply, 194 | ( 195 | sh_coeff_4, 196 | self.opacity, 197 | self.uv, 198 | self.conic, 199 | self.rays, 200 | self.splat_start_end_idx_by_tile_idx, 201 | self.sorted_gaussian_idx_by_splat_idx, 202 | test_size, 203 | background_rgb, 204 | ), 205 | raise_exception=True, 206 | ) 207 | self.assertTrue(test) 208 | 209 | def test_rasterize_image_grad_SH_9(self): 210 | test_size = torch.tensor([40, 60], dtype=torch.int, device=self.device) 211 | sh_coeff_9 = ( 212 | torch.ones( 213 | self.uv.shape[0], 214 | 3, 215 | 9, 216 | dtype=torch.float64, 217 | device=self.device, 218 | requires_grad=True, 219 | ) 220 | * 0.5 221 | ) 222 | background_rgb = ( 223 | torch.ones(3, dtype=torch.float64, device=self.device, requires_grad=False) * 0.5 224 | ) 225 | test = torch.autograd.gradcheck( 226 | RenderImage.apply, 227 | ( 228 | sh_coeff_9, 229 | self.opacity, 230 | self.uv, 231 | self.conic, 232 | self.rays, 233 | self.splat_start_end_idx_by_tile_idx, 234 | self.sorted_gaussian_idx_by_splat_idx, 235 | test_size, 236 | background_rgb, 237 | ), 238 | raise_exception=True, 239 | ) 240 | self.assertTrue(test) 241 | 242 | def test_rasterize_image_grad_SH_9_no_background(self): 243 | test_size = torch.tensor([40, 60], dtype=torch.int, device=self.device) 244 | sh_coeff_9 = ( 245 | torch.ones( 246 | 
self.uv.shape[0], 247 | 3, 248 | 9, 249 | dtype=torch.float64, 250 | device=self.device, 251 | requires_grad=True, 252 | ) 253 | * 0.5 254 | ) 255 | background_rgb = torch.zeros( 256 | 3, dtype=torch.float64, device=self.device, requires_grad=False 257 | ) 258 | test = torch.autograd.gradcheck( 259 | RenderImage.apply, 260 | ( 261 | sh_coeff_9, 262 | self.opacity, 263 | self.uv, 264 | self.conic, 265 | self.rays, 266 | self.splat_start_end_idx_by_tile_idx, 267 | self.sorted_gaussian_idx_by_splat_idx, 268 | test_size, 269 | background_rgb, 270 | ), 271 | raise_exception=True, 272 | ) 273 | self.assertTrue(test) 274 | 275 | def test_rasterize_image_grad_SH_16(self): 276 | test_size = torch.tensor([40, 60], dtype=torch.int, device=self.device) 277 | sh_coeff_16 = ( 278 | torch.ones( 279 | self.uv.shape[0], 280 | 3, 281 | 16, 282 | dtype=torch.float64, 283 | device=self.device, 284 | requires_grad=True, 285 | ) 286 | * 0.5 287 | ) 288 | background_rgb = ( 289 | torch.ones(3, dtype=torch.float64, device=self.device, requires_grad=False) * 0.5 290 | ) 291 | test = torch.autograd.gradcheck( 292 | RenderImage.apply, 293 | ( 294 | sh_coeff_16, 295 | self.opacity, 296 | self.uv, 297 | self.conic, 298 | self.rays, 299 | self.splat_start_end_idx_by_tile_idx, 300 | self.sorted_gaussian_idx_by_splat_idx, 301 | test_size, 302 | background_rgb, 303 | ), 304 | raise_exception=True, 305 | atol=3e-5, 306 | ) 307 | self.assertTrue(test) 308 | 309 | def test_rasterize_image_grad_SH_16_no_background(self): 310 | test_size = torch.tensor([40, 60], dtype=torch.int, device=self.device) 311 | sh_coeff_16 = ( 312 | torch.ones( 313 | self.uv.shape[0], 314 | 3, 315 | 16, 316 | dtype=torch.float64, 317 | device=self.device, 318 | requires_grad=True, 319 | ) 320 | * 0.5 321 | ) 322 | background_rgb = torch.zeros( 323 | 3, dtype=torch.float64, device=self.device, requires_grad=False 324 | ) 325 | test = torch.autograd.gradcheck( 326 | RenderImage.apply, 327 | ( 328 | sh_coeff_16, 329 | self.opacity, 330 | self.uv, 331 | self.conic, 332 | self.rays, 333 | self.splat_start_end_idx_by_tile_idx, 334 | self.sorted_gaussian_idx_by_splat_idx, 335 | test_size, 336 | background_rgb, 337 | ), 338 | raise_exception=True, 339 | atol=3e-5, 340 | ) 341 | self.assertTrue(test) 342 | 343 | 344 | if __name__ == "__main__": 345 | unittest.main() 346 | -------------------------------------------------------------------------------- /test/test_structs.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | 4 | from splat_py.structs import Tiles 5 | 6 | 7 | class TestTiles(unittest.TestCase): 8 | """Test tile generation""" 9 | 10 | def test_initialization(self): 11 | image_height = 1080 12 | image_width = 1920 13 | device = torch.device("cpu") 14 | 15 | tiles = Tiles(image_height, image_width, device) 16 | 17 | self.assertEqual(tiles.image_height, image_height) 18 | self.assertEqual(tiles.image_width, image_width) 19 | self.assertEqual(tiles.device, device) 20 | 21 | self.assertEqual(tiles.image_height_padded, 1088) 22 | self.assertEqual(tiles.image_width_padded, 1920) 23 | 24 | self.assertEqual(tiles.y_tiles_count, 68) 25 | self.assertEqual(tiles.x_tiles_count, 120) 26 | self.assertEqual(tiles.tile_count, 8160) 27 | 28 | 29 | if __name__ == "__main__": 30 | unittest.main() 31 | -------------------------------------------------------------------------------- /test/test_tile_culling.py: -------------------------------------------------------------------------------- 1 | 
import unittest
2 | import torch
3 |
4 | from splat_py.utils import transform_points_torch
5 | from splat_py.cuda_autograd_functions import (
6 | CameraPointProjection,
7 | ComputeSigmaWorld,
8 | ComputeProjectionJacobian,
9 | ComputeConic,
10 | )
11 | from splat_py.structs import Gaussians, Tiles
12 | from splat_py.tile_culling import (
13 | get_splats,
14 | )
15 | from gaussian_test_data import get_test_data
16 |
17 |
18 | class TestCulling(unittest.TestCase):
19 | def setUp(self):
20 | self.assertTrue(torch.cuda.is_available())
21 | self.device = torch.device("cuda")
22 | self.gaussians, self.camera, self.camera_T_world = get_test_data(self.device)
23 |
24 | def test_tile_culling(self):
25 | near_thresh = 0.3  # cull gaussians closer than this camera-frame depth
26 | cull_mask_padding = 10  # keep gaussians up to this many pixels outside the image
27 | mh_dist = 3.0  # mahalanobis distance cutoff used by get_splats for tile assignment
28 |
29 | xyz_camera_frame = transform_points_torch(self.gaussians.xyz, self.camera_T_world)
30 | uv = CameraPointProjection.apply(xyz_camera_frame, self.camera.K)
31 |
32 | # perform frustum culling
33 | culling_mask = torch.zeros(
34 | xyz_camera_frame.shape[0],
35 | dtype=torch.bool,
36 | device=self.gaussians.xyz.device,
37 | )
38 | culling_mask = culling_mask | (xyz_camera_frame[:, 2] < near_thresh)
39 | culling_mask = (
40 | culling_mask
41 | | (uv[:, 0] < -1 * cull_mask_padding)
42 | | (uv[:, 0] > self.camera.width + cull_mask_padding)
43 | | (uv[:, 1] < -1 * cull_mask_padding)
44 | | (uv[:, 1] > self.camera.height + cull_mask_padding)
45 | )
46 |
47 | # cull gaussians outside of the camera frustum
48 | uv = uv[~culling_mask, :]
49 | xyz_camera_frame = xyz_camera_frame[~culling_mask, :]
50 |
51 | culled_gaussians = Gaussians(
52 | xyz=self.gaussians.xyz[~culling_mask, :],
53 | quaternion=self.gaussians.quaternion[~culling_mask, :],
54 | scale=self.gaussians.scale[~culling_mask, :],
55 | opacity=torch.sigmoid(
56 | self.gaussians.opacity[~culling_mask]
57 | ), # apply sigmoid activation to opacity
58 | rgb=self.gaussians.rgb[~culling_mask, :],
59 | )
60 |
61 | sigma_world = ComputeSigmaWorld.apply(culled_gaussians.quaternion, culled_gaussians.scale)
62 | J = ComputeProjectionJacobian.apply(xyz_camera_frame, self.camera.K)
63 | conic = ComputeConic.apply(sigma_world, J, self.camera_T_world)
64 |
65 | # perform tile culling
66 | tiles = Tiles(self.camera.height, self.camera.width, uv.device)
67 |
68 | sorted_gaussian_idx_by_splat_idx, splat_start_end_idx_by_tile_idx = get_splats(
69 | uv, tiles, conic, xyz_camera_frame, mh_dist
70 | )
71 |
72 | # fmt: off
73 | expected_sorted_gaussian_idx_by_splat_idx = torch.tensor(
74 | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
75 | 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0,
76 | 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 1, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2,
77 | 2, 2, 2, 2, 2, 2, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0,
78 | 2, 0, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 0,
79 | 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 1, 2, 1, 2, 1, 2, 1, 2,
80 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2,
81 | 0, 2, 0, 2, 0, 2, 0, 1, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
82 | 2, 2, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 1, 2,
83 | 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 0, 2, 0, 2, 0,
84 | 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 1, 2, 0, 1, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2,
85 | 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2,
86 | 0, 1, 2, 0, 1, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0,
87 | 2, 0, 2, 0, 2, 0, 2, 0, 2, 0,
2, 0, 2, 0, 2, 0, 1, 2, 0, 1, 2, 1, 2, 2, 88 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 89 | 2, 0, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 90 | 2, 2, 2, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 1, 2, 0, 1, 2, 91 | 0, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 0, 2, 0, 92 | 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 2, 2, 2, 93 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 94 | 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 95 | 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 96 | 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 97 | 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 98 | 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 99 | 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 100 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 101 | device=sorted_gaussian_idx_by_splat_idx.device, 102 | dtype=sorted_gaussian_idx_by_splat_idx.dtype 103 | ) 104 | # fmt: on 105 | self.assertTrue( 106 | torch.equal(sorted_gaussian_idx_by_splat_idx, expected_sorted_gaussian_idx_by_splat_idx) 107 | ) 108 | self.assertEqual(splat_start_end_idx_by_tile_idx.shape[0], 1201) 109 | 110 | 111 | if __name__ == "__main__": 112 | unittest.main() 113 | -------------------------------------------------------------------------------- /test/test_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import unittest 3 | import torch 4 | 5 | from splat_py.utils import ( 6 | quaternion_to_rotation_torch, 7 | transform_points_torch, 8 | compute_rays, 9 | compute_rays_in_world_frame, 10 | ) 11 | from gaussian_test_data import get_test_camera, get_test_camera_T_world 12 | 13 | 14 | class TestUtils(unittest.TestCase): 15 | def test_quaternion_to_rotation_torch(self): 16 | q = torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0, math.sqrt(2) / 2, 0.0, math.sqrt(2) / 2]) 17 | q = q.reshape(-1, 4) 18 | R = quaternion_to_rotation_torch(q) 19 | 20 | self.assertEqual(R.shape, (2, 3, 3)) 21 | # transpose/inverse each rotation matrix in the 3D tensor 22 | # R * R_inv = I 23 | R_inv = torch.transpose(R, 1, 2) 24 | eye_tensor = torch.eye(3).repeat(2, 1, 1) 25 | self.assertTrue(torch.allclose(torch.bmm(R, R_inv), eye_tensor, atol=1e-6)) 26 | 27 | def test_transform_points_torch(self): 28 | pts = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]) 29 | pts = pts.reshape(-1, 3) 30 | expected_pts = torch.tensor([4.0, 0.0, 4.0, 7.0, -3.0, 7.0, 10.0, -6.0, 10.0]) 31 | expected_pts = expected_pts.reshape(-1, 3) 32 | 33 | q = torch.tensor([0.0, math.sqrt(2) / 2, 0.0, math.sqrt(2) / 2]).unsqueeze(dim=0) 34 | transform = torch.eye(4) 35 | transform[:3, :3] = quaternion_to_rotation_torch(q) 36 | transform[:3, 3] = torch.tensor([1.0, 2.0, 3.0]) 37 | 38 | transformed_pts = transform_points_torch(pts, transform) 39 | self.assertEqual(transformed_pts.shape, expected_pts.shape) 40 | 41 | self.assertTrue(transformed_pts.allclose(expected_pts, atol=1e-6)) 42 | 43 | transform_inv = torch.inverse(transform.unsqueeze(dim=0)) 44 | transformed_back_original_pts = transform_points_torch(transformed_pts, transform_inv) 45 | self.assertTrue(transformed_back_original_pts.allclose(pts, atol=1e-6)) 46 | 47 | def test_compute_rays_camera_frame(self): 48 | # get test data 49 | device = torch.device("cuda" if torch.cuda.is_available() else 
"cpu") 50 | camera = get_test_camera(device) 51 | 52 | # compute rays 53 | rays = compute_rays(camera) 54 | self.assertEqual(rays.shape, (640 * 480, 3)) 55 | rays = rays.reshape(camera.height, camera.width, 3) 56 | self.assertEqual(rays.shape, (480, 640, 3)) 57 | 58 | # check some values 59 | self.assertAlmostEqual(rays[0, 0, 0].item(), -0.5403921008110046) 60 | self.assertAlmostEqual(rays[0, 0, 1].item(), -0.4250645041465759) 61 | self.assertAlmostEqual(rays[0, 0, 2].item(), 0.7261518836021423) 62 | 63 | self.assertAlmostEqual(rays[240, 320, 0].item(), 0.0) 64 | self.assertAlmostEqual(rays[240, 320, 1].item(), 0.0) 65 | self.assertAlmostEqual(rays[240, 320, 2].item(), 1.0) 66 | 67 | self.assertAlmostEqual(rays[0, 639, 0].item(), 0.5391948819160461) 68 | self.assertAlmostEqual(rays[0, 639, 1].item(), -0.425452321767807) 69 | self.assertAlmostEqual(rays[0, 639, 2].item(), 0.7268144488334656) 70 | 71 | def test_compute_rays_world_frame(self): 72 | # get test data 73 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 74 | camera = get_test_camera(device) 75 | camera_T_world = get_test_camera_T_world(device) 76 | rays = compute_rays_in_world_frame(camera, camera_T_world) 77 | self.assertEqual(rays.shape, (480, 640, 3)) 78 | 79 | # check some values 80 | self.assertAlmostEqual(rays[0, 0, 0].item(), -0.5390445590019226) 81 | self.assertAlmostEqual(rays[0, 0, 1].item(), -0.6224945187568665) 82 | self.assertAlmostEqual(rays[0, 0, 2].item(), 0.5673900842666626) 83 | 84 | self.assertAlmostEqual(rays[240, 320, 0].item(), -0.004399406723678112) 85 | self.assertAlmostEqual(rays[240, 320, 1].item(), -0.2905626893043518) 86 | self.assertAlmostEqual(rays[240, 320, 2].item(), 0.9568459391593933) 87 | 88 | self.assertAlmostEqual(rays[0, 639, 0].item(), 0.540492832660675) 89 | self.assertAlmostEqual(rays[0, 639, 1].item(), -0.6134769916534424) 90 | self.assertAlmostEqual(rays[0, 639, 2].item(), 0.5757721662521362) 91 | 92 | 93 | if __name__ == "__main__": 94 | unittest.main() 95 | --------------------------------------------------------------------------------