├── docs ├── .gitignore ├── requirements.txt ├── jaxsplat.gif ├── render │ ├── random-output.png │ └── index.md ├── Makefile ├── make.bat ├── conf.py └── index.md ├── lib ├── kernels │ ├── kernels.h │ ├── other.h │ ├── other.cu │ ├── backward.h │ ├── forward.h │ ├── helpers.h │ ├── forward.cu │ └── backward.cu ├── ffi.h ├── ffi.cu ├── common.h ├── ops.h └── ops.cu ├── examples ├── requirements.txt ├── single_image.py └── benchmark.py ├── .clang-format ├── .readthedocs.yaml ├── .gitignore ├── pyproject.toml ├── CMakeLists.txt ├── jaxsplat ├── _types.py ├── _project │ ├── impl.py │ ├── abstract.py │ ├── lowering.py │ └── __init__.py ├── _rasterize │ ├── impl.py │ ├── abstract.py │ ├── lowering.py │ └── __init__.py └── __init__.py ├── LICENSE.txt ├── README.md └── COPYRIGHT.txt /docs/.gitignore: -------------------------------------------------------------------------------- 1 | _*/ 2 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | furo 3 | myst-parser 4 | -------------------------------------------------------------------------------- /docs/jaxsplat.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yklcs/jaxsplat/HEAD/docs/jaxsplat.gif -------------------------------------------------------------------------------- /docs/render/random-output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yklcs/jaxsplat/HEAD/docs/render/random-output.png -------------------------------------------------------------------------------- /lib/kernels/kernels.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "backward.h" // IWYU pragma: export 4 | #include "forward.h" // IWYU pragma: export 5 | #include "other.h" // IWYU pragma: 
export 6 | -------------------------------------------------------------------------------- /examples/requirements.txt: -------------------------------------------------------------------------------- 1 | jax 2 | optax 3 | imageio[pyav] 4 | jaxsplat @ git+https://github.com/yklcs/jaxsplat 5 | torch 6 | diff-gaussian-rasterization @ git+https://github.com/yklcs/diff-gaussian-rasterization 7 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | IndentWidth: 4 2 | AllowAllParametersOfDeclarationOnNextLine: false 3 | AllowAllArgumentsOnNextLine: false 4 | BinPackParameters: false 5 | BinPackArguments: false 6 | AlignAfterOpenBracket: BlockIndent 7 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.12" 7 | 8 | sphinx: 9 | configuration: docs/conf.py 10 | 11 | python: 12 | install: 13 | - requirements: docs/requirements.txt 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | CMakeCache.txt 2 | CMakeFiles/ 3 | /Makefile 4 | cmake_install.cmake 5 | compile_commands.json 6 | install_manifest.txt 7 | build/ 8 | .cache/ 9 | *.so 10 | 11 | venv/ 12 | __pycache__/ 13 | *.py[cod] 14 | 15 | _gsplat/ 16 | test.png 17 | test.jpg 18 | out.png 19 | out.mp4 20 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["scikit-build-core"] 3 | build-backend = "scikit_build_core.build" 4 | 5 | [project] 6 | name = "jaxsplat" 7 | version = "0.1.0" 8 | readme = 
"README.md" 9 | license = {file = "LICENSE.txt"} 10 | 11 | [tool.scikit-build] 12 | cmake.version = ">=3.24" 13 | -------------------------------------------------------------------------------- /lib/kernels/other.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace kernels { 6 | 7 | __global__ void 8 | tiled_memset(float *dst, size_t n_dst, const float *src, size_t n_src); 9 | 10 | __global__ void compute_cov2d_bounds( 11 | const unsigned num_points, 12 | const float *__restrict__ covs2d, 13 | float *__restrict__ conics, 14 | float *__restrict__ radii 15 | ); 16 | 17 | } // namespace kernels 18 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /lib/ffi.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | namespace py = pybind11; 9 | 10 | template py::capsule encapsulate_function(T *fn) { 11 | return pybind11::capsule( 12 | reinterpret_cast(fn), 13 | "xla._CUSTOM_CALL_TARGET" 14 | ); 15 | } 16 | 17 | template 18 | const T *unpack_descriptor(const char *opaque, std::size_t opaque_len) { 19 | if (opaque_len != sizeof(T)) { 20 | throw std::runtime_error("Invalid opaque object size"); 21 | } 22 | return reinterpret_cast(opaque); 23 | } 24 | 25 | template py::bytes pack_descriptor(const T &descriptor) { 26 | const std::string str = 27 | std::string(reinterpret_cast(&descriptor), sizeof(T)); 28 | return py::bytes(str); 29 | } 30 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.24) 2 | 3 | set(CMAKE_CUDA_FLAGS -std=c++17) # must be set before CUDA detection 4 | project(_jaxsplat LANGUAGES CXX CUDA) 5 | 6 | set(CMAKE_CXX_STANDARD 17) 7 | set(CMAKE_CUDA_ARCHITECTURES native) 8 | 9 | set(PYBIND11_NEWPYTHON ON) 10 | find_package(pybind11 CONFIG REQUIRED) 11 | 12 | include_directories(${CMAKE_CURRENT_LIST_DIR}/lib) 13 | include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) 14 | 15 | pybind11_add_module( 16 | _jaxsplat 17 | ${CMAKE_CURRENT_LIST_DIR}/lib/ffi.cu 18 | ${CMAKE_CURRENT_LIST_DIR}/lib/ops.cu 19 | ${CMAKE_CURRENT_LIST_DIR}/lib/kernels/forward.cu 20 | ${CMAKE_CURRENT_LIST_DIR}/lib/kernels/backward.cu 21 | ${CMAKE_CURRENT_LIST_DIR}/lib/kernels/other.cu 22 | ) 23 | 24 | if(DEFINED SKBUILD) 25 | set(JAXSPLAT_INSTALL_DEST ${SKBUILD_PLATLIB_DIR}) 26 | else() 27 | set(JAXSPLAT_INSTALL_DEST ${CMAKE_CURRENT_LIST_DIR}) 28 | endif() 29 | 30 | message(STATUS "Installing _jaxsplat to ${JAXSPLAT_INSTALL_DEST}") 31 | install(TARGETS _jaxsplat LIBRARY DESTINATION ${JAXSPLAT_INSTALL_DEST}) 32 | -------------------------------------------------------------------------------- /jaxsplat/_types.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax.interpreters import mlir 3 | from jax.interpreters.mlir import ir 4 | from jax.core import ShapedArray, canonicalize_shape 5 | import jax.numpy as jnp 6 | from jax.typing import DTypeLike 7 
| from jax.dtypes import canonicalize_dtype 8 | 9 | 10 | class Type: 11 | shape: tuple[int, ...] 12 | dtype: jnp.dtype 13 | 14 | def __init__(self, shape: tuple[int, ...], dtype: DTypeLike): 15 | self.shape = canonicalize_shape(shape) 16 | self.dtype = jnp.dtype(dtype) 17 | 18 | def ir_type(self): 19 | return mlir.dtype_to_ir_type(self.dtype) 20 | 21 | def ir_tensor_type(self): 22 | return ir.RankedTensorType.get(self.shape, self.ir_type()) 23 | 24 | def layout(self): 25 | return tuple(range(len(self.shape) - 1, -1, -1)) 26 | 27 | def shaped_array(self): 28 | return ShapedArray(self.shape, self.dtype) 29 | 30 | def assert_(self, other: jax.Array): 31 | assert self.shape == other.shape and canonicalize_dtype( 32 | self.dtype 33 | ) == canonicalize_dtype(other.dtype) 34 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Lucas Yunkyu Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /jaxsplat/_project/impl.py: -------------------------------------------------------------------------------- 1 | from jax.interpreters import mlir, xla 2 | from jax.lib import xla_client 3 | from jax import core 4 | 5 | import functools 6 | 7 | import _jaxsplat 8 | from jaxsplat._project import lowering, abstract 9 | 10 | 11 | # register GPU XLA custom calls 12 | for name, value in _jaxsplat.registrations().items(): 13 | xla_client.register_custom_call_target(name, value, platform="gpu") 14 | 15 | 16 | # forward 17 | _project_fwd_p = core.Primitive("project_fwd") 18 | _project_fwd_p.multiple_results = True 19 | _project_fwd_p.def_impl(functools.partial(xla.apply_primitive, _project_fwd_p)) 20 | _project_fwd_p.def_abstract_eval(abstract._project_fwd_abs) 21 | 22 | mlir.register_lowering( 23 | prim=_project_fwd_p, 24 | rule=lowering._project_fwd_rule, 25 | platform="gpu", 26 | ) 27 | 28 | # backward 29 | _project_bwd_p = core.Primitive("project_bwd") 30 | _project_bwd_p.multiple_results = True 31 | _project_bwd_p.def_impl(functools.partial(xla.apply_primitive, _project_bwd_p)) 32 | _project_bwd_p.def_abstract_eval(abstract._project_bwd_abs) 33 | 34 | mlir.register_lowering( 35 | prim=_project_bwd_p, 36 | rule=lowering._project_bwd_rule, 37 | platform="gpu", 38 | ) 39 | -------------------------------------------------------------------------------- /jaxsplat/_rasterize/impl.py: -------------------------------------------------------------------------------- 1 | from jax.interpreters import mlir, xla 2 | from jax.lib import xla_client 3 | from jax import core 4 | 5 | import functools 6 | 7 | import _jaxsplat 8 
| from jaxsplat._rasterize import lowering, abstract 9 | 10 | 11 | # register GPU XLA custom calls 12 | for name, value in _jaxsplat.registrations().items(): 13 | xla_client.register_custom_call_target(name, value, platform="gpu") 14 | 15 | 16 | # forward 17 | _rasterize_fwd_p = core.Primitive("rasterize_fwd") 18 | _rasterize_fwd_p.multiple_results = True 19 | _rasterize_fwd_p.def_impl(functools.partial(xla.apply_primitive, _rasterize_fwd_p)) 20 | _rasterize_fwd_p.def_abstract_eval(abstract._rasterize_fwd_abs) 21 | 22 | mlir.register_lowering( 23 | prim=_rasterize_fwd_p, 24 | rule=lowering._rasterize_fwd_rule, 25 | platform="gpu", 26 | ) 27 | 28 | # backward 29 | _rasterize_bwd_p = core.Primitive("rasterize_bwd") 30 | _rasterize_bwd_p.multiple_results = True 31 | _rasterize_bwd_p.def_impl(functools.partial(xla.apply_primitive, _rasterize_bwd_p)) 32 | _rasterize_bwd_p.def_abstract_eval(abstract._rasterize_bwd_abs) 33 | 34 | mlir.register_lowering( 35 | prim=_rasterize_bwd_p, 36 | rule=lowering._rasterize_bwd_rule, 37 | platform="gpu", 38 | ) 39 | -------------------------------------------------------------------------------- /lib/kernels/other.cu: -------------------------------------------------------------------------------- 1 | #include "helpers.h" 2 | #include "other.h" 3 | 4 | #include 5 | #include 6 | 7 | namespace cg = cooperative_groups; 8 | 9 | __global__ void kernels::tiled_memset( 10 | float *dst, 11 | size_t n_dst, 12 | const float *src, 13 | size_t n_src 14 | ) { 15 | for (int tid = blockIdx.x * blockDim.x + threadIdx.x; tid < n_dst; 16 | tid += blockDim.x * gridDim.x) { 17 | dst[tid] = src[tid % n_src]; 18 | } 19 | } 20 | 21 | __global__ void kernels::compute_cov2d_bounds( 22 | const unsigned num_points, 23 | const float *__restrict__ covs2d, 24 | float *__restrict__ conics, 25 | float *__restrict__ radii 26 | ) { 27 | unsigned row = cg::this_grid().thread_rank(); 28 | if (row >= num_points) { 29 | return; 30 | } 31 | int index = row * 3; 32 | float3 
conic; 33 | float radius; 34 | float3 cov2d{ 35 | (float)covs2d[index], 36 | (float)covs2d[index + 1], 37 | (float)covs2d[index + 2] 38 | }; 39 | helpers::compute_cov2d_bounds(cov2d, conic, radius); 40 | conics[index] = conic.x; 41 | conics[index + 1] = conic.y; 42 | conics[index + 2] = conic.z; 43 | radii[row] = radius; 44 | } 45 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, os.path.abspath("..")) 5 | 6 | project = "jaxsplat" 7 | copyright = "2024, Lucas Yunkyu Lee" 8 | author = "Lucas Yunkyu Lee" 9 | 10 | 11 | extensions = ["myst_parser", "sphinx.ext.autodoc", "sphinx.ext.napoleon"] 12 | 13 | autodoc_mock_imports = ["jax", "jaxlib", "jaxsplat._jaxsplat"] 14 | 15 | templates_path = ["_templates"] 16 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 17 | 18 | html_theme = "furo" 19 | html_static_path = ["_static"] 20 | 21 | html_theme_options = { 22 | "footer_icons": [ 23 | { 24 | "name": "GitHub", 25 | "url": "https://github.com/yklcs/jaxsplat", 26 | "html": """ 27 | 28 | 29 | 30 | """, 31 | "class": "", 32 | }, 33 | ], 34 | } 35 | python_maximum_signature_line_length = 10 36 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # jaxsplat 2 | 3 | ![](./jaxsplat.gif) 4 | 5 | jaxsplat is a port of 3D Gaussian Splatting to [JAX](https://github.com/google/jax). 6 | Fully differentiable, CUDA accelerated. 7 | 8 | ## Installation 9 | 10 | Requires a working CUDA toolchain to install. Simply `pip install`ing directly from source should build and install jaxsplat: 11 | 12 | ```shell 13 | $ python -m venv venv && . 
venv/bin/activate 14 | $ pip install git+https://github.com/yklcs/jaxsplat 15 | ``` 16 | 17 | ## Usage 18 | 19 | The primary function of jaxsplat is `jaxsplat.render`, which renders 3D Gaussians to a 2D image differentiably. 20 | View [the rendering API docs](./render/index) for more complete docs. 21 | 22 | ```python 23 | img = jaxsplat.render( 24 | means3d, 25 | scales, 26 | quats, 27 | colors, 28 | opacities, 29 | viewmat=viewmat, 30 | background=background, 31 | img_shape=img_shape, 32 | f=f, 33 | c=c, 34 | glob_scale=glob_scale, 35 | clip_thresh=clip_thresh, 36 | block_size=block_size, 37 | ) 38 | ``` 39 | 40 | ## Bibliography 41 | 42 | - [3D Gaussian Splatting for Real-Time Radiance Field Rendering](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/) (Kerbl et al., SIGGRAPH 2023) 43 | - [gsplat](https://github.com/nerfstudio-project/gsplat) 44 | 45 | ```{eval-rst} 46 | .. toctree:: 47 | :maxdepth: 1 48 | :hidden: 49 | 50 | self 51 | GitHub 52 | 53 | .. toctree:: 54 | :maxdepth: 2 55 | :hidden: 56 | :caption: API Reference: 57 | 58 | render/index 59 | 60 | .. 
toctree:: 61 | :maxdepth: 2 62 | :hidden: 63 | :caption: Examples: 64 | 65 | Single image 66 | ``` 67 | -------------------------------------------------------------------------------- /lib/ffi.cu: -------------------------------------------------------------------------------- 1 | #include "ffi.h" 2 | #include "ops.h" 3 | 4 | py::dict registrations() { 5 | py::dict dict; 6 | 7 | dict["project_fwd"] = encapsulate_function(ops::project::fwd::xla); 8 | dict["project_bwd"] = encapsulate_function(ops::project::bwd::xla); 9 | dict["rasterize_fwd"] = encapsulate_function(ops::rasterize::fwd::xla); 10 | dict["rasterize_bwd"] = encapsulate_function(ops::rasterize::bwd::xla); 11 | 12 | return dict; 13 | } 14 | 15 | py::bytes make_descriptor( 16 | unsigned num_points, 17 | std::pair img_shape, 18 | std::pair f, 19 | std::pair c, 20 | float glob_scale, 21 | float clip_thresh, 22 | unsigned block_width 23 | ) { 24 | float4 intrins = {f.first, f.second, c.first, c.second}; 25 | 26 | // img_shape is in (H,W) 27 | dim3 img_shape_dim3 = {img_shape.second, img_shape.first, 1}; 28 | 29 | const unsigned block_dim_1d = block_width * block_width; 30 | const unsigned grid_dim_1d = (num_points + block_dim_1d - 1) / block_dim_1d; 31 | dim3 block_dim_2d = {block_width, block_width, 1}; 32 | dim3 grid_dim_2d = { 33 | (img_shape_dim3.x + block_width - 1) / block_width, 34 | (img_shape_dim3.y + block_width - 1) / block_width, 35 | 1 36 | }; 37 | 38 | ops::Descriptor desc = { 39 | num_points, 40 | img_shape_dim3, 41 | intrins, 42 | glob_scale, 43 | clip_thresh, 44 | block_width, 45 | grid_dim_1d, 46 | block_dim_1d, 47 | grid_dim_2d, 48 | block_dim_2d 49 | }; 50 | 51 | return pack_descriptor(desc); 52 | } 53 | 54 | PYBIND11_MODULE(_jaxsplat, m) { 55 | m.def("registrations", ®istrations); 56 | 57 | m.def( 58 | "make_descriptor", 59 | make_descriptor, 60 | py::arg("num_points"), 61 | py::arg("img_shape"), 62 | py::arg("f"), 63 | py::arg("c"), 64 | py::arg("glob_scale"), 65 | 
py::arg("clip_thresh"), 66 | py::arg("block_width") 67 | ); 68 | } 69 | -------------------------------------------------------------------------------- /jaxsplat/__init__.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jaxsplat._project import project 3 | from jaxsplat._rasterize import rasterize 4 | 5 | 6 | def render( 7 | means3d: jax.Array, 8 | scales: jax.Array, 9 | quats: jax.Array, 10 | colors: jax.Array, 11 | opacities: jax.Array, 12 | *, 13 | viewmat: jax.Array, 14 | background: jax.Array, 15 | img_shape: tuple[int, int], 16 | f: tuple[float, float], 17 | c: tuple[int, int], 18 | glob_scale: float, 19 | clip_thresh: float, 20 | block_size: int = 16, 21 | ) -> jax.Array: 22 | """ 23 | Renders 3D Gaussians to a 2D image differentiably. 24 | Output is differentiable w.r.t. all non-keyword-only arguments. 25 | 26 | Args: 27 | means3d (Array): (N, 3) array of 3D Gaussian means 28 | scales (Array): (N, 3) array of 3D Gaussian scales 29 | quats (Array): (N, 4) array of 3D Gaussian quaternions, must be normalized 30 | colors (Array): (N, 3) array of 3D Gaussian colors 31 | opacities (Array): (N, 1) array of 3D Gaussian opacities 32 | Keyword Args: 33 | viewmat (Array): (4, 4) array containing view matrix 34 | background (Array): (3,) array of background color 35 | img_shape (tuple[int, int]): Image shape in (H, W) 36 | f (tuple[float, float]): Focal lengths in (fx, fy) 37 | c (tuple[int, int]): Principal points in (cx, cy) 38 | glob_scale (float): Global scaling factor 39 | clip_thresh (float): Minimum z depth clipping threshold 40 | block_size (int): CUDA block size, 1 < block_size <= 16. 
41 | Returns: 42 | Array: Rendered image 43 | """ 44 | (xys, depths, radii, conics, _num_tiles_hit, cum_tiles_hit) = project( 45 | means3d, 46 | scales, 47 | quats, 48 | viewmat, 49 | img_shape=img_shape, 50 | f=f, 51 | c=c, 52 | glob_scale=glob_scale, 53 | clip_thresh=clip_thresh, 54 | block_width=block_size, 55 | ) 56 | 57 | img = rasterize( 58 | colors, 59 | opacities, 60 | background, 61 | xys, 62 | depths, 63 | radii, 64 | conics, 65 | cum_tiles_hit, 66 | img_shape=img_shape, 67 | block_width=block_size, 68 | ) 69 | 70 | return img 71 | 72 | 73 | __all__ = ["render", "project", "rasterize"] 74 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # jaxsplat 2 | 3 | [![Documentation Status](https://readthedocs.org/projects/jaxsplat/badge/?version=latest)](https://jaxsplat.readthedocs.io/en/latest/?badge=latest) 4 | 5 | ![](./docs/jaxsplat.gif) 6 | 7 | A port of 3D Gaussian Splatting to JAX. 8 | Fully differentiable, CUDA accelerated. 9 | 10 | [Read documentation](https://jaxsplat.readthedocs.io) 11 | 12 | ## Installation 13 | 14 | Requires a working CUDA toolchain to install. 15 | Simply `pip install`ing directly from source should build and install jaxsplat: 16 | 17 | ```shell 18 | $ python -m venv venv && . 
venv/bin/activate 19 | $ pip install git+https://github.com/yklcs/jaxsplat 20 | ``` 21 | 22 | ## Usage 23 | 24 | The primary function of this library is `jaxsplat.render`: 25 | 26 | ```python 27 | img = jaxsplat.render( 28 | means3d, # jax.Array (N, 3) 29 | scales, # jax.Array (N, 3) 30 | quats, # jax.Array (N, 4) normalized 31 | colors, # jax.Array (N, 3) 32 | opacities, # jax.Array (N, 1) 33 | viewmat=viewmat, # jax.Array (4, 4) 34 | background=background, # jax.Array (3,) 35 | img_shape=img_shape, # tuple[int, int] = (H, W) 36 | f=f, # tuple[float, float] = (fx, fy) 37 | c=c, # tuple[int, int] = (cx, cy) 38 | glob_scale=glob_scale, # float 39 | clip_thresh=clip_thresh, # float 40 | block_size=block_size, # int <= 16 41 | ) 42 | ``` 43 | 44 | The rendered output is differentiable w.r.t. `means3d`, `scales`, `quats`, `colors`, and `opacities`. 45 | 46 | Alternatively, `jaxsplat.project` projects 3D Gaussians to 2D, and `jaxsplat.rasterize` sorts and rasterizes 2D Gaussians. 47 | `jaxsplat.render` successively calls `jaxsplat.project` and `jaxsplat.rasterize` under the hood. 48 | 49 | ## Examples 50 | 51 | See [/examples](./examples) for examples. 52 | These can be run like the following: 53 | 54 | ```shell 55 | $ python -m venv venv && . venv/bin/activate 56 | $ pip install -r examples/requirements.txt 57 | 58 | # Train Gaussians on a single image 59 | $ python -m examples.single_image input.png 60 | ``` 61 | 62 | ## Method 63 | 64 | We use modified versions of [gsplat](https://github.com/nerfstudio-project/gsplat)'s kernels. 65 | The [original INRIA implementation](https://github.com/graphdeco-inria/diff-gaussian-rasterization) uses a custom license and contains dynamically shaped tensors which are harder to port to JAX/XLA. 
66 | -------------------------------------------------------------------------------- /docs/render/index.md: -------------------------------------------------------------------------------- 1 | # Rendering 2 | 3 | The primary function of this library is `jaxsplat.render`: 4 | 5 | ```{eval-rst} 6 | .. autofunction:: jaxsplat.render 7 | ``` 8 | 9 | ## Basic usage 10 | 11 | Let's render random 3D Gaussians to an image. 12 | We start by importing packages: 13 | 14 | ```python 15 | import jaxsplat 16 | import jax.numpy as jnp 17 | import jax 18 | import imageio.v3 as iio 19 | ``` 20 | 21 | `jaxsplat.render` takes 5 array inputs describing the geometry and appearance of 3D Gaussians. 22 | Instead of means, 3D covariances, colors, and opacities, we use means, scales, quaternions, colors, and opacities. 23 | Let's initialize them randomly with 1000 Gaussians: 24 | 25 | ```python 26 | key = jax.random.key(0) 27 | key, *subkeys = jax.random.split(key, 6) 28 | 29 | num_points = 1000 30 | 31 | means3d = jax.random.uniform(subkeys[0], (num_points, 3), minval=-3, maxval=3) 32 | scales = jax.random.uniform(subkeys[1], (num_points, 3), maxval=0.5) 33 | quats = jax.random.normal(subkeys[2], (num_points, 4)) 34 | colors = jax.random.uniform(subkeys[3], (num_points, 3)) 35 | opacities = jax.random.uniform(subkeys[4], (num_points, 1)) 36 | ``` 37 | 38 | Our quaternions need to be normalized before using them with `jaxsplat.render`: 39 | 40 | ```python 41 | quats /= jnp.linalg.norm(quats, axis=-1, keepdims=True) 42 | ``` 43 | 44 | We then set the other parameters. Let's output a `(900, 1600)` image, use 90 degrees FoV, and set the principal point to the center of the viewport. 
45 | We also set the global scale of Gaussians to 1 and clip any Gaussians closer than 0.01 to the camera: 46 | 47 | ```python 48 | viewmat = jnp.array([ 49 | [1, 0, 0, 0], 50 | [0, 1, 0, 0], 51 | [0, 0, 1, 8], 52 | [0, 0, 0, 1], 53 | ], dtype=jnp.float32) 54 | background = jnp.ones((3,), dtype=jnp.float32) 55 | W = 1600 56 | H = 900 57 | img_shape = H, W 58 | f = (W / 2, H / 2) 59 | c = (W / 2, H / 2) 60 | glob_scale = 1 61 | clip_thresh = 0.01 62 | ``` 63 | 64 | We're now ready to render and save the image to a file: 65 | 66 | ```python 67 | img = jaxsplat.render( 68 | means3d, 69 | scales, 70 | quats, 71 | colors, 72 | opacities, 73 | viewmat=viewmat, 74 | background=background, 75 | img_shape=img_shape, 76 | f=f, 77 | c=c, 78 | glob_scale=glob_scale, 79 | clip_thresh=clip_thresh, 80 | ) 81 | iio.imwrite("output.png", (img * 255).astype(jnp.uint8)) 82 | ``` 83 | 84 | We got our output! 85 | 86 | ![](./random-output.png) 87 | -------------------------------------------------------------------------------- /lib/kernels/backward.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | 8 | namespace kernels { 9 | 10 | // for f : R(n) -> R(m), J in R(m, n), 11 | // v is cotangent in R(m), e.g. dL/df in R(m), 12 | // compute vjp i.e. 
vT J -> R(n) 13 | __global__ void project_gaussians_bwd( 14 | const int num_points, const float3 *__restrict__ means3d, 15 | const float3 *__restrict__ scales, const float glob_scale, 16 | const float4 *__restrict__ quats, const float *__restrict__ viewmat, 17 | const float4 intrins, const dim3 img_size, const float *__restrict__ cov3d, 18 | const int *__restrict__ radii, const float3 *__restrict__ conics, 19 | const float *__restrict__ compensation, const float2 *__restrict__ v_xy, 20 | const float *__restrict__ v_depth, const float3 *__restrict__ v_conic, 21 | const float *__restrict__ v_compensation, float3 *__restrict__ v_cov2d, 22 | float *__restrict__ v_cov3d, float3 *__restrict__ v_mean3d, 23 | float3 *__restrict__ v_scale, float4 *__restrict__ v_quat); 24 | 25 | // compute jacobians of output image wrt binned and sorted gaussians 26 | __global__ void nd_rasterize_backward_kernel( 27 | const dim3 tile_bounds, const dim3 img_size, const unsigned channels, 28 | const int32_t *__restrict__ gaussians_ids_sorted, 29 | const int2 *__restrict__ tile_bins, const float2 *__restrict__ xys, 30 | const float3 *__restrict__ conics, const float *__restrict__ rgbs, 31 | const float *__restrict__ opacities, const float *__restrict__ background, 32 | const float *__restrict__ final_Ts, const int *__restrict__ final_index, 33 | const float *__restrict__ v_output, 34 | const float *__restrict__ v_output_alpha, float2 *__restrict__ v_xy, 35 | float2 *__restrict__ v_xy_abs, float3 *__restrict__ v_conic, 36 | float *__restrict__ v_rgb, float *__restrict__ v_opacity); 37 | 38 | __global__ void rasterize_bwd( 39 | const dim3 tile_bounds, const dim3 img_size, 40 | const int32_t *__restrict__ gaussian_ids_sorted, 41 | const int2 *__restrict__ tile_bins, const float2 *__restrict__ xys, 42 | const float3 *__restrict__ conics, const float3 *__restrict__ rgbs, 43 | const float *__restrict__ opacities, const float3 &__restrict__ background, 44 | const float *__restrict__ final_Ts, const 
int *__restrict__ final_index, 45 | const float3 *__restrict__ v_output, 46 | const float *__restrict__ v_output_alpha, float2 *__restrict__ v_xy, 47 | float2 *__restrict__ v_xy_abs, float3 *__restrict__ v_conic, 48 | float3 *__restrict__ v_rgb, float *__restrict__ v_opacity); 49 | 50 | __device__ void project_cov3d_ewa_vjp(const float3 &mean3d, const float *cov3d, 51 | const float *viewmat, const float fx, 52 | const float fy, const float3 &v_cov2d, 53 | float3 &v_mean3d, float *v_cov3d); 54 | 55 | __device__ void scale_rot_to_cov3d_vjp(const float3 scale, 56 | const float glob_scale, 57 | const float4 quat, const float *v_cov3d, 58 | float3 &v_scale, float4 &v_quat); 59 | 60 | } // namespace kernels 61 | -------------------------------------------------------------------------------- /lib/kernels/forward.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | 8 | namespace kernels { 9 | 10 | // compute the 2d gaussian parameters from 3d gaussian parameters 11 | __global__ void project_gaussians_fwd( 12 | const int num_points, const float3 *__restrict__ means3d, 13 | const float3 *__restrict__ scales, const float glob_scale, 14 | const float4 *__restrict__ quats, const float *__restrict__ viewmat, 15 | const float4 intrins, const dim3 img_size, const dim3 tile_bounds, 16 | const unsigned block_width, const float clip_thresh, 17 | float *__restrict__ covs3d, float2 *__restrict__ xys, 18 | float *__restrict__ depths, int *__restrict__ radii, 19 | float3 *__restrict__ conics, float *__restrict__ compensation, 20 | int32_t *__restrict__ num_tiles_hit); 21 | 22 | // compute output color image from binned and sorted gaussians 23 | __global__ void 24 | rasterize_fwd(const dim3 tile_bounds, const dim3 img_size, 25 | const int32_t *__restrict__ gaussian_ids_sorted, 26 | const int2 *__restrict__ tile_bins, 27 | const float2 *__restrict__ xys, const float3 *__restrict__ conics, 28 
| const float3 *__restrict__ colors, 29 | const float *__restrict__ opacities, float *__restrict__ final_Ts, 30 | int *__restrict__ final_index, float3 *__restrict__ out_img, 31 | const float3 &__restrict__ background); 32 | 33 | // compute output color image from binned and sorted gaussians 34 | __global__ void nd_rasterize_forward( 35 | const dim3 tile_bounds, const dim3 img_size, const unsigned channels, 36 | const int32_t *__restrict__ gaussian_ids_sorted, 37 | const int2 *__restrict__ tile_bins, const float2 *__restrict__ xys, 38 | const float3 *__restrict__ conics, const float *__restrict__ colors, 39 | const float *__restrict__ opacities, float *__restrict__ final_Ts, 40 | int *__restrict__ final_index, float *__restrict__ out_img, 41 | const float *__restrict__ background); 42 | 43 | // device helper to approximate projected 2d cov from 3d mean and cov 44 | __device__ void project_cov3d_ewa(const float3 &mean3d, const float *cov3d, 45 | const float *viewmat, const float fx, 46 | const float fy, const float tan_fovx, 47 | const float tan_fovy, float3 &cov2d, 48 | float &comp); 49 | 50 | // device helper to get 3D covariance from scale and quat parameters 51 | __device__ void scale_rot_to_cov3d(const float3 scale, const float glob_scale, 52 | const float4 quat, float *cov3d); 53 | 54 | __global__ void map_gaussian_to_intersects( 55 | const int num_points, const float2 *__restrict__ xys, 56 | const float *__restrict__ depths, const int *__restrict__ radii, 57 | const int32_t *__restrict__ cum_tiles_hit, const dim3 tile_bounds, 58 | const unsigned block_width, int64_t *__restrict__ isect_ids, 59 | int32_t *__restrict__ gaussian_ids); 60 | 61 | __global__ void get_tile_bin_edges(const int num_intersects, 62 | const int64_t *__restrict__ isect_ids_sorted, 63 | int2 *__restrict__ tile_bins); 64 | 65 | __global__ void 66 | rasterize_fwd(const dim3 tile_bounds, const dim3 img_size, 67 | const int32_t *__restrict__ gaussian_ids_sorted, 68 | const int2 *__restrict__ 
tile_bins, 69 | const float2 *__restrict__ xys, const float3 *__restrict__ conics, 70 | const float3 *__restrict__ colors, 71 | const float *__restrict__ opacities, float *__restrict__ final_Ts, 72 | int *__restrict__ final_index, float3 *__restrict__ out_img, 73 | const float3 &__restrict__ background); 74 | 75 | __global__ void nd_rasterize_forward( 76 | const dim3 tile_bounds, const dim3 img_size, const unsigned channels, 77 | const int32_t *__restrict__ gaussian_ids_sorted, 78 | const int2 *__restrict__ tile_bins, const float2 *__restrict__ xys, 79 | const float3 *__restrict__ conics, const float *__restrict__ colors, 80 | const float *__restrict__ opacities, float *__restrict__ final_Ts, 81 | int *__restrict__ final_index, float *__restrict__ out_img, 82 | const float *__restrict__ background); 83 | 84 | } // namespace kernels 85 | -------------------------------------------------------------------------------- /jaxsplat/_rasterize/abstract.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | from jaxsplat._types import Type 5 | 6 | 7 | class RasterizeFwdTypes: 8 | def __init__( 9 | self, 10 | num_points: int, 11 | img_shape: tuple[int, int], 12 | ): 13 | self.in_colors = Type((num_points, 3), jnp.float32) 14 | self.in_opacities = Type((num_points, 1), jnp.float32) 15 | self.in_background = Type((3,), jnp.float32) 16 | self.in_xys = Type((num_points, 2), jnp.float32) 17 | self.in_depths = Type((num_points, 1), jnp.float32) 18 | self.in_radii = Type((num_points, 1), jnp.int32) 19 | self.in_conics = Type((num_points, 3), jnp.float32) 20 | self.in_cum_tiles_hit = Type((num_points, 1), jnp.uint32) 21 | 22 | self.out_final_Ts = Type((*img_shape, 1), jnp.float32) 23 | self.out_final_idx = Type((*img_shape, 1), jnp.int32) 24 | self.out_img = Type((*img_shape, 3), jnp.float32) 25 | 26 | 27 | def _rasterize_fwd_abs( 28 | colors: jax.Array, 29 | opacities: jax.Array, 30 | background: 
jax.Array, 31 | xys: jax.Array, 32 | depths: jax.Array, 33 | radii: jax.Array, 34 | conics: jax.Array, 35 | cum_tiles_hit: jax.Array, 36 | # 37 | num_points: int, 38 | img_shape: tuple[int, int], 39 | block_width: int, 40 | ): 41 | t = RasterizeFwdTypes( 42 | num_points, 43 | img_shape, 44 | ) 45 | 46 | t.in_colors.assert_(colors) 47 | t.in_opacities.assert_(opacities) 48 | t.in_background.assert_(background) 49 | t.in_xys.assert_(xys) 50 | t.in_depths.assert_(depths) 51 | t.in_radii.assert_(radii) 52 | t.in_conics.assert_(conics) 53 | t.in_cum_tiles_hit.assert_(cum_tiles_hit) 54 | 55 | return ( 56 | t.out_final_Ts.shaped_array(), 57 | t.out_final_idx.shaped_array(), 58 | t.out_img.shaped_array(), 59 | ) 60 | 61 | 62 | class RasterizeBwdTypes: 63 | def __init__( 64 | self, 65 | num_points: int, 66 | img_shape: tuple[int, int], 67 | ): 68 | self.in_colors = Type((num_points, 3), jnp.float32) 69 | self.in_opacities = Type((num_points, 1), jnp.float32) 70 | self.in_background = Type((3,), jnp.float32) 71 | self.in_xys = Type((num_points, 2), jnp.float32) 72 | self.in_depths = Type((num_points, 1), jnp.float32) 73 | self.in_radii = Type((num_points, 1), jnp.int32) 74 | self.in_conics = Type((num_points, 3), jnp.float32) 75 | self.in_cum_tiles_hit = Type((num_points, 1), jnp.uint32) 76 | self.in_final_Ts = Type((*img_shape, 1), jnp.float32) 77 | self.in_final_idx = Type((*img_shape, 1), jnp.int32) 78 | self.in_v_img = Type((*img_shape, 3), jnp.float32) 79 | self.in_v_img_alpha = Type((*img_shape, 1), jnp.float32) 80 | 81 | self.out_v_color = Type((num_points, 3), jnp.float32) 82 | self.out_v_opacity = Type((num_points, 1), jnp.float32) 83 | self.out_v_xy = Type((num_points, 2), jnp.float32) 84 | self.out_v_xy_abs = Type((num_points, 2), jnp.float32) 85 | self.out_v_conic = Type((num_points, 3), jnp.float32) 86 | 87 | 88 | def _rasterize_bwd_abs( 89 | colors: jax.Array, 90 | opacities: jax.Array, 91 | background: jax.Array, 92 | xys: jax.Array, 93 | depths: jax.Array, 94 
| radii: jax.Array, 95 | conics: jax.Array, 96 | cum_tiles_hit: jax.Array, 97 | final_Ts: jax.Array, 98 | final_idx: jax.Array, 99 | v_img: jax.Array, 100 | v_img_alpha: jax.Array, 101 | # 102 | num_points: int, 103 | img_shape: tuple[int, int], 104 | block_width: int, 105 | ): 106 | t = RasterizeBwdTypes(num_points, img_shape) 107 | 108 | t.in_colors.assert_(colors) 109 | t.in_opacities.assert_(opacities) 110 | t.in_background.assert_(background) 111 | t.in_xys.assert_(xys) 112 | t.in_depths.assert_(depths) 113 | t.in_radii.assert_(radii) 114 | t.in_conics.assert_(conics) 115 | t.in_cum_tiles_hit.assert_(cum_tiles_hit) 116 | t.in_final_Ts.assert_(final_Ts) 117 | t.in_final_idx.assert_(final_idx) 118 | t.in_v_img.assert_(v_img) 119 | t.in_v_img_alpha.assert_(v_img_alpha) 120 | 121 | return ( 122 | t.out_v_color.shaped_array(), 123 | t.out_v_opacity.shaped_array(), 124 | t.out_v_xy.shaped_array(), 125 | t.out_v_xy_abs.shaped_array(), 126 | t.out_v_conic.shaped_array(), 127 | ) 128 | -------------------------------------------------------------------------------- /jaxsplat/_rasterize/lowering.py: -------------------------------------------------------------------------------- 1 | from jax.interpreters import mlir 2 | from jax.interpreters.mlir import ir 3 | from jaxlib.hlo_helpers import custom_call 4 | 5 | import _jaxsplat 6 | from jaxsplat._rasterize.abstract import ( 7 | RasterizeFwdTypes, 8 | RasterizeBwdTypes, 9 | ) 10 | 11 | 12 | def _rasterize_fwd_rule( 13 | ctx: mlir.LoweringRuleContext, 14 | # 15 | colors: ir.Value, 16 | opacities: ir.Value, 17 | background: ir.Value, 18 | xys: ir.Value, 19 | depths: ir.Value, 20 | radii: ir.Value, 21 | conics: ir.Value, 22 | cum_tiles_hit: ir.Value, 23 | # 24 | num_points: int, 25 | img_shape: tuple[int, int], 26 | block_width: int, 27 | ): 28 | opaque = _jaxsplat.make_descriptor( 29 | num_points=num_points, 30 | img_shape=img_shape, 31 | f=(0.0, 0.0), 32 | c=(0.0, 0.0), 33 | glob_scale=0.0, 34 | clip_thresh=0.0, 35 | 
block_width=block_width, 36 | ) 37 | 38 | t = RasterizeFwdTypes(num_points, img_shape) 39 | 40 | return custom_call( 41 | "rasterize_fwd", 42 | operands=[ 43 | colors, 44 | opacities, 45 | background, 46 | xys, 47 | depths, 48 | radii, 49 | conics, 50 | cum_tiles_hit, 51 | ], 52 | operand_layouts=[ 53 | t.in_colors.layout(), 54 | t.in_opacities.layout(), 55 | t.in_background.layout(), 56 | t.in_xys.layout(), 57 | t.in_depths.layout(), 58 | t.in_radii.layout(), 59 | t.in_conics.layout(), 60 | t.in_cum_tiles_hit.layout(), 61 | ], 62 | result_types=[ 63 | t.out_final_Ts.ir_tensor_type(), 64 | t.out_final_idx.ir_tensor_type(), 65 | t.out_img.ir_tensor_type(), 66 | ], 67 | result_layouts=[ 68 | t.out_final_Ts.layout(), 69 | t.out_final_idx.layout(), 70 | t.out_img.layout(), 71 | ], 72 | backend_config=opaque, 73 | ).results 74 | 75 | 76 | def _rasterize_bwd_rule( 77 | ctx: mlir.LoweringRuleContext, 78 | # 79 | colors: ir.Value, 80 | opacities: ir.Value, 81 | background: ir.Value, 82 | xys: ir.Value, 83 | depths: ir.Value, 84 | radii: ir.Value, 85 | conics: ir.Value, 86 | cum_tiles_hit: ir.Value, 87 | final_Ts: ir.Value, 88 | final_idx: ir.Value, 89 | v_img: ir.Value, 90 | v_img_alpha: ir.Value, 91 | # 92 | num_points: int, 93 | img_shape: tuple[int, int], 94 | block_width: int, 95 | ): 96 | opaque = _jaxsplat.make_descriptor( 97 | num_points=num_points, 98 | img_shape=img_shape, 99 | f=(0.0, 0.0), 100 | c=(0.0, 0.0), 101 | glob_scale=0.0, 102 | clip_thresh=0.0, 103 | block_width=block_width, 104 | ) 105 | 106 | t = RasterizeBwdTypes(num_points, img_shape) 107 | 108 | return custom_call( 109 | "rasterize_bwd", 110 | operands=[ 111 | colors, 112 | opacities, 113 | background, 114 | xys, 115 | depths, 116 | radii, 117 | conics, 118 | cum_tiles_hit, 119 | final_Ts, 120 | final_idx, 121 | v_img, 122 | v_img_alpha, 123 | ], 124 | operand_layouts=[ 125 | t.in_colors.layout(), 126 | t.in_opacities.layout(), 127 | t.in_background.layout(), 128 | t.in_xys.layout(), 129 | 
t.in_depths.layout(), 130 | t.in_radii.layout(), 131 | t.in_conics.layout(), 132 | t.in_cum_tiles_hit.layout(), 133 | t.in_final_Ts.layout(), 134 | t.in_final_idx.layout(), 135 | t.in_v_img.layout(), 136 | t.in_v_img_alpha.layout(), 137 | ], 138 | result_types=[ 139 | t.out_v_color.ir_tensor_type(), 140 | t.out_v_opacity.ir_tensor_type(), 141 | t.out_v_xy.ir_tensor_type(), 142 | t.out_v_xy_abs.ir_tensor_type(), 143 | t.out_v_conic.ir_tensor_type(), 144 | ], 145 | result_layouts=[ 146 | t.out_v_color.layout(), 147 | t.out_v_opacity.layout(), 148 | t.out_v_xy.layout(), 149 | t.out_v_xy_abs.layout(), 150 | t.out_v_conic.layout(), 151 | ], 152 | backend_config=opaque, 153 | ).results 154 | -------------------------------------------------------------------------------- /jaxsplat/_project/abstract.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | from jaxsplat._types import Type 5 | 6 | 7 | class ProjectFwdTypes: 8 | def __init__(self, num_points: int): 9 | self.in_mean3ds = Type((num_points, 3), jnp.float32) 10 | self.in_scales = Type((num_points, 3), jnp.float32) 11 | self.in_quats = Type((num_points, 4), jnp.float32) 12 | self.in_viewmat = Type((4, 4), jnp.float32) 13 | 14 | self.out_cov3ds = Type((num_points, 6), jnp.float32) 15 | self.out_xys = Type((num_points, 2), jnp.float32) 16 | self.out_depths = Type((num_points, 1), jnp.float32) 17 | self.out_radii = Type((num_points, 1), jnp.int32) 18 | self.out_conics = Type((num_points, 3), jnp.float32) 19 | self.out_compensation = Type((num_points, 1), jnp.float32) 20 | self.out_num_tiles_hit = Type((num_points, 1), jnp.uint32) 21 | self.out_cum_tiles_hit = Type((num_points, 1), jnp.uint32) 22 | 23 | 24 | def _project_fwd_abs( 25 | mean3ds: jax.Array, 26 | scales: jax.Array, 27 | quats: jax.Array, 28 | viewmat: jax.Array, 29 | # 30 | num_points: int, 31 | img_shape: tuple[int, int], 32 | f: tuple[float, float], 33 | c: tuple[float, 
float], 34 | glob_scale: float, 35 | clip_thresh: float, 36 | block_width: int, 37 | ): 38 | t = ProjectFwdTypes(num_points) 39 | 40 | t.in_mean3ds.assert_(mean3ds) 41 | t.in_scales.assert_(scales) 42 | t.in_quats.assert_(quats) 43 | t.in_viewmat.assert_(viewmat) 44 | 45 | return ( 46 | t.out_cov3ds.shaped_array(), 47 | t.out_xys.shaped_array(), 48 | t.out_depths.shaped_array(), 49 | t.out_radii.shaped_array(), 50 | t.out_conics.shaped_array(), 51 | t.out_compensation.shaped_array(), 52 | t.out_num_tiles_hit.shaped_array(), 53 | t.out_cum_tiles_hit.shaped_array(), 54 | ) 55 | 56 | 57 | class ProjectBwdTypes: 58 | def __init__(self, num_points: int): 59 | self.in_mean3ds = Type((num_points, 3), jnp.float32) 60 | self.in_scales = Type((num_points, 3), jnp.float32) 61 | self.in_quats = Type((num_points, 4), jnp.float32) 62 | self.in_viewmat = Type((4, 4), jnp.float32) 63 | self.in_cov3ds = Type((num_points, 6), jnp.float32) 64 | self.in_xys = Type((num_points, 2), jnp.float32) 65 | self.in_radii = Type((num_points, 1), jnp.int32) 66 | self.in_conics = Type((num_points, 3), jnp.float32) 67 | self.in_compensation = Type((num_points, 1), jnp.float32) 68 | self.in_v_compensation = Type((num_points, 1), jnp.float32) 69 | self.in_v_xy = Type((num_points, 2), jnp.float32) 70 | self.in_v_depth = Type((num_points, 1), jnp.float32) 71 | self.in_v_conic = Type((num_points, 3), jnp.float32) 72 | 73 | self.out_v_mean3d = Type((num_points, 3), jnp.float32) 74 | self.out_v_scale = Type((num_points, 3), jnp.float32) 75 | self.out_v_quat = Type((num_points, 4), jnp.float32) 76 | self.out_v_cov2d = Type((num_points, 3), jnp.float32) 77 | self.out_v_cov3d = Type((num_points, 6), jnp.float32) 78 | 79 | 80 | def _project_bwd_abs( 81 | mean3ds: jax.Array, 82 | scales: jax.Array, 83 | quats: jax.Array, 84 | viewmat: jax.Array, 85 | cov3ds: jax.Array, 86 | xys: jax.Array, 87 | radii: jax.Array, 88 | conics: jax.Array, 89 | compensation: jax.Array, 90 | v_compensation: jax.Array, 91 | v_xy: 
jax.Array, 92 | v_depth: jax.Array, 93 | v_conic: jax.Array, 94 | # 95 | num_points: int, 96 | img_shape: tuple[int, int], 97 | f: tuple[float, float], 98 | c: tuple[float, float], 99 | glob_scale: float, 100 | clip_thresh: float, 101 | block_width: int, 102 | ): 103 | t = ProjectBwdTypes(num_points) 104 | 105 | t.in_mean3ds.assert_(mean3ds) 106 | t.in_scales.assert_(scales) 107 | t.in_quats.assert_(quats) 108 | t.in_viewmat.assert_(viewmat) 109 | t.in_cov3ds.assert_(cov3ds) 110 | t.in_xys.assert_(xys) 111 | t.in_radii.assert_(radii) 112 | t.in_conics.assert_(conics) 113 | t.in_compensation.assert_(compensation) 114 | t.in_v_compensation.assert_(v_compensation) 115 | t.in_v_xy.assert_(v_xy) 116 | t.in_v_depth.assert_(v_depth) 117 | t.in_v_conic.assert_(v_conic) 118 | 119 | return ( 120 | t.out_v_mean3d.shaped_array(), 121 | t.out_v_scale.shaped_array(), 122 | t.out_v_quat.shaped_array(), 123 | t.out_v_cov2d.shaped_array(), 124 | t.out_v_cov3d.shaped_array(), 125 | ) 126 | -------------------------------------------------------------------------------- /jaxsplat/_project/lowering.py: -------------------------------------------------------------------------------- 1 | from jax.interpreters import mlir 2 | from jax.interpreters.mlir import ir 3 | from jaxlib.hlo_helpers import custom_call 4 | 5 | import _jaxsplat 6 | from jaxsplat._project.abstract import ( 7 | ProjectFwdTypes, 8 | ProjectBwdTypes, 9 | ) 10 | 11 | 12 | def _project_fwd_rule( 13 | ctx: mlir.LoweringRuleContext, 14 | # 15 | mean3ds: ir.Value, 16 | scales: ir.Value, 17 | quats: ir.Value, 18 | viewmat: ir.Value, 19 | # 20 | num_points: int, 21 | img_shape: tuple[int, int], 22 | f: tuple[float, float], 23 | c: tuple[float, float], 24 | glob_scale: float, 25 | clip_thresh: float, 26 | block_width: int, 27 | ): 28 | opaque = _jaxsplat.make_descriptor( 29 | num_points=num_points, 30 | img_shape=img_shape, 31 | f=f, 32 | c=c, 33 | glob_scale=glob_scale, 34 | clip_thresh=clip_thresh, 35 | 
block_width=block_width, 36 | ) 37 | 38 | t = ProjectFwdTypes(num_points) 39 | 40 | return custom_call( 41 | "project_fwd", 42 | operands=[ 43 | mean3ds, 44 | scales, 45 | quats, 46 | viewmat, 47 | ], 48 | operand_layouts=[ 49 | t.in_mean3ds.layout(), 50 | t.in_scales.layout(), 51 | t.in_quats.layout(), 52 | t.in_viewmat.layout(), 53 | ], 54 | result_types=[ 55 | t.out_cov3ds.ir_tensor_type(), 56 | t.out_xys.ir_tensor_type(), 57 | t.out_depths.ir_tensor_type(), 58 | t.out_radii.ir_tensor_type(), 59 | t.out_conics.ir_tensor_type(), 60 | t.out_compensation.ir_tensor_type(), 61 | t.out_num_tiles_hit.ir_tensor_type(), 62 | t.out_cum_tiles_hit.ir_tensor_type(), 63 | ], 64 | result_layouts=[ 65 | t.out_cov3ds.layout(), 66 | t.out_xys.layout(), 67 | t.out_depths.layout(), 68 | t.out_radii.layout(), 69 | t.out_conics.layout(), 70 | t.out_compensation.layout(), 71 | t.out_num_tiles_hit.layout(), 72 | t.out_cum_tiles_hit.layout(), 73 | ], 74 | backend_config=opaque, 75 | ).results 76 | 77 | 78 | def _project_bwd_rule( 79 | ctx: mlir.LoweringRuleContext, 80 | # 81 | mean3ds: ir.Value, 82 | scales: ir.Value, 83 | quats: ir.Value, 84 | viewmat: ir.Value, 85 | cov3ds: ir.Value, 86 | xys: ir.Value, 87 | radii: ir.Value, 88 | conics: ir.Value, 89 | compensation: ir.Value, 90 | v_compensation: ir.Value, 91 | v_xy: ir.Value, 92 | v_depth: ir.Value, 93 | v_conic: ir.Value, 94 | # 95 | num_points: int, 96 | img_shape: tuple[int, int], 97 | f: tuple[float, float], 98 | c: tuple[float, float], 99 | glob_scale: float, 100 | clip_thresh: float, 101 | block_width: int, 102 | ): 103 | opaque = _jaxsplat.make_descriptor( 104 | num_points=num_points, 105 | img_shape=img_shape, 106 | f=f, 107 | c=c, 108 | glob_scale=glob_scale, 109 | clip_thresh=clip_thresh, 110 | block_width=block_width, 111 | ) 112 | 113 | t = ProjectBwdTypes(num_points) 114 | 115 | return custom_call( 116 | "project_bwd", 117 | operands=[ 118 | mean3ds, 119 | scales, 120 | quats, 121 | viewmat, 122 | cov3ds, 123 | xys, 124 
| radii, 125 | conics, 126 | compensation, 127 | v_compensation, 128 | v_xy, 129 | v_depth, 130 | v_conic, 131 | ], 132 | operand_layouts=[ 133 | t.in_mean3ds.layout(), 134 | t.in_scales.layout(), 135 | t.in_quats.layout(), 136 | t.in_viewmat.layout(), 137 | t.in_cov3ds.layout(), 138 | t.in_xys.layout(), 139 | t.in_radii.layout(), 140 | t.in_conics.layout(), 141 | t.in_compensation.layout(), 142 | t.in_v_compensation.layout(), 143 | t.in_v_xy.layout(), 144 | t.in_v_depth.layout(), 145 | t.in_v_conic.layout(), 146 | ], 147 | result_types=[ 148 | t.out_v_mean3d.ir_tensor_type(), 149 | t.out_v_scale.ir_tensor_type(), 150 | t.out_v_quat.ir_tensor_type(), 151 | t.out_v_cov2d.ir_tensor_type(), 152 | t.out_v_cov3d.ir_tensor_type(), 153 | ], 154 | result_layouts=[ 155 | t.out_v_mean3d.layout(), 156 | t.out_v_scale.layout(), 157 | t.out_v_quat.layout(), 158 | t.out_v_cov2d.layout(), 159 | t.out_v_cov3d.layout(), 160 | ], 161 | backend_config=opaque, 162 | ).results 163 | -------------------------------------------------------------------------------- /lib/common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | constexpr unsigned MAX_GRID_DIM = 256; 10 | 11 | #define CUDA_THROW_IF_ERR(err) \ 12 | do { \ 13 | cuda_throw_if_err((err), __FILE__, __LINE__); \ 14 | } while (false) 15 | 16 | inline void cuda_throw_if_err(cudaError_t error, const char *file, int line) { 17 | if (error != cudaSuccess) { 18 | std::cerr << "Encountered CUDA error in " << file << ":" << line 19 | << std::endl; 20 | std::cerr << cudaGetErrorString(error) << std::endl; 21 | throw std::runtime_error(cudaGetErrorString(error)); 22 | } 23 | } 24 | 25 | template struct mat3 { 26 | T m[3][3]; 27 | 28 | // Fills diags with val, rest with 0. 
29 | inline __device__ mat3(T m) { 30 | this->m[0][0] = m; 31 | this->m[0][1] = static_cast(0.f); 32 | this->m[0][2] = static_cast(0.f); 33 | this->m[1][0] = static_cast(0.f); 34 | this->m[1][1] = m; 35 | this->m[1][2] = static_cast(0.f); 36 | this->m[2][0] = static_cast(0.f); 37 | this->m[2][1] = static_cast(0.f); 38 | this->m[2][2] = m; 39 | } 40 | 41 | inline __device__ 42 | mat3(T m00, T m10, T m20, T m01, T m11, T m21, T m02, T m12, T m22) { 43 | this->m[0][0] = m00; 44 | this->m[0][1] = m10; 45 | this->m[0][2] = m20; 46 | this->m[1][0] = m01; 47 | this->m[1][1] = m11; 48 | this->m[1][2] = m21; 49 | this->m[2][0] = m02; 50 | this->m[2][1] = m12; 51 | this->m[2][2] = m22; 52 | } 53 | 54 | inline __device__ T *operator[](const size_t idx) { return m[idx]; } 55 | inline __device__ T const *operator[](const size_t idx) const { 56 | return m[idx]; 57 | } 58 | 59 | inline __device__ mat3 operator*(mat3 const &rhs) const { 60 | return mat3( 61 | m[0][0] * rhs[0][0] + m[1][0] * rhs[0][1] + m[2][0] * rhs[0][2], 62 | m[0][1] * rhs[0][0] + m[1][1] * rhs[0][1] + m[2][1] * rhs[0][2], 63 | m[0][2] * rhs[0][0] + m[1][2] * rhs[0][1] + m[2][2] * rhs[0][2], 64 | m[0][0] * rhs[1][0] + m[1][0] * rhs[1][1] + m[2][0] * rhs[1][2], 65 | m[0][1] * rhs[1][0] + m[1][1] * rhs[1][1] + m[2][1] * rhs[1][2], 66 | m[0][2] * rhs[1][0] + m[1][2] * rhs[1][1] + m[2][2] * rhs[1][2], 67 | m[0][0] * rhs[2][0] + m[1][0] * rhs[2][1] + m[2][0] * rhs[2][2], 68 | m[0][1] * rhs[2][0] + m[1][1] * rhs[2][1] + m[2][1] * rhs[2][2], 69 | m[0][2] * rhs[2][0] + m[1][2] * rhs[2][1] + m[2][2] * rhs[2][2] 70 | ); 71 | } 72 | 73 | inline __device__ mat3 operator+(mat3 const &rhs) const { 74 | return mat3( 75 | m[0][0] + rhs[0][0], 76 | m[0][1] + rhs[0][1], 77 | m[0][2] + rhs[0][2], 78 | m[1][0] + rhs[1][0], 79 | m[1][1] + rhs[1][1], 80 | m[1][2] + rhs[1][2], 81 | m[2][0] + rhs[2][0], 82 | m[2][1] + rhs[2][1], 83 | m[2][2] + rhs[2][2] 84 | ); 85 | } 86 | 87 | inline __device__ mat3 transpose() const { 88 | return 
mat3( 89 | m[0][0], 90 | m[1][0], 91 | m[2][0], 92 | m[0][1], 93 | m[1][1], 94 | m[2][1], 95 | m[0][2], 96 | m[1][2], 97 | m[2][2] 98 | ); 99 | } 100 | }; 101 | 102 | template struct mat2 { 103 | T m[2][2]; 104 | 105 | // Fills diags with val, rest with 0. 106 | inline __device__ mat2(T val) { 107 | m[0][0] = val; 108 | m[0][1] = static_cast(0.f); 109 | m[1][0] = static_cast(0.f); 110 | m[1][1] = val; 111 | } 112 | 113 | inline __device__ mat2(T m00, T m10, T m01, T m11) { 114 | this->m[0][0] = m00; 115 | this->m[0][1] = m10; 116 | this->m[1][0] = m01; 117 | this->m[1][1] = m11; 118 | } 119 | 120 | inline __device__ T *operator[](const size_t idx) { return m[idx]; } 121 | inline __device__ T const *operator[](const size_t idx) const { 122 | return m[idx]; 123 | } 124 | 125 | inline __device__ mat2 operator-() const { 126 | return mat2(-m[0][0], -m[0][1], -m[1][0], -m[1][1]); 127 | } 128 | 129 | inline __device__ mat2 operator*(mat2 const &rhs) const { 130 | return mat2( 131 | m[0][0] * rhs[0][0] + m[1][0] * rhs[0][1], 132 | m[0][1] * rhs[0][0] + m[1][1] * rhs[0][1], 133 | m[0][0] * rhs[1][0] + m[1][0] * rhs[1][1], 134 | m[0][1] * rhs[1][0] + m[1][1] * rhs[1][1] 135 | ); 136 | } 137 | }; 138 | -------------------------------------------------------------------------------- /jaxsplat/_rasterize/__init__.py: -------------------------------------------------------------------------------- 1 | import jax 2 | 3 | from typing import TypedDict 4 | from dataclasses import dataclass 5 | from functools import partial 6 | 7 | from jaxsplat._rasterize import impl 8 | 9 | 10 | @jax.tree_util.register_pytree_node_class 11 | @dataclass(frozen=True, kw_only=True) 12 | class RasterizeDescriptor: 13 | num_points: int 14 | img_shape: tuple[int, int] 15 | block_width: int 16 | 17 | def tree_flatten(self): 18 | children = () 19 | aux = ( 20 | self.num_points, 21 | self.img_shape, 22 | self.block_width, 23 | ) 24 | return children, aux 25 | 26 | @classmethod 27 | def tree_unflatten(cls, 
aux, children): 28 | ( 29 | num_points, 30 | img_shape, 31 | block_width, 32 | ) = aux 33 | return cls( 34 | num_points=num_points, 35 | img_shape=img_shape, 36 | block_width=block_width, 37 | ) 38 | 39 | 40 | def rasterize( 41 | colors: jax.Array, 42 | opacities: jax.Array, 43 | background: jax.Array, 44 | xys: jax.Array, 45 | depths: jax.Array, 46 | radii: jax.Array, 47 | conics: jax.Array, 48 | cum_tiles_hit: jax.Array, 49 | *, 50 | img_shape: tuple[int, int], 51 | block_width: int, 52 | ) -> jax.Array: 53 | desc = RasterizeDescriptor( 54 | num_points=colors.shape[0], img_shape=img_shape, block_width=block_width 55 | ) 56 | 57 | (img, _img_alpha) = _rasterize( 58 | desc, 59 | colors, 60 | opacities, 61 | background, 62 | xys, 63 | depths, 64 | radii, 65 | conics, 66 | cum_tiles_hit, 67 | ) 68 | return img 69 | 70 | 71 | @partial(jax.custom_vjp, nondiff_argnums=(0,)) 72 | def _rasterize( 73 | desc: RasterizeDescriptor, 74 | colors: jax.Array, 75 | opacities: jax.Array, 76 | background: jax.Array, 77 | xys: jax.Array, 78 | depths: jax.Array, 79 | radii: jax.Array, 80 | conics: jax.Array, 81 | cum_tiles_hit: jax.Array, 82 | ): 83 | primals, _ = _rasterize_fwd( 84 | desc, 85 | colors, 86 | opacities, 87 | background, 88 | xys, 89 | depths, 90 | radii, 91 | conics, 92 | cum_tiles_hit, 93 | ) 94 | 95 | return primals 96 | 97 | 98 | class RasterizeResiduals(TypedDict): 99 | colors: jax.Array 100 | opacities: jax.Array 101 | background: jax.Array 102 | xys: jax.Array 103 | depths: jax.Array 104 | radii: jax.Array 105 | conics: jax.Array 106 | cum_tiles_hit: jax.Array 107 | final_Ts: jax.Array 108 | final_idx: jax.Array 109 | 110 | 111 | def _rasterize_fwd( 112 | desc: RasterizeDescriptor, 113 | colors: jax.Array, 114 | opacities: jax.Array, 115 | background: jax.Array, 116 | xys: jax.Array, 117 | depths: jax.Array, 118 | radii: jax.Array, 119 | conics: jax.Array, 120 | cum_tiles_hit: jax.Array, 121 | ): 122 | (final_Ts, final_idx, img) = impl._rasterize_fwd_p.bind( 123 
| colors, 124 | opacities, 125 | background, 126 | xys, 127 | depths, 128 | radii, 129 | conics, 130 | cum_tiles_hit, 131 | num_points=desc.num_points, 132 | img_shape=desc.img_shape, 133 | block_width=desc.block_width, 134 | ) 135 | 136 | img_alpha = 1 - final_Ts 137 | primals = (img, img_alpha) 138 | 139 | residuals: RasterizeResiduals = { 140 | "colors": colors, 141 | "opacities": opacities, 142 | "background": background, 143 | "xys": xys, 144 | "depths": depths, 145 | "radii": radii, 146 | "conics": conics, 147 | "cum_tiles_hit": cum_tiles_hit, 148 | "final_Ts": final_Ts, 149 | "final_idx": final_idx, 150 | } 151 | 152 | return primals, residuals 153 | 154 | 155 | def _rasterize_bwd( 156 | desc: RasterizeDescriptor, 157 | residuals: RasterizeResiduals, 158 | cotangents, 159 | ): 160 | (v_img, v_img_alpha) = cotangents 161 | 162 | ( 163 | v_colors, 164 | v_opacity, 165 | v_xy, 166 | _v_xy_abs, 167 | v_conic, 168 | ) = impl._rasterize_bwd_p.bind( 169 | residuals["colors"], 170 | residuals["opacities"], 171 | residuals["background"], 172 | residuals["xys"], 173 | residuals["depths"], 174 | residuals["radii"], 175 | residuals["conics"], 176 | residuals["cum_tiles_hit"], 177 | residuals["final_Ts"], 178 | residuals["final_idx"], 179 | v_img, 180 | v_img_alpha, 181 | # 182 | num_points=desc.num_points, 183 | img_shape=desc.img_shape, 184 | block_width=desc.block_width, 185 | ) 186 | 187 | return (v_colors, v_opacity, None, v_xy, None, None, v_conic, None) 188 | 189 | 190 | _rasterize.defvjp(_rasterize_fwd, _rasterize_bwd) 191 | 192 | __all__ = ["rasterize"] 193 | -------------------------------------------------------------------------------- /lib/ops.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | 7 | namespace ops { 8 | 9 | struct Descriptor { 10 | unsigned num_points; 11 | dim3 img_shape; 12 | 13 | float4 intrins; 14 | float glob_scale; 15 | float clip_thresh; 16 | 17 | 
unsigned block_width;
    unsigned grid_dim_1d;
    unsigned block_dim_1d;
    dim3 grid_dim_2d;
    dim3 block_dim_2d;
};

// Device prefix sum over per-gaussian tile-hit counts.
void cumsum(
    cudaStream_t stream,
    const int32_t *input,
    int32_t *output,
    const int num_items
);

// Sorts gaussian instances by tile/depth and computes per-tile id ranges.
void sort_and_bin(
    cudaStream_t stream,
    const Descriptor &d,
    unsigned num_intersects,
    const float2 *xys,
    const float *depths,
    const int *radii,
    const int *cum_tiles_hit,
    int *gaussian_ids_sorted,
    int2 *tile_bins
);

namespace project::fwd {

// XLA custom-call entry point; `opaque` carries a serialized Descriptor.
void xla(
    cudaStream_t stream,
    void **buffers,
    const char *opaque,
    std::size_t opaque_len
);
struct Tensors;
Tensors
unpack_tensors(cudaStream_t stream, const Descriptor &d, void **buffers);

} // namespace project::fwd

namespace project::bwd {

void xla(
    cudaStream_t stream,
    void **buffers,
    const char *opaque,
    std::size_t opaque_len
);
struct Tensors;
Tensors
unpack_tensors(cudaStream_t stream, const Descriptor &d, void **buffers);

} // namespace project::bwd

namespace rasterize::fwd {

void xla(
    cudaStream_t stream,
    void **buffers,
    const char *opaque,
    std::size_t opaque_len
);
struct Tensors;
Tensors
unpack_tensors(cudaStream_t stream, const Descriptor &d, void **buffers);

} // namespace rasterize::fwd

namespace rasterize::bwd {

void xla(
    cudaStream_t stream,
    void **buffers,
    const char *opaque,
    std::size_t opaque_len
);
struct Tensors;
Tensors
unpack_tensors(cudaStream_t stream, const Descriptor &d, void **buffers);

} // namespace rasterize::bwd

// Typed views over the XLA buffer table for each op.
// (Fix: dropped the redundant `ops::` qualifier on the first definition so all
// four Tensors definitions use the same qualification style.)
struct project::fwd::Tensors {
    struct In {
        // geometry
        const float3 *mean3ds;
        const float3 *scales;
        const float4 *quats;
        const float *viewmat;
    } in;
    struct Out {
        // projection
        float *cov3ds;
        float2 *xys;
        float *depths;
        int *radii;
        float3 *conics;
        float *compensation;
        int *num_tiles_hit;
        int *cum_tiles_hit;
    } out;
};

struct project::bwd::Tensors {
    struct In {
        // geometry
        const float3 *mean3ds;
        const float3 *scales;
        const float4 *quats;
        const float *viewmat;

        // projection output
        const float *cov3ds;
        const float2 *xys;
        const int *radii;
        const float3 *conics;
        const float *compensation;

        // incoming cotangents
        const float *v_compensation;
        const float2 *v_xy;
        const float *v_depth;
        const float3 *v_conic;
    } in;
    struct Out {
        // geometry
        float3 *v_mean3d;
        float3 *v_scale;
        float4 *v_quat;

        // projection
        float3 *v_cov2d;
        float *v_cov3d;
    } out;
};

struct rasterize::fwd::Tensors {
    struct In {
        // appearance
        const float3 *colors;
        const float *opacities;
        const float3 *background;

        // projection output
        const float2 *xys;
        const float *depths;
        const int *radii;
        const float3 *conics;
        const int *cum_tiles_hit;
    } in;
    struct Out {
        // rasterization output
        float *final_Ts;
        int *final_idx;
        float3 *out_img;
    } out;

    // scratch buffers for sorting/binning
    int num_intersects;
    int *gaussian_ids_sorted;
    int2 *tile_bins;
};

struct rasterize::bwd::Tensors {
    struct In {
        // appearance
        const float3 *colors;
        const float *opacities;
        const float3 *background;

        // projection output
        const float2 *xys;
        const float *depths;
        const int *radii;
        const float3 *conics;
        const int *cum_tiles_hit;

        // rasterization output
        const float *final_Ts;
        const int *final_idx;

        // vjps
        const float3 *v_out_img;
        const float *v_out_img_alpha;
    } in;
    struct Out {
        // appearance
        float3 *v_colors;
        float *v_opacity;

        // projection
        float2 *v_xy;
        float2 *v_xy_abs;
        float3 *v_conic;
    } out;

    // scratch buffers for sorting/binning
    int num_intersects;
    int *gaussian_ids_sorted;
    int2 *tile_bins;
};

} // namespace ops
--------------------------------------------------------------------------------
/jaxsplat/_project/__init__.py:
--------------------------------------------------------------------------------
import jax

from typing import TypedDict
from dataclasses import dataclass
from functools import partial

from jaxsplat._project import impl


@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True, kw_only=True)
class ProjectDescriptor:
    # Static (hashable) parameters of the projection op; registered as a
    # leafless pytree so it can be passed through JAX transformations.
    num_points: int
    img_shape: tuple[int, int]
    f: tuple[float, float]
    c: tuple[float, float]
    glob_scale: float
    clip_thresh: float
    block_width: int

    def tree_flatten(self):
        children = ()
        aux = (
            self.num_points,
            self.img_shape,
            self.f,
            self.c,
            self.glob_scale,
            self.clip_thresh,
            self.block_width,
        )
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        (
            num_points,
            img_shape,
            f,
            c,
            glob_scale,
            clip_thresh,
            block_width,
        ) = aux
        return cls(
            num_points=num_points,
            img_shape=img_shape,
            f=f,
            c=c,
            glob_scale=glob_scale,
            clip_thresh=clip_thresh,
            block_width=block_width,
        )


def project(
    mean3ds: jax.Array,
    scales: jax.Array,
    quats: jax.Array,
    viewmat: jax.Array,
    *,
    img_shape: tuple[int, int],
    f: tuple[float, float],
    c: tuple[float, float],
    glob_scale: float,
    clip_thresh: float,
    block_width: int,
) -> tuple[jax.Array, jax.Array,
jax.Array, jax.Array, jax.Array, jax.Array]: 69 | desc = ProjectDescriptor( 70 | num_points=mean3ds.shape[0], 71 | glob_scale=glob_scale, 72 | f=f, 73 | c=c, 74 | img_shape=img_shape, 75 | block_width=block_width, 76 | clip_thresh=clip_thresh, 77 | ) 78 | (xys, depths, radii, conics, num_tiles_hit, cum_tiles_hit, _compensation) = ( 79 | _project( 80 | desc, 81 | mean3ds, 82 | scales, 83 | quats, 84 | viewmat, 85 | ) 86 | ) 87 | return (xys, depths, radii, conics, num_tiles_hit, cum_tiles_hit) 88 | 89 | 90 | @partial(jax.custom_vjp, nondiff_argnums=(0,)) 91 | def _project( 92 | desc: ProjectDescriptor, 93 | mean3ds: jax.Array, 94 | scales: jax.Array, 95 | quats: jax.Array, 96 | viewmat: jax.Array, 97 | ): 98 | primals, _ = _project_fwd(desc, mean3ds, scales, quats, viewmat) 99 | 100 | return primals 101 | 102 | 103 | class ProjectResiduals(TypedDict): 104 | mean3ds: jax.Array 105 | scales: jax.Array 106 | quats: jax.Array 107 | viewmat: jax.Array 108 | cov3ds: jax.Array 109 | xys: jax.Array 110 | radii: jax.Array 111 | conics: jax.Array 112 | compensation: jax.Array 113 | 114 | 115 | def _project_fwd( 116 | desc: ProjectDescriptor, 117 | mean3ds: jax.Array, 118 | scales: jax.Array, 119 | quats: jax.Array, 120 | viewmat: jax.Array, 121 | ): 122 | (cov3ds, xys, depths, radii, conics, compensation, num_tiles_hit, cum_tiles_hit) = ( 123 | impl._project_fwd_p.bind( 124 | mean3ds, 125 | scales, 126 | quats, 127 | viewmat, 128 | num_points=desc.num_points, 129 | img_shape=desc.img_shape, 130 | f=desc.f, 131 | c=desc.c, 132 | glob_scale=desc.glob_scale, 133 | clip_thresh=desc.clip_thresh, 134 | block_width=desc.block_width, 135 | ) 136 | ) 137 | 138 | primals = (xys, depths, radii, conics, num_tiles_hit, cum_tiles_hit, compensation) 139 | 140 | residuals: ProjectResiduals = { 141 | "mean3ds": mean3ds, 142 | "scales": scales, 143 | "quats": quats, 144 | "viewmat": viewmat, 145 | "cov3ds": cov3ds, 146 | "xys": xys, 147 | "radii": radii, 148 | "conics": conics, 149 | 
"compensation": compensation, 150 | } 151 | 152 | return primals, residuals 153 | 154 | 155 | def _project_bwd( 156 | desc: ProjectDescriptor, 157 | residuals: ProjectResiduals, 158 | cotangents, 159 | ): 160 | ( 161 | v_xy, 162 | v_depth, 163 | _v_radii, 164 | v_conic, 165 | _v_num_tiles_hit, 166 | _v_cum_tiles_hit, 167 | v_compensation, 168 | ) = cotangents 169 | 170 | ( 171 | v_mean3d, 172 | v_scale, 173 | v_quat, 174 | v_cov2d, 175 | v_cov3d, 176 | ) = impl._project_bwd_p.bind( 177 | residuals["mean3ds"], 178 | residuals["scales"], 179 | residuals["quats"], 180 | residuals["viewmat"], 181 | residuals["cov3ds"], 182 | residuals["xys"], 183 | residuals["radii"], 184 | residuals["conics"], 185 | residuals["compensation"], 186 | v_compensation, 187 | v_xy, 188 | v_depth, 189 | v_conic, 190 | # 191 | num_points=desc.num_points, 192 | img_shape=desc.img_shape, 193 | f=desc.f, 194 | c=desc.c, 195 | glob_scale=desc.glob_scale, 196 | clip_thresh=desc.clip_thresh, 197 | block_width=desc.block_width, 198 | ) 199 | 200 | return ( 201 | v_mean3d, 202 | v_scale, 203 | v_quat, 204 | None, 205 | ) 206 | 207 | 208 | _project.defvjp(_project_fwd, _project_bwd) 209 | 210 | __all__ = ["project"] 211 | -------------------------------------------------------------------------------- /examples/single_image.py: -------------------------------------------------------------------------------- 1 | import jaxsplat 2 | import jax 3 | import jax.numpy as jnp 4 | import imageio.v3 as iio 5 | import optax 6 | import argparse 7 | import time 8 | 9 | 10 | def main( 11 | iterations: int, 12 | num_points: int, 13 | lr: float, 14 | gt_path: str, 15 | out_img_path: str, 16 | out_vid_path: str, 17 | ): 18 | gt = jnp.array(iio.imread(gt_path)).astype(jnp.float32)[..., :3] / 255 19 | 20 | key = jax.random.key(0) 21 | params, coeffs = init(key, num_points, gt.shape[:2]) 22 | 23 | optimizer = optax.adam(lr) 24 | optimizer_state = optimizer.init(params) 25 | 26 | def loss_fn(params): 27 | output = 
render_fn(params, coeffs) 28 | loss = jnp.mean(jnp.square(output - gt)) 29 | return loss 30 | 31 | # @jax.jit 32 | def train_step( 33 | params, 34 | optimizer_state: optax.OptState, 35 | ): 36 | loss, grads = jax.value_and_grad(loss_fn)(params) 37 | updates, optimizer_state = optimizer.update(grads, optimizer_state) 38 | params = optax.apply_updates(params, updates) 39 | 40 | return params, optimizer_state, loss 41 | 42 | log_every = 50 43 | with iio.imopen(out_vid_path, "w", plugin="pyav") as video: 44 | video.init_video_stream("h264") 45 | 46 | cum_time = 0 47 | cum_time_split = 0 48 | for i in range(iterations): 49 | img = (render_fn(params, coeffs) * 255).astype(jnp.uint8) 50 | video.write_frame(img) 51 | 52 | start = time.perf_counter() 53 | params, optimizer_state, loss = train_step(params, optimizer_state) 54 | end = time.perf_counter() 55 | 56 | cum_time += end - start 57 | cum_time_split += end - start 58 | 59 | if i % log_every == 0: 60 | print( 61 | f"iter {i} loss {loss:.4f}, {cum_time_split/log_every*1000:.3f}ms avg per step" 62 | ) 63 | cum_time_split = 0 64 | print( 65 | f"done training in {cum_time:.3f}s ({cum_time/iterations*1000:.3f}ms avg per step)" 66 | ) 67 | 68 | out = render_fn(params, coeffs) 69 | iio.imwrite(out_img_path, (out * 255).astype(jnp.uint8)) 70 | 71 | 72 | def init(key, num_points, img_shape): 73 | key, subkey = jax.random.split(key) 74 | means3d = jax.random.uniform( 75 | subkey, 76 | (num_points, 3), 77 | minval=jnp.array([-6, -6, -1]), 78 | maxval=jnp.array([6, 6, 1]), 79 | dtype=jnp.float32, 80 | ) 81 | 82 | key, subkey = jax.random.split(key) 83 | scales = jax.random.uniform( 84 | subkey, (num_points, 3), dtype=jnp.float32, minval=0, maxval=0.5 85 | ) 86 | 87 | key, subkey = jax.random.split(key) 88 | u, v, w = jax.random.uniform(subkey, (3, num_points, 1)) 89 | quats = jnp.hstack( 90 | [ 91 | jnp.sqrt(1 - u) * jnp.sin(2 * jnp.pi * v), 92 | jnp.sqrt(1 - u) * jnp.cos(2 * jnp.pi * v), 93 | jnp.sqrt(u) * jnp.sin(2 * jnp.pi * 
w), 94 | jnp.sqrt(u) * jnp.cos(2 * jnp.pi * w), 95 | ] 96 | ) 97 | 98 | viewmat = jnp.array( 99 | [ 100 | [1.0, 0.0, 0.0, 0.0], 101 | [0.0, 1.0, 0.0, 0.0], 102 | [0.0, 0.0, 1.0, 8.0], 103 | [0.0, 0.0, 0.0, 1.0], 104 | ] 105 | ) 106 | 107 | key, subkey = jax.random.split(key) 108 | colors = jax.random.uniform(subkey, (num_points, 3), dtype=jnp.float32) 109 | 110 | key, subkey = jax.random.split(key) 111 | opacities = jax.random.uniform(subkey, (num_points, 1), minval=0.5) 112 | 113 | background = jnp.array([0, 0, 0], dtype=jnp.float32) 114 | 115 | H, W = img_shape 116 | fx, fy = W / 2, H / 2 117 | cx, cy = W / 2, H / 2 118 | glob_scale = 1 119 | clip_thresh = 0.01 120 | block_size = 16 121 | 122 | return ( 123 | { 124 | "means3d": means3d, 125 | "scales": scales, 126 | "quats": quats, 127 | "colors": colors, 128 | "opacities": opacities, 129 | }, 130 | { 131 | "viewmat": viewmat, 132 | "background": background, 133 | "img_shape": img_shape, 134 | "f": (fx, fy), 135 | "c": (cx, cy), 136 | "glob_scale": glob_scale, 137 | "clip_thresh": clip_thresh, 138 | "block_size": block_size, 139 | }, 140 | ) 141 | 142 | 143 | def render_fn(params, coeffs): 144 | means3d = params["means3d"] 145 | quats = params["quats"] / (jnp.linalg.norm(params["quats"], axis=-1, keepdims=True)) 146 | scales = params["scales"] 147 | colors = jax.nn.sigmoid(params["colors"]) 148 | opacities = jax.nn.sigmoid(params["opacities"]) 149 | 150 | img = jaxsplat.render( 151 | means3d=means3d, 152 | scales=scales, 153 | quats=quats, 154 | colors=colors, 155 | opacities=opacities, 156 | viewmat=coeffs["viewmat"], 157 | background=coeffs["background"], 158 | img_shape=coeffs["img_shape"], 159 | f=coeffs["f"], 160 | c=coeffs["c"], 161 | glob_scale=coeffs["glob_scale"], 162 | clip_thresh=coeffs["clip_thresh"], 163 | block_size=coeffs["block_size"], 164 | ) 165 | 166 | return img 167 | 168 | 169 | if __name__ == "__main__": 170 | parser = argparse.ArgumentParser( 171 | prog="python -m examples.single_image", 
172 | description="Fits 3D Gaussians to single 2D image", 173 | ) 174 | parser.add_argument("input") 175 | parser.add_argument("--iters", type=int, default=1000) 176 | parser.add_argument("--num_points", type=int, default=50_000) 177 | parser.add_argument("--lr", type=float, default=0.01) 178 | parser.add_argument("--out_image", default="out.png") 179 | parser.add_argument("--out_video", default="out.mp4") 180 | 181 | args = parser.parse_args() 182 | main( 183 | args.iters, args.num_points, args.lr, args.input, args.out_image, args.out_video 184 | ) 185 | -------------------------------------------------------------------------------- /lib/kernels/helpers.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "common.h" 4 | 5 | #include 6 | 7 | namespace helpers { 8 | 9 | inline __device__ void get_bbox(const float2 center, const float2 dims, 10 | const dim3 img_size, uint2 &bb_min, 11 | uint2 &bb_max) { 12 | // get bounding box with center and dims, within bounds 13 | // bounding box coords returned in tile coords, inclusive min, exclusive max 14 | // clamp between 0 and tile bounds 15 | bb_min.x = min(max(0, (int)(center.x - dims.x)), img_size.x); 16 | bb_max.x = min(max(0, (int)(center.x + dims.x + 1)), img_size.x); 17 | bb_min.y = min(max(0, (int)(center.y - dims.y)), img_size.y); 18 | bb_max.y = min(max(0, (int)(center.y + dims.y + 1)), img_size.y); 19 | } 20 | 21 | inline __device__ void get_tile_bbox(const float2 pix_center, 22 | const float pix_radius, 23 | const dim3 tile_bounds, uint2 &tile_min, 24 | uint2 &tile_max, const int block_size) { 25 | // gets gaussian dimensions in tile space, i.e. 
the span of a gaussian in
    // tile_grid (image divided into tiles)
    float2 tile_center = {pix_center.x / (float)block_size,
                          pix_center.y / (float)block_size};
    float2 tile_radius = {pix_radius / (float)block_size,
                          pix_radius / (float)block_size};
    get_bbox(tile_center, tile_radius, tile_bounds, tile_min, tile_max);
}

inline __device__ bool compute_cov2d_bounds(const float3 cov2d, float3 &conic,
                                            float &radius) {
    // find eigenvalues of 2d covariance matrix
    // expects upper triangular values of cov matrix as float3
    // then compute the radius and conic dimensions
    // the conic is the inverse cov2d matrix, represented here with upper
    // triangular values.
    float det = cov2d.x * cov2d.z - cov2d.y * cov2d.y;
    if (det == 0.f)
        return false;
    float inv_det = 1.f / det;

    // inverse of 2x2 cov2d matrix
    conic.x = cov2d.z * inv_det;
    conic.y = -cov2d.y * inv_det;
    conic.z = cov2d.x * inv_det;

    float b = 0.5f * (cov2d.x + cov2d.z);
    float v1 = b + sqrt(max(0.1f, b * b - det));
    float v2 = b - sqrt(max(0.1f, b * b - det));
    // take 3 sigma of covariance
    radius = ceil(3.f * sqrt(max(v1, v2)));
    return true;
}

// compute vjp from df/d_conic to df/c_cov2d
inline __device__ void cov2d_to_conic_vjp(const float3 &conic,
                                          const float3 &v_conic,
                                          float3 &v_cov2d) {
    // conic = inverse cov2d
    // df/d_cov2d = -conic * df/d_conic * conic
    mat2 X = mat2(conic.x, conic.y, conic.y, conic.z);
    mat2 G = mat2(v_conic.x, v_conic.y / 2.f, v_conic.y / 2.f, v_conic.z);
    mat2 v_Sigma = -X * G * X;
    v_cov2d.x = v_Sigma[0][0];
    v_cov2d.y = v_Sigma[1][0] + v_Sigma[0][1];
    v_cov2d.z = v_Sigma[1][1];
}

inline __device__ void cov2d_to_compensation_vjp(const float compensation,
                                                 const float3 &conic,
                                                 const float v_compensation,
                                                 float3 &v_cov2d) {
    // comp = sqrt(det(cov2d - 0.3 I) / det(cov2d))
    // conic = inverse(cov2d)
    // df / d_cov2d = df / d comp * 0.5 / comp * [ d comp^2 / d cov2d ]
    // d comp^2 / d cov2d = (1 - comp^2) * conic - 0.3 I * det(conic)
    // (fix: f-suffixed literals below avoid implicit promotion of this whole
    // expression chain to double arithmetic in device code)
    float inv_det = conic.x * conic.z - conic.y * conic.y;
    float one_minus_sqr_comp = 1.f - compensation * compensation;
    float v_sqr_comp = v_compensation * 0.5f / (compensation + 1e-6f);
    v_cov2d.x += v_sqr_comp * (one_minus_sqr_comp * conic.x - 0.3f * inv_det);
    v_cov2d.y += 2.f * v_sqr_comp * (one_minus_sqr_comp * conic.y);
    v_cov2d.z += v_sqr_comp * (one_minus_sqr_comp * conic.z - 0.3f * inv_det);
}

// helper for applying R^T * p for a ROW MAJOR 4x3 matrix [R, t], ignoring t
inline __device__ float3 transform_4x3_rot_only_transposed(const float *mat,
                                                           const float3 p) {
    float3 out = {
        mat[0] * p.x + mat[4] * p.y + mat[8] * p.z,
        mat[1] * p.x + mat[5] * p.y + mat[9] * p.z,
        mat[2] * p.x + mat[6] * p.y + mat[10] * p.z,
    };
    return out;
}

// helper for applying R * p + T, expect mat to be ROW MAJOR
inline __device__ float3 transform_4x3(const float *mat, const float3 p) {
    float3 out = {
        mat[0] * p.x + mat[1] * p.y + mat[2] * p.z + mat[3],
        mat[4] * p.x + mat[5] * p.y + mat[6] * p.z + mat[7],
        mat[8] * p.x + mat[9] * p.y + mat[10] * p.z + mat[11],
    };
    return out;
}

// helper to apply 4x4 transform to 3d vector, return homo coords
// expects mat to be ROW MAJOR
inline __device__ float4 transform_4x4(const float *mat, const float3 p) {
    float4 out = {
        mat[0] * p.x + mat[1] * p.y + mat[2] * p.z + mat[3],
        mat[4] * p.x + mat[5] * p.y + mat[6] * p.z + mat[7],
        mat[8] * p.x + mat[9] * p.y + mat[10] * p.z + mat[11],
        mat[12] * p.x + mat[13] * p.y + mat[14] * p.z + mat[15],
    };
    return out;
}

// perspective-divide and apply intrinsics (fxfy, principal point pp)
inline __device__ float2 project_pix(const float2 fxfy, const float3 p_view,
                                     const float2 pp) {
    float rw = 1.f / (p_view.z + 1e-6f);
    float2 p_proj = {p_view.x * rw, p_view.y * rw};
    float2 p_pix = {p_proj.x * fxfy.x + pp.x, p_proj.y * fxfy.y + pp.y};
    return p_pix;
}

// given v_xy_pix, get v_xyz
inline __device__ float3 project_pix_vjp(const float2 fxfy, const float3 p_view,
                                         const float2 v_xy) {
    float rw = 1.f / (p_view.z + 1e-6f);
    float2 v_proj = {fxfy.x * v_xy.x, fxfy.y * v_xy.y};
    float3 v_view = {v_proj.x * rw, v_proj.y * rw,
                     -(v_proj.x * p_view.x + v_proj.y * p_view.y) * rw * rw};
    return v_view;
}

inline __device__ mat3 quat_to_rotmat(const float4 quat) {
    // quat to rotation matrix; quat packs (w, x, y, z) into (x, y, z, w)
    float w = quat.x;
    float x = quat.y;
    float y = quat.z;
    float z = quat.w;

    // glm matrices are column-major
    return mat3(1.f - 2.f * (y * y + z * z), 2.f * (x * y + w * z),
                2.f * (x * z - w * y), 2.f * (x * y - w * z),
                1.f - 2.f * (x * x + z * z), 2.f * (y * z + w * x),
                2.f * (x * z + w * y), 2.f * (y * z - w * x),
                1.f - 2.f * (x * x + y * y));
}

inline __device__ float4 quat_to_rotmat_vjp(const float4 quat,
                                            const mat3 v_R) {
    float w = quat.x;
    float x = quat.y;
    float y = quat.z;
    float z = quat.w;

    float4 v_quat;
    // v_R is COLUMN MAJOR
    // w element stored in x field
    v_quat.x =
        2.f * (x * (v_R[1][2] - v_R[2][1]) + y * (v_R[2][0] - v_R[0][2]) +
               z * (v_R[0][1] - v_R[1][0]));
    // x element in y field
    v_quat.y =
        2.f * (-2.f * x * (v_R[1][1] + v_R[2][2]) +
               y * (v_R[0][1] + v_R[1][0]) + z * (v_R[0][2] + v_R[2][0]) +
               w * (v_R[1][2] - v_R[2][1]));
    // y element in z field
    v_quat.z =
        2.f * (x * (v_R[0][1] + v_R[1][0]) -
               2.f * y * (v_R[0][0] + v_R[2][2]) +
               z * (v_R[1][2] + v_R[2][1]) + w * (v_R[2][0] - v_R[0][2]));
    // z element in w field
    v_quat.w =
        2.f * (x * (v_R[0][2] + v_R[2][0]) + y * (v_R[1][2] + v_R[2][1]) -
               2.f * z * (v_R[0][0] + v_R[1][1]) +
               w * (v_R[0][1] - v_R[1][0]));
    return v_quat;
}

inline __device__ mat3 scale_to_mat(const float3 scale,
                                    const float glob_scale) {
    mat3 S = mat3(1.f);
    S[0][0] = glob_scale * scale.x;
    S[1][1] = glob_scale * scale.y;
    S[2][2] = glob_scale * scale.z;
    return S;
}

// device helper for culling near points
inline __device__ bool clip_near_plane(const float3 p, const float *viewmat,
                                       float3 &p_view, float thresh) {
    p_view = transform_4x3(viewmat, p);
    if (p_view.z <= thresh) {
        return true;
    }
    return false;
}

} // namespace helpers
--------------------------------------------------------------------------------
/examples/benchmark.py:
--------------------------------------------------------------------------------
import jax
import jax.numpy as jnp
import numpy as np
import torch
import imageio.v3 as iio

import time
import argparse

import gsplat
import jaxsplat
import diff_gaussian_rasterization


def main(
    iterations: int,
    num_points: int,
):
    """Benchmarks jaxsplat against gsplat and diff-gaussian-rasterization."""
    coeffs_jax = init_coeffs()
    experiments: dict[str, Experiment] = {
        "jaxsplat": JaxsplatExperiment(coeffs_jax),
        "gsplat": GsplatExperiment(coeffs_jax),
        "diff_gaussian_rasterization": DiffGaussianRasterizationExperiment(coeffs_jax),
    }

    for name, experiment in experiments.items():
        key = jax.random.key(0)
        params_jax = init_params(key, num_points)
        # untimed warmup render, also saved for visual inspection
        _, img = experiment.run(params_jax)

        iio.imwrite(f"test-{name}.png", (img * 255).astype(jnp.uint8))

        for i in range(iterations):
            key, subkey = jax.random.split(key)
35 | params_jax = init_params(subkey, num_points) 36 | _, img = experiment.run(params_jax) 37 | print(f"{name} avg {experiment.total_time/iterations*1000:.4f}ms") 38 | 39 | 40 | class Experiment: 41 | _start_time: float 42 | total_time: float = 0 43 | 44 | def _start(self): 45 | self._start_time = time.perf_counter() 46 | 47 | def _end(self) -> float: 48 | end = time.perf_counter() 49 | return end - self._start_time 50 | 51 | def run(self, params_jax: dict[str, jax.Array]) -> tuple[float, jax.Array]: ... 52 | 53 | 54 | class JaxsplatExperiment(Experiment): 55 | _coeffs: dict 56 | 57 | def __init__(self, coeffs_jax): 58 | self._coeffs = coeffs_jax 59 | self.render = self.renderer() 60 | 61 | def run(self, params_jax): 62 | self._start() 63 | img = self.render(params_jax).block_until_ready() 64 | delta = self._end() 65 | 66 | self.total_time += delta 67 | return delta, img 68 | 69 | def renderer(self): 70 | viewmat = self._coeffs["viewmat"] 71 | background = self._coeffs["background"] 72 | img_shape = self._coeffs["img_shape"] 73 | f = self._coeffs["f"] 74 | c = self._coeffs["c"] 75 | glob_scale = self._coeffs["glob_scale"] 76 | clip_thresh = self._coeffs["clip_thresh"] 77 | block_size = self._coeffs["block_size"] 78 | 79 | def render(params: dict[str, jax.Array]) -> jax.Array: 80 | img = jaxsplat.render( 81 | means3d=params["means3d"], 82 | scales=params["scales"], 83 | quats=params["quats"], 84 | colors=params["colors"], 85 | opacities=params["opacities"], 86 | viewmat=viewmat, 87 | background=background, 88 | img_shape=img_shape, 89 | f=f, 90 | c=c, 91 | glob_scale=glob_scale, 92 | clip_thresh=clip_thresh, 93 | block_size=block_size, 94 | ) 95 | return img 96 | 97 | return render 98 | 99 | 100 | class GsplatExperiment(Experiment): 101 | _coeffs: dict 102 | 103 | def __init__(self, coeffs_jax): 104 | self._coeffs = jax_to_torch_dict(coeffs_jax) 105 | self.render = self.renderer() 106 | 107 | def run(self, params_jax): 108 | params = jax_to_torch_dict(params_jax) 
109 | self._start() 110 | img = self.render(params) 111 | torch.cuda.synchronize() 112 | delta = self._end() 113 | 114 | self.total_time += delta 115 | return delta, jnp.asarray(img.cpu()) 116 | 117 | def renderer(self): 118 | glob_scale = self._coeffs["glob_scale"] 119 | viewmat = self._coeffs["viewmat"] 120 | img_height = self._coeffs["img_shape"][0] 121 | img_width = self._coeffs["img_shape"][1] 122 | fx = self._coeffs["f"][0] 123 | fy = self._coeffs["f"][1] 124 | cx = self._coeffs["c"][0] 125 | cy = self._coeffs["c"][1] 126 | clip_thresh = self._coeffs["clip_thresh"] 127 | block_width = self._coeffs["block_size"] 128 | background = self._coeffs["background"] 129 | 130 | def render(params: dict[str, torch.Tensor]) -> torch.Tensor: 131 | (xys, depths, radii, conics, compensation, num_tiles_hit, cov3ds) = ( 132 | gsplat.project_gaussians( 133 | means3d=params["means3d"], 134 | scales=params["scales"], 135 | glob_scale=glob_scale, 136 | quats=params["quats"], 137 | viewmat=viewmat, 138 | img_height=img_height, 139 | img_width=img_width, 140 | fx=fx, 141 | fy=fy, 142 | cx=cx, 143 | cy=cy, 144 | clip_thresh=clip_thresh, 145 | block_width=block_width, 146 | ) 147 | ) 148 | 149 | img = gsplat.rasterize_gaussians( 150 | xys=xys, 151 | depths=depths, 152 | radii=radii, 153 | conics=conics, 154 | num_tiles_hit=num_tiles_hit, 155 | colors=params["colors"], 156 | opacity=params["opacities"], 157 | img_height=img_height, 158 | img_width=img_width, 159 | block_width=block_width, 160 | background=background, 161 | return_alpha=False, 162 | ) 163 | 164 | return img 165 | 166 | return render 167 | 168 | 169 | class DiffGaussianRasterizationExperiment(Experiment): 170 | _coeffs: dict 171 | _settings: diff_gaussian_rasterization.GaussianRasterizationSettings 172 | 173 | def __init__(self, coeffs_jax): 174 | self._coeffs = jax_to_torch_dict(coeffs_jax) 175 | h, w = self._coeffs["img_shape"] 176 | fx, fy = self._coeffs["f"] 177 | far, near = 1000, self._coeffs["clip_thresh"] 178 | 
viewmat = self._coeffs["viewmat"].T 179 | projmat = torch.tensor( 180 | [ 181 | [2 * fx / w, 0, 0, 0], 182 | [0, 2 * fy / h, 0, 0], 183 | [0, 0, (far + near) / (far - near), 1], 184 | [0, 0, -2 * far * near / (far - near), 0], 185 | ] 186 | ).cuda() 187 | self._settings = diff_gaussian_rasterization.GaussianRasterizationSettings( 188 | image_height=h, 189 | image_width=w, 190 | tanfovx=0.5 * w / fx, 191 | tanfovy=0.5 * h / fy, 192 | bg=self._coeffs["background"], 193 | scale_modifier=self._coeffs["glob_scale"], 194 | viewmatrix=viewmat, 195 | projmatrix=viewmat @ projmat, 196 | sh_degree=0, 197 | campos=viewmat[:3, 3], 198 | prefiltered=False, 199 | debug=False, 200 | ) 201 | self.render = self.renderer() 202 | 203 | def run(self, params_jax): 204 | params = jax_to_torch_dict(params_jax) 205 | self._start() 206 | img = self.render(params) 207 | torch.cuda.synchronize() 208 | delta = self._end() 209 | 210 | self.total_time += delta 211 | return delta, jnp.asarray(img.permute(1, 2, 0).cpu()) 212 | 213 | def renderer(self): 214 | settings = self._settings 215 | 216 | def render( 217 | params: dict[str, torch.Tensor], 218 | ) -> torch.Tensor: 219 | img, _ = diff_gaussian_rasterization.rasterize_gaussians( 220 | means3D=params["means3d"], 221 | means2D=torch.zeros_like(params["means3d"]), 222 | sh=torch.Tensor([]), 223 | colors_precomp=params["colors"], 224 | opacities=params["opacities"], 225 | scales=params["scales"], 226 | rotations=params["quats"], 227 | cov3Ds_precomp=torch.Tensor([]), 228 | raster_settings=settings, 229 | ) # type: ignore 230 | 231 | return img 232 | 233 | return render 234 | 235 | 236 | def jax_to_torch(array_jax: jax.Array) -> torch.Tensor: 237 | array_np = np.asarray(array_jax) 238 | tensor = torch.from_numpy(array_np.copy()).cuda() 239 | return tensor 240 | 241 | 242 | def jax_to_torch_dict(dict_jax) -> dict: 243 | return { 244 | k: (jax_to_torch(v) if isinstance(v, jax.Array) else v) 245 | for k, v in dict_jax.items() 246 | } 247 | 248 | 249 
| def init_params(key, num_points: int) -> dict[str, jax.Array]: 250 | subkeys = jax.random.split(key, 5) 251 | 252 | means3d = jax.random.uniform(subkeys[0], (num_points, 3), minval=-3, maxval=3) 253 | scales = jax.random.uniform(subkeys[1], (num_points, 3), maxval=0.5) 254 | quats = jax.random.normal(subkeys[2], (num_points, 4)) 255 | quats /= jnp.linalg.norm(quats, axis=-1, keepdims=True) 256 | colors = jax.random.uniform(subkeys[3], (num_points, 3)) 257 | opacities = jax.random.uniform(subkeys[4], (num_points, 1)) 258 | 259 | return { 260 | "means3d": means3d, 261 | "scales": scales, 262 | "quats": quats, 263 | "colors": colors, 264 | "opacities": opacities, 265 | } 266 | 267 | 268 | def init_coeffs(): 269 | viewmat = jnp.array( 270 | [ 271 | [1, 0, 0, 0], 272 | [0, 1, 0, 0], 273 | [0, 0, 1, 8], 274 | [0, 0, 0, 1], 275 | ], 276 | dtype=jnp.float32, 277 | ) 278 | W, H = 1600, 900 279 | 280 | return { 281 | "viewmat": viewmat, 282 | "background": jnp.ones((3,), dtype=jnp.float32), 283 | "img_shape": (H, W), 284 | "f": (W / 2, H / 2), 285 | "c": (W / 2, H / 2), 286 | "glob_scale": 1.0, 287 | "clip_thresh": 0.01, 288 | "block_size": 16, 289 | } 290 | 291 | 292 | if __name__ == "__main__": 293 | parser = argparse.ArgumentParser( 294 | prog="python -m examples.benchmark", 295 | description="Benchmarks jaxsplat and other methods", 296 | ) 297 | parser.add_argument("--iters", type=int, default=100) 298 | parser.add_argument("--num_points", type=int, default=50_000) 299 | 300 | args = parser.parse_args() 301 | main(args.iters, args.num_points) 302 | -------------------------------------------------------------------------------- /COPYRIGHT.txt: -------------------------------------------------------------------------------- 1 | This project is released under the MIT license. 2 | 3 | CUDA kernels for jaxsplat were modified from gsplat 4 | , originally released under 5 | Apache License 2.0. Changes include removing GLM and refactoring. 
The original 6 | license text of gsplat is as follows: 7 | 8 | 9 | Apache License 10 | Version 2.0, January 2004 11 | http://www.apache.org/licenses/ 12 | 13 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 14 | 15 | 1. Definitions. 16 | 17 | "License" shall mean the terms and conditions for use, reproduction, 18 | and distribution as defined by Sections 1 through 9 of this document. 19 | 20 | "Licensor" shall mean the copyright owner or entity authorized by 21 | the copyright owner that is granting the License. 22 | 23 | "Legal Entity" shall mean the union of the acting entity and all 24 | other entities that control, are controlled by, or are under common 25 | control with that entity. For the purposes of this definition, 26 | "control" means (i) the power, direct or indirect, to cause the 27 | direction or management of such entity, whether by contract or 28 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 29 | outstanding shares, or (iii) beneficial ownership of such entity. 30 | 31 | "You" (or "Your") shall mean an individual or Legal Entity 32 | exercising permissions granted by this License. 33 | 34 | "Source" form shall mean the preferred form for making modifications, 35 | including but not limited to software source code, documentation 36 | source, and configuration files. 37 | 38 | "Object" form shall mean any form resulting from mechanical 39 | transformation or translation of a Source form, including but 40 | not limited to compiled object code, generated documentation, 41 | and conversions to other media types. 42 | 43 | "Work" shall mean the work of authorship, whether in Source or 44 | Object form, made available under the License, as indicated by a 45 | copyright notice that is included in or attached to the work 46 | (an example is provided in the Appendix below). 
47 | 48 | "Derivative Works" shall mean any work, whether in Source or Object 49 | form, that is based on (or derived from) the Work and for which the 50 | editorial revisions, annotations, elaborations, or other modifications 51 | represent, as a whole, an original work of authorship. For the purposes 52 | of this License, Derivative Works shall not include works that remain 53 | separable from, or merely link (or bind by name) to the interfaces of, 54 | the Work and Derivative Works thereof. 55 | 56 | "Contribution" shall mean any work of authorship, including 57 | the original version of the Work and any modifications or additions 58 | to that Work or Derivative Works thereof, that is intentionally 59 | submitted to Licensor for inclusion in the Work by the copyright owner 60 | or by an individual or Legal Entity authorized to submit on behalf of 61 | the copyright owner. For the purposes of this definition, "submitted" 62 | means any form of electronic, verbal, or written communication sent 63 | to the Licensor or its representatives, including but not limited to 64 | communication on electronic mailing lists, source code control systems, 65 | and issue tracking systems that are managed by, or on behalf of, the 66 | Licensor for the purpose of discussing and improving the Work, but 67 | excluding communication that is conspicuously marked or otherwise 68 | designated in writing by the copyright owner as "Not a Contribution." 69 | 70 | "Contributor" shall mean Licensor and any individual or Legal Entity 71 | on behalf of whom a Contribution has been received by Licensor and 72 | subsequently incorporated within the Work. 73 | 74 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | copyright license to reproduce, prepare Derivative Works of, 78 | publicly display, publicly perform, sublicense, and distribute the 79 | Work and such Derivative Works in Source or Object form. 80 | 81 | 3. Grant of Patent License. Subject to the terms and conditions of 82 | this License, each Contributor hereby grants to You a perpetual, 83 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 84 | (except as stated in this section) patent license to make, have made, 85 | use, offer to sell, sell, import, and otherwise transfer the Work, 86 | where such license applies only to those patent claims licensable 87 | by such Contributor that are necessarily infringed by their 88 | Contribution(s) alone or by combination of their Contribution(s) 89 | with the Work to which such Contribution(s) was submitted. If You 90 | institute patent litigation against any entity (including a 91 | cross-claim or counterclaim in a lawsuit) alleging that the Work 92 | or a Contribution incorporated within the Work constitutes direct 93 | or contributory patent infringement, then any patent licenses 94 | granted to You under this License for that Work shall terminate 95 | as of the date such litigation is filed. 96 | 97 | 4. Redistribution. 
You may reproduce and distribute copies of the 98 | Work or Derivative Works thereof in any medium, with or without 99 | modifications, and in Source or Object form, provided that You 100 | meet the following conditions: 101 | 102 | (a) You must give any other recipients of the Work or 103 | Derivative Works a copy of this License; and 104 | 105 | (b) You must cause any modified files to carry prominent notices 106 | stating that You changed the files; and 107 | 108 | (c) You must retain, in the Source form of any Derivative Works 109 | that You distribute, all copyright, patent, trademark, and 110 | attribution notices from the Source form of the Work, 111 | excluding those notices that do not pertain to any part of 112 | the Derivative Works; and 113 | 114 | (d) If the Work includes a "NOTICE" text file as part of its 115 | distribution, then any Derivative Works that You distribute must 116 | include a readable copy of the attribution notices contained 117 | within such NOTICE file, excluding those notices that do not 118 | pertain to any part of the Derivative Works, in at least one 119 | of the following places: within a NOTICE text file distributed 120 | as part of the Derivative Works; within the Source form or 121 | documentation, if provided along with the Derivative Works; or, 122 | within a display generated by the Derivative Works, if and 123 | wherever such third-party notices normally appear. The contents 124 | of the NOTICE file are for informational purposes only and 125 | do not modify the License. You may add Your own attribution 126 | notices within Derivative Works that You distribute, alongside 127 | or as an addendum to the NOTICE text from the Work, provided 128 | that such additional attribution notices cannot be construed 129 | as modifying the License. 
130 | 131 | You may add Your own copyright statement to Your modifications and 132 | may provide additional or different license terms and conditions 133 | for use, reproduction, or distribution of Your modifications, or 134 | for any such Derivative Works as a whole, provided Your use, 135 | reproduction, and distribution of the Work otherwise complies with 136 | the conditions stated in this License. 137 | 138 | 5. Submission of Contributions. Unless You explicitly state otherwise, 139 | any Contribution intentionally submitted for inclusion in the Work 140 | by You to the Licensor shall be under the terms and conditions of 141 | this License, without any additional terms or conditions. 142 | Notwithstanding the above, nothing herein shall supersede or modify 143 | the terms of any separate license agreement you may have executed 144 | with Licensor regarding such Contributions. 145 | 146 | 6. Trademarks. This License does not grant permission to use the trade 147 | names, trademarks, service marks, or product names of the Licensor, 148 | except as required for reasonable and customary use in describing the 149 | origin of the Work and reproducing the content of the NOTICE file. 150 | 151 | 7. Disclaimer of Warranty. Unless required by applicable law or 152 | agreed to in writing, Licensor provides the Work (and each 153 | Contributor provides its Contributions) on an "AS IS" BASIS, 154 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 155 | implied, including, without limitation, any warranties or conditions 156 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 157 | PARTICULAR PURPOSE. You are solely responsible for determining the 158 | appropriateness of using or redistributing the Work and assume any 159 | risks associated with Your exercise of permissions under this License. 160 | 161 | 8. Limitation of Liability. 
In no event and under no legal theory, 162 | whether in tort (including negligence), contract, or otherwise, 163 | unless required by applicable law (such as deliberate and grossly 164 | negligent acts) or agreed to in writing, shall any Contributor be 165 | liable to You for damages, including any direct, indirect, special, 166 | incidental, or consequential damages of any character arising as a 167 | result of this License or out of the use or inability to use the 168 | Work (including but not limited to damages for loss of goodwill, 169 | work stoppage, computer failure or malfunction, or any and all 170 | other commercial damages or losses), even if such Contributor 171 | has been advised of the possibility of such damages. 172 | 173 | 9. Accepting Warranty or Additional Liability. While redistributing 174 | the Work or Derivative Works thereof, You may choose to offer, 175 | and charge a fee for, acceptance of support, warranty, indemnity, 176 | or other liability obligations and/or rights consistent with this 177 | License. However, in accepting such obligations, You may act only 178 | on Your own behalf and on Your sole responsibility, not on behalf 179 | of any other Contributor, and only if You agree to indemnify, 180 | defend, and hold each Contributor harmless for any liability 181 | incurred by, or claims asserted against, such Contributor by reason 182 | of your accepting any such warranty or additional liability. 183 | 184 | END OF TERMS AND CONDITIONS 185 | 186 | APPENDIX: How to apply the Apache License to your work. 187 | 188 | To apply the Apache License to your work, attach the following 189 | boilerplate notice, with the fields enclosed by brackets "[]" 190 | replaced with your own identifying information. (Don't include 191 | the brackets!) The text should be enclosed in the appropriate 192 | comment syntax for the file format. 
We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------
/lib/kernels/forward.cu:
--------------------------------------------------------------------------------
#include "forward.h"
#include "helpers.h"

// NOTE(review): the original include targets were stripped by the dump
// (bare "#include" lines). <cooperative_groups.h> and <cuda_fp16.h> match
// the cg:: and __half usage in this file -- confirm against the repo.
#include <cooperative_groups.h>
#include <cuda_fp16.h>

namespace cg = cooperative_groups;

namespace kernels {

// Kernel: project each 3D gaussian into screen space.
// One thread per gaussian. Writes per-gaussian 3D covariance, 2D center,
// depth, conic, pixel radius, blur compensation, and tile-hit count.
// Gaussians that are culled (near-plane clip, degenerate covariance, or
// empty tile bbox) are left with radii == 0 and num_tiles_hit == 0 so
// later kernels skip them.
__global__ void project_gaussians_fwd(
    const int num_points,
    const float3 *__restrict__ means3d,
    const float3 *__restrict__ scales,
    const float glob_scale,
    const float4 *__restrict__ quats,
    const float *__restrict__ viewmat,
    const float4 intrins, // (fx, fy, cx, cy)
    const dim3 img_size,
    const dim3 tile_bounds,
    const unsigned block_width,
    const float clip_thresh,
    float *__restrict__ covs3d,
    float2 *__restrict__ xys,
    float *__restrict__ depths,
    int *__restrict__ radii,
    float3 *__restrict__ conics,
    float *__restrict__ compensation,
    int32_t *__restrict__ num_tiles_hit
) {
    unsigned idx = cg::this_grid().thread_rank(); // idx of thread within grid
    if (idx >= num_points) {
        return;
    }
    // default to "culled": any early return below leaves this gaussian
    // invisible to the later binning/rasterization kernels
    radii[idx] = 0;
    num_tiles_hit[idx] = 0;

    float3 p_world = means3d[idx];
    float3 p_view;
    if (helpers::clip_near_plane(p_world, viewmat, p_view, clip_thresh)) {
        return; // behind the near plane
    }

    // compute the 3D covariance from scale + rotation
    float3 scale = scales[idx];
    float4 quat = quats[idx];
    float *cur_cov3d = &(covs3d[6 * idx]); // 6 upper-triangular entries
    scale_rot_to_cov3d(scale, glob_scale, quat, cur_cov3d);

    // project to 2d with ewa approximation
    float fx = intrins.x;
    float fy = intrins.y;
    float cx = intrins.z;
    float cy = intrins.w;
    // 0.5f: keep the arithmetic in float (these were double literals)
    float tan_fovx = 0.5f * img_size.x / fx;
    float tan_fovy = 0.5f * img_size.y / fy;
    float3 cov2d;
    float comp;
    project_cov3d_ewa(
        p_world,
        cur_cov3d,
        viewmat,
        fx,
        fy,
        tan_fovx,
        tan_fovy,
        cov2d,
        comp
    );

    float3 conic;
    float radius;
    bool ok = helpers::compute_cov2d_bounds(cov2d, conic, radius);
    if (!ok)
        return; // zero determinant
    conics[idx] = conic;

    // compute the projected mean and its tile-space bounding box
    float2 center = helpers::project_pix({fx, fy}, p_view, {cx, cy});
    uint2 tile_min, tile_max;
    helpers::get_tile_bbox(
        center,
        radius,
        tile_bounds,
        tile_min,
        tile_max,
        block_width
    );
    int32_t tile_area = (tile_max.x - tile_min.x) * (tile_max.y - tile_min.y);
    if (tile_area <= 0) {
        return; // bbox outside of bounds
    }

    num_tiles_hit[idx] = tile_area;
    depths[idx] = p_view.z;
    radii[idx] = (int)radius;
    xys[idx] = center;
    compensation[idx] = comp;
}

// Kernel: expand each visible gaussian into one entry per tile it hits.
// One thread per gaussian; output slots come from the exclusive prefix
// implied by cum_tiles_hit. Writes (tile_id << 32 | depth bits) sort keys
// to isect_ids and the gaussian index to gaussian_ids.
__global__ void map_gaussian_to_intersects(
    const int num_points,
    const float2 *__restrict__ xys,
    const float *__restrict__ depths,
    const int *__restrict__ radii,
    const int32_t *__restrict__ cum_tiles_hit,
    const dim3 tile_bounds,
    const unsigned block_width,
    int64_t *__restrict__ isect_ids,
    int32_t *__restrict__ gaussian_ids
) {
    unsigned idx = cg::this_grid().thread_rank();
    if (idx >= num_points)
        return;
    if (radii[idx] <= 0)
        return; // culled by project_gaussians_fwd
    // get the tile bbox for gaussian
    uint2 tile_min, tile_max;
    float2 center = xys[idx];
    helpers::get_tile_bbox(
        center,
        radii[idx],
        tile_bounds,
        tile_min,
        tile_max,
        block_width
    );

    // update the intersection info for all tiles this gaussian hits;
    // cum_tiles_hit is an inclusive prefix sum, so this gaussian's slots
    // start at the previous gaussian's total
    int32_t cur_idx = (idx == 0) ? 0 : cum_tiles_hit[idx - 1];
    // reinterpret the float depth's bits as int32 so sorting the packed
    // key orders same-tile entries front-to-back (valid for depths >= 0)
    int64_t depth_id = (int64_t) * (int32_t *)&(depths[idx]);
    for (int i = tile_min.y; i < tile_max.y; ++i) {
        for (int j = tile_min.x; j < tile_max.x; ++j) {
            // isect_id is tile ID and depth as int32
            int64_t tile_id = i * tile_bounds.x + j; // tile within image
            isect_ids[cur_idx] = (tile_id << 32) | depth_id; // tile | depth
            gaussian_ids[cur_idx] = idx; // 3D gaussian id
            ++cur_idx; // handles gaussians that hit more than one tile
        }
    }
}

// Kernel: map sorted intersection IDs to per-tile [start, end) ranges.
// Expects isect_ids_sorted ordered by increasing tile ID, i.e. each
// tile's intersections form one contiguous chunk; one thread per entry
// records the boundaries where the tile ID changes.
__global__ void get_tile_bin_edges(
    const int num_intersects,
    const int64_t *__restrict__ isect_ids_sorted,
    int2 *__restrict__ tile_bins
) {
    unsigned idx = cg::this_grid().thread_rank();
    if (idx >= num_intersects)
        return;
    // save the indices where the tile_id changes
    int32_t cur_tile_idx = (int32_t)(isect_ids_sorted[idx] >> 32);
    if (idx == 0 || idx == num_intersects - 1) {
        if (idx == 0)
            tile_bins[cur_tile_idx].x = 0;
        if (idx == num_intersects - 1)
            tile_bins[cur_tile_idx].y = num_intersects;
    }
    if (idx == 0)
        return;
    int32_t prev_tile_idx = (int32_t)(isect_ids_sorted[idx - 1] >> 32);
    if (prev_tile_idx != cur_tile_idx) {
        tile_bins[prev_tile_idx].y = idx;
        tile_bins[cur_tile_idx].x = idx;
        return;
    }
}

// kernel function for rasterizing each tile
// each thread treats a single pixel
// each thread group uses the same gaussian data in a tile
__global__ void nd_rasterize_forward(
    const dim3 tile_bounds,
    const dim3 img_size,
    const unsigned
        channels,
    const int32_t *__restrict__ gaussian_ids_sorted,
    const int2 *__restrict__ tile_bins,
    const float2 *__restrict__ xys,
    const float3 *__restrict__ conics,
    const float *__restrict__ colors,
    const float *__restrict__ opacities,
    float *__restrict__ final_Ts,
    int *__restrict__ final_index,
    float *__restrict__ out_img,
    const float *__restrict__ background
) {
    auto block = cg::this_thread_block();
    // one CUDA block per tile; (i, j) is this thread's pixel coordinate
    int32_t tile_id =
        block.group_index().y * tile_bounds.x + block.group_index().x;
    unsigned i =
        block.group_index().y * block.group_dim().y + block.thread_index().y;
    unsigned j =
        block.group_index().x * block.group_dim().x + block.thread_index().x;

    float px = (float)j + 0.5;
    float py = (float)i + 0.5;
    int32_t pix_id = i * img_size.x + j;

    // keep not rasterizing threads around for reading data
    bool inside = (i < img_size.y && j < img_size.x);
    bool done = !inside;

    // gaussians assigned to this tile, processed in block-sized batches
    int2 range = tile_bins[tile_id];
    const int block_size = block.size();
    int num_batches = (range.y - range.x + block_size - 1) / block_size;

    // dynamic shared memory layout:
    //   [ids | xy+opacity | conics | per-thread half color accumulators]
    extern __shared__ int s[];
    int32_t *id_batch = (int32_t *)s;
    float3 *xy_opacity_batch = (float3 *)&id_batch[block_size];
    float3 *conic_batch = (float3 *)&xy_opacity_batch[block_size];
    __half *color_out_batch = (__half *)&conic_batch[block_size];
#pragma unroll
    for (int c = 0; c < channels; ++c)
        color_out_batch[block.thread_rank() * channels + c] = __float2half(0.f);

    // current visibility left to render
    float T = 1.f;
    // index of most recent gaussian to write to this thread's pixel
    int cur_idx = 0;

    // collect and process batches of gaussians
    // each thread loads one gaussian at a time before rasterizing its
    // designated pixel
    int tr = block.thread_rank();
    __half *pix_out = &color_out_batch[block.thread_rank() * channels];

    for (int b = 0; b < num_batches; ++b) {
        // resync all threads before beginning next batch
        // end early if entire tile is done
        if (__syncthreads_count(done) >= block_size) {
            break;
        }
        // each thread fetch 1 gaussian from front to back

        int batch_start = range.x + block_size * b;
        int idx = batch_start + tr;
        if (idx < range.y) {
            int32_t g_id = gaussian_ids_sorted[idx];
            id_batch[tr] = g_id;
            const float2 xy = xys[g_id];
            const float opac = opacities[g_id];
            xy_opacity_batch[tr] = {xy.x, xy.y, opac};
            conic_batch[tr] = conics[g_id];
        }

        // wait for other threads to collect the gaussians in batch
        block.sync();

        int batch_size = min(block_size, range.y - batch_start);
        for (int t = 0; (t < batch_size) && !done; ++t) {
            const float3 conic = conic_batch[t];
            const float3 xy_opac = xy_opacity_batch[t];
            const float opac = xy_opac.z;
            const float2 delta = {xy_opac.x - px, xy_opac.y - py};
            // sigma: Mahalanobis-style falloff from the 2D conic
            const float sigma = 0.5f * (conic.x * delta.x * delta.x +
                                        conic.z * delta.y * delta.y) +
                                conic.y * delta.x * delta.y;
            const float alpha = min(0.999f, opac * __expf(-sigma));
            // skip negligible contributions (and invalid negative sigma)
            if (sigma < 0.f || alpha < 1.f / 255.f) {
                continue;
            }

            const float next_T = T * (1.f - alpha);
            if (next_T <= 1e-4f) {
                // we want to render the last gaussian that contributes and note
                // that here idx > range.x so we don't underflow
                done = true;
                break;
            }

            // front-to-back alpha compositing, accumulated in half precision
            int32_t g = id_batch[t];
            const float vis = alpha * T;
#pragma unroll
            for (int c = 0; c < channels; ++c) {
                pix_out[c] = __hadd(
                    pix_out[c],
                    __float2half(colors[channels * g + c] * vis)
                );
            }
            T = next_T;
            cur_idx = batch_start + t;
        }
    }

    if (inside) {
        // add background
        final_Ts[pix_id] = T; // transmittance at last gaussian in this pixel
        final_index[pix_id] =
            cur_idx; // index of in bin of last gaussian in this pixel
#pragma unroll
        for (int c = 0; c < channels; ++c) {
            out_img[pix_id * channels + c] =
                __half2float(pix_out[c]) + T * background[c];
        }
    }
}

// Fixed-3-channel (float3) variant of the rasterizer above, using static
// shared arrays instead of a dynamically-sized workspace.
__global__ void rasterize_fwd(
    const dim3 tile_bounds,
    const dim3 img_size,
    const int32_t *__restrict__ gaussian_ids_sorted,
    const int2 *__restrict__ tile_bins,
    const float2 *__restrict__ xys,
    const float3 *__restrict__ conics,
    const float3 *__restrict__ colors,
    const float *__restrict__ opacities,
    float *__restrict__ final_Ts,
    int *__restrict__ final_index,
    float3 *__restrict__ out_img,
    const float3 &__restrict__ background
) {
    // each thread draws one pixel, but also timeshares caching gaussians in a
    // shared tile

    auto block = cg::this_thread_block();
    int32_t tile_id =
        block.group_index().y * tile_bounds.x + block.group_index().x;
    unsigned i =
        block.group_index().y * block.group_dim().y + block.thread_index().y;
    unsigned j =
        block.group_index().x * block.group_dim().x + block.thread_index().x;

    float px = (float)j + 0.5;
    float py = (float)i + 0.5;
    int32_t pix_id = i * img_size.x + j;

    // return if out of bounds
    // keep not rasterizing threads around for reading data
    bool inside = (i < img_size.y && j < img_size.x);
    bool done = !inside;

    // have all threads in tile process the same gaussians in batches
    // first collect gaussians between range.x and range.y in batches
    // which gaussians to look through in this tile
    int2 range = tile_bins[tile_id];
    const int block_size = block.size();
    int num_batches = (range.y - range.x + block_size - 1) / block_size;

    __shared__ int32_t
        id_batch[MAX_GRID_DIM];
    __shared__ float3 xy_opacity_batch[MAX_GRID_DIM];
    __shared__ float3 conic_batch[MAX_GRID_DIM];

    // current visibility left to render
    float T = 1.f;
    // index of most recent gaussian to write to this thread's pixel
    int cur_idx = 0;

    // collect and process batches of gaussians
    // each thread loads one gaussian at a time before rasterizing its
    // designated pixel
    int tr = block.thread_rank();
    float3 pix_out = {0.f, 0.f, 0.f};
    for (int b = 0; b < num_batches; ++b) {
        // resync all threads before beginning next batch
        // end early if entire tile is done
        if (__syncthreads_count(done) >= block_size) {
            break;
        }

        // each thread fetch 1 gaussian from front to back
        // index of gaussian to load
        int batch_start = range.x + block_size * b;
        int idx = batch_start + tr;
        if (idx < range.y) {
            int32_t g_id = gaussian_ids_sorted[idx];
            id_batch[tr] = g_id;
            const float2 xy = xys[g_id];
            const float opac = opacities[g_id];
            xy_opacity_batch[tr] = {xy.x, xy.y, opac};
            conic_batch[tr] = conics[g_id];
        }

        // wait for other threads to collect the gaussians in batch
        block.sync();

        // process gaussians in the current batch for this pixel
        int batch_size = min(block_size, range.y - batch_start);
        for (int t = 0; (t < batch_size) && !done; ++t) {
            const float3 conic = conic_batch[t];
            const float3 xy_opac = xy_opacity_batch[t];
            const float opac = xy_opac.z;
            const float2 delta = {xy_opac.x - px, xy_opac.y - py};
            // gaussian falloff evaluated via the 2D conic (inverse cov2d)
            const float sigma = 0.5f * (conic.x * delta.x * delta.x +
                                        conic.z * delta.y * delta.y) +
                                conic.y * delta.x * delta.y;
            const float alpha = min(0.999f, opac * __expf(-sigma));
            if (sigma < 0.f || alpha < 1.f / 255.f) {
                continue;
            }

            const float next_T = T * (1.f - alpha);
            if (next_T <= 1e-4f) { // this pixel is done
                // we want to render the last gaussian that contributes and note
                // that here idx > range.x so we don't underflow
                done = true;
                break;
            }

            // front-to-back alpha compositing
            int32_t g = id_batch[t];
            const float vis = alpha * T;
            const float3 c = colors[g];
            pix_out.x = pix_out.x + c.x * vis;
            pix_out.y = pix_out.y + c.y * vis;
            pix_out.z = pix_out.z + c.z * vis;
            T = next_T;
            cur_idx = batch_start + t;
        }
    }

    if (inside) {
        // add background
        final_Ts[pix_id] = T; // transmittance at last gaussian in this pixel
        final_index[pix_id] =
            cur_idx; // index of in bin of last gaussian in this pixel
        float3 final_color;
        final_color.x = pix_out.x + T * background.x;
        final_color.y = pix_out.y + T * background.y;
        final_color.z = pix_out.z + T * background.z;
        out_img[pix_id] = final_color;
    }
}

// device helper to approximate projected 2d cov from 3d mean and cov
// (EWA splatting: cov2d = J W V W^T J^T with J the projection Jacobian)
__device__ void project_cov3d_ewa(
    const float3 &__restrict__ mean3d,
    const float *__restrict__ cov3d,
    const float *__restrict__ viewmat,
    const float fx,
    const float fy,
    const float tan_fovx,
    const float tan_fovy,
    float3 &cov2d,
    float &compensation
) {
    // we expect row major matrices as input, glm uses column major
    // upper 3x3 submatrix (rotation part of the view matrix)
    mat3 W = mat3(
        viewmat[0],
        viewmat[4],
        viewmat[8],
        viewmat[1],
        viewmat[5],
        viewmat[9],
        viewmat[2],
        viewmat[6],
        viewmat[10]
    );
    float3 p = {viewmat[3], viewmat[7], viewmat[11]};
    // t = W * mean3d + p : the mean in view space
    float3 t = {
        W[0][0] * mean3d.x + W[1][0] * mean3d.y + W[2][0] * mean3d.z + p.x,
        W[0][1] * mean3d.x + W[1][1] * mean3d.y + W[2][1] * mean3d.z + p.y,
        W[0][2] * mean3d.x + W[1][2] * mean3d.y + W[2][2] * mean3d.z + p.z,
    };

    // clamp x/y to 1.3x the view frustum so the Jacobian (and hence the
    // projected covariance) stays bounded near the frustum edges
    float lim_x = 1.3f * tan_fovx;
    float lim_y = 1.3f * tan_fovy;
    t.x = t.z * min(lim_x, max(-lim_x, t.x / t.z));
    t.y = t.z * min(lim_y, max(-lim_y, t.y / t.z));

    float rz = 1.f / t.z;
    float rz2 = rz * rz;

    // column major
    // we only care about the top 2x2 submatrix
    mat3 J = mat3(
        fx * rz,
        0.f,
        0.f,
        0.f,
        fy * rz,
        0.f,
        -fx * t.x * rz2,
        -fy * t.y * rz2,
        0.f
    );
    mat3 T = J * W;

    // V: symmetric 3D covariance rebuilt from its 6 upper-triangular entries
    mat3 V = mat3(
        cov3d[0],
        cov3d[1],
        cov3d[2],
        cov3d[1],
        cov3d[3],
        cov3d[4],
        cov3d[2],
        cov3d[4],
        cov3d[5]
    );

    mat3 cov = T * V * T.transpose();

    // add a little blur along axes and save upper triangular elements
    // and compute the density compensation factor due to the blurs
    float c00 = cov[0][0], c11 = cov[1][1], c01 = cov[0][1];
    float det_orig = c00 * c11 - c01 * c01;
    cov2d.x = c00 + 0.3f;
    cov2d.y = c01;
    cov2d.z = c11 + 0.3f;
    float det_blur = cov2d.x * cov2d.z - cov2d.y * cov2d.y;
    compensation = sqrt(max(0.f, det_orig / det_blur));
}

// device helper to get 3D covariance from scale and quat parameters
// cov3d = (R S)(R S)^T, stored as 6 upper-triangular entries
__device__ void scale_rot_to_cov3d(
    const float3 scale,
    const float glob_scale,
    const float4 quat,
    float *cov3d
) {
    // printf("quat %.2f %.2f %.2f %.2f\n", quat.x, quat.y, quat.z, quat.w);
    mat3 R = helpers::quat_to_rotmat(quat);
    // printf("R %.2f %.2f %.2f\n", R[0][0], R[1][1], R[2][2]);
    mat3 S = helpers::scale_to_mat(scale, glob_scale);
    mat3 M = R * S;

    mat3 tmp = M * M.transpose();

    // save upper right because symmetric
    cov3d[0] = tmp[0][0];
    cov3d[1] = tmp[0][1];
    cov3d[2] = tmp[0][2];
    cov3d[3] = tmp[1][1];
    cov3d[4] = tmp[1][2];
    cov3d[5] = tmp[2][2];
}

} //
// namespace kernels  (end of forward.cu)

// ---------------- /lib/kernels/backward.cu ----------------

#include "backward.h"
#include "common.h"
#include "helpers.h"

// NOTE(review): the three system-header names were lost during extraction;
// restored from usage (cooperative groups, cg::reduce, half precision).
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda_fp16.h>

namespace cg = cooperative_groups;

namespace kernels {

// Warp-wide component-wise sum; result is broadcast to every lane of `tile`.
inline __device__ void warpSum3(float3 &val, cg::thread_block_tile<32> &tile) {
    val.x = cg::reduce(tile, val.x, cg::plus<float>());
    val.y = cg::reduce(tile, val.y, cg::plus<float>());
    val.z = cg::reduce(tile, val.z, cg::plus<float>());
}

inline __device__ void warpSum2(float2 &val, cg::thread_block_tile<32> &tile) {
    val.x = cg::reduce(tile, val.x, cg::plus<float>());
    val.y = cg::reduce(tile, val.y, cg::plus<float>());
}

inline __device__ void warpSum(float &val, cg::thread_block_tile<32> &tile) {
    val = cg::reduce(tile, val, cg::plus<float>());
}

// Backward pass of N-channel rasterization. One thread per pixel; threads walk
// the tile's depth-sorted gaussians back-to-front and accumulate gradients.
// Launch: grid = tile grid, block = one image tile.
// Requires dynamic shared memory of channels * blockDim.x * blockDim.y *
// sizeof(half) bytes (per-pixel running color sums are kept in half).
__global__ void nd_rasterize_backward_kernel(
    const dim3 tile_bounds,
    const dim3 img_size,
    const unsigned channels,
    const int32_t *__restrict__ gaussians_ids_sorted,
    const int2 *__restrict__ tile_bins,
    const float2 *__restrict__ xys,
    const float3 *__restrict__ conics,
    const float *__restrict__ rgbs,
    const float *__restrict__ opacities,
    const float *__restrict__ background,
    const float *__restrict__ final_Ts,
    const int *__restrict__ final_index,
    const float *__restrict__ v_output,
    const float *__restrict__ v_output_alpha,
    float2 *__restrict__ v_xy,
    float2 *__restrict__ v_xy_abs,
    float3 *__restrict__ v_conic,
    float *__restrict__ v_rgb,
    float *__restrict__ v_opacity
) {
    auto block = cg::this_thread_block();
    const int tr = block.thread_rank();
    int32_t tile_id = blockIdx.y * tile_bounds.x + blockIdx.x;
    unsigned i = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned j = blockIdx.x * blockDim.x + threadIdx.x;
    // pixel center
    float px = (float)j + 0.5;
    float py = (float)i + 0.5;
    // clamp so out-of-image threads still read a valid address
    const int32_t pix_id = min(i * img_size.x + j, img_size.x * img_size.y - 1);

    // keep non-rasterizing threads around for reading data
    const bool inside = (i < img_size.y && j < img_size.x);
    // which gaussians get gradients for this pixel
    const int2 range = tile_bins[tile_id];
    // df/d_out for this pixel
    const float *v_out = &(v_output[channels * pix_id]);
    const float v_out_alpha = v_output_alpha[pix_id];
    // this is the transmittance T AFTER the last gaussian in this pixel
    float T_final = final_Ts[pix_id];
    float T = T_final;
    // running per-channel sum of contributions from gaussians behind the
    // current one, stored in half to fit channels*blockSize in shared memory
    extern __shared__ half workspace[];

    half *S = (half *)(&workspace[channels * tr]);
#pragma unroll
    for (int c = 0; c < channels; ++c) {
        S[c] = __float2half(0.f);
    }
    const int bin_final = inside ? final_index[pix_id] : 0;
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    // iterate far enough back for every lane in the warp
    const int warp_bin_final = cg::reduce(warp, bin_final, cg::greater<int>());
    for (int idx = warp_bin_final - 1; idx >= range.x; --idx) {
        int valid = inside && idx < bin_final;
        const int32_t g = gaussians_ids_sorted[idx];
        const float3 conic = conics[g];
        const float2 center = xys[g];
        const float2 delta = {center.x - px, center.y - py};
        // Mahalanobis half-distance of the pixel under this gaussian
        const float sigma =
            0.5f * (conic.x * delta.x * delta.x + conic.z * delta.y * delta.y) +
            conic.y * delta.x * delta.y;
        valid &= (sigma >= 0.f);
        const float opac = opacities[g];
        const float vis = __expf(-sigma);
        const float alpha = min(0.99f, opac * vis);
        valid &= (alpha >= 1.f / 255.f);
        // if every lane is inactive for this gaussian, skip it entirely
        if (!warp.any(valid)) {
            continue;
        }
        float v_alpha = 0.f;
        float3 v_conic_local = {0.f, 0.f, 0.f};
        float2 v_xy_local = {0.f, 0.f};
        float2 v_xy_abs_local = {0.f, 0.f};
        float v_opacity_local = 0.f;
        if (valid) {
            // recover the T in front of this gaussian
            const float ra = 1.f / (1.f - alpha);
            T *= ra;
            const float fac = alpha * T;
            for (int c = 0; c < channels; ++c) {
                // gradient wrt rgb
                atomicAdd(&(v_rgb[channels * g + c]), fac * v_out[c]);
                // contribution from this pixel
                v_alpha +=
                    (rgbs[channels * g + c] * T - __half2float(S[c]) * ra) *
                    v_out[c];
                // contribution from the background pixel
                v_alpha += -T_final * ra * background[c] * v_out[c];
                // update the running sum
                S[c] = __hadd(S[c], __float2half(rgbs[channels * g + c] * fac));
            }
            v_alpha += T_final * ra * v_out_alpha;
            const float v_sigma = -opac * vis * v_alpha;
            v_conic_local = {
                0.5f * v_sigma * delta.x * delta.x,
                v_sigma * delta.x * delta.y,
                0.5f * v_sigma * delta.y * delta.y
            };
            v_xy_local = {
                v_sigma * (conic.x * delta.x + conic.y * delta.y),
                v_sigma * (conic.y * delta.x + conic.z * delta.y)
            };
            v_xy_abs_local = {abs(v_xy_local.x), abs(v_xy_local.y)};
            v_opacity_local = vis * v_alpha;
        }
        // reduce per-lane gradients and issue one atomic per warp
        warpSum3(v_conic_local, warp);
        warpSum2(v_xy_local, warp);
        warpSum2(v_xy_abs_local, warp);
        warpSum(v_opacity_local, warp);
        if (warp.thread_rank() == 0) {
            float *v_conic_ptr = (float *)(v_conic);
            float *v_xy_ptr = (float *)(v_xy);
            float *v_xy_abs_ptr = (float *)(v_xy_abs);
            atomicAdd(v_conic_ptr + 3 * g + 0, v_conic_local.x);
            atomicAdd(v_conic_ptr + 3 * g + 1, v_conic_local.y);
            atomicAdd(v_conic_ptr + 3 * g + 2, v_conic_local.z);
            atomicAdd(v_xy_ptr + 2 * g + 0, v_xy_local.x);
            atomicAdd(v_xy_ptr + 2 * g + 1, v_xy_local.y);
            atomicAdd(v_xy_abs_ptr + 2 * g + 0, v_xy_abs_local.x);
            atomicAdd(v_xy_abs_ptr + 2 * g + 1, v_xy_abs_local.y);
            atomicAdd(v_opacity + g, v_opacity_local);
        }
    }
}

// Backward pass of 3-channel rasterization. One thread per pixel; the block
// cooperatively stages batches of gaussians in shared memory, then every
// pixel walks them back-to-front accumulating gradients.
// Launch: grid = tile grid, block = one image tile; requires
// block.size() <= MAX_GRID_DIM (static shared arrays are sized by it).
__global__ void rasterize_bwd(
    const dim3 tile_bounds,
    const dim3 img_size,
    const int32_t *__restrict__ gaussian_ids_sorted,
    const int2 *__restrict__ tile_bins,
    const float2 *__restrict__ xys,
    const float3 *__restrict__ conics,
    const float3 *__restrict__ rgbs,
    const float *__restrict__ opacities,
    const float3 &__restrict__ background,
    const float *__restrict__ final_Ts,
    const int *__restrict__ final_index,
    const float3 *__restrict__ v_output,
    const float *__restrict__ v_output_alpha,
    float2 *__restrict__ v_xy,
    float2 *__restrict__ v_xy_abs,
    float3 *__restrict__ v_conic,
    float3 *__restrict__ v_rgb,
    float *__restrict__ v_opacity
) {
    auto block = cg::this_thread_block();
    int32_t tile_id =
        block.group_index().y * tile_bounds.x + block.group_index().x;
    unsigned i =
        block.group_index().y * block.group_dim().y + block.thread_index().y;
    unsigned j =
        block.group_index().x * block.group_dim().x + block.thread_index().x;

    // pixel center
    const float px = (float)j + 0.5;
    const float py = (float)i + 0.5;
    // clamp this value to the last pixel so out-of-image threads read safely
    const int32_t pix_id = min(i * img_size.x + j, img_size.x * img_size.y - 1);

    // keep non-rasterizing threads around for reading data
    const bool inside = (i < img_size.y && j < img_size.x);

    // this is the transmittance T AFTER the last gaussian in this pixel
    float T_final = final_Ts[pix_id];
    float T = T_final;
    // accumulated color contribution from gaussians behind the current one
    float3 buffer = {0.f, 0.f, 0.f};
    // index of the last gaussian to contribute to this pixel
    const int bin_final = inside ? final_index[pix_id] : 0;

    // have all threads in the tile process the same gaussians in batches:
    // first collect gaussians between range.x and range.y in batches
    const int2 range = tile_bins[tile_id];
    const int block_size = block.size();
    const int num_batches = (range.y - range.x + block_size - 1) / block_size;

    // NOTE(review): indexed by thread rank, so block.size() must not exceed
    // MAX_GRID_DIM — confirm the launch enforces this.
    __shared__ int32_t id_batch[MAX_GRID_DIM];
    __shared__ float3 xy_opacity_batch[MAX_GRID_DIM];
    __shared__ float3 conic_batch[MAX_GRID_DIM];
    __shared__ float3 rgbs_batch[MAX_GRID_DIM];

    // df/d_out for this pixel
    const float3 v_out = v_output[pix_id];
    const float v_out_alpha = v_output_alpha[pix_id];

    // each thread loads one gaussian at a time before rasterizing
    const int tr = block.thread_rank();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    const int warp_bin_final = cg::reduce(warp, bin_final, cg::greater<int>());
    for (int b = 0; b < num_batches; ++b) {
        // resync all threads before overwriting the shared batch
        block.sync();

        // each thread fetches 1 gaussian from back to front;
        // index 0 is the furthest back in the batch
        const int batch_end = range.y - 1 - block_size * b;
        int batch_size = min(block_size, batch_end + 1 - range.x);
        const int idx = batch_end - tr;
        if (idx >= range.x) {
            int32_t g_id = gaussian_ids_sorted[idx];
            id_batch[tr] = g_id;
            const float2 xy = xys[g_id];
            const float opac = opacities[g_id];
            xy_opacity_batch[tr] = {xy.x, xy.y, opac};
            conic_batch[tr] = conics[g_id];
            rgbs_batch[tr] = rgbs[g_id];
        }
        // wait for the whole batch to be staged
        block.sync();

        // process the batch back-to-front for this pixel; skip entries that
        // are behind every lane's last contributing gaussian
        for (int t = max(0, batch_end - warp_bin_final); t < batch_size; ++t) {
            int valid = inside;
            if (batch_end - t > bin_final) {
                valid = 0;
            }
            float alpha;
            float opac;
            float2 delta;
            float3 conic;
            float vis;
            if (valid) {
                conic = conic_batch[t];
                float3 xy_opac = xy_opacity_batch[t];
                opac = xy_opac.z;
                delta = {xy_opac.x - px, xy_opac.y - py};
                float sigma = 0.5f * (conic.x * delta.x * delta.x +
                                      conic.z * delta.y * delta.y) +
                              conic.y * delta.x * delta.y;
                vis = __expf(-sigma);
                alpha = min(0.99f, opac * vis);
                if (sigma < 0.f || alpha < 1.f / 255.f) {
                    valid = 0;
                }
            }
            // if all lanes in this warp are inactive, skip this gaussian
            if (!warp.any(valid)) {
                continue;
            }
            // initialize everything to 0; only set if the lane is valid
            float3 v_rgb_local = {0.f, 0.f, 0.f};
            float3 v_conic_local = {0.f, 0.f, 0.f};
            float2 v_xy_local = {0.f, 0.f};
            float2 v_xy_abs_local = {0.f, 0.f};
            float v_opacity_local = 0.f;
            if (valid) {
                // recover the T in front of this gaussian
                float ra = 1.f / (1.f - alpha);
                T *= ra;
                const float fac = alpha * T;
                float v_alpha = 0.f;
                v_rgb_local = {fac * v_out.x, fac * v_out.y, fac * v_out.z};

                const float3 rgb = rgbs_batch[t];
                // contribution from this pixel
                v_alpha += (rgb.x * T - buffer.x * ra) * v_out.x;
                v_alpha += (rgb.y * T - buffer.y * ra) * v_out.y;
                v_alpha += (rgb.z * T - buffer.z * ra) * v_out.z;

                v_alpha += T_final * ra * v_out_alpha;
                // contribution from the background pixel
                v_alpha += -T_final * ra * background.x * v_out.x;
                v_alpha += -T_final * ra * background.y * v_out.y;
                v_alpha += -T_final * ra * background.z * v_out.z;
                // update the running sum
                buffer.x += rgb.x * fac;
                buffer.y += rgb.y * fac;
                buffer.z += rgb.z * fac;

                const float v_sigma = -opac * vis * v_alpha;
                v_conic_local = {
                    0.5f * v_sigma * delta.x * delta.x,
                    v_sigma * delta.x * delta.y,
                    0.5f * v_sigma * delta.y * delta.y
                };
                v_xy_local = {
                    v_sigma * (conic.x * delta.x + conic.y * delta.y),
                    v_sigma * (conic.y * delta.x + conic.z * delta.y)
                };
                v_xy_abs_local = {abs(v_xy_local.x), abs(v_xy_local.y)};
                v_opacity_local = vis * v_alpha;
            }
            // reduce per-lane gradients and issue one atomic per warp
            warpSum3(v_rgb_local, warp);
            warpSum3(v_conic_local, warp);
            warpSum2(v_xy_local, warp);
            warpSum2(v_xy_abs_local, warp);
            warpSum(v_opacity_local, warp);
            if (warp.thread_rank() == 0) {
                int32_t g = id_batch[t];
                float *v_rgb_ptr = (float *)(v_rgb);
                atomicAdd(v_rgb_ptr + 3 * g + 0, v_rgb_local.x);
                atomicAdd(v_rgb_ptr + 3 * g + 1, v_rgb_local.y);
                atomicAdd(v_rgb_ptr + 3 * g + 2, v_rgb_local.z);

                float *v_conic_ptr = (float *)(v_conic);
                atomicAdd(v_conic_ptr + 3 * g + 0, v_conic_local.x);
                atomicAdd(v_conic_ptr + 3 * g + 1, v_conic_local.y);
                atomicAdd(v_conic_ptr + 3 * g + 2, v_conic_local.z);

                float *v_xy_ptr = (float *)(v_xy);
                atomicAdd(v_xy_ptr + 2 * g + 0, v_xy_local.x);
                atomicAdd(v_xy_ptr + 2 * g + 1, v_xy_local.y);

                float *v_xy_abs_ptr = (float *)(v_xy_abs);
                atomicAdd(v_xy_abs_ptr + 2 * g + 0, v_xy_abs_local.x);
                atomicAdd(v_xy_abs_ptr + 2 * g + 1, v_xy_abs_local.y);

                atomicAdd(v_opacity + g, v_opacity_local);
            }
        }
    }
}

// Backward pass of the projection step: per-gaussian, propagates gradients
// from 2D (xy, depth, conic, compensation) back to 3D (mean, scale, quat).
// Launch: 1-D grid covering num_points.
__global__ void project_gaussians_bwd(
    const int num_points,
    const float3 *__restrict__ means3d,
    const float3 *__restrict__ scales,
    const float glob_scale,
    const float4 *__restrict__ quats,
    const float *__restrict__ viewmat,
    const float4 intrins,
const dim3 img_size, 354 | const float *__restrict__ cov3d, 355 | const int *__restrict__ radii, 356 | const float3 *__restrict__ conics, 357 | const float *__restrict__ compensation, 358 | const float2 *__restrict__ v_xy, 359 | const float *__restrict__ v_depth, 360 | const float3 *__restrict__ v_conic, 361 | const float *__restrict__ v_compensation, 362 | float3 *__restrict__ v_cov2d, 363 | float *__restrict__ v_cov3d, 364 | float3 *__restrict__ v_mean3d, 365 | float3 *__restrict__ v_scale, 366 | float4 *__restrict__ v_quat 367 | ) { 368 | unsigned idx = cg::this_grid().thread_rank(); // idx of thread within grid 369 | if (idx >= num_points || radii[idx] <= 0) { 370 | return; 371 | } 372 | float3 p_world = means3d[idx]; 373 | float fx = intrins.x; 374 | float fy = intrins.y; 375 | float3 p_view = helpers::transform_4x3(viewmat, p_world); 376 | // get v_mean3d from v_xy 377 | v_mean3d[idx] = helpers::transform_4x3_rot_only_transposed( 378 | viewmat, 379 | helpers::project_pix_vjp({fx, fy}, p_view, v_xy[idx]) 380 | ); 381 | 382 | // get z gradient contribution to mean3d gradient 383 | // z = viemwat[8] * mean3d.x + viewmat[9] * mean3d.y + viewmat[10] * 384 | // mean3d.z + viewmat[11] 385 | float v_z = v_depth[idx]; 386 | v_mean3d[idx].x += viewmat[8] * v_z; 387 | v_mean3d[idx].y += viewmat[9] * v_z; 388 | v_mean3d[idx].z += viewmat[10] * v_z; 389 | 390 | // get v_cov2d 391 | helpers::cov2d_to_conic_vjp(conics[idx], v_conic[idx], v_cov2d[idx]); 392 | helpers::cov2d_to_compensation_vjp( 393 | compensation[idx], 394 | conics[idx], 395 | v_compensation[idx], 396 | v_cov2d[idx] 397 | ); 398 | // get v_cov3d (and v_mean3d contribution) 399 | project_cov3d_ewa_vjp( 400 | p_world, 401 | &(cov3d[6 * idx]), 402 | viewmat, 403 | fx, 404 | fy, 405 | v_cov2d[idx], 406 | v_mean3d[idx], 407 | &(v_cov3d[6 * idx]) 408 | ); 409 | // get v_scale and v_quat 410 | scale_rot_to_cov3d_vjp( 411 | scales[idx], 412 | glob_scale, 413 | quats[idx], 414 | &(v_cov3d[6 * idx]), 415 | 
v_scale[idx], 416 | v_quat[idx] 417 | ); 418 | } 419 | 420 | // output space: 2D covariance, input space: cov3d 421 | __device__ void project_cov3d_ewa_vjp( 422 | const float3 &__restrict__ mean3d, 423 | const float *__restrict__ cov3d, 424 | const float *__restrict__ viewmat, 425 | const float fx, 426 | const float fy, 427 | const float3 &__restrict__ v_cov2d, 428 | float3 &__restrict__ v_mean3d, 429 | float *__restrict__ v_cov3d 430 | ) { 431 | // viewmat is row major, glm is column major 432 | // upper 3x3 submatrix 433 | mat3 W = { 434 | viewmat[0], 435 | viewmat[4], 436 | viewmat[8], 437 | viewmat[1], 438 | viewmat[5], 439 | viewmat[9], 440 | viewmat[2], 441 | viewmat[6], 442 | viewmat[10] 443 | }; 444 | 445 | float3 p = {viewmat[3], viewmat[7], viewmat[11]}; 446 | float3 t = { 447 | W[0][0] * mean3d.x + W[1][0] * mean3d.y + W[2][0] * mean3d.z + p.x, 448 | W[0][1] * mean3d.x + W[1][1] * mean3d.y + W[2][1] * mean3d.z + p.y, 449 | W[0][2] * mean3d.x + W[1][2] * mean3d.y + W[2][2] * mean3d.z + p.z, 450 | }; 451 | float rz = 1.f / t.z; 452 | float rz2 = rz * rz; 453 | 454 | // column major 455 | // we only care about the top 2x2 submatrix 456 | mat3 J = { 457 | fx * rz, 458 | 0.f, 459 | 0.f, 460 | 0.f, 461 | fy * rz, 462 | 0.f, 463 | -fx * t.x * rz2, 464 | -fy * t.y * rz2, 465 | 0.f 466 | }; 467 | mat3 V = { 468 | cov3d[0], 469 | cov3d[1], 470 | cov3d[2], 471 | cov3d[1], 472 | cov3d[3], 473 | cov3d[4], 474 | cov3d[2], 475 | cov3d[4], 476 | cov3d[5] 477 | }; 478 | // cov = T * V * Tt; G = df/dcov = v_cov 479 | // -> d/dV = Tt * G * T 480 | // -> df/dT = G * T * Vt + Gt * T * V 481 | mat3 v_cov = { 482 | v_cov2d.x, 483 | 0.5f * v_cov2d.y, 484 | 0.f, 485 | 0.5f * v_cov2d.y, 486 | v_cov2d.z, 487 | 0.f, 488 | 0.f, 489 | 0.f, 490 | 0.f 491 | }; 492 | 493 | mat3 T = J * W; 494 | mat3 Tt = T.transpose(); 495 | mat3 Vt = V.transpose(); 496 | mat3 v_V = Tt * v_cov * T; 497 | mat3 v_T = v_cov * T * Vt + v_cov.transpose() * T * V; 498 | 499 | // vjp of cov3d parameters 500 | 
// v_cov3d_i = v_V : dV/d_cov3d_i 501 | // where : is frobenius inner product 502 | v_cov3d[0] = v_V[0][0]; 503 | v_cov3d[1] = v_V[0][1] + v_V[1][0]; 504 | v_cov3d[2] = v_V[0][2] + v_V[2][0]; 505 | v_cov3d[3] = v_V[1][1]; 506 | v_cov3d[4] = v_V[1][2] + v_V[2][1]; 507 | v_cov3d[5] = v_V[2][2]; 508 | 509 | // compute df/d_mean3d 510 | // T = J * W 511 | mat3 v_J = v_T * W.transpose(); 512 | float rz3 = rz2 * rz; 513 | float3 v_t = { 514 | -fx * rz2 * v_J[2][0], 515 | -fy * rz2 * v_J[2][1], 516 | -fx * rz2 * v_J[0][0] + 2.f * fx * t.x * rz3 * v_J[2][0] - 517 | fy * rz2 * v_J[1][1] + 2.f * fy * t.y * rz3 * v_J[2][1] 518 | }; 519 | // printf("v_t %.2f %.2f %.2f\n", v_t[0], v_t[1], v_t[2]); 520 | // printf("W %.2f %.2f %.2f\n", W[0][0], W[0][1], W[0][2]); 521 | v_mean3d.x += v_t.x * W[0][0] + v_t.y * W[0][1] + v_t.z * W[0][2]; 522 | v_mean3d.y += v_t.x * W[1][0] + v_t.y * W[1][1] + v_t.z * W[1][2]; 523 | v_mean3d.z += v_t.x * W[2][0] + v_t.y * W[2][1] + v_t.z * W[2][2]; 524 | } 525 | 526 | // given cotangent v in output space (e.g. 
// d_L/d_cov3d, in R(6)), compute the vjp for scale and rotation.
__device__ void scale_rot_to_cov3d_vjp(
    const float3 scale,
    const float glob_scale,
    const float4 quat,
    const float *__restrict__ v_cov3d,
    float3 &__restrict__ v_scale,
    float4 &__restrict__ v_quat
) {
    // cov3d holds the upper-triangular elements of a symmetric matrix;
    // off-diagonal elements carry grads from both ij and ji entries, so
    // they must be halved when expanding back into the full matrix
    mat3 v_V = {
        v_cov3d[0], 0.5f * v_cov3d[1], 0.5f * v_cov3d[2],
        0.5f * v_cov3d[1], v_cov3d[3], 0.5f * v_cov3d[4],
        0.5f * v_cov3d[2], 0.5f * v_cov3d[4], v_cov3d[5]
    };
    mat3 R = helpers::quat_to_rotmat(quat);
    mat3 S = helpers::scale_to_mat(scale, glob_scale);
    mat3 M = R * S;
    // https://math.stackexchange.com/a/3850121
    // for D = W * X, G = df/dD: df/dW = G * Xt, df/dX = Wt * G;
    // here V = M * Mt, so v_M = 2 * v_V * M
    mat3 v_M = v_V * M * 2.f;
    // v_S = Rt * v_M; S is diagonal, so only the diagonal survives
    v_scale.x =
        (R[0][0] * v_M[0][0] + R[0][1] * v_M[0][1] + R[0][2] * v_M[0][2]) *
        glob_scale;
    v_scale.y =
        (R[1][0] * v_M[1][0] + R[1][1] * v_M[1][1] + R[1][2] * v_M[1][2]) *
        glob_scale;
    v_scale.z =
        (R[2][0] * v_M[2][0] + R[2][1] * v_M[2][1] + R[2][2] * v_M[2][2]) *
        glob_scale;

    mat3 v_R = v_M * S;
    v_quat = helpers::quat_to_rotmat_vjp(quat, v_R);
}

} // namespace kernels

// ---------------- /lib/ops.cu ----------------

#include "common.h"
#include "ffi.h"
#include "kernels/kernels.h"
#include "ops.h"

// NOTE(review): four system-header names were lost in extraction (two pairs
// of bare `#include` lines). <cuda_runtime.h> and <cstdint> are certainly
// needed by this translation unit; confirm the remaining two against the
// repository.
#include <cuda_runtime.h>

#include <cstddef>
#include <cstdint>

// XLA custom-call entry point for the projection forward pass.
// `buffers` holds device pointers in the order declared by the lowering;
// `opaque` carries the serialized Descriptor.
void ops::project::fwd::xla(
    cudaStream_t stream,
    void **buffers,
    const char *opaque,
    std::size_t
opaque_len 17 | ) { 18 | cudaError_t cuda_err; 19 | const auto &d = *unpack_descriptor(opaque, opaque_len); 20 | 21 | const auto tensors = unpack_tensors(stream, d, buffers); 22 | cudaStreamSynchronize(stream); 23 | 24 | kernels:: 25 | project_gaussians_fwd<<>>( 26 | // in 27 | d.num_points, 28 | tensors.in.mean3ds, 29 | tensors.in.scales, 30 | d.glob_scale, 31 | tensors.in.quats, 32 | tensors.in.viewmat, 33 | d.intrins, 34 | d.img_shape, 35 | d.grid_dim_2d, 36 | d.block_width, 37 | d.clip_thresh, 38 | 39 | // out 40 | tensors.out.cov3ds, 41 | tensors.out.xys, 42 | tensors.out.depths, 43 | tensors.out.radii, 44 | tensors.out.conics, 45 | tensors.out.compensation, 46 | tensors.out.num_tiles_hit 47 | ); 48 | cuda_err = cudaGetLastError(); 49 | CUDA_THROW_IF_ERR(cuda_err); 50 | cudaStreamSynchronize(stream); 51 | 52 | cumsum( 53 | stream, 54 | tensors.out.num_tiles_hit, 55 | tensors.out.cum_tiles_hit, 56 | d.num_points 57 | ); 58 | cuda_err = cudaGetLastError(); 59 | CUDA_THROW_IF_ERR(cuda_err); 60 | cudaStreamSynchronize(stream); 61 | } 62 | 63 | void ops::project::bwd::xla( 64 | cudaStream_t stream, 65 | void **buffers, 66 | const char *opaque, 67 | std::size_t opaque_len 68 | ) { 69 | cudaError_t cuda_err; 70 | const auto &d = *unpack_descriptor(opaque, opaque_len); 71 | 72 | const auto &tensors = unpack_tensors(stream, d, buffers); 73 | cudaStreamSynchronize(stream); 74 | 75 | kernels:: 76 | project_gaussians_bwd<<>>( 77 | // in 78 | d.num_points, 79 | tensors.in.mean3ds, 80 | tensors.in.scales, 81 | d.glob_scale, 82 | tensors.in.quats, 83 | tensors.in.viewmat, 84 | d.intrins, 85 | d.img_shape, 86 | tensors.in.cov3ds, 87 | tensors.in.radii, 88 | tensors.in.conics, 89 | tensors.in.compensation, 90 | tensors.in.v_xy, 91 | tensors.in.v_depth, 92 | tensors.in.v_conic, 93 | tensors.in.v_compensation, 94 | 95 | // out 96 | tensors.out.v_cov2d, 97 | tensors.out.v_cov3d, 98 | tensors.out.v_mean3d, 99 | tensors.out.v_scale, 100 | tensors.out.v_quat 101 | ); 102 | cuda_err 
= cudaGetLastError(); 103 | CUDA_THROW_IF_ERR(cuda_err); 104 | cudaStreamSynchronize(stream); 105 | } 106 | 107 | void ops::rasterize::fwd::xla( 108 | cudaStream_t stream, 109 | void **buffers, 110 | const char *opaque, 111 | std::size_t opaque_len 112 | ) { 113 | cudaError_t cuda_err; 114 | const auto &d = *unpack_descriptor(opaque, opaque_len); 115 | 116 | const auto tensors = unpack_tensors(stream, d, buffers); 117 | cudaStreamSynchronize(stream); 118 | 119 | if (tensors.num_intersects == 0) { 120 | kernels::tiled_memset<<>>( 121 | reinterpret_cast(tensors.out.out_img), 122 | d.img_shape.x * d.img_shape.y * 3, 123 | reinterpret_cast(tensors.in.background), 124 | 3 125 | ); 126 | return; 127 | } 128 | 129 | sort_and_bin( 130 | stream, 131 | d, 132 | tensors.num_intersects, 133 | tensors.in.xys, 134 | tensors.in.depths, 135 | tensors.in.radii, 136 | tensors.in.cum_tiles_hit, 137 | tensors.gaussian_ids_sorted, 138 | tensors.tile_bins 139 | ); 140 | cuda_err = cudaGetLastError(); 141 | CUDA_THROW_IF_ERR(cuda_err); 142 | cudaStreamSynchronize(stream); 143 | 144 | kernels::rasterize_fwd<<>>( 145 | // in 146 | d.grid_dim_2d, 147 | d.img_shape, 148 | tensors.gaussian_ids_sorted, 149 | tensors.tile_bins, 150 | tensors.in.xys, 151 | tensors.in.conics, 152 | tensors.in.colors, 153 | tensors.in.opacities, 154 | 155 | // out 156 | tensors.out.final_Ts, 157 | tensors.out.final_idx, 158 | tensors.out.out_img, 159 | *tensors.in.background 160 | ); 161 | cuda_err = cudaGetLastError(); 162 | CUDA_THROW_IF_ERR(cuda_err); 163 | cudaStreamSynchronize(stream); 164 | 165 | cuda_err = cudaFreeAsync(tensors.gaussian_ids_sorted, stream); 166 | CUDA_THROW_IF_ERR(cuda_err); 167 | cuda_err = cudaFreeAsync(tensors.tile_bins, stream); 168 | CUDA_THROW_IF_ERR(cuda_err); 169 | cudaStreamSynchronize(stream); 170 | }; 171 | 172 | void ops::rasterize::bwd::xla( 173 | cudaStream_t stream, 174 | void **buffers, 175 | const char *opaque, 176 | std::size_t opaque_len 177 | ) { 178 | cudaError_t 
cuda_err; 179 | const auto &d = *unpack_descriptor(opaque, opaque_len); 180 | 181 | const auto tensors = unpack_tensors(stream, d, buffers); 182 | cudaStreamSynchronize(stream); 183 | 184 | if (tensors.num_intersects == 0) { 185 | return; 186 | } 187 | 188 | sort_and_bin( 189 | stream, 190 | d, 191 | tensors.num_intersects, 192 | tensors.in.xys, 193 | tensors.in.depths, 194 | tensors.in.radii, 195 | tensors.in.cum_tiles_hit, 196 | tensors.gaussian_ids_sorted, 197 | tensors.tile_bins 198 | ); 199 | cuda_err = cudaGetLastError(); 200 | CUDA_THROW_IF_ERR(cuda_err); 201 | cudaStreamSynchronize(stream); 202 | 203 | kernels::rasterize_bwd<<>>( 204 | // in 205 | d.grid_dim_2d, 206 | d.img_shape, 207 | tensors.gaussian_ids_sorted, 208 | tensors.tile_bins, 209 | tensors.in.xys, 210 | tensors.in.conics, 211 | tensors.in.colors, 212 | tensors.in.opacities, 213 | *tensors.in.background, 214 | tensors.in.final_Ts, 215 | tensors.in.final_idx, 216 | tensors.in.v_out_img, 217 | tensors.in.v_out_img_alpha, 218 | 219 | // out 220 | tensors.out.v_xy, 221 | tensors.out.v_xy_abs, 222 | tensors.out.v_conic, 223 | tensors.out.v_colors, 224 | tensors.out.v_opacity 225 | ); 226 | cuda_err = cudaGetLastError(); 227 | CUDA_THROW_IF_ERR(cuda_err); 228 | cudaStreamSynchronize(stream); 229 | 230 | cuda_err = cudaFreeAsync(tensors.gaussian_ids_sorted, stream); 231 | CUDA_THROW_IF_ERR(cuda_err); 232 | cuda_err = cudaFreeAsync(tensors.tile_bins, stream); 233 | CUDA_THROW_IF_ERR(cuda_err); 234 | cudaStreamSynchronize(stream); 235 | }; 236 | 237 | ops::project::fwd::Tensors ops::project::fwd::unpack_tensors( 238 | cudaStream_t stream, 239 | const Descriptor &d, 240 | void **buffers 241 | ) { 242 | Tensors tensors; 243 | cudaError_t cuda_err; 244 | std::size_t idx = 0; 245 | 246 | tensors.in.mean3ds = static_cast(buffers[idx++]); 247 | tensors.in.scales = static_cast(buffers[idx++]); 248 | tensors.in.quats = static_cast(buffers[idx++]); 249 | tensors.in.viewmat = static_cast(buffers[idx++]); 250 | 
251 | tensors.out.cov3ds = static_cast(buffers[idx++]); 252 | tensors.out.xys = static_cast(buffers[idx++]); 253 | tensors.out.depths = static_cast(buffers[idx++]); 254 | tensors.out.radii = static_cast(buffers[idx++]); 255 | tensors.out.conics = static_cast(buffers[idx++]); 256 | tensors.out.compensation = static_cast(buffers[idx++]); 257 | tensors.out.num_tiles_hit = static_cast(buffers[idx++]); 258 | tensors.out.cum_tiles_hit = static_cast(buffers[idx++]); 259 | 260 | cuda_err = cudaMemsetAsync( 261 | tensors.out.cov3ds, 262 | 0, 263 | sizeof(*tensors.out.cov3ds) * d.num_points * 6, 264 | stream 265 | ); 266 | CUDA_THROW_IF_ERR(cuda_err); 267 | 268 | cuda_err = cudaMemsetAsync( 269 | tensors.out.xys, 270 | 0, 271 | sizeof(*tensors.out.xys) * d.num_points, 272 | stream 273 | ); 274 | CUDA_THROW_IF_ERR(cuda_err); 275 | 276 | cuda_err = cudaMemsetAsync( 277 | tensors.out.depths, 278 | 0, 279 | sizeof(*tensors.out.depths) * d.num_points, 280 | stream 281 | ); 282 | CUDA_THROW_IF_ERR(cuda_err); 283 | 284 | cuda_err = cudaMemsetAsync( 285 | tensors.out.radii, 286 | 0, 287 | sizeof(*tensors.out.radii) * d.num_points, 288 | stream 289 | ); 290 | CUDA_THROW_IF_ERR(cuda_err); 291 | 292 | cuda_err = cudaMemsetAsync( 293 | tensors.out.conics, 294 | 0, 295 | sizeof(*tensors.out.conics) * d.num_points, 296 | stream 297 | ); 298 | CUDA_THROW_IF_ERR(cuda_err); 299 | 300 | cuda_err = cudaMemsetAsync( 301 | tensors.out.compensation, 302 | 0, 303 | sizeof(*tensors.out.compensation) * d.num_points, 304 | stream 305 | ); 306 | CUDA_THROW_IF_ERR(cuda_err); 307 | 308 | cuda_err = cudaMemsetAsync( 309 | tensors.out.num_tiles_hit, 310 | 0, 311 | sizeof(*tensors.out.num_tiles_hit) * d.num_points, 312 | stream 313 | ); 314 | CUDA_THROW_IF_ERR(cuda_err); 315 | 316 | cuda_err = cudaMemsetAsync( 317 | tensors.out.cum_tiles_hit, 318 | 0, 319 | sizeof(*tensors.out.cum_tiles_hit) * d.num_points, 320 | stream 321 | ); 322 | CUDA_THROW_IF_ERR(cuda_err); 323 | 324 | return tensors; 325 | } 326 | 
// Maps the XLA buffer table onto typed in/out pointers for the projection
// backward pass and zero-initializes every gradient output.
// NOTE(review): the static_cast target types were stripped during
// extraction; restored from the kernel signatures — confirm against ops.h.
ops::project::bwd::Tensors ops::project::bwd::unpack_tensors(
    cudaStream_t stream,
    const Descriptor &d,
    void **buffers
) {
    Tensors tensors;
    cudaError_t cuda_err;
    std::size_t idx = 0;

    tensors.in.mean3ds = static_cast<const float3 *>(buffers[idx++]);
    tensors.in.scales = static_cast<const float3 *>(buffers[idx++]);
    tensors.in.quats = static_cast<const float4 *>(buffers[idx++]);
    tensors.in.viewmat = static_cast<const float *>(buffers[idx++]);
    tensors.in.cov3ds = static_cast<const float *>(buffers[idx++]);
    tensors.in.xys = static_cast<const float2 *>(buffers[idx++]);
    tensors.in.radii = static_cast<const int *>(buffers[idx++]);
    tensors.in.conics = static_cast<const float3 *>(buffers[idx++]);
    tensors.in.compensation = static_cast<const float *>(buffers[idx++]);
    tensors.in.v_compensation = static_cast<const float *>(buffers[idx++]);
    tensors.in.v_xy = static_cast<const float2 *>(buffers[idx++]);
    tensors.in.v_depth = static_cast<const float *>(buffers[idx++]);
    tensors.in.v_conic = static_cast<const float3 *>(buffers[idx++]);

    tensors.out.v_mean3d = static_cast<float3 *>(buffers[idx++]);
    tensors.out.v_scale = static_cast<float3 *>(buffers[idx++]);
    tensors.out.v_quat = static_cast<float4 *>(buffers[idx++]);
    tensors.out.v_cov2d = static_cast<float3 *>(buffers[idx++]);
    tensors.out.v_cov3d = static_cast<float *>(buffers[idx++]);

    // gradients are accumulated with atomics, so they must start at zero
    cuda_err = cudaMemsetAsync(
        tensors.out.v_cov2d,
        0,
        sizeof(*tensors.out.v_cov2d) * d.num_points,
        stream
    );
    CUDA_THROW_IF_ERR(cuda_err);

    // v_cov3d packs 6 floats per gaussian
    cuda_err = cudaMemsetAsync(
        tensors.out.v_cov3d,
        0,
        sizeof(*tensors.out.v_cov3d) * d.num_points * 6,
        stream
    );
    CUDA_THROW_IF_ERR(cuda_err);

    cuda_err = cudaMemsetAsync(
        tensors.out.v_mean3d,
        0,
        sizeof(*tensors.out.v_mean3d) * d.num_points,
        stream
    );
    CUDA_THROW_IF_ERR(cuda_err);

    cuda_err = cudaMemsetAsync(
        tensors.out.v_scale,
        0,
        sizeof(*tensors.out.v_scale) * d.num_points,
        stream
    );
    CUDA_THROW_IF_ERR(cuda_err);

    cuda_err = cudaMemsetAsync(
        tensors.out.v_quat,
        0,
        sizeof(*tensors.out.v_quat) * d.num_points,
        stream
    );
    CUDA_THROW_IF_ERR(cuda_err);

    return tensors;
}

// Maps the XLA buffer table onto typed in/out pointers for the rasterization
// forward pass, reads back the intersection count, zero-initializes outputs,
// and allocates the sort/bin scratch buffers.
// NOTE(review): static_cast target types restored — confirm against ops.h.
ops::rasterize::fwd::Tensors ops::rasterize::fwd::unpack_tensors(
    cudaStream_t stream,
    const Descriptor &d,
    void **buffers
) {
    Tensors tensors;
    cudaError_t cuda_err;
    std::size_t idx = 0;

    tensors.in.colors = static_cast<const float3 *>(buffers[idx++]);
    tensors.in.opacities = static_cast<const float *>(buffers[idx++]);
    tensors.in.background = static_cast<const float3 *>(buffers[idx++]);
    tensors.in.xys = static_cast<const float2 *>(buffers[idx++]);
    tensors.in.depths = static_cast<const float *>(buffers[idx++]);
    tensors.in.radii = static_cast<const int *>(buffers[idx++]);
    tensors.in.conics = static_cast<const float3 *>(buffers[idx++]);
    tensors.in.cum_tiles_hit = static_cast<const int32_t *>(buffers[idx++]);

    tensors.out.final_Ts = static_cast<float *>(buffers[idx++]);
    tensors.out.final_idx = static_cast<int *>(buffers[idx++]);
    tensors.out.out_img = static_cast<float3 *>(buffers[idx++]);

    tensors.gaussian_ids_sorted = nullptr;
    tensors.tile_bins = nullptr;

    // the total intersection count is the last element of the prefix sum
    tensors.num_intersects = 0;
    cuda_err = cudaMemcpyAsync(
        &tensors.num_intersects,
        tensors.in.cum_tiles_hit + d.num_points - 1,
        sizeof(tensors.num_intersects),
        cudaMemcpyKind::cudaMemcpyDeviceToHost,
        stream
    );
    CUDA_THROW_IF_ERR(cuda_err);
    // bugfix: the device→host copy above is asynchronous, but
    // num_intersects is used below to size the scratch allocations —
    // synchronize before reading it on the host.
    cuda_err = cudaStreamSynchronize(stream);
    CUDA_THROW_IF_ERR(cuda_err);

    const auto img_size = d.img_shape.x * d.img_shape.y;

    cuda_err = cudaMemsetAsync(
        tensors.out.final_Ts,
        0,
        sizeof(*tensors.out.final_Ts) * img_size,
        stream
    );
    CUDA_THROW_IF_ERR(cuda_err);

    cuda_err = cudaMemsetAsync(
        tensors.out.final_idx,
        0,
        sizeof(*tensors.out.final_idx) * img_size,
        stream
    );
    CUDA_THROW_IF_ERR(cuda_err);

    cuda_err = cudaMemsetAsync(
        tensors.out.out_img,
        0,
sizeof(*tensors.out.out_img) * img_size, 456 | stream 457 | ); 458 | CUDA_THROW_IF_ERR(cuda_err); 459 | 460 | cuda_err = cudaMallocAsync( 461 | &tensors.gaussian_ids_sorted, 462 | sizeof(*tensors.gaussian_ids_sorted) * tensors.num_intersects, 463 | stream 464 | ); 465 | CUDA_THROW_IF_ERR(cuda_err); 466 | 467 | cuda_err = cudaMemsetAsync( 468 | tensors.gaussian_ids_sorted, 469 | 0, 470 | sizeof(*tensors.gaussian_ids_sorted) * tensors.num_intersects, 471 | stream 472 | ); 473 | CUDA_THROW_IF_ERR(cuda_err); 474 | 475 | cuda_err = cudaMallocAsync( 476 | &tensors.tile_bins, 477 | sizeof(*tensors.tile_bins) * d.grid_dim_2d.x * d.grid_dim_2d.y, 478 | stream 479 | ); 480 | CUDA_THROW_IF_ERR(cuda_err); 481 | 482 | cuda_err = cudaMemsetAsync( 483 | tensors.tile_bins, 484 | 0, 485 | sizeof(*tensors.tile_bins) * d.grid_dim_2d.x * d.grid_dim_2d.y, 486 | stream 487 | ); 488 | CUDA_THROW_IF_ERR(cuda_err); 489 | 490 | return tensors; 491 | } 492 | 493 | ops::rasterize::bwd::Tensors ops::rasterize::bwd::unpack_tensors( 494 | cudaStream_t stream, 495 | const Descriptor &d, 496 | void **buffers 497 | ) { 498 | Tensors tensors; 499 | cudaError_t cuda_err; 500 | std::size_t idx = 0; 501 | 502 | tensors.in.colors = static_cast(buffers[idx++]); 503 | tensors.in.opacities = static_cast(buffers[idx++]); 504 | tensors.in.background = static_cast(buffers[idx++]); 505 | tensors.in.xys = static_cast(buffers[idx++]); 506 | tensors.in.depths = static_cast(buffers[idx++]); 507 | tensors.in.radii = static_cast(buffers[idx++]); 508 | tensors.in.conics = static_cast(buffers[idx++]); 509 | tensors.in.cum_tiles_hit = static_cast(buffers[idx++]); 510 | tensors.in.final_Ts = static_cast(buffers[idx++]); 511 | tensors.in.final_idx = static_cast(buffers[idx++]); 512 | tensors.in.v_out_img = static_cast(buffers[idx++]); 513 | tensors.in.v_out_img_alpha = static_cast(buffers[idx++]); 514 | 515 | tensors.out.v_colors = static_cast(buffers[idx++]); 516 | tensors.out.v_opacity = static_cast(buffers[idx++]); 
517 | tensors.out.v_xy = static_cast(buffers[idx++]); 518 | tensors.out.v_xy_abs = static_cast(buffers[idx++]); 519 | tensors.out.v_conic = static_cast(buffers[idx++]); 520 | 521 | tensors.gaussian_ids_sorted = nullptr; 522 | tensors.tile_bins = nullptr; 523 | 524 | tensors.num_intersects = 0; 525 | cuda_err = cudaMemcpyAsync( 526 | &tensors.num_intersects, 527 | tensors.in.cum_tiles_hit + d.num_points - 1, 528 | sizeof(tensors.num_intersects), 529 | cudaMemcpyKind::cudaMemcpyDeviceToHost, 530 | stream 531 | ); 532 | CUDA_THROW_IF_ERR(cuda_err); 533 | 534 | cuda_err = cudaMemsetAsync( 535 | tensors.out.v_xy, 536 | 0, 537 | sizeof(*tensors.out.v_xy) * d.num_points, 538 | stream 539 | ); 540 | CUDA_THROW_IF_ERR(cuda_err); 541 | 542 | cuda_err = cudaMemsetAsync( 543 | tensors.out.v_xy_abs, 544 | 0, 545 | sizeof(*tensors.out.v_xy_abs) * d.num_points, 546 | stream 547 | ); 548 | CUDA_THROW_IF_ERR(cuda_err); 549 | 550 | cuda_err = cudaMemsetAsync( 551 | tensors.out.v_conic, 552 | 0, 553 | sizeof(*tensors.out.v_conic) * d.num_points, 554 | stream 555 | ); 556 | CUDA_THROW_IF_ERR(cuda_err); 557 | 558 | cuda_err = cudaMemsetAsync( 559 | tensors.out.v_colors, 560 | 0, 561 | sizeof(*tensors.out.v_colors) * d.num_points, 562 | stream 563 | ); 564 | CUDA_THROW_IF_ERR(cuda_err); 565 | 566 | cuda_err = cudaMemsetAsync( 567 | tensors.out.v_opacity, 568 | 0, 569 | sizeof(*tensors.out.v_opacity) * d.num_points, 570 | stream 571 | ); 572 | CUDA_THROW_IF_ERR(cuda_err); 573 | 574 | cuda_err = cudaMallocAsync( 575 | &tensors.gaussian_ids_sorted, 576 | sizeof(*tensors.gaussian_ids_sorted) * tensors.num_intersects, 577 | stream 578 | ); 579 | CUDA_THROW_IF_ERR(cuda_err); 580 | 581 | cuda_err = cudaMemsetAsync( 582 | tensors.gaussian_ids_sorted, 583 | 0, 584 | sizeof(*tensors.gaussian_ids_sorted) * tensors.num_intersects, 585 | stream 586 | ); 587 | CUDA_THROW_IF_ERR(cuda_err); 588 | 589 | cuda_err = cudaMallocAsync( 590 | &tensors.tile_bins, 591 | sizeof(*tensors.tile_bins) * 
d.grid_dim_2d.x * d.grid_dim_2d.y, 592 | stream 593 | ); 594 | CUDA_THROW_IF_ERR(cuda_err); 595 | 596 | cuda_err = cudaMemsetAsync( 597 | tensors.tile_bins, 598 | 0, 599 | sizeof(*tensors.tile_bins) * d.grid_dim_2d.x * d.grid_dim_2d.y, 600 | stream 601 | ); 602 | CUDA_THROW_IF_ERR(cuda_err); 603 | 604 | return tensors; 605 | } 606 | 607 | void ops::cumsum( 608 | cudaStream_t stream, 609 | const int32_t *input, 610 | int32_t *output, 611 | const int num_items 612 | ) { 613 | cudaError_t cuda_err; 614 | 615 | void *sum_ws = nullptr; 616 | size_t sum_ws_bytes; 617 | 618 | cuda_err = cub::DeviceScan::InclusiveSum( 619 | sum_ws, 620 | sum_ws_bytes, 621 | input, 622 | output, 623 | num_items, 624 | stream 625 | ); 626 | CUDA_THROW_IF_ERR(cuda_err); 627 | 628 | cuda_err = cudaMalloc(&sum_ws, sum_ws_bytes); 629 | CUDA_THROW_IF_ERR(cuda_err); 630 | 631 | cuda_err = cub::DeviceScan::InclusiveSum( 632 | sum_ws, 633 | sum_ws_bytes, 634 | input, 635 | output, 636 | num_items, 637 | stream 638 | ); 639 | CUDA_THROW_IF_ERR(cuda_err); 640 | 641 | cuda_err = cudaFree(sum_ws); 642 | CUDA_THROW_IF_ERR(cuda_err); 643 | } 644 | 645 | void ops::sort_and_bin( 646 | cudaStream_t stream, 647 | const Descriptor &d, 648 | unsigned num_intersects, 649 | const float2 *xys, 650 | const float *depths, 651 | const int *radii, 652 | const int *cum_tiles_hit, 653 | int *gaussian_ids_sorted, 654 | int2 *tile_bins 655 | ) { 656 | cudaError_t cuda_err; 657 | 658 | std::int32_t *gaussian_ids_unsorted; 659 | std::int64_t *isect_ids_unsorted; 660 | std::int64_t *isect_ids_sorted; 661 | 662 | cuda_err = cudaMalloc( 663 | &gaussian_ids_unsorted, 664 | num_intersects * sizeof(*gaussian_ids_unsorted) 665 | ); 666 | CUDA_THROW_IF_ERR(cuda_err); 667 | 668 | cuda_err = cudaMemsetAsync( 669 | gaussian_ids_unsorted, 670 | 0, 671 | num_intersects * sizeof(*gaussian_ids_unsorted), 672 | stream 673 | ); 674 | CUDA_THROW_IF_ERR(cuda_err); 675 | 676 | cuda_err = cudaMalloc( 677 | &isect_ids_unsorted, 678 | 
num_intersects * sizeof(*isect_ids_unsorted) 679 | ); 680 | CUDA_THROW_IF_ERR(cuda_err); 681 | 682 | cuda_err = cudaMemsetAsync( 683 | isect_ids_unsorted, 684 | 0, 685 | num_intersects * sizeof(*isect_ids_unsorted), 686 | stream 687 | ); 688 | CUDA_THROW_IF_ERR(cuda_err); 689 | 690 | cuda_err = cudaMalloc( 691 | &isect_ids_sorted, 692 | num_intersects * sizeof(*isect_ids_sorted) 693 | ); 694 | CUDA_THROW_IF_ERR(cuda_err); 695 | 696 | cuda_err = cudaMemsetAsync( 697 | isect_ids_sorted, 698 | 0, 699 | num_intersects * sizeof(*isect_ids_sorted), 700 | stream 701 | ); 702 | CUDA_THROW_IF_ERR(cuda_err); 703 | 704 | kernels::map_gaussian_to_intersects<<< 705 | d.grid_dim_1d, 706 | d.block_dim_1d, 707 | 0, 708 | stream>>>( 709 | d.num_points, 710 | xys, 711 | depths, 712 | radii, 713 | cum_tiles_hit, 714 | d.grid_dim_2d, 715 | d.block_width, 716 | isect_ids_unsorted, 717 | gaussian_ids_unsorted 718 | ); 719 | cuda_err = cudaGetLastError(); 720 | CUDA_THROW_IF_ERR(cuda_err); 721 | 722 | // sort intersections by ascending tile ID and depth with RadixSort 723 | int32_t max_tile_id = (int32_t)(d.grid_dim_2d.x * d.grid_dim_2d.y); 724 | int msb = 32 - __builtin_clz(max_tile_id) + 1; 725 | // allocate workspace memory 726 | void *sort_ws = nullptr; 727 | size_t sort_ws_bytes; 728 | cuda_err = cub::DeviceRadixSort::SortPairs( 729 | sort_ws, 730 | sort_ws_bytes, 731 | isect_ids_unsorted, 732 | isect_ids_sorted, 733 | gaussian_ids_unsorted, 734 | gaussian_ids_sorted, 735 | num_intersects, 736 | 0, 737 | 32 + msb, 738 | stream 739 | ); 740 | CUDA_THROW_IF_ERR(cuda_err); 741 | 742 | cuda_err = cudaMalloc(&sort_ws, sort_ws_bytes); 743 | CUDA_THROW_IF_ERR(cuda_err); 744 | 745 | cuda_err = cub::DeviceRadixSort::SortPairs( 746 | sort_ws, 747 | sort_ws_bytes, 748 | isect_ids_unsorted, 749 | isect_ids_sorted, 750 | gaussian_ids_unsorted, 751 | gaussian_ids_sorted, 752 | num_intersects, 753 | 0, 754 | 32 + msb, 755 | stream 756 | ); 757 | CUDA_THROW_IF_ERR(cuda_err); 758 | 759 | cuda_err = 
cudaFree(sort_ws); 760 | CUDA_THROW_IF_ERR(cuda_err); 761 | 762 | // printf( 763 | // "%d %d\n", 764 | // (d.num_intersects + d.block_dim_1d - 1) / d.block_dim_1d, 765 | // d.block_dim_1d 766 | // ); 767 | kernels::get_tile_bin_edges<<< 768 | (num_intersects + d.block_dim_1d - 1) / d.block_dim_1d, 769 | d.block_dim_1d, 770 | 0, 771 | stream>>>(num_intersects, isect_ids_sorted, tile_bins); 772 | cuda_err = cudaGetLastError(); 773 | CUDA_THROW_IF_ERR(cuda_err); 774 | 775 | // int test[4]; 776 | // cudaMemcpy(test, tile_bins, sizeof(int) * 4, cudaMemcpyDefault); 777 | // printf("tile_bins_test %d %d %d %d\n", test[0], test[1], test[2], 778 | // test[3]); 779 | 780 | cudaStreamSynchronize(stream); 781 | 782 | // free intermediate work spaces 783 | cuda_err = cudaFree(isect_ids_unsorted); 784 | CUDA_THROW_IF_ERR(cuda_err); 785 | 786 | cuda_err = cudaFree(isect_ids_sorted); 787 | CUDA_THROW_IF_ERR(cuda_err); 788 | 789 | cuda_err = cudaFree(gaussian_ids_unsorted); 790 | CUDA_THROW_IF_ERR(cuda_err); 791 | 792 | cudaStreamSynchronize(stream); 793 | } 794 | --------------------------------------------------------------------------------