├── .gitattributes ├── pyproject.toml ├── deps ├── jax-tcnn │ ├── src │ │ └── jaxtcnn │ │ │ ├── __init__.py │ │ │ └── hashgrid_tcnn │ │ │ ├── __init__.py │ │ │ ├── impl.py │ │ │ ├── lowering.py │ │ │ └── abstract.py │ ├── pyproject.toml │ ├── CMakeLists.txt │ ├── lib │ │ ├── impl │ │ │ ├── tcnnutils.h │ │ │ └── hashgrid.cu │ │ └── ffi.cc │ ├── default.nix │ └── setup.py ├── volume-rendering-jax │ ├── pyproject.toml │ ├── src │ │ └── volrendjax │ │ │ ├── morton3d │ │ │ ├── __init__.py │ │ │ ├── impl.py │ │ │ ├── abstract.py │ │ │ └── lowering.py │ │ │ ├── __init__.py │ │ │ ├── packbits │ │ │ ├── impl.py │ │ │ ├── __init__.py │ │ │ ├── abstract.py │ │ │ └── lowering.py │ │ │ ├── marching │ │ │ ├── impl.py │ │ │ ├── abstract.py │ │ │ ├── __init__.py │ │ │ └── lowering.py │ │ │ └── integrating │ │ │ ├── abstract.py │ │ │ ├── impl.py │ │ │ ├── __init__.py │ │ │ └── lowering.py │ ├── CMakeLists.txt │ ├── default.nix │ ├── lib │ │ ├── impl │ │ │ ├── packbits.cu │ │ │ └── volrend.h │ │ └── ffi.cc │ └── setup.py ├── spherical-harmonics-encoding-jax │ ├── pyproject.toml │ ├── lib │ │ ├── impl │ │ │ └── spherical_harmonics_encoding.h │ │ └── ffi.cc │ ├── CMakeLists.txt │ ├── default.nix │ ├── setup.py │ └── src │ │ └── shjax │ │ └── __init__.py ├── serde-helper │ ├── default.nix │ └── serde.h ├── colmap-locked │ └── default.nix ├── pyimgui │ └── default.nix ├── nvfetcher.toml ├── tyro │ └── default.nix ├── dearpygui │ └── default.nix ├── pycolmap │ ├── default.nix │ └── expose-bundle-adjustment-function.patch ├── default.nix ├── _sources │ ├── generated.nix │ └── generated.json └── tiny-cuda-nn │ └── default.nix ├── models ├── renderers │ └── __init__.py └── imagefit.py ├── CITATION.cff ├── utils ├── _constants.py ├── sfm.py ├── __main__.py └── args.py ├── .gitignore ├── app ├── nerf │ ├── __main__.py │ ├── _utils.py │ └── test.py └── imagefit.py ├── flake.lock ├── flake.nix └── LICENSE /.gitattributes: 
-------------------------------------------------------------------------------- 1 | deps/** linguist-vendored=false 2 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pyright] 2 | include = [ 3 | "app", 4 | "models", 5 | "utils", 6 | ] 7 | -------------------------------------------------------------------------------- /deps/jax-tcnn/src/jaxtcnn/__init__.py: -------------------------------------------------------------------------------- 1 | from .hashgrid_tcnn import HashGridMetadata, hashgrid_encode 2 | 3 | 4 | __all__ = [ 5 | "HashGridMetadata", 6 | "hashgrid_encode" 7 | ] 8 | -------------------------------------------------------------------------------- /models/renderers/__init__.py: -------------------------------------------------------------------------------- 1 | from .cuda import render_image_inference, render_rays_train 2 | 3 | 4 | __all__ = [ 5 | "render_rays_train", 6 | "render_image_inference", 7 | ] 8 | -------------------------------------------------------------------------------- /deps/jax-tcnn/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "cmake", 4 | "ninja", 5 | "pybind11>=2.6", 6 | "setuptools>=42", 7 | "setuptools_scm[toml]>=3.4", 8 | "wheel", 9 | ] 10 | build-backend = "setuptools.build_meta" 11 | -------------------------------------------------------------------------------- /deps/volume-rendering-jax/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "cmake", 4 | "ninja", 5 | "pybind11>=2.6", 6 | "setuptools>=42", 7 | "setuptools_scm[toml]>=3.4", 8 | "wheel", 9 | ] 10 | build-backend = "setuptools.build_meta" 11 | -------------------------------------------------------------------------------- 
/deps/spherical-harmonics-encoding-jax/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "cmake", 4 | "ninja", 5 | "pybind11>=2.6", 6 | "setuptools>=42", 7 | "setuptools_scm[toml]>=3.4", 8 | "wheel", 9 | ] 10 | build-backend = "setuptools.build_meta" 11 | -------------------------------------------------------------------------------- /deps/volume-rendering-jax/src/volrendjax/morton3d/__init__.py: -------------------------------------------------------------------------------- 1 | import jax 2 | 3 | from . import impl 4 | 5 | 6 | def morton3d(xyzs: jax.Array): 7 | return impl.morton3d_p.bind(xyzs) 8 | 9 | def morton3d_invert(idcs: jax.Array): 10 | return impl.morton3d_invert_p.bind(idcs) 11 | -------------------------------------------------------------------------------- /deps/serde-helper/default.nix: -------------------------------------------------------------------------------- 1 | { version, stdenvNoCC }: stdenvNoCC.mkDerivation { 2 | pname = "serde-helper"; 3 | inherit version; 4 | src = ./serde.h; 5 | 6 | dontUnpack = true; 7 | 8 | installPhase = '' 9 | install -Dvm644 $src $out/include/serde-helper/serde.h 10 | ''; 11 | } 12 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | - family-names: "Zhang" 5 | given-names: "Gaoyang" 6 | - family-names: "Chen" 7 | given-names: "Yingxi" 8 | title: "jaxngp" 9 | date-released: 2023-05-31 10 | url: "https://github.com/blurgyy/jaxngp" 11 | -------------------------------------------------------------------------------- /utils/_constants.py: -------------------------------------------------------------------------------- 1 | from colorama import Fore, Style 2 | 3 | 4 | _tqdm_format = "SBRIGHT{desc}RESET: HI{percentage:3.0f}%RESET {n_fmt}/{total_fmt} [{elapsed} 4 | #include 5 | 6 | namespace shjax { 7 | 8 | struct SphericalHarmonicsEncodingDescriptor { 9 | std::uint32_t n; 10 | std::uint32_t degree; 11 | }; 12 | 13 | void spherical_harmonics_encoding_cuda_f32( 14 | cudaStream_t stream, 15 | void **buffers, 16 | const char *opaque, 17 | std::size_t opaque_len 18 | ); 19 | 20 | } // namespace shjax 21 | -------------------------------------------------------------------------------- /deps/colmap-locked/default.nix: -------------------------------------------------------------------------------- 1 | { source, lib, config, colmap 2 | , flann 3 | , metis 4 | }: 5 | 6 | let 7 | cudaSupport = config.cudaSupport or false; 8 | in 9 | 10 | colmap.overrideAttrs (o: { 11 | pname = if cudaSupport 12 | then source.pname + "-cuda" 13 | else source.pname; 14 | inherit (source) version src; 15 | buildInputs = o.buildInputs ++ [ 16 | flann 17 | metis 18 | ]; 19 | cmakeFlags = o.cmakeFlags ++ (lib.optional 20 | cudaSupport 21 | "-DCMAKE_CUDA_ARCHITECTURES=all-major" 22 | ); 23 | }) 24 | -------------------------------------------------------------------------------- /deps/pyimgui/default.nix: -------------------------------------------------------------------------------- 1 | { source, lib, buildPythonPackage 2 | 3 | , setuptools-scm 4 | 5 | , click 6 | , cython 7 | , glfw 8 | , pyopengl 9 | , wheel 10 | }: 11 | 12 | buildPythonPackage { 13 | inherit (source) pname version src; 14 | format = 
"pyproject"; 15 | 16 | nativeBuildInputs = [ setuptools-scm ]; 17 | buildInputs = [ wheel ]; 18 | propagatedBuildInputs = [ click cython glfw pyopengl ]; 19 | 20 | pythonImportsCheck = [ "imgui" ]; 21 | 22 | meta = { 23 | homepage = "https://github.com/pyimgui/pyimgui"; 24 | description = "Cython-based Python bindings for dear imgui"; 25 | license = lib.licenses.bsd3; 26 | }; 27 | } 28 | -------------------------------------------------------------------------------- /deps/nvfetcher.toml: -------------------------------------------------------------------------------- 1 | [colmap-locked] 2 | src.github = "colmap/colmap" 3 | fetch.github = "colmap/colmap" 4 | git.fetchSubmodules = true 5 | 6 | [dearpygui] 7 | src.github = "hoffstadt/DearPyGui" 8 | fetch.github = "hoffstadt/DearPyGui" 9 | git.fetchSubmodules = true 10 | 11 | [pycolmap] 12 | src.github = "colmap/pycolmap" 13 | fetch.github = "colmap/pycolmap" 14 | git.fetchSubmodules = true 15 | 16 | [pyimgui] 17 | src.pypi = "imgui" 18 | fetch.pypi = "imgui" 19 | 20 | [tiny-cuda-nn] 21 | src.github = "NVlabs/tiny-cuda-nn" 22 | fetch.github = "NVlabs/tiny-cuda-nn" 23 | git.fetchSubmodules = true 24 | 25 | [tyro] 26 | src.pypi = "tyro" 27 | fetch.pypi = "tyro" 28 | -------------------------------------------------------------------------------- /deps/spherical-harmonics-encoding-jax/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.23) 2 | project(spherical_harmonics_encoding_jax LANGUAGES CXX CUDA) 3 | set(CMAKE_CUDA_ARCHITECTURES "all") 4 | 5 | message(STATUS "Using CMake version " ${CMAKE_VERSION}) 6 | 7 | find_package(Python COMPONENTS Interpreter Development REQUIRED) 8 | find_package(pybind11 CONFIG REQUIRED) 9 | find_package(fmt REQUIRED) 10 | 11 | include_directories(${CMAKE_CURRENT_LIST_DIR}/lib) 12 | 13 | include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) 14 | pybind11_add_module( 15 | cudaops 16 | 
${CMAKE_CURRENT_LIST_DIR}/lib/impl/spherical_harmonics_encoding.cu 17 | ${CMAKE_CURRENT_LIST_DIR}/lib/ffi.cc 18 | ) 19 | target_link_libraries(cudaops PRIVATE fmt::fmt) 20 | 21 | install(TARGETS cudaops DESTINATION shjax) 22 | -------------------------------------------------------------------------------- /deps/volume-rendering-jax/src/volrendjax/packbits/impl.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import jax 4 | from jax.interpreters import mlir, xla 5 | from jax.lib import xla_client 6 | 7 | from . import abstract, lowering 8 | from .. import volrendutils_cuda 9 | 10 | 11 | # register GPU XLA custom calls 12 | for name, value in volrendutils_cuda.get_packbits_registrations().items(): 13 | xla_client.register_custom_call_target(name, value, platform="gpu") 14 | 15 | packbits_p = jax.core.Primitive("packbits🎱") 16 | packbits_p.multiple_results = True 17 | packbits_p.def_impl(functools.partial(xla.apply_primitive, packbits_p)) 18 | packbits_p.def_abstract_eval(abstract.pack_density_into_bits_abstract) 19 | 20 | mlir.register_lowering( 21 | prim=packbits_p, 22 | rule=lowering.packbits_lowering_rule, 23 | platform="gpu", 24 | ) 25 | -------------------------------------------------------------------------------- /deps/volume-rendering-jax/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.23) 2 | project(volume_rendering_jax LANGUAGES CXX CUDA) 3 | set(CMAKE_CUDA_ARCHITECTURES "all") 4 | 5 | message(STATUS "Using CMake version " ${CMAKE_VERSION}) 6 | 7 | find_package(Python COMPONENTS Interpreter Development REQUIRED) 8 | find_package(pybind11 CONFIG REQUIRED) 9 | find_package(fmt REQUIRED) 10 | 11 | include_directories(${CMAKE_CURRENT_LIST_DIR}/lib) 12 | 13 | include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) 14 | pybind11_add_module( 15 | volrendutils_cuda 16 | 
${CMAKE_CURRENT_LIST_DIR}/lib/impl/packbits.cu 17 | ${CMAKE_CURRENT_LIST_DIR}/lib/impl/marching.cu 18 | ${CMAKE_CURRENT_LIST_DIR}/lib/impl/integrating.cu 19 | ${CMAKE_CURRENT_LIST_DIR}/lib/ffi.cc 20 | ) 21 | target_link_libraries(volrendutils_cuda PRIVATE fmt::fmt) 22 | 23 | install(TARGETS volrendutils_cuda DESTINATION volrendjax) 24 | -------------------------------------------------------------------------------- /deps/volume-rendering-jax/src/volrendjax/packbits/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import jax 3 | import jax.numpy as jnp 4 | 5 | from . import impl 6 | 7 | 8 | def packbits( 9 | density_threshold: float, 10 | density_grid: jax.Array, 11 | ) -> Tuple[jax.Array, jax.Array]: 12 | """ 13 | Pack the given `density_grid` into a compact representation of type uint8, where each bit is 14 | high if its corresponding density grid cell's density is larger than `density_threshold`, low 15 | otherwise. 
16 | 17 | Inputs: 18 | density_threshold `broadcastable to [N]` 19 | density_grid `[N]` 20 | 21 | Returns: 22 | occ_mask `[N] bool`: boolean mask that indicates whether this grid is occupied 23 | occ_bitfield `[N//8]` 24 | """ 25 | return impl.packbits_p.bind( 26 | jnp.broadcast_to(density_threshold, density_grid.shape), 27 | density_grid, 28 | ) 29 | -------------------------------------------------------------------------------- /deps/tyro/default.nix: -------------------------------------------------------------------------------- 1 | # with import {}; 2 | 3 | { source, lib, buildPythonPackage 4 | 5 | , poetry-core 6 | , backports-cached-property 7 | 8 | , colorama 9 | , frozendict 10 | , pyyaml 11 | , typing-extensions 12 | 13 | , docstring-parser 14 | , rich 15 | , shtab 16 | }: 17 | 18 | buildPythonPackage { 19 | inherit (source) pname version src; 20 | format = "pyproject"; 21 | 22 | nativeBuildInputs = [ 23 | poetry-core 24 | backports-cached-property 25 | ]; 26 | 27 | buildInputs = [ 28 | colorama 29 | frozendict 30 | typing-extensions 31 | ]; 32 | 33 | propagatedBuildInputs = [ 34 | docstring-parser 35 | pyyaml 36 | rich 37 | shtab 38 | ]; 39 | 40 | pythonImportsCheck = [ "tyro" ]; 41 | 42 | meta = { 43 | homepage = "https://github.com/brentyi/tyro"; 44 | description = "Strongly typed, zero-effort CLI interfaces & config objects"; 45 | license = lib.licenses.mit; 46 | }; 47 | } 48 | -------------------------------------------------------------------------------- /deps/dearpygui/default.nix: -------------------------------------------------------------------------------- 1 | { source, lib, buildPythonPackage 2 | 3 | , cmake 4 | , setuptools-scm 5 | 6 | , libglvnd 7 | , libxcrypt 8 | , xorg 9 | 10 | , pillow 11 | }: 12 | 13 | buildPythonPackage { 14 | inherit (source) pname version src; 15 | 16 | nativeBuildInputs = [ 17 | cmake 18 | setuptools-scm 19 | ]; 20 | dontUseCmakeConfigure = true; 21 | 22 | preBuild = '' 23 | export 
MAKEFLAGS="''${MAKEFLAGS:+''${MAKEFLAGS} }-j$NIX_BUILD_CORES" 24 | ''; 25 | 26 | buildInputs = with xorg; [ 27 | libX11 28 | libXcursor 29 | libXi 30 | libXinerama 31 | libXrandr 32 | ] ++ [ 33 | libglvnd 34 | libxcrypt 35 | ]; 36 | 37 | propagatedBuildInputs = [ pillow ]; 38 | 39 | doCheck = false; 40 | 41 | pythonImportsCheck = [ 42 | "dearpygui" 43 | "dearpygui.dearpygui" 44 | ]; 45 | 46 | meta = { 47 | homepage = "https://github.com/hoffstadt/DearPyGui"; 48 | description = "A fast and powerful Graphical User Interface Toolkit for Python with minimal dependencies"; 49 | license = lib.licenses.mit; 50 | }; 51 | } 52 | -------------------------------------------------------------------------------- /deps/jax-tcnn/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.23) 2 | project(volume_rendering_jax LANGUAGES CXX CUDA) 3 | # use `cmake -DCMAKE_CUDA_ARCHITECTURES=61;62;75` to build for compute capabilities 61, 62, and 75 4 | # set(CMAKE_CUDA_ARCHITECTURES "all") 5 | message(STATUS "Enabled CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") 6 | 7 | message(STATUS "Using CMake version " ${CMAKE_VERSION}) 8 | 9 | find_package(Python COMPONENTS Interpreter Development REQUIRED) 10 | find_package(pybind11 CONFIG REQUIRED) 11 | find_package(fmt REQUIRED) 12 | 13 | include_directories(${CMAKE_CURRENT_LIST_DIR}/lib) 14 | 15 | include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) 16 | pybind11_add_module( 17 | tcnnutils 18 | ${CMAKE_CURRENT_LIST_DIR}/lib/impl/hashgrid.cu 19 | ${CMAKE_CURRENT_LIST_DIR}/lib/ffi.cc 20 | ) 21 | 22 | # e.g. 
`cmake -DTCNN_MIN_GPU_ARCH=61` 23 | message(STATUS "TCNN_MIN_GPU_ARCH=${TCNN_MIN_GPU_ARCH}") 24 | target_compile_definitions(tcnnutils PUBLIC -DTCNN_MIN_GPU_ARCH=${TCNN_MIN_GPU_ARCH}) 25 | 26 | target_link_libraries(tcnnutils PRIVATE tiny-cuda-nn fmt::fmt) 27 | 28 | install(TARGETS tcnnutils DESTINATION jaxtcnn) 29 | -------------------------------------------------------------------------------- /deps/jax-tcnn/lib/impl/tcnnutils.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #ifdef __CUDACC__ 5 | #include 6 | #endif 7 | 8 | 9 | namespace jaxtcnn { 10 | 11 | struct HashGridDescriptor { 12 | // number of input coordinates to be encoded 13 | std::uint32_t n_coords; 14 | 15 | // number of levels (2), "n_levels" in tcnn 16 | std::uint32_t L; 17 | 18 | // number of features that each level should output (2), "n_features_per_level" in tcnn 19 | std::uint32_t F; 20 | 21 | // coarsest resolution (16), "base_resolution" in tcnn 22 | std::uint32_t N_min; 23 | 24 | // scale factor between consecutive levels 25 | // float const per_level_scale() const { 26 | // return std::exp2f(std::log2f(this->N_max) - std::log2f(this->N_min)) / (this->L - 1); 27 | // } 28 | float per_level_scale; 29 | }; 30 | 31 | void hashgrid_encode( 32 | cudaStream_t stream, 33 | void **buffers, 34 | const char *opaque, 35 | std::size_t opaque_len 36 | ); 37 | 38 | void hashgrid_encode_backward( 39 | cudaStream_t stream, 40 | void **buffers, 41 | const char *opaque, 42 | std::size_t opaque_len 43 | ); 44 | 45 | } // namespace jaxtcnn 46 | -------------------------------------------------------------------------------- /deps/volume-rendering-jax/src/volrendjax/morton3d/impl.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import jax 4 | from jax.interpreters import mlir, xla 5 | from jax.lib import xla_client 6 | 7 | from . import abstract, lowering 8 | from .. 
import volrendutils_cuda 9 | 10 | 11 | # register GPU XLA custom calls 12 | for name, value in volrendutils_cuda.get_morton3d_registrations().items(): 13 | xla_client.register_custom_call_target(name, value, platform="gpu") 14 | 15 | morton3d_p = jax.core.Primitive("morton3d⚡") 16 | morton3d_p.multiple_results = False 17 | morton3d_p.def_impl(functools.partial(xla.apply_primitive, morton3d_p)) 18 | morton3d_p.def_abstract_eval(abstract.morton3d_abstract) 19 | 20 | morton3d_invert_p = jax.core.Primitive("morton3d⚡invert") 21 | morton3d_invert_p.multiple_results = False 22 | morton3d_invert_p.def_impl(functools.partial(xla.apply_primitive, morton3d_invert_p)) 23 | morton3d_invert_p.def_abstract_eval(abstract.morton3d_invert_abstract) 24 | 25 | # register mlir lowering rules 26 | mlir.register_lowering( 27 | prim=morton3d_p, 28 | rule=lowering.morton3d_lowering_rule, 29 | platform="gpu", 30 | ) 31 | mlir.register_lowering( 32 | prim=morton3d_invert_p, 33 | rule=lowering.morton3d_invert_lowering_rule, 34 | platform="gpu", 35 | ) 36 | -------------------------------------------------------------------------------- /deps/volume-rendering-jax/src/volrendjax/morton3d/abstract.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | 5 | # jit rules 6 | def morton3d_abstract( 7 | # input array 8 | xyzs: jax.ShapedArray, 9 | ): 10 | length, _ = xyzs.shape 11 | 12 | dtype = jax.dtypes.canonicalize_dtype(xyzs.dtype) 13 | if dtype != jnp.uint32: 14 | raise NotImplementedError( 15 | "morton3d is only implemented for input coordinates of type `jnp.uint32`, got {}".format( 16 | dtype, 17 | ) 18 | ) 19 | 20 | out_shapes = { 21 | "idcs": (length,), 22 | } 23 | 24 | return jax.ShapedArray(shape=out_shapes["idcs"], dtype=jnp.uint32) 25 | 26 | 27 | def morton3d_invert_abstract( 28 | # input array 29 | idcs: jax.ShapedArray, 30 | ): 31 | length, = idcs.shape 32 | 33 | dtype = 
jax.dtypes.canonicalize_dtype(idcs.dtype) 34 | if dtype != jnp.uint32: 35 | raise NotImplementedError( 36 | "morton3d_invert is only implemented for input indices of type `jnp.uint32`, got {}".format( 37 | dtype, 38 | ) 39 | ) 40 | 41 | out_shapes = { 42 | "xyzs": (length, 3), 43 | } 44 | 45 | return jax.ShapedArray(shape=out_shapes["xyzs"], dtype=jnp.uint32) 46 | -------------------------------------------------------------------------------- /deps/volume-rendering-jax/src/volrendjax/packbits/abstract.py: -------------------------------------------------------------------------------- 1 | import chex 2 | import jax 3 | import jax.numpy as jnp 4 | 5 | 6 | # jit rules 7 | def pack_density_into_bits_abstract( 8 | # input array 9 | density_threshold: jax.ShapedArray, 10 | density_grid: jax.ShapedArray, 11 | ): 12 | chex.assert_rank([density_threshold, density_grid], 1) 13 | chex.assert_shape(density_threshold, density_grid.shape) 14 | n_bits = density_grid.shape[0] 15 | if n_bits % 8 != 0: 16 | raise ValueError( 17 | "pack_density_into_bits expects size of density grid to be divisible by 8, got {}".format( 18 | n_bits, 19 | ) 20 | ) 21 | n_bytes = n_bits // 8 22 | 23 | dtype = jax.dtypes.canonicalize_dtype(density_grid.dtype) 24 | if dtype != jnp.float32: 25 | raise NotImplementedError( 26 | "pack_density_into_bits is only implemented for densities of `jnp.float32` type, got {}".format( 27 | dtype, 28 | ) 29 | ) 30 | 31 | out_shapes = { 32 | "occupied_mask": (n_bits,), 33 | "occupancy_bitfield": (n_bytes,), 34 | } 35 | return ( 36 | jax.ShapedArray(out_shapes["occupied_mask"], jnp.bool_), 37 | jax.ShapedArray(out_shapes["occupancy_bitfield"], jnp.uint8), 38 | ) 39 | -------------------------------------------------------------------------------- /deps/volume-rendering-jax/src/volrendjax/marching/impl.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import jax 4 | from jax.interpreters import mlir, xla 5 
| from jax.lib import xla_client 6 | 7 | from . import abstract, lowering 8 | from .. import volrendutils_cuda 9 | 10 | 11 | # register GPU XLA custom calls 12 | for name, value in volrendutils_cuda.get_marching_registrations().items(): 13 | xla_client.register_custom_call_target(name, value, platform="gpu") 14 | 15 | march_rays_p = jax.core.Primitive("march_rays🗧") 16 | march_rays_p.multiple_results = True 17 | march_rays_p.def_impl(functools.partial(xla.apply_primitive, march_rays_p)) 18 | march_rays_p.def_abstract_eval(abstract.march_rays_abstract) 19 | 20 | march_rays_inference_p = jax.core.Primitive("march_rays_inference🗧") 21 | march_rays_inference_p.multiple_results = True 22 | march_rays_inference_p.def_impl(functools.partial(xla.apply_primitive, march_rays_inference_p)) 23 | march_rays_inference_p.def_abstract_eval(abstract.march_rays_inference_abstract) 24 | 25 | # register mlir lowering rules 26 | mlir.register_lowering( 27 | prim=march_rays_p, 28 | rule=lowering.march_rays_lowering_rule, 29 | platform="gpu", 30 | ) 31 | mlir.register_lowering( 32 | prim=march_rays_inference_p, 33 | rule=lowering.march_rays_inference_lowering_rule, 34 | platform="gpu", 35 | ) 36 | -------------------------------------------------------------------------------- /deps/pycolmap/default.nix: -------------------------------------------------------------------------------- 1 | { source, config, lib, buildPythonPackage 2 | 3 | , cmake 4 | , setuptools-scm 5 | 6 | , cudatoolkit 7 | , boost17x 8 | , ceres-solver 9 | , colmap-locked 10 | , eigen 11 | , flann 12 | , freeimage 13 | , libGLU 14 | , metis 15 | , glew 16 | , qt5 17 | }: 18 | 19 | let 20 | cudaSupport = config.cudaSupport or false; 21 | in 22 | 23 | buildPythonPackage { 24 | pname = if cudaSupport 25 | then source.pname + "-cuda" 26 | else source.pname; 27 | inherit (source) version src; 28 | 29 | patches = [ 30 | ./expose-bundle-adjustment-function.patch 31 | ]; 32 | 33 | nativeBuildInputs = [ 34 | cmake 35 | 
setuptools-scm 36 | qt5.wrapQtAppsHook 37 | ]; 38 | dontUseCmakeConfigure = true; 39 | cmakeFlags = [ 40 | "-DCUDA_ENABLED=ON" 41 | "-DCUDA_NVCC_FLAGS=--std=c++14" 42 | ]; 43 | 44 | buildInputs = [ 45 | colmap-locked 46 | cudatoolkit 47 | boost17x 48 | ceres-solver 49 | eigen 50 | flann 51 | freeimage 52 | libGLU 53 | metis 54 | glew 55 | qt5.qtbase 56 | ]; 57 | 58 | preBuild = '' 59 | export MAKEFLAGS="''${MAKEFLAGS:+''${MAKEFLAGS} }-j$NIX_BUILD_CORES" 60 | ''; 61 | 62 | meta = { 63 | homepage = "https://github.com/colmap/pycolmap"; 64 | description = "Python bindings for COLMAP"; 65 | license = lib.licenses.bsd3; 66 | }; 67 | } 68 | -------------------------------------------------------------------------------- /deps/serde-helper/serde.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #undef __noinline__ // REF: 4 | #include 5 | #include 6 | #include 7 | #include 8 | #define __noinline__ noinline 9 | 10 | #include 11 | 12 | template 13 | std::enable_if_t< 14 | sizeof(To) == sizeof(From) && 15 | std::is_trivially_copyable_v && 16 | std::is_trivially_copyable_v, 17 | To> 18 | // constexpr support needs compiler magic 19 | bit_cast(const From& src) noexcept 20 | { 21 | static_assert(std::is_trivially_constructible_v, 22 | "This implementation additionally requires " 23 | "destination type to be trivially constructible"); 24 | 25 | To dst; 26 | std::memcpy(&dst, &src, sizeof(To)); 27 | return dst; 28 | } 29 | 30 | template 31 | std::string serialize(T const &descriptor) { 32 | return std::string{bit_cast(&descriptor), sizeof(T)}; 33 | } 34 | 35 | template 36 | T const *deserialize(char const *opaque, std::size_t opaque_len) { 37 | if (opaque_len != sizeof(T)) { 38 | throw std::runtime_error(fmt::format("deserialize: Invalid opaque object size, expected {}, got {}", sizeof(T), opaque_len)); 39 | } 40 | return bit_cast(opaque); 41 | } 42 | 
-------------------------------------------------------------------------------- /app/nerf/__main__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from typing import Annotated 4 | from typing_extensions import assert_never 5 | 6 | import tyro 7 | 8 | from utils.args import NeRFTrainingArgs,NeRFTestingArgs,NeRFGUIArgs 9 | from utils import common 10 | 11 | 12 | 13 | CmdTrain = Annotated[ 14 | NeRFTrainingArgs, 15 | tyro.conf.subcommand( 16 | name="train", 17 | prefix_name=False, 18 | ), 19 | ] 20 | CmdTest = Annotated[ 21 | NeRFTestingArgs, 22 | tyro.conf.subcommand( 23 | name="test", 24 | prefix_name=False, 25 | ), 26 | ] 27 | CmdGui = Annotated[ 28 | NeRFGUIArgs, 29 | tyro.conf.subcommand( 30 | name="gui", 31 | prefix_name=False, 32 | ), 33 | ] 34 | 35 | 36 | MainArgsType = CmdTrain | CmdTest | CmdGui 37 | 38 | 39 | def main(args: MainArgsType): 40 | logger = common.setup_logging("nerf") 41 | KEY = common.set_deterministic(args.common.seed) 42 | 43 | if isinstance(args, NeRFTrainingArgs): 44 | from app.nerf.train import train 45 | return train(KEY, args, logger) 46 | elif isinstance(args, NeRFTestingArgs): 47 | from app.nerf.test import test 48 | return test(KEY, args, logger) 49 | elif isinstance(args, NeRFGUIArgs): 50 | from app.nerf.gui import GuiWindow 51 | return GuiWindow(KEY, args, logger) 52 | else: 53 | assert_never(args) 54 | 55 | 56 | if __name__ == "__main__": 57 | args = tyro.cli(MainArgsType) 58 | exit(main(args)) 59 | -------------------------------------------------------------------------------- /deps/default.nix: -------------------------------------------------------------------------------- 1 | let 2 | filterAttrs = predicate: attrs: with builtins; listToAttrs (filter 3 | (v: v != null) 4 | (attrValues (mapAttrs 5 | (name: value: if (predicate name value) then { inherit name value; } else null) 6 | attrs) 7 | ) 8 | ); 9 | mapPackage = basedir: fn: with builtins; 10 | 
mapAttrs (name: _: fn name) 11 | (filterAttrs 12 | (name: type: type == "directory" && name != "_sources") 13 | (readDir basedir)); 14 | # Compute capability, used for building tiny-cuda-nn 15 | # NOTE: removing unused compute capabilities will build faster, GPUs and their compute 16 | # capabilities can be found at: 17 | # 18 | # All the compute capabilities since `5.0`. REF: 19 | cudaCapabilities = [ 20 | # /nix/store/hsbbv8a72hwjrka8igd7hk66skvi03rp-cudatoolkit-11.7.0-unsplit/bin/nvcc --list-gpu-arch 21 | "3.5" 22 | "3.7" 23 | "5.0" 24 | "5.2" 25 | "5.3" 26 | "6.0" 27 | "6.1" 28 | "6.2" 29 | "7.0" 30 | "7.2" 31 | "7.5" 32 | "8.0" 33 | "8.6" 34 | "8.7" 35 | ## the two compute capabilities below require newer nvcc (this environment uses CUDA 11.7) 36 | # "8.9" 37 | # "9.0" 38 | ]; 39 | in { 40 | inherit filterAttrs; 41 | packages = pkgs: mapPackage ./. (name: pkgs.${name}); 42 | overlay = final: prev: mapPackage ./. (name: let 43 | generated = final.callPackage ./_sources/generated.nix {}; 44 | package = import ./${name}; 45 | args = with builtins; intersectAttrs (functionArgs package) { 46 | inherit generated cudaCapabilities; 47 | version = "0.1.0"; 48 | source = generated.${name}; 49 | buildSharedLib = false; 50 | }; 51 | in 52 | final.python3.pkgs.callPackage package args 53 | ); 54 | } 55 | -------------------------------------------------------------------------------- /deps/_sources/generated.nix: -------------------------------------------------------------------------------- 1 | # This file was generated by nvfetcher, please do not modify it manually. 
2 | { fetchgit, fetchurl, fetchFromGitHub, dockerTools }: 3 | { 4 | colmap-locked = { 5 | pname = "colmap-locked"; 6 | version = "3.8"; 7 | src = fetchFromGitHub ({ 8 | owner = "colmap"; 9 | repo = "colmap"; 10 | rev = "3.8"; 11 | fetchSubmodules = true; 12 | sha256 = "sha256-ArWQBRuWRkRXnNs154pxTgrGcZyMH6doG/R89LC/0Ms="; 13 | }); 14 | }; 15 | dearpygui = { 16 | pname = "dearpygui"; 17 | version = "v1.9.1"; 18 | src = fetchFromGitHub ({ 19 | owner = "hoffstadt"; 20 | repo = "DearPyGui"; 21 | rev = "v1.9.1"; 22 | fetchSubmodules = true; 23 | sha256 = "sha256-Af1jhQYT0CYNFMWihAtP6jRNYKm3XKEu3brFOPSGCnk="; 24 | }); 25 | }; 26 | pycolmap = { 27 | pname = "pycolmap"; 28 | version = "v0.4.0"; 29 | src = fetchFromGitHub ({ 30 | owner = "colmap"; 31 | repo = "pycolmap"; 32 | rev = "v0.4.0"; 33 | fetchSubmodules = true; 34 | sha256 = "sha256-W3d+uHZXkH1/QlER1HV8t1MOBOrHIXYsVeYv1zbsbW4="; 35 | }); 36 | }; 37 | pyimgui = { 38 | pname = "pyimgui"; 39 | version = "2.0.0"; 40 | src = fetchurl { 41 | url = "https://pypi.io/packages/source/i/imgui/imgui-2.0.0.tar.gz"; 42 | sha256 = "sha256-L7247tO429fqmK+eTBxlgrC8TalColjeFjM9jGU9Z+E="; 43 | }; 44 | }; 45 | tiny-cuda-nn = { 46 | pname = "tiny-cuda-nn"; 47 | version = "v1.6"; 48 | src = fetchFromGitHub ({ 49 | owner = "NVlabs"; 50 | repo = "tiny-cuda-nn"; 51 | rev = "v1.6"; 52 | fetchSubmodules = true; 53 | sha256 = "sha256-qW6Fk2GB71fvZSsfu+mykabSxEKvaikZ/pQQZUycOy0="; 54 | }); 55 | }; 56 | tyro = { 57 | pname = "tyro"; 58 | version = "0.5.3"; 59 | src = fetchurl { 60 | url = "https://pypi.io/packages/source/t/tyro/tyro-0.5.3.tar.gz"; 61 | sha256 = "sha256-ygdNkRr4bjDDHioXoMWPZ1c0IamIkqGyvAuvJx3Bhis="; 62 | }; 63 | }; 64 | } 65 | -------------------------------------------------------------------------------- /deps/spherical-harmonics-encoding-jax/lib/ffi.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 
"impl/spherical_harmonics_encoding.h" 9 | 10 | namespace shjax { 11 | 12 | template 13 | pybind11::bytes to_pybind11_bytes(T const &descriptor) { 14 | return pybind11::bytes(serialize(descriptor)); 15 | } 16 | 17 | template 18 | pybind11::capsule encapsulate_function(T *fn) { 19 | return pybind11::capsule(bit_cast(fn), "xla._CUSTOM_CALL_TARGET"); 20 | } 21 | 22 | // expose gpu function 23 | namespace { 24 | 25 | pybind11::dict get_registrations() { 26 | pybind11::dict dict; 27 | 28 | dict["spherical_harmonics_encoding_cuda_f32"] = encapsulate_function(spherical_harmonics_encoding_cuda_f32); 29 | 30 | return dict; 31 | } 32 | 33 | PYBIND11_MODULE(cudaops, m) { 34 | m.def("get_registrations", &get_registrations); 35 | m.def("make_spherical_harmonics_encoding_descriptor", 36 | [](std::uint32_t n, std::uint8_t degree) { 37 | if (degree < 1 || degree > 8) { 38 | throw std::runtime_error(fmt::format("Spherical harmonics encoding supports degrees 1 to 8 (inclusive) but got {}", degree)); 39 | } 40 | return to_pybind11_bytes(SphericalHarmonicsEncodingDescriptor{ 41 | .n = n, 42 | .degree = degree, 43 | }); 44 | }, 45 | "Description of the data passed to the spherical harmonics encoding function.\n\n" 46 | "Args:\n" 47 | " n: specifies how many inputs are to be encoded\n" 48 | " degree: specifies the highest degree of the output encoding, supports integers 1..8\n" 49 | "\n" 50 | "Returns:\n" 51 | " Serialized bytes that can be passed as the opaque parameter for spherical_harmonics_encoding_cuda" 52 | ); 53 | }; 54 | 55 | } // namespace 56 | 57 | } // namespace shjax 58 | -------------------------------------------------------------------------------- /deps/volume-rendering-jax/src/volrendjax/packbits/lowering.py: -------------------------------------------------------------------------------- 1 | from jax.interpreters import mlir 2 | from jax.interpreters.mlir import ir 3 | 4 | from .. 
import volrendutils_cuda 5 | 6 | try: 7 | from jaxlib.mhlo_helpers import custom_call 8 | except ModuleNotFoundError: 9 | # A more recent jaxlib would have `hlo_helpers` instead of `mhlo_helpers` 10 | # 11 | from jaxlib.hlo_helpers import custom_call 12 | 13 | 14 | # helper function for mapping given shapes to their default mlir layouts 15 | def default_layouts(*shapes): 16 | return [range(len(shape) - 1, -1, -1) for shape in shapes] 17 | 18 | 19 | def packbits_lowering_rule( 20 | ctx: mlir.LoweringRule, 21 | 22 | # input array 23 | density_threshold: ir.Value, 24 | density_grid: ir.Value, 25 | ): 26 | n_bits = ir.RankedTensorType(density_grid.type).shape[0] 27 | n_bytes = n_bits // 8 28 | 29 | opaque = volrendutils_cuda.make_packbits_descriptor(n_bytes) 30 | 31 | shapes = { 32 | "in.density_threshold": (n_bits,), 33 | "in.density_grid": (n_bits,), 34 | 35 | "out.occupied_mask": (n_bits,), 36 | "out.occupancy_bitfield": (n_bytes,), 37 | } 38 | 39 | return custom_call( 40 | call_target_name="pack_density_into_bits", 41 | out_types = [ 42 | ir.RankedTensorType.get(shapes["out.occupied_mask"], ir.IntegerType.get_signless(1)), 43 | ir.RankedTensorType.get(shapes["out.occupancy_bitfield"], ir.IntegerType.get_unsigned(8)), 44 | ], 45 | operands=[ 46 | density_threshold, 47 | density_grid, 48 | ], 49 | backend_config=opaque, 50 | operand_layouts=default_layouts( 51 | shapes["in.density_threshold"], 52 | shapes["in.density_grid"], 53 | ), 54 | result_layouts=default_layouts( 55 | shapes["out.occupied_mask"], 56 | shapes["out.occupancy_bitfield"], 57 | ), 58 | ) 59 | -------------------------------------------------------------------------------- /deps/jax-tcnn/lib/ffi.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "impl/tcnnutils.h" 9 | 10 | namespace jaxtcnn { 11 | 12 | template 13 | pybind11::bytes to_pybind11_bytes(T const &descriptor) { 14 | 
return pybind11::bytes(serialize(descriptor)); 15 | } 16 | 17 | template 18 | pybind11::capsule encapsulate_function(T *fn) { 19 | return pybind11::capsule(bit_cast(fn), "xla._CUSTOM_CALL_TARGET"); 20 | } 21 | 22 | // expose gpu function 23 | namespace { 24 | 25 | pybind11::dict get_hashgrid_registrations() { 26 | pybind11::dict dict; 27 | dict["hashgrid_encode"] = encapsulate_function(hashgrid_encode); 28 | dict["hashgrid_encode_backward"] = encapsulate_function(hashgrid_encode_backward); 29 | return dict; 30 | } 31 | 32 | PYBIND11_MODULE(tcnnutils, m) { 33 | m.def("get_hashgrid_registrations", &get_hashgrid_registrations); 34 | m.def("make_hashgrid_descriptor", 35 | [](std::uint32_t const n_coords 36 | , std::uint32_t const L 37 | , std::uint32_t const F 38 | , std::uint32_t const N_min 39 | , float const per_level_scale) { 40 | return to_pybind11_bytes(HashGridDescriptor{ 41 | .n_coords = n_coords, 42 | .L = L, 43 | .F = F, 44 | .N_min = N_min, 45 | .per_level_scale = per_level_scale 46 | }); 47 | }, 48 | "Static arguments passed to the `pack_density_into_bits` function.\n\n" 49 | "Args:\n" 50 | " n_coords: number of input coordinates to be encoded\n" 51 | " L: number of levels, 'n_levels' in tcnn\n" 52 | " F: number of features that each level should output (2), 'n_features_per_level' in tcnn\n" 53 | " N_min: coarsest resolution (16), 'base_resolution' in tcnn\n" 54 | " per_level_scale: scale factor between consecutive levels\n" 55 | ); 56 | }; 57 | 58 | } // namespace 59 | 60 | } // namespace jax_tcnn 61 | -------------------------------------------------------------------------------- /deps/spherical-harmonics-encoding-jax/default.nix: -------------------------------------------------------------------------------- 1 | { lib, version, symlinkJoin, buildPythonPackage 2 | 3 | , setuptools-scm 4 | , cmake 5 | , ninja 6 | , pybind11 7 | , fmt 8 | 9 | , serde-helper 10 | , cudatoolkit 11 | , python3 12 | , chex 13 | , jax 14 | , jaxlib 15 | }: 16 | 17 | let 18 | 
cudatoolkit-unsplit = symlinkJoin { 19 | name = "${cudatoolkit.name}-unsplit"; 20 | paths = [ cudatoolkit.out cudatoolkit.lib ]; 21 | }; 22 | fmt-unsplit = symlinkJoin { 23 | name = "${fmt.name}-unsplit"; 24 | paths = [ fmt.out fmt.dev ]; 25 | }; 26 | in 27 | 28 | buildPythonPackage rec { 29 | pname = "spherical-harmonics-encoding-jax"; 30 | inherit version; 31 | src = ./.; 32 | 33 | format = "pyproject"; 34 | 35 | CUDA_HOME = cudatoolkit-unsplit; 36 | 37 | nativeBuildInputs = [ 38 | cmake 39 | ninja 40 | pybind11 41 | setuptools-scm 42 | ]; 43 | dontUseCmakeConfigure = true; 44 | 45 | buildInputs = [ 46 | serde-helper 47 | cudatoolkit-unsplit 48 | fmt-unsplit 49 | ]; 50 | 51 | propagatedBuildInputs = [ 52 | chex 53 | jax 54 | jaxlib 55 | ]; 56 | 57 | doCheck = false; 58 | 59 | preFixup = '' 60 | patchelf --set-rpath "${lib.makeLibraryPath buildInputs}" $out/lib/python${python3.pythonVersion}/site-packages/shjax/*.so 61 | ''; 62 | 63 | pythonImportsCheck = [ "shjax" ]; 64 | 65 | # development 66 | dot_clangd = '' 67 | CompileFlags: # Tweak the parse settings 68 | Add: 69 | - "-Wall" # enable more warnings 70 | - "-Wshadow" # warn if a local declared variable shadows a global one 71 | - "-std=c++20" # use cpp20 standard (std::bit_cast needs this) 72 | - "-I${serde-helper}/include" 73 | - "-I${cudatoolkit-unsplit}/include" 74 | - "-I${fmt.dev}/include" 75 | - "-I${pybind11}/include" 76 | - "-I${python3}/include/python${python3.pythonVersion}" 77 | - "--cuda-path=${cudatoolkit-unsplit}" 78 | Remove: "-W*" # strip all other warning-related flags 79 | Compiler: "clang++" # Change argv[0] of compile flags to clang++ 80 | 81 | # vim: ft=yaml: 82 | ''; 83 | shellHook = '' 84 | echo "use \`echo \$dot_clangd >.clangd\` for development" 85 | [[ "$-" == *i* ]] && exec "$SHELL" 86 | ''; 87 | } 88 | -------------------------------------------------------------------------------- /deps/volume-rendering-jax/default.nix: 
-------------------------------------------------------------------------------- 1 | { lib, version, symlinkJoin, buildPythonPackage 2 | 3 | , setuptools-scm 4 | , cmake 5 | , ninja 6 | , pybind11 7 | , fmt 8 | 9 | , serde-helper 10 | , cudatoolkit 11 | , python3 12 | , chex 13 | , jax 14 | , jaxlib 15 | }: 16 | 17 | let 18 | cudatoolkit-unsplit = symlinkJoin { 19 | name = "${cudatoolkit.name}-unsplit"; 20 | paths = [ cudatoolkit.out cudatoolkit.lib ]; 21 | }; 22 | fmt-unsplit = symlinkJoin { 23 | name = "fmtlib"; 24 | # libfmt.so resides in the "out" output and is set into RPATH of the python extension 25 | paths = [ fmt.dev fmt.out ]; 26 | }; 27 | in 28 | 29 | buildPythonPackage rec { 30 | pname = "volume-rendering-jax"; 31 | inherit version; 32 | src = ./.; 33 | 34 | format = "pyproject"; 35 | 36 | CUDA_HOME = cudatoolkit-unsplit; 37 | 38 | nativeBuildInputs = [ 39 | cmake 40 | ninja 41 | pybind11 42 | setuptools-scm 43 | ]; 44 | dontUseCmakeConfigure = true; 45 | 46 | buildInputs = [ 47 | serde-helper 48 | cudatoolkit-unsplit 49 | fmt-unsplit 50 | ]; 51 | 52 | propagatedBuildInputs = [ 53 | chex 54 | jax 55 | jaxlib 56 | ]; 57 | 58 | preFixup = '' 59 | patchelf --set-rpath "${lib.makeLibraryPath buildInputs}" $out/lib/python${python3.pythonVersion}/site-packages/volrendjax/*.so 60 | ''; 61 | 62 | doCheck = false; 63 | 64 | pythonImportsCheck = [ "volrendjax" ]; 65 | 66 | # development 67 | dot_clangd = '' 68 | CompileFlags: # Tweak the parse settings 69 | Add: 70 | - "-Wall" # enable more warnings 71 | - "-Wshadow" # warn if a local declared variable shadows a global one 72 | - "-std=c++20" # use cpp20 standard (std::bit_cast needs this) 73 | - "-I${serde-helper}/include" 74 | - "-I${cudatoolkit-unsplit}/include" 75 | - "-I${fmt.dev}/include" 76 | - "-I${pybind11}/include" 77 | - "-I${python3}/include/python${python3.pythonVersion}" 78 | - "--cuda-path=${cudatoolkit-unsplit}" 79 | Remove: "-W*" # strip all other warning-related flags 80 | Compiler: "clang++" # 
Change argv[0] of compile flags to clang++ 81 | 82 | # vim: ft=yaml: 83 | ''; 84 | shellHook = '' 85 | echo "use \`echo \$dot_clangd >.clangd\` for development" 86 | [[ "$-" == *i* ]] && exec "$SHELL" 87 | ''; 88 | } 89 | -------------------------------------------------------------------------------- /deps/volume-rendering-jax/src/volrendjax/morton3d/lowering.py: -------------------------------------------------------------------------------- 1 | from jax.interpreters import mlir 2 | from jax.interpreters.mlir import ir 3 | 4 | from .. import volrendutils_cuda 5 | 6 | try: 7 | from jaxlib.mhlo_helpers import custom_call 8 | except ModuleNotFoundError: 9 | # A more recent jaxlib would have `hlo_helpers` instead of `mhlo_helpers` 10 | # 11 | from jaxlib.hlo_helpers import custom_call 12 | 13 | 14 | # helper function for mapping given shapes to their default mlir layouts 15 | def default_layouts(*shapes): 16 | return [range(len(shape) - 1, -1, -1) for shape in shapes] 17 | 18 | 19 | def morton3d_lowering_rule( 20 | ctx: mlir.LoweringRule, 21 | 22 | # input array 23 | xyzs: ir.Value, 24 | ): 25 | length, _ = ir.RankedTensorType(xyzs.type).shape 26 | 27 | opaque = volrendutils_cuda.make_morton3d_descriptor(length) 28 | 29 | shapes = { 30 | "in.xyzs": (length, 3), 31 | 32 | "out.idcs": (length,), 33 | } 34 | 35 | return [custom_call( 36 | call_target_name="morton3d", 37 | out_types=[ 38 | ir.RankedTensorType.get(shapes["out.idcs"], ir.IntegerType.get_unsigned(32)), 39 | ], 40 | operands=[ 41 | xyzs, 42 | ], 43 | backend_config=opaque, 44 | operand_layouts=default_layouts( 45 | shapes["in.xyzs"], 46 | ), 47 | result_layouts=default_layouts( 48 | shapes["out.idcs"], 49 | ), 50 | )] 51 | 52 | 53 | def morton3d_invert_lowering_rule( 54 | ctx: mlir.LoweringRule, 55 | 56 | # input array 57 | idcs: ir.Value, 58 | ): 59 | length, = ir.RankedTensorType(idcs.type).shape 60 | 61 | opaque = volrendutils_cuda.make_morton3d_descriptor(length) 62 | 63 | shapes = { 64 | "in.idcs": 
(length,), 65 | 66 | "out.xyzs": (length, 3), 67 | } 68 | 69 | return [custom_call( 70 | call_target_name="morton3d_invert", 71 | out_types=[ 72 | ir.RankedTensorType.get(shapes["out.xyzs"], ir.IntegerType.get_unsigned(32)), 73 | ], 74 | operands=[ 75 | idcs, 76 | ], 77 | backend_config=opaque, 78 | operand_layouts=default_layouts( 79 | shapes["in.idcs"], 80 | ), 81 | result_layouts=default_layouts( 82 | shapes["out.xyzs"], 83 | ), 84 | )] 85 | -------------------------------------------------------------------------------- /flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "flake-utils": { 4 | "locked": { 5 | "lastModified": 1676283394, 6 | "narHash": "sha256-XX2f9c3iySLCw54rJ/CZs+ZK6IQy7GXNY4nSOyu2QG4=", 7 | "owner": "numtide", 8 | "repo": "flake-utils", 9 | "rev": "3db36a8b464d0c4532ba1c7dda728f4576d6d073", 10 | "type": "github" 11 | }, 12 | "original": { 13 | "owner": "numtide", 14 | "repo": "flake-utils", 15 | "rev": "3db36a8b464d0c4532ba1c7dda728f4576d6d073", 16 | "type": "github" 17 | } 18 | }, 19 | "nixgl": { 20 | "inputs": { 21 | "flake-utils": [ 22 | "flake-utils" 23 | ], 24 | "nixpkgs": [ 25 | "nixpkgs" 26 | ] 27 | }, 28 | "locked": { 29 | "lastModified": 1676383589, 30 | "narHash": "sha256-KCkWZXCjH+C4Kn7fUGSrEl5btk+sERHhZueSsvVbPWc=", 31 | "owner": "guibou", 32 | "repo": "nixgl", 33 | "rev": "c917918ab9ebeee27b0dd657263d3f57ba6bb8ad", 34 | "type": "github" 35 | }, 36 | "original": { 37 | "owner": "guibou", 38 | "repo": "nixgl", 39 | "rev": "c917918ab9ebeee27b0dd657263d3f57ba6bb8ad", 40 | "type": "github" 41 | } 42 | }, 43 | "nixpkgs": { 44 | "locked": { 45 | "lastModified": 1669833724, 46 | "narHash": "sha256-/HEZNyGbnQecrgJnfE8d0WC5c1xuPSD2LUpB6YXlg4c=", 47 | "owner": "nixos", 48 | "repo": "nixpkgs", 49 | "rev": "4d2b37a84fad1091b9de401eb450aae66f1a741e", 50 | "type": "github" 51 | }, 52 | "original": { 53 | "owner": "nixos", 54 | "ref": "22.11", 55 | "repo": "nixpkgs", 56 | "type": 
"github" 57 | } 58 | }, 59 | "nixpkgs-with-nvidia-driver-fix": { 60 | "locked": { 61 | "lastModified": 1679654343, 62 | "narHash": "sha256-T2v0cmTWstqhUDhU/lgMpY5YnCmtoaKZ7oSTIM93ERk=", 63 | "owner": "nixos", 64 | "repo": "nixpkgs", 65 | "rev": "90c44c959814b11009fe3abb33344fa6f5e6d290", 66 | "type": "github" 67 | }, 68 | "original": { 69 | "owner": "nixos", 70 | "ref": "pull/222762/head", 71 | "repo": "nixpkgs", 72 | "type": "github" 73 | } 74 | }, 75 | "root": { 76 | "inputs": { 77 | "flake-utils": "flake-utils", 78 | "nixgl": "nixgl", 79 | "nixpkgs": "nixpkgs", 80 | "nixpkgs-with-nvidia-driver-fix": "nixpkgs-with-nvidia-driver-fix" 81 | } 82 | } 83 | }, 84 | "root": "root", 85 | "version": 7 86 | } 87 | -------------------------------------------------------------------------------- /models/imagefit.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | import chex 4 | import flax.linen as nn 5 | from flax.linen.dtypes import Dtype 6 | import jax 7 | 8 | from models.encoders import FrequencyEncoder, HashGridEncoder 9 | 10 | 11 | class ImageFitter(nn.Module): 12 | encoding: Literal["hashgrid", "frequency"] 13 | encoding_dtype: Dtype 14 | 15 | @nn.compact 16 | def __call__(self, uv: jax.Array) -> jax.Array: 17 | """ 18 | Inputs: 19 | uv [..., 2]: coordinates in $\\R^2$ (normalized in range [0, 1]). 20 | 21 | Returns: 22 | rgb [..., 3]: predicted color for each input uv coordinate (normalized in range [0, 1]). 
23 | """ 24 | chex.assert_axis_dimension(uv, -1, 2) 25 | 26 | if self.encoding == "hashgrid": 27 | # [..., L*F] 28 | x = HashGridEncoder( 29 | dim=2, 30 | L=16, 31 | # ~1Mi entries per level 32 | T=2**20, 33 | F=2, 34 | N_min=16, 35 | N_max=2**19, # 524288 36 | param_dtype=self.encoding_dtype, 37 | )(uv) 38 | elif self.encoding == "frequency": 39 | # [..., dim*L] 40 | x = FrequencyEncoder(dim=2, L=10)(uv) 41 | else: 42 | raise ValueError("Unexpected encoding type '{}'".format(self.encoding)) 43 | 44 | DenseLayer = lambda dim, name: nn.Dense( 45 | features=dim, 46 | name=name, 47 | # the paper uses glorot initialization, in practice glorot initialization converges 48 | # to a better result than kaiming initialization, though the gap is small. 49 | # TODO: 50 | # experiment with initializers (or not) 51 | kernel_init=nn.initializers.lecun_normal(), 52 | bias_init=nn.initializers.zeros, 53 | param_dtype=x.dtype 54 | ) 55 | # feed to the MLP 56 | x = DenseLayer(128, name="linear1")(x) 57 | x = nn.relu(x) 58 | x = DenseLayer(128, name="linear2")(x) 59 | x = nn.relu(x) 60 | 61 | if self.encoding == "frequency": 62 | x = nn.relu(DenseLayer(256, name="linear3")(x)) 63 | x = nn.relu(DenseLayer(512, name="linear4")(x)) 64 | x = nn.relu(DenseLayer(512, name="linear5")(x)) 65 | x = nn.relu(DenseLayer(512, name="linear6")(x)) 66 | x = nn.relu(DenseLayer(512, name="linear7")(x)) 67 | 68 | x = nn.Dense(3, name="color_predictor", param_dtype=x.dtype)(x) 69 | rgb = nn.sigmoid(x) 70 | 71 | return rgb 72 | -------------------------------------------------------------------------------- /deps/volume-rendering-jax/lib/impl/packbits.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "volrend.h" 4 | 5 | namespace volrendjax { 6 | 7 | namespace { 8 | 9 | __global__ void pack_bits_kernel( 10 | // inputs 11 | /// static 12 | std::uint32_t const n_bytes 13 | 14 | /// array 15 | , float const * const __restrict__ 
density_threshold 16 | , float const * const __restrict__ density_grid 17 | 18 | // output 19 | , bool * const __restrict__ occupied_mask 20 | , std::uint8_t * const __restrict__ occupancy_bitfield 21 | ) { 22 | std::uint32_t const i = blockIdx.x * blockDim.x + threadIdx.x; 23 | if (i >= n_bytes) { return; } 24 | 25 | std::uint8_t byte = (std::uint8_t)0x00; 26 | 27 | #pragma unroll 28 | for (std::uint8_t idx = 0; idx < 8; ++idx) { 29 | bool const predicate = (density_grid[i*8+idx] > density_threshold[i*8+idx]); 30 | occupied_mask[i*8+idx] = predicate; 31 | byte |= predicate ? ((std::uint8_t)0x01 << idx) : (std::uint8_t)0x00; 32 | } 33 | occupancy_bitfield[i] = byte; 34 | } 35 | 36 | void pack_bits_launcher(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) { 37 | // buffer indexing helper 38 | std::uint32_t __buffer_idx = 0; 39 | auto const next_buffer = [&]() { return buffers[__buffer_idx++]; }; 40 | 41 | // inputs 42 | /// static 43 | PackbitsDescriptor const &desc = *deserialize(opaque, opaque_len); 44 | 45 | /// array 46 | float const * const __restrict__ density_threshold = static_cast(next_buffer()); 47 | float const * const __restrict__ density_grid = static_cast(next_buffer()); 48 | 49 | // output 50 | bool * const __restrict__ occupied_mask = static_cast(next_buffer()); 51 | std::uint8_t * const __restrict__ occupancy_bitfield = static_cast(next_buffer()); 52 | 53 | // kernel launch 54 | std::uint32_t static constexpr blockSize = 512; 55 | std::uint32_t const numBlocks = (desc.n_bytes + blockSize - 1) / blockSize; 56 | pack_bits_kernel<<>>( 57 | // inputs 58 | /// static 59 | desc.n_bytes 60 | 61 | /// array 62 | , density_threshold 63 | , density_grid 64 | 65 | /// output 66 | , occupied_mask 67 | , occupancy_bitfield 68 | ); 69 | 70 | // abort on error 71 | CUDA_CHECK_THROW(cudaGetLastError()); 72 | } 73 | 74 | } 75 | 76 | void pack_density_into_bits( 77 | cudaStream_t stream, 78 | void **buffers, 79 | const char *opaque, 80 
| std::size_t opaque_len 81 | ) { 82 | pack_bits_launcher(stream, buffers, opaque, opaque_len); 83 | } 84 | 85 | } 86 | -------------------------------------------------------------------------------- /deps/jax-tcnn/default.nix: -------------------------------------------------------------------------------- 1 | { lib, version, symlinkJoin, linkFarm, cudaCapabilities, buildPythonPackage 2 | 3 | , setuptools-scm 4 | , cmake 5 | , ninja 6 | , pybind11 7 | , fmt 8 | 9 | , serde-helper 10 | , cudatoolkit 11 | , tiny-cuda-nn 12 | , nlohmann_json 13 | , python3 14 | , chex 15 | , jax 16 | , jaxlib 17 | }: 18 | 19 | let 20 | dropDot = x: builtins.replaceStrings ["."] [""] x; 21 | minGpuArch = let 22 | min = lhs: rhs: if (builtins.compareVersions lhs rhs) < 0 23 | then lhs 24 | else rhs; 25 | in dropDot (builtins.foldl' min "998244353" cudaCapabilities); 26 | 27 | cudatoolkit-unsplit = symlinkJoin { 28 | name = "${cudatoolkit.name}-unsplit"; 29 | paths = [ cudatoolkit.out cudatoolkit.lib ]; 30 | }; 31 | fmt-unsplit = symlinkJoin { 32 | name = "${fmt.name}-unsplit"; 33 | paths = [ fmt.out fmt.dev ]; 34 | }; 35 | nlohmann_json-symlinked = linkFarm "${nlohmann_json.name}-symlinked" [ 36 | { name = "include/json"; path = "${nlohmann_json}/include/nlohmann"; } 37 | { name = "include/nlohmann"; path = "${nlohmann_json}/include/nlohmann"; } 38 | ]; 39 | in 40 | 41 | buildPythonPackage rec { 42 | pname = "jax-tcnn"; 43 | inherit version; 44 | src = ./.; 45 | 46 | format = "pyproject"; 47 | 48 | CUDA_HOME = cudatoolkit-unsplit; 49 | 50 | nativeBuildInputs = [ 51 | cmake 52 | ninja 53 | pybind11 54 | setuptools-scm 55 | ]; 56 | dontUseCmakeConfigure = true; 57 | cmakeFlags = [ 58 | "-DTCNN_MIN_GPU_ARCH=${minGpuArch}" 59 | "-DCMAKE_CUDA_ARCHITECTURES=${lib.concatStringsSep ";" (map dropDot cudaCapabilities)}" 60 | ]; 61 | 62 | buildInputs = [ 63 | cudatoolkit-unsplit 64 | fmt-unsplit 65 | serde-helper 66 | tiny-cuda-nn 67 | nlohmann_json-symlinked 68 | ]; 69 | 70 | 
propagatedBuildInputs = [ 71 | chex 72 | jax 73 | jaxlib 74 | ]; 75 | 76 | preFixup = '' 77 | patchelf --set-rpath "${lib.makeLibraryPath buildInputs}/lib" $out/lib/python${python3.pythonVersion}/site-packages/jaxtcnn/*.so 78 | ''; 79 | 80 | doCheck = false; 81 | 82 | pythonImportsCheck = [ "jaxtcnn" ]; 83 | 84 | # development 85 | dot_clangd = '' 86 | CompileFlags: # Tweak the parse settings 87 | Add: 88 | - "-Wall" # enable more warnings 89 | - "-Wshadow" # warn if a local declared variable shadows a global one 90 | - "-std=c++20" # use cpp20 standard (std::bit_cast needs this) 91 | - "-DTCNN_MIN_GPU_ARCH=${minGpuArch}" 92 | - "-I${tiny-cuda-nn}/include" 93 | - "-I${nlohmann_json-symlinked}/include" 94 | - "-I${serde-helper}/include" 95 | - "-I${cudatoolkit-unsplit}/include" 96 | - "-I${fmt.dev}/include" 97 | - "-I${pybind11}/include" 98 | - "-I${python3}/include/python${python3.pythonVersion}" 99 | - "--cuda-path=${cudatoolkit-unsplit}" 100 | Remove: "-W*" # strip all other warning-related flags 101 | Compiler: "clang++" # Change argv[0] of compile flags to clang++ 102 | 103 | # vim: ft=yaml: 104 | ''; 105 | shellHook = '' 106 | echo "use \`echo \$dot_clangd >.clangd\` for development" 107 | [[ "$-" == *i* ]] && exec "$SHELL" 108 | ''; 109 | } 110 | -------------------------------------------------------------------------------- /deps/jax-tcnn/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import subprocess 5 | 6 | from setuptools import Extension, find_packages, setup 7 | from setuptools.command.build_ext import build_ext 8 | 9 | HERE = os.path.dirname(os.path.realpath(__file__)) 10 | 11 | 12 | class CMakeBuildExt(build_ext): 13 | def build_extensions(self): 14 | # First: configure CMake build 15 | import platform 16 | import sys 17 | import sysconfig 18 | 19 | import pybind11 20 | 21 | # Work out the relevant Python paths to pass to CMake, adapted from the 22 | # 
PyTorch build system 23 | if platform.system() == "Windows": 24 | cmake_python_library = "{}/libs/python{}.lib".format( 25 | sysconfig.get_config_var("prefix"), 26 | sysconfig.get_config_var("VERSION"), 27 | ) 28 | if not os.path.exists(cmake_python_library): 29 | cmake_python_library = "{}/libs/python{}.lib".format( 30 | sys.base_prefix, 31 | sysconfig.get_config_var("VERSION"), 32 | ) 33 | else: 34 | cmake_python_library = "{}/{}".format( 35 | sysconfig.get_config_var("LIBDIR"), 36 | sysconfig.get_config_var("INSTSONAME"), 37 | ) 38 | cmake_python_include_dir = sysconfig.get_path("include") 39 | 40 | install_dir = os.path.abspath( 41 | os.path.dirname(self.get_ext_fullpath("dummy")) 42 | ) 43 | os.makedirs(install_dir, exist_ok=True) 44 | cmake_args = [ 45 | "-DCMAKE_INSTALL_PREFIX={}".format(install_dir), 46 | "-DPython_EXECUTABLE={}".format(sys.executable), 47 | "-DPython_LIBRARIES={}".format(cmake_python_library), 48 | "-DPython_INCLUDE_DIRS={}".format(cmake_python_include_dir), 49 | "-DCMAKE_BUILD_TYPE={}".format( 50 | "Debug" if self.debug else "Release" 51 | ), 52 | "-DCMAKE_PREFIX_PATH={}".format(pybind11.get_cmake_dir()), 53 | "-G Ninja", 54 | ] + os.environ["cmakeFlags"].split() 55 | os.makedirs(self.build_temp, exist_ok=True) 56 | subprocess.check_call( 57 | ["cmake", HERE] + cmake_args, cwd=self.build_temp 58 | ) 59 | 60 | # Build all the extensions 61 | super().build_extensions() 62 | 63 | # Finally run install 64 | subprocess.check_call( 65 | ["cmake", "--build", ".", "--target", "install"], 66 | cwd=self.build_temp, 67 | ) 68 | 69 | def build_extension(self, ext): 70 | target_name = ext.name.split(".")[-1] 71 | subprocess.check_call( 72 | ["cmake", "--build", ".", "--target", target_name], 73 | cwd=self.build_temp, 74 | ) 75 | 76 | extensions = [ 77 | Extension( 78 | "jax_tcnn.tcnnutils", # Python dotted name, whose final component should be a buildable target defined in CMakeLists.txt 79 | [ # source paths, relative to this setup.py file 80 | 
"lib/ffi.cc", 81 | "lib/impl/hashgrid.cu", 82 | ], 83 | ), 84 | ] 85 | 86 | setup( 87 | name="jax-tcnn", 88 | author="blurgyy", 89 | package_dir={"": "src"}, 90 | packages=find_packages("src"), 91 | include_package_data=True, 92 | install_requires=["jax", "jaxlib", "chex"], 93 | ext_modules=extensions, 94 | cmdclass={"build_ext": CMakeBuildExt}, 95 | ) 96 | -------------------------------------------------------------------------------- /deps/spherical-harmonics-encoding-jax/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import subprocess 5 | 6 | from setuptools import Extension, find_packages, setup 7 | from setuptools.command.build_ext import build_ext 8 | 9 | HERE = os.path.dirname(os.path.realpath(__file__)) 10 | 11 | 12 | class CMakeBuildExt(build_ext): 13 | def build_extensions(self): 14 | # First: configure CMake build 15 | import platform 16 | import sys 17 | import sysconfig 18 | 19 | import pybind11 20 | 21 | # Work out the relevant Python paths to pass to CMake, adapted from the 22 | # PyTorch build system 23 | if platform.system() == "Windows": 24 | cmake_python_library = "{}/libs/python{}.lib".format( 25 | sysconfig.get_config_var("prefix"), 26 | sysconfig.get_config_var("VERSION"), 27 | ) 28 | if not os.path.exists(cmake_python_library): 29 | cmake_python_library = "{}/libs/python{}.lib".format( 30 | sys.base_prefix, 31 | sysconfig.get_config_var("VERSION"), 32 | ) 33 | else: 34 | cmake_python_library = "{}/{}".format( 35 | sysconfig.get_config_var("LIBDIR"), 36 | sysconfig.get_config_var("INSTSONAME"), 37 | ) 38 | cmake_python_include_dir = sysconfig.get_path("include") 39 | 40 | install_dir = os.path.abspath( 41 | os.path.dirname(self.get_ext_fullpath("dummy")) 42 | ) 43 | os.makedirs(install_dir, exist_ok=True) 44 | cmake_args = [ 45 | "-DCMAKE_INSTALL_PREFIX={}".format(install_dir), 46 | "-DPython_EXECUTABLE={}".format(sys.executable), 47 | 
"-DPython_LIBRARIES={}".format(cmake_python_library), 48 | "-DPython_INCLUDE_DIRS={}".format(cmake_python_include_dir), 49 | "-DCMAKE_BUILD_TYPE={}".format( 50 | "Debug" if self.debug else "Release" 51 | ), 52 | "-DCMAKE_PREFIX_PATH={}".format(pybind11.get_cmake_dir()), 53 | "-G Ninja", 54 | ] 55 | os.makedirs(self.build_temp, exist_ok=True) 56 | subprocess.check_call( 57 | ["cmake", HERE] + cmake_args, cwd=self.build_temp 58 | ) 59 | 60 | # Build all the extensions 61 | super().build_extensions() 62 | 63 | # Finally run install 64 | subprocess.check_call( 65 | ["cmake", "--build", ".", "--target", "install"], 66 | cwd=self.build_temp, 67 | ) 68 | 69 | def build_extension(self, ext): 70 | target_name = ext.name.split(".")[-1] 71 | subprocess.check_call( 72 | ["cmake", "--build", ".", "--target", target_name], 73 | cwd=self.build_temp, 74 | ) 75 | 76 | extensions = [ 77 | Extension( 78 | "shjax.cudaops", # Python dotted name, whose final component should be a buildable target defined in CMakeLists.txt 79 | [ # source paths, relative to this setup.py file 80 | "lib/ffi.cc", 81 | "lib/impl/spherical_harmonics_encoding.cu", 82 | ], 83 | ) 84 | ] 85 | 86 | setup( 87 | name="spherical-harmonics-encoding-jax", 88 | author="blurgyy", 89 | package_dir={"": "src"}, 90 | packages=find_packages("src"), 91 | include_package_data=True, 92 | install_requires=["jax", "jaxlib", "chex"], 93 | ext_modules=extensions, 94 | cmdclass={"build_ext": CMakeBuildExt}, 95 | ) 96 | -------------------------------------------------------------------------------- /deps/volume-rendering-jax/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import subprocess 5 | 6 | from setuptools import Extension, find_packages, setup 7 | from setuptools.command.build_ext import build_ext 8 | 9 | HERE = os.path.dirname(os.path.realpath(__file__)) 10 | 11 | 12 | class CMakeBuildExt(build_ext): 13 | def build_extensions(self): 
14 | # First: configure CMake build 15 | import platform 16 | import sys 17 | import sysconfig 18 | 19 | import pybind11 20 | 21 | # Work out the relevant Python paths to pass to CMake, adapted from the 22 | # PyTorch build system 23 | if platform.system() == "Windows": 24 | cmake_python_library = "{}/libs/python{}.lib".format( 25 | sysconfig.get_config_var("prefix"), 26 | sysconfig.get_config_var("VERSION"), 27 | ) 28 | if not os.path.exists(cmake_python_library): 29 | cmake_python_library = "{}/libs/python{}.lib".format( 30 | sys.base_prefix, 31 | sysconfig.get_config_var("VERSION"), 32 | ) 33 | else: 34 | cmake_python_library = "{}/{}".format( 35 | sysconfig.get_config_var("LIBDIR"), 36 | sysconfig.get_config_var("INSTSONAME"), 37 | ) 38 | cmake_python_include_dir = sysconfig.get_path("include") 39 | 40 | install_dir = os.path.abspath( 41 | os.path.dirname(self.get_ext_fullpath("dummy")) 42 | ) 43 | os.makedirs(install_dir, exist_ok=True) 44 | cmake_args = [ 45 | "-DCMAKE_INSTALL_PREFIX={}".format(install_dir), 46 | "-DPython_EXECUTABLE={}".format(sys.executable), 47 | "-DPython_LIBRARIES={}".format(cmake_python_library), 48 | "-DPython_INCLUDE_DIRS={}".format(cmake_python_include_dir), 49 | "-DCMAKE_BUILD_TYPE={}".format( 50 | "Debug" if self.debug else "Release" 51 | ), 52 | "-DCMAKE_PREFIX_PATH={}".format(pybind11.get_cmake_dir()), 53 | "-G Ninja", 54 | ] 55 | os.makedirs(self.build_temp, exist_ok=True) 56 | subprocess.check_call( 57 | ["cmake", HERE] + cmake_args, cwd=self.build_temp 58 | ) 59 | 60 | # Build all the extensions 61 | super().build_extensions() 62 | 63 | # Finally run install 64 | subprocess.check_call( 65 | ["cmake", "--build", ".", "--target", "install"], 66 | cwd=self.build_temp, 67 | ) 68 | 69 | def build_extension(self, ext): 70 | target_name = ext.name.split(".")[-1] 71 | subprocess.check_call( 72 | ["cmake", "--build", ".", "--target", target_name], 73 | cwd=self.build_temp, 74 | ) 75 | 76 | extensions = [ 77 | Extension( 78 | 
"volrendjax.volrendutils_cuda", # Python dotted name, whose final component should be a buildable target defined in CMakeLists.txt 79 | [ # source paths, relative to this setup.py file 80 | "lib/ffi.cc", 81 | "lib/impl/packbits.cu", 82 | "lib/impl/marching.cu", 83 | "lib/impl/integrating.cu", 84 | ], 85 | ), 86 | ] 87 | 88 | setup( 89 | name="volume-rendering-jax", 90 | author="blurgyy", 91 | package_dir={"": "src"}, 92 | packages=find_packages("src"), 93 | include_package_data=True, 94 | install_requires=["jax", "jaxlib", "chex"], 95 | ext_modules=extensions, 96 | cmdclass={"build_ext": CMakeBuildExt}, 97 | ) 98 | -------------------------------------------------------------------------------- /deps/tiny-cuda-nn/default.nix: -------------------------------------------------------------------------------- 1 | # adapted from 2 | 3 | { source, cudaCapabilities 4 | 5 | , buildSharedLib ? false 6 | 7 | , cmake 8 | , cudaPackages 9 | , lib 10 | , ninja 11 | , stdenv 12 | , symlinkJoin 13 | , which 14 | }: 15 | let 16 | cuda-common-redist = with cudaPackages; [ 17 | libcublas # cublas_v2.h 18 | libcusolver # cusolverDn.h 19 | libcusparse # cusparse.h 20 | ]; 21 | 22 | cuda-native-redist = symlinkJoin { 23 | name = "cuda-redist"; 24 | paths = with cudaPackages; [ 25 | cuda_cudart # cuda_runtime.h 26 | cuda_nvcc 27 | ] ++ cuda-common-redist; 28 | }; 29 | 30 | cuda-redist = symlinkJoin { 31 | name = "cuda-redist"; 32 | paths = cuda-common-redist; 33 | }; 34 | in 35 | stdenv.mkDerivation (finalAttrs: rec { 36 | inherit (source) pname version src; 37 | 38 | outputs = [ "out" "dev" ]; 39 | 40 | nativeBuildInputs = [ 41 | cmake 42 | cuda-native-redist 43 | ninja 44 | which 45 | ]; 46 | 47 | # build a shared library for faster development 48 | postPatch = lib.optionalString buildSharedLib '' 49 | sed -E \ 50 | -e 's/BUILD_SHARED_LIBS OFF/BUILD_SHARED_LIBS ON/g' \ 51 | -e 's/STATIC/SHARED/g' \ 52 | -i CMakeLists.txt 53 | ''; 54 | 55 | # by default tcnn builds a static library, 
but that's too slow 56 | cmakeFlags = [ 57 | "-DTCNN_BUILD_EXAMPLES=OFF" 58 | "-DTCNN_BUILD_BENCHMARK=OFF" 59 | ]; 60 | 61 | buildInputs = [ 62 | cuda-redist 63 | ]; 64 | 65 | # NOTE: We cannot use pythonImportsCheck for this module because it uses torch to immediately 66 | # initialize CUDA and GPU access is not allowed in the nix build environment. 67 | # NOTE: There are no tests for the C++ library or the python bindings, so we just skip the check 68 | # phase -- we're not missing anything. 69 | doCheck = false; 70 | 71 | preConfigure = let 72 | dropDot = x: builtins.replaceStrings ["."] [""] x; 73 | in '' 74 | export TCNN_CUDA_ARCHITECTURES=${ 75 | lib.concatStringsSep "\\;" (map dropDot cudaCapabilities) 76 | } 77 | export CUDA_HOME=${cuda-native-redist} 78 | export LIBRARY_PATH=${cuda-native-redist}/lib/stubs:$LIBRARY_PATH 79 | ''; 80 | 81 | installPhase = '' 82 | runHook preInstall 83 | 84 | # install headers 85 | mkdir -p $dev/include 86 | cp -vr ../include/* $dev/include 87 | cp -vr ../dependencies/* $dev/include 88 | 89 | # install built library 90 | mkdir -p $out/lib $out 91 | cp -v libtiny-cuda-nn.${if buildSharedLib then "so" else "a"} $out/lib/ 92 | 93 | runHook postInstall 94 | ''; 95 | # Fixes: 96 | # > RPATH of binary /nix/store/xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx-tiny-cuda-nn-v1.6/lib/libtiny-cuda-nn.so contains a forbidden reference to /build/ 97 | # REF: 98 | preFixup = lib.optionalString buildSharedLib '' 99 | patchelf --set-rpath ${lib.makeLibraryPath buildInputs} $out/lib/libtiny-cuda-nn.so 100 | ''; 101 | 102 | passthru = { 103 | inherit cudaPackages; 104 | }; 105 | 106 | meta = with lib; { 107 | description = "Lightning fast C++/CUDA neural network framework"; 108 | homepage = "https://github.com/NVlabs/tiny-cuda-nn"; 109 | license = licenses.bsd3; 110 | maintainers = with maintainers; [ connorbaker ]; 111 | platforms = platforms.linux; 112 | }; 113 | }) 114 | -------------------------------------------------------------------------------- 
/deps/_sources/generated.json: -------------------------------------------------------------------------------- 1 | { 2 | "colmap-locked": { 3 | "cargoLocks": null, 4 | "date": null, 5 | "extract": null, 6 | "name": "colmap-locked", 7 | "passthru": null, 8 | "pinned": false, 9 | "src": { 10 | "deepClone": false, 11 | "fetchSubmodules": true, 12 | "leaveDotGit": false, 13 | "name": null, 14 | "owner": "colmap", 15 | "repo": "colmap", 16 | "rev": "3.8", 17 | "sha256": "sha256-ArWQBRuWRkRXnNs154pxTgrGcZyMH6doG/R89LC/0Ms=", 18 | "type": "github" 19 | }, 20 | "version": "3.8" 21 | }, 22 | "dearpygui": { 23 | "cargoLocks": null, 24 | "date": null, 25 | "extract": null, 26 | "name": "dearpygui", 27 | "passthru": null, 28 | "pinned": false, 29 | "src": { 30 | "deepClone": false, 31 | "fetchSubmodules": true, 32 | "leaveDotGit": false, 33 | "name": null, 34 | "owner": "hoffstadt", 35 | "repo": "DearPyGui", 36 | "rev": "v1.9.1", 37 | "sha256": "sha256-Af1jhQYT0CYNFMWihAtP6jRNYKm3XKEu3brFOPSGCnk=", 38 | "type": "github" 39 | }, 40 | "version": "v1.9.1" 41 | }, 42 | "pycolmap": { 43 | "cargoLocks": null, 44 | "date": null, 45 | "extract": null, 46 | "name": "pycolmap", 47 | "passthru": null, 48 | "pinned": false, 49 | "src": { 50 | "deepClone": false, 51 | "fetchSubmodules": true, 52 | "leaveDotGit": false, 53 | "name": null, 54 | "owner": "colmap", 55 | "repo": "pycolmap", 56 | "rev": "v0.4.0", 57 | "sha256": "sha256-W3d+uHZXkH1/QlER1HV8t1MOBOrHIXYsVeYv1zbsbW4=", 58 | "type": "github" 59 | }, 60 | "version": "v0.4.0" 61 | }, 62 | "pyimgui": { 63 | "cargoLocks": null, 64 | "date": null, 65 | "extract": null, 66 | "name": "pyimgui", 67 | "passthru": null, 68 | "pinned": false, 69 | "src": { 70 | "name": null, 71 | "sha256": "sha256-L7247tO429fqmK+eTBxlgrC8TalColjeFjM9jGU9Z+E=", 72 | "type": "url", 73 | "url": "https://pypi.io/packages/source/i/imgui/imgui-2.0.0.tar.gz" 74 | }, 75 | "version": "2.0.0" 76 | }, 77 | "tiny-cuda-nn": { 78 | "cargoLocks": null, 79 | "date": null, 80 
| "extract": null, 81 | "name": "tiny-cuda-nn", 82 | "passthru": null, 83 | "pinned": false, 84 | "src": { 85 | "deepClone": false, 86 | "fetchSubmodules": true, 87 | "leaveDotGit": false, 88 | "name": null, 89 | "owner": "NVlabs", 90 | "repo": "tiny-cuda-nn", 91 | "rev": "v1.6", 92 | "sha256": "sha256-qW6Fk2GB71fvZSsfu+mykabSxEKvaikZ/pQQZUycOy0=", 93 | "type": "github" 94 | }, 95 | "version": "v1.6" 96 | }, 97 | "tyro": { 98 | "cargoLocks": null, 99 | "date": null, 100 | "extract": null, 101 | "name": "tyro", 102 | "passthru": null, 103 | "pinned": false, 104 | "src": { 105 | "name": null, 106 | "sha256": "sha256-ygdNkRr4bjDDHioXoMWPZ1c0IamIkqGyvAuvJx3Bhis=", 107 | "type": "url", 108 | "url": "https://pypi.io/packages/source/t/tyro/tyro-0.5.3.tar.gz" 109 | }, 110 | "version": "0.5.3" 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /deps/jax-tcnn/src/jaxtcnn/hashgrid_tcnn/impl.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import functools 3 | from typing import Tuple 4 | 5 | import chex 6 | import jax 7 | from jax.interpreters import mlir, xla 8 | from jax.lib import xla_client 9 | 10 | from . import abstract, lowering 11 | from .. 
import tcnnutils 12 | 13 | 14 | # register GPU XLA custom calls 15 | for name, value in tcnnutils.get_hashgrid_registrations().items(): 16 | xla_client.register_custom_call_target(name, value, platform="gpu") 17 | 18 | 19 | # primitives 20 | hashgrid_encode_p = jax.core.Primitive("hashgrid🏁") 21 | hashgrid_encode_p.multiple_results = True 22 | hashgrid_encode_p.def_impl(functools.partial(xla.apply_primitive, hashgrid_encode_p)) 23 | hashgrid_encode_p.def_abstract_eval(abstract.hashgrid_encode_abstract) 24 | 25 | hashgrid_encode_backward_p = jax.core.Primitive("hashgrid🏁backward") 26 | hashgrid_encode_backward_p.multiple_results = True 27 | hashgrid_encode_backward_p.def_impl(functools.partial(xla.apply_primitive, hashgrid_encode_backward_p)) 28 | hashgrid_encode_backward_p.def_abstract_eval(abstract.hashgrid_encode_backward_abstract) 29 | 30 | 31 | # lowering rules 32 | mlir.register_lowering( 33 | prim=hashgrid_encode_p, 34 | rule=lowering.hashgrid_encode_lowering_rule, 35 | platform="gpu", 36 | ) 37 | mlir.register_lowering( 38 | prim=hashgrid_encode_backward_p, 39 | rule=lowering.hashgrid_encode_backward_lowering_rule, 40 | platform="gpu", 41 | ) 42 | 43 | 44 | @jax.tree_util.register_pytree_node_class 45 | @dataclasses.dataclass(frozen=True, kw_only=True) 46 | class HashGridMetadata: 47 | # number of levels, "n_levels" in tcnn 48 | L: int 49 | 50 | # number of features that each level should output, "n_features_per_level" in tcnn 51 | F: int 52 | 53 | # coarsest resolution, "base_resolution" in tcnn 54 | N_min: int 55 | 56 | # scale factor between consecutive levels 57 | per_level_scale: float 58 | 59 | def tree_flatten(self): 60 | children = () 61 | aux = (self.L, self.F, self.N_min, self.per_level_scale) 62 | return children, aux 63 | 64 | @classmethod 65 | def tree_unflatten(cls, aux, children): 66 | L, F, N_min, per_level_scale = aux 67 | return cls( 68 | L=L, 69 | F=F, 70 | N_min=N_min, 71 | per_level_scale=per_level_scale, 72 | ) 73 | 74 | 75 | 
@functools.partial(jax.custom_vjp, nondiff_argnums=[0])
def __hashgrid_encode(
    desc: HashGridMetadata,
    offset_table_data: jax.Array,
    coords_rm: jax.Array,
    params: jax.Array,
):
    """Multiresolution hash-grid encoding (forward pass).

    Args:
        desc: static hash-grid hyperparameters; non-differentiable
            (``nondiff_argnums=[0]``).
        offset_table_data: uint32 prefix sum of per-level table sizes,
            shape ``(L + 1,)``.
        coords_rm: row-major input coordinates, shape ``(dim, n_coords)``.
        params: learnable hash-table entries, shape ``(n_params, F)``.

    Returns:
        encoded_coords_rm: encoded features, shape ``(L * F, n_coords)``.
        dy_dcoords_rm: jacobian of the outputs w.r.t. the input coordinates,
            shape ``(dim * L * F, n_coords)``; saved as a residual for the
            backward pass.
    """
    encoded_coords_rm, dy_dcoords_rm = hashgrid_encode_p.bind(
        offset_table_data,
        coords_rm,
        params,
        L=desc.L,
        F=desc.F,
        N_min=desc.N_min,
        per_level_scale=desc.per_level_scale,
    )
    return encoded_coords_rm, dy_dcoords_rm


def __hashgrid_encode_fwd(
    desc: HashGridMetadata,
    offset_table_data: jax.Array,
    coords_rm: jax.Array,
    params: jax.Array,
):
    """custom_vjp forward rule: run the primal and stash residuals for bwd."""
    primal_outputs = __hashgrid_encode(
        desc=desc,
        offset_table_data=offset_table_data,
        coords_rm=coords_rm,
        params=params,
    )
    encoded_coords_rm, dy_dcoords_rm = primal_outputs
    # residuals consumed by __hashgrid_encode_bwd
    aux = {
        "in.offset_table_data": offset_table_data,
        "in.coords_rm": coords_rm,
        "in.params": params,
        "out.dy_dcoords_rm": dy_dcoords_rm,
    }
    return primal_outputs, aux


def __hashgrid_encode_bwd(desc: HashGridMetadata, aux, grads):
    """custom_vjp backward rule.

    Returns one cotangent per differentiable primal argument, in order
    ``(offset_table_data, coords_rm, params)``; the offset table is integer
    metadata, so its cotangent is ``None``.
    """
    # only the cotangent of the encoded features is used; dy_dcoords_rm is an
    # auxiliary primal output, so its incoming gradient is discarded
    dL_dy_rm, _ = grads
    dL_dparams, dL_dcoords_rm = hashgrid_encode_backward_p.bind(
        aux["in.offset_table_data"],
        aux["in.coords_rm"],
        aux["in.params"],
        dL_dy_rm,
        aux["out.dy_dcoords_rm"],

        L=desc.L,
        F=desc.F,
        N_min=desc.N_min,
        per_level_scale=desc.per_level_scale,
    )
    return None, dL_dcoords_rm, dL_dparams


__hashgrid_encode.defvjp(
    fwd=__hashgrid_encode_fwd,
    bwd=__hashgrid_encode_bwd,
)
# helper function for mapping given shapes to their default mlir layouts
def default_layouts(*shapes):
    """Return the row-major (default) MLIR layout for each given shape."""
    return [range(len(shape) - 1, -1, -1) for shape in shapes]


def hashgrid_encode_lowering_rule(
    # NOTE(review): was annotated `mlir.LoweringRule`; lowering rules receive
    # a LoweringRuleContext instance (unused here)
    ctx: mlir.LoweringRuleContext,

    # arrays
    offset_table_data: ir.Value,
    coords_rm: ir.Value,
    params: ir.Value,

    # static args
    L: int,
    F: int,
    N_min: int,
    per_level_scale: float,
):
    """MLIR lowering rule for the forward hash-grid encoding primitive.

    Packs the static arguments into an opaque descriptor and emits a
    ``hashgrid_encode`` XLA custom call returning two f32 buffers:
    the encoded features and the jacobian w.r.t. the input coordinates.
    """
    # runtime shapes are read off the operands' ranked tensor types
    dim, n_coords = ir.RankedTensorType(coords_rm.type).shape
    n_params, _ = ir.RankedTensorType(params.type).shape

    # serialized descriptor consumed by the CUDA kernel launcher
    opaque = tcnnutils.make_hashgrid_descriptor(
        n_coords,
        L,
        F,
        N_min,
        per_level_scale,
    )

    shapes = {
        "in.offset_table_data": (L + 1,),
        "in.coords_rm": (dim, n_coords),
        "in.params": (n_params, F),

        "out.encoded_coords_rm": (L * F, n_coords),
        "out.dy_dcoords_rm": (dim * L * F, n_coords),
    }

    return custom_call(
        call_target_name="hashgrid_encode",
        out_types=[
            ir.RankedTensorType.get(shapes["out.encoded_coords_rm"], ir.F32Type.get()),
            ir.RankedTensorType.get(shapes["out.dy_dcoords_rm"], ir.F32Type.get()),
        ],
        operands=[
            offset_table_data,
            coords_rm,
            params,
        ],
        backend_config=opaque,
        operand_layouts=default_layouts(
            shapes["in.offset_table_data"],
            shapes["in.coords_rm"],
            shapes["in.params"],
        ),
        result_layouts=default_layouts(
            shapes["out.encoded_coords_rm"],
            shapes["out.dy_dcoords_rm"],
        ),
    )
def integrate_rays_abstract(
    rays_sample_startidx: jax.Array,
    rays_n_samples: jax.Array,

    bgs: jax.Array,
    dss: jax.Array,
    z_vals: jax.Array,
    drgbs: jax.Array,
):
    """Abstract evaluation rule for the ray-integration primitive.

    Validates input shapes/dtypes and returns the output avals:
    a (1,) uint32 measured-batch-size helper, the per-ray (n_rays, 4)
    final RGB+depth, and the per-ray (n_rays,) final opacities.
    """
    n_rays, = rays_sample_startidx.shape
    total_samples, = dss.shape

    # per-ray inputs
    chex.assert_shape([rays_sample_startidx, rays_n_samples], (n_rays,))
    chex.assert_shape(bgs, (n_rays, 3))
    # per-sample inputs
    chex.assert_shape(z_vals, (total_samples,))
    chex.assert_shape(drgbs, (total_samples, 4))

    pred_dtype = jax.dtypes.canonicalize_dtype(drgbs.dtype)
    if pred_dtype != jnp.float32:
        raise NotImplementedError(
            "integrate_rays is only implemented for input prediction (density, color) of `jnp.float32` type, got {}".format(
                pred_dtype,
            )
        )

    return (
        # helper.measured_batch_size
        jax.ShapedArray(shape=(1,), dtype=jnp.uint32),

        # out.final_rgbds
        jax.ShapedArray(shape=(n_rays, 4), dtype=jnp.float32),
        # out.final_opacities
        jax.ShapedArray(shape=(n_rays,), dtype=jnp.float32),
    )
"integrate_rays is only implemented for input color of `jnp.float32` type, got {}".format( 79 | dtype, 80 | ) 81 | ) 82 | 83 | out_shapes = { 84 | "dL_dbgs": (n_rays, 3), 85 | "dL_dz_vals": (total_samples,), 86 | "dL_ddrgbs": (total_samples, 4), 87 | } 88 | 89 | return ( 90 | jax.ShapedArray(shape=out_shapes["dL_dbgs"], dtype=jnp.float32), 91 | jax.ShapedArray(shape=out_shapes["dL_dz_vals"], dtype=jnp.float32), 92 | jax.ShapedArray(shape=out_shapes["dL_ddrgbs"], dtype=jnp.float32), 93 | ) 94 | 95 | 96 | def integrate_rays_inference_abstract( 97 | rays_bg: jax.ShapedArray, 98 | rays_rgbd: jax.ShapedArray, 99 | rays_T: jax.ShapedArray, 100 | 101 | n_samples: jax.ShapedArray, 102 | indices: jax.ShapedArray, 103 | dss: jax.ShapedArray, 104 | z_vals: jax.ShapedArray, 105 | drgbs: jax.ShapedArray, 106 | ): 107 | (n_total_rays, _), (n_rays, march_steps_cap) = rays_rgbd.shape, dss.shape 108 | 109 | chex.assert_shape(rays_bg, (n_total_rays, 3)) 110 | chex.assert_shape(rays_rgbd, (n_total_rays, 4)) 111 | chex.assert_shape(rays_T, (n_total_rays,)) 112 | chex.assert_shape([n_samples, indices], (n_rays,)) 113 | chex.assert_shape([dss, z_vals], (n_rays, march_steps_cap)) 114 | chex.assert_shape(drgbs, (n_rays, march_steps_cap, 4)) 115 | 116 | out_shapes = { 117 | "terminate_cnt": (1,), 118 | "terminated": (n_rays,), 119 | "rays_rgbd": (n_rays, 4), 120 | "rays_T": (n_rays,), 121 | } 122 | 123 | return ( 124 | jax.ShapedArray(shape=out_shapes["terminate_cnt"], dtype=jnp.uint32), 125 | jax.ShapedArray(shape=out_shapes["terminated"], dtype=jnp.bool_), 126 | jax.ShapedArray(shape=out_shapes["rays_rgbd"], dtype=jnp.float32), 127 | jax.ShapedArray(shape=out_shapes["rays_T"], dtype=jnp.float32), 128 | ) 129 | -------------------------------------------------------------------------------- /deps/volume-rendering-jax/src/volrendjax/marching/abstract.py: -------------------------------------------------------------------------------- 1 | import chex 2 | import jax 3 | import jax.numpy as 
def march_rays_abstract(
    # arrays
    rays_o: jax.ShapedArray,
    rays_d: jax.ShapedArray,
    t_starts: jax.ShapedArray,
    t_ends: jax.ShapedArray,
    noises: jax.ShapedArray,
    occupancy_bitfield: jax.ShapedArray,

    # static args
    total_samples: int,
    diagonal_n_steps: int,
    K: int,
    G: int,
    bound: float,
    stepsize_portion: float,
):
    """Abstract evaluation rule for the training-time ray-marching primitive.

    Validates inputs and returns output avals: three helper buffers
    (write cursor, exceeded-sample count, per-ray validity mask) followed by
    per-ray sample bookkeeping and the flattened per-sample buffers.
    """
    n_rays, _ = rays_o.shape

    chex.assert_shape([rays_o, rays_d], (n_rays, 3))
    chex.assert_shape([t_starts, t_ends, noises], (n_rays,))

    # K cascades of a G^3 occupancy grid, bit-packed 8 cells per byte
    chex.assert_shape(occupancy_bitfield, (K*G*G*G//8,))
    chex.assert_type(occupancy_bitfield, jnp.uint8)

    # static-arg sanity checks, in the same order as the original asserts
    for static_arg in (total_samples, diagonal_n_steps, K, G, bound):
        chex.assert_scalar_positive(static_arg)
    chex.assert_scalar_non_negative(stepsize_portion)

    ray_dtype = jax.dtypes.canonicalize_dtype(rays_o.dtype)
    if ray_dtype != jnp.float32:
        raise NotImplementedError(
            "march_rays is only implemented for input coordinates of `jnp.float32` type, got {}".format(
                ray_dtype,
            )
        )

    u32, f32 = jnp.uint32, jnp.float32
    return (
        jax.ShapedArray(shape=(1,), dtype=u32),                # helper.next_sample_write_location
        jax.ShapedArray(shape=(1,), dtype=u32),                # helper.number_of_exceeded_samples
        jax.ShapedArray(shape=(n_rays,), dtype=jnp.bool_),     # helper.ray_is_valid
        jax.ShapedArray(shape=(n_rays,), dtype=u32),           # out.rays_n_samples
        jax.ShapedArray(shape=(n_rays,), dtype=u32),           # out.rays_sample_startidx
        jax.ShapedArray(shape=(total_samples,), dtype=u32),    # out.idcs
        jax.ShapedArray(shape=(total_samples, 3), dtype=f32),  # out.xyzs
        jax.ShapedArray(shape=(total_samples, 3), dtype=f32),  # out.dirs
        jax.ShapedArray(shape=(total_samples,), dtype=f32),    # out.dss
        jax.ShapedArray(shape=(total_samples,), dtype=f32),    # out.z_vals
    )
def hashgrid_encode_abstract(
    # arrays
    offset_table_data: jax.Array,
    coords_rm: jax.Array,
    params: jax.Array,

    # static args
    L: int,
    F: int,
    N_min: int,
    per_level_scale: float,
):
    """Abstract evaluation rule for the forward hash-grid encoding primitive.

    Validates shapes, dtypes, and static arguments, then returns the output
    avals: the (L*F, n_coords) f32 encoded features and the
    (dim*L*F, n_coords) f32 jacobian w.r.t. the input coordinates.
    """
    dim, n_coords = coords_rm.shape
    # the CUDA kernel only supports 3-D inputs
    if dim != 3:
        raise NotImplementedError(
            "hashgrid encoding is only implemented for 3D coordinates, expected input coordinates to have shape ({}, n_coords), but got shape {}".format(
                dim, coords_rm.shape
            )
        )

    n_params, _ = params.shape

    chex.assert_shape(offset_table_data, (L + 1,))
    chex.assert_shape(coords_rm, (dim, n_coords))
    chex.assert_shape(params, (n_params, F))

    # static args must be concrete Python scalars of the right type
    chex.assert_scalar(L)
    chex.assert_scalar(F)
    chex.assert_scalar(N_min)
    chex.assert_scalar(per_level_scale)
    chex.assert_type([L, F, N_min], int)
    chex.assert_type(per_level_scale, float)

    offset_dtype = jax.dtypes.canonicalize_dtype(offset_table_data.dtype)
    if offset_dtype != jnp.uint32:
        raise RuntimeError(
            "hashgrid encoding expects `offset_table_data` (a prefix sum of the hash table sizes of each level) to be of type uint32, got {}".format(offset_dtype)
        )

    coord_dtype = jax.dtypes.canonicalize_dtype(coords_rm.dtype)
    if coord_dtype != jnp.float32:
        raise NotImplementedError(
            "hashgrid encoding is only implemented for input coordinates of type float32, got {}".format(
                coord_dtype
            )
        )

    param_dtype = jax.dtypes.canonicalize_dtype(params.dtype)
    if param_dtype != jnp.float32:
        raise NotImplementedError(
            "hashgrid encoding is only implemented for parameters of type float32, got {}".format(
                param_dtype
            )
        )

    out_shapes = {
        "encoded_coords_rm": (L * F, n_coords),
        "dy_dcoords_rm": (dim * L * F, n_coords),
    }

    return (
        jax.ShapedArray(shape=out_shapes["encoded_coords_rm"], dtype=jnp.float32),
        jax.ShapedArray(shape=out_shapes["dy_dcoords_rm"], dtype=jnp.float32),
    )
def extract_features(
    images_dir: Path,
    db_path: Path,
    camera_model: CameraModelType,
):
    """Run COLMAP SIFT feature extraction over *images_dir* into *db_path*.

    All images are treated as coming from one physical camera
    (``camera_mode="SINGLE"``) with the given COLMAP camera model.
    """
    images_dir = Path(images_dir)
    db_path = Path(db_path)
    # highest-quality SIFT settings: affine shape estimation + domain size pooling
    sift_options = pycolmap.SiftExtractionOptions(
        estimate_affine_shape=True,
        domain_size_pooling=True,
    )
    pycolmap.extract_features(
        database_path=db_path,
        image_path=images_dir,
        # top-level camera_mode supersedes the obsolete
        # ImageReaderOptions.single_camera flag
        camera_mode="SINGLE",
        # top-level camera_model supersedes the obsolete
        # ImageReaderOptions.camera_model field
        camera_model=camera_model,
        reader_options=pycolmap.ImageReaderOptions(),
        sift_options=sift_options,
    )
def sparse_reconstruction(
    images_dir: Path,
    sparse_reconstructions_dir: Path,
    db_path: Path,
    matcher: ColmapMatcherType,
) -> Dict[int, pycolmap.Reconstruction]:
    """Run COLMAP incremental mapping (sparse SfM reconstruction).

    Returns the reconstructed models keyed by model index, as produced by
    ``pycolmap.incremental_mapping``.
    """
    images_dir, sparse_reconstructions_dir = Path(images_dir), Path(sparse_reconstructions_dir)
    mapping_options = pycolmap.IncrementalMapperOptions(
        # principal point estimation is an ill-posed problem in general (`False` is already the
        # default, setting to False here explicitly works as a reminder to self)
        ba_refine_principal_point=False,
        # :src/colmap/util/option_manager.cc:ModifyForExtremeQuality
        ba_local_max_num_iterations=40,
        ba_local_max_refinements=3,
        ba_global_max_num_iterations=100,
        # below 3 options are for individual/video data, for internet photos, they should be left
        # default
        # :src/colmap/util/option_manager.cc:ModifyForVideoData,ModifyForIndividualData
        min_focal_length_ratio=0.1,
        max_focal_length_ratio=10,
        max_extra_param=1e15,
    )
    if matcher == "Sequential":
        # sequential matching implies video-like input
        # :src/colmap/util/option_manager.cc:ModifyForVideoData
        mapping_options.ba_global_images_ratio = 1.4
        mapping_options.ba_global_points_ratio = 1.4
    maps = pycolmap.incremental_mapping(
        database_path=db_path,
        image_path=images_dir,
        output_path=sparse_reconstructions_dir,
        options=mapping_options,
    )
    return maps
def undistort(
    images_dir: Path,
    sparse_reconstruction_dir: Path,
    undistorted_images_dir: Path,
):
    """Undistort *images_dir* using the cameras of a sparse reconstruction.

    Writes the undistorted images (and a copy of the model) to
    *undistorted_images_dir* via ``pycolmap.undistort_images``.
    """
    images_dir, sparse_reconstruction_dir, undistorted_images_dir = (
        Path(images_dir),
        Path(sparse_reconstruction_dir),
        Path(undistorted_images_dir),
    )
    pycolmap.undistort_images(
        output_path=undistorted_images_dir,
        input_path=sparse_reconstruction_dir,
        image_path=images_dir,
    )


def export_text_format_model(
    sparse_reconstruction_dir: Path,
    text_model_dir: Path,
):
    """Export a binary COLMAP sparse model to COLMAP's text format.

    Creates *text_model_dir* (including parents) if needed, then writes
    cameras.txt / images.txt / points3D.txt there.
    """
    sparse_reconstruction_dir, text_model_dir = (
        Path(sparse_reconstruction_dir),
        Path(text_model_dir),
    )
    text_model_dir.mkdir(parents=True, exist_ok=True)
    reconstruction = pycolmap.Reconstruction(sparse_reconstruction_dir)
    reconstruction.write_text(text_model_dir.as_posix())
import functools
from typing import Tuple

import jax
from jax.interpreters import mlir, xla
from jax.lib import xla_client

from . import abstract, lowering
from .. import volrendutils_cuda


# register GPU XLA custom calls
for name, value in volrendutils_cuda.get_integrating_registrations().items():
    xla_client.register_custom_call_target(name, value, platform="gpu")


# primitives
integrate_rays_p = jax.core.Primitive("integrate_rays🎨")
integrate_rays_p.multiple_results = True
integrate_rays_p.def_impl(functools.partial(xla.apply_primitive, integrate_rays_p))
integrate_rays_p.def_abstract_eval(abstract.integrate_rays_abstract)

integrate_rays_bwd_p = jax.core.Primitive("integrate_rays🎨backward")
integrate_rays_bwd_p.multiple_results = True
integrate_rays_bwd_p.def_impl(functools.partial(xla.apply_primitive, integrate_rays_bwd_p))
integrate_rays_bwd_p.def_abstract_eval(abstract.integrate_rays_backward_abstract)

integrate_rays_inference_p = jax.core.Primitive("integrate_rays🎨inference")
integrate_rays_inference_p.multiple_results = True
integrate_rays_inference_p.def_impl(functools.partial(xla.apply_primitive, integrate_rays_inference_p))
integrate_rays_inference_p.def_abstract_eval(abstract.integrate_rays_inference_abstract)

# register mlir lowering rules
mlir.register_lowering(
    prim=integrate_rays_p,
    rule=lowering.integrate_rays_lowering_rule,
    platform="gpu",
)
mlir.register_lowering(
    prim=integrate_rays_bwd_p,
    # NOTE(review): "lowring" is presumably spelled this way at the definition site in
    # lowering.py as well; if ever fixed, rename both places together.
    rule=lowering.integrate_rays_backward_lowring_rule,
    platform="gpu",
)
mlir.register_lowering(
    prim=integrate_rays_inference_p,
    rule=lowering.integrate_rays_inference_lowering_rule,
    platform="gpu",
)

@jax.custom_vjp
def __integrate_rays(
    near_distance: float,
    rays_sample_startidx: jax.Array,
    rays_n_samples: jax.Array,
    bgs: jax.Array,
    dss: jax.Array,
    z_vals: jax.Array,
    drgbs: jax.Array,
) -> Tuple[jax.Array, jax.Array, jax.Array]:
    # Forward pass of differentiable ray integration; see the public wrapper
    # `volrendjax.integrating.integrate_rays` for full parameter documentation.
    # (annotation fixed: this function returns 3 values, not 4)
    bgs = jax.numpy.broadcast_to(bgs, (rays_sample_startidx.shape[0], 3))

    measured_batch_size, final_rgbds, final_opacities = integrate_rays_p.bind(
        rays_sample_startidx,
        rays_n_samples,
        bgs,
        dss,
        z_vals,
        drgbs,
    )

    return measured_batch_size, final_rgbds, final_opacities

def __fwd_integrate_rays(
    near_distance: float,
    rays_sample_startidx: jax.Array,
    rays_n_samples: jax.Array,
    bgs: jax.Array,
    dss: jax.Array,
    z_vals: jax.Array,
    drgbs: jax.Array,
):
    # Broadcast exactly as the primal does, so the residuals saved below match what the
    # backward primitive expects.
    bgs = jax.numpy.broadcast_to(bgs, (rays_sample_startidx.shape[0], 3))

    primal_outputs = __integrate_rays(
        near_distance=near_distance,
        rays_sample_startidx=rays_sample_startidx,
        rays_n_samples=rays_n_samples,
        bgs=bgs,
        dss=dss,
        z_vals=z_vals,
        drgbs=drgbs,
    )
    measured_batch_size, final_rgbds, final_opacities = primal_outputs
    aux = {
        "in.near_distance": near_distance,
        "in.rays_sample_startidx": rays_sample_startidx,
        "in.rays_n_samples": rays_n_samples,
        "in.bgs": bgs,
        "in.dss": dss,
        "in.z_vals": z_vals,
        "in.drgbs": drgbs,

        "out.measured_batch_size": measured_batch_size,
        "out.final_rgbds": final_rgbds,
        "out.final_opacities": final_opacities,
    }
    return primal_outputs, aux

def __bwd_integrate_rays(aux, grads):
    # grads holds the cotangents of (measured_batch_size, final_rgbds, final_opacities);
    # dL_dfinal_opacities is unpacked but not propagated to the backward primitive.
    _, dL_dfinal_rgbds, dL_dfinal_opacities = grads
    dL_dbgs, dL_dz_vals, dL_ddrgbs = integrate_rays_bwd_p.bind(
        aux["in.rays_sample_startidx"],
        aux["in.rays_n_samples"],
        aux["in.bgs"],
        aux["in.dss"],
        aux["in.z_vals"],
        aux["in.drgbs"],

        aux["out.final_rgbds"],
        aux["out.final_opacities"],

        dL_dfinal_rgbds,

        near_distance=aux["in.near_distance"],
    )
    # The returned tuple must line up, position by position, with the primal inputs of
    # `__integrate_rays`:
    #   (near_distance, rays_sample_startidx, rays_n_samples, bgs, dss, z_vals, drgbs)
    # REF:
    # :
    # Wherever we used to use nondiff_argnums for array values, we should just pass those as
    # regular arguments. In the bwd rule, we need to produce values for them, but we can just
    # produce `None` values to indicate there's no corresponding gradient value.
    return (
        # `near_distance` is a static scalar, return no gradient for it
        None,
        # `rays_sample_startidx` and `rays_n_samples` are integer-valued arrays, no gradient
        None, None,
        # BUGFIX: `bgs` is the 4-th primal input and `dss` the 5-th.  The previous
        # ordering (None, dL_dbgs) assigned the background-color gradient to `dss` and no
        # gradient to `bgs`.  Gradient for background colors:
        dL_dbgs,
        # `dss` (per-sample integration step sizes), no gradient
        None,
        # gradients for z_vals and model predictions (densities and rgbs)
        dL_dz_vals, dL_ddrgbs,
    )

__integrate_rays.defvjp(
    fwd=__fwd_integrate_rays,
    bwd=__bwd_integrate_rays,
)
def integrate_rays(
    near_distance: float,
    rays_sample_startidx: jax.Array,
    rays_n_samples: jax.Array,
    bgs: jax.Array,
    dss: jax.Array,
    z_vals: jax.Array,
    drgbs: jax.Array,
) -> Tuple[jax.Array, jax.Array, jax.Array]:
    """Composite per-sample densities/colors into per-ray colors, depths and opacities.

    This is a thin wrapper on top of `impl.__integrate_rays`, which carries the custom
    vjp (the wrapper exists because the @jax.custom_vjp decorator makes the decorated
    function's docstring invisible to LSPs).

    Inputs:
        near_distance `float`: camera's near distance; samples that fall behind the
                               camera's near plane yet still carry non-negligible density
                               incur a penalty on their densities

        rays_sample_startidx `[n_rays]`: i-th element is the index of the first sample in z_vals,
                                         densities, and rgbs of the i-th ray
        rays_n_samples `[n_rays]`: i-th element is the number of samples for the i-th ray

        bgs `[n_rays, 3]`: background colors of each ray
        dss [total_samples]: it means `ds`s, the notation `ds` comes from the article "Local and
                             global illumination in the volume rendering integral" written by Nelson
                             Max and Min Chen, 2005.  The product of `ds[i]` and `densities[i]`
                             represents the probability of the ray terminates anywhere between
                             `z_vals[i]` and `z_vals[i]+ds[i]`.
                             Note that `ds[i]` is _not_ the same as `z_vals[i+1]-z_vals[i]` (though
                             they may equal), because: (1) if empty spaces are skipped during ray
                             marching, `z_vals[i+1]-z_vals[i]` may be very large, in which case it's
                             no longer appropriate to assume the density is constant along this
                             large segment; (2) `z_vals[i+1]` is not defined for the last sample.
        z_vals [total_samples]: z_vals[i] is the distance of the i-th sample from the camera
        drgbs [total_samples, 4]: density (1) and rgb (3) values along a ray

    Returns:
        measured_batch_size `uint`: total number of samples that got composited into output
        final_rgbds `[n_rays, 4]`: integrated ray colors and estimated depths according to input
                                   densities and rgbs
        final_opacities `[n_rays]`: accumulated opacities along each ray
    """
    measured_batch_size, final_rgbds, final_opacities = impl.__integrate_rays(
        near_distance=near_distance,
        rays_sample_startidx=rays_sample_startidx,
        rays_n_samples=rays_n_samples,
        bgs=bgs,
        dss=dss,
        z_vals=z_vals,
        drgbs=drgbs,
    )

    # `measured_batch_size` comes back as a length-1 array; expose it as a scalar.
    # (return annotation fixed: this function returns 3 values, not 4)
    return measured_batch_size[0], final_rgbds, final_opacities


def integrate_rays_inference(
    rays_bg: jax.Array,
    rays_rgbd: jax.Array,
    rays_T: jax.Array,

    n_samples: jax.Array,
    indices: jax.Array,
    dss: jax.Array,
    z_vals: jax.Array,
    drgbs: jax.Array,
) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
    """One inference-time integration step over a working set of (possibly live) rays.

    Inputs:
        rays_bg `float` `[n_total_rays, 3]`: normalized background color of each ray in question
        rays_rgbd `float` `[n_total_rays, 4]`: target array to write rendered colors and estimated
                                               depths to
        rays_T `float` `[n_total_rays]`: accumulated transmittance of each ray

        n_samples `uint32` `[n_rays]`: output of ray marching, specifies how many samples are
                                       generated for this ray at this iteration
        indices `uint32` `[n_rays]`: values are in range [0, n_total_rays), specifies the location
                                     in `rays_bg`, `rays_rgbd`, `rays_T`, and `rays_depth`
                                     corresponding to this ray
        dss `float` `[n_rays, march_steps_cap]`: each sample's `ds`
        z_vals `float` `[n_rays, march_steps_cap]`: each sample's distance to its ray origin
        drgbs `float` `[n_rays, march_steps_cap, 4]`: predicted density (1) and RGB (3) values from a NeRF model

    Returns:
        terminate_cnt `uint32`: number of rays that terminated this iteration
        terminated `bool` `[n_rays]`: a binary mask, the i-th location being True means the i-th ray
                                      has terminated
        rays_rgbd `float` `[n_total_rays, 4]`: the input `rays_rgbd` with ray colors and estimated
                                               depths updated
        rays_T `float` `[n_total_rays]`: the input `rays_T` with transmittance values updated
    """
    terminate_cnt, terminated, rays_rgbd_out, rays_T_out = impl.integrate_rays_inference_p.bind(
        rays_bg,
        rays_rgbd,
        rays_T,

        n_samples,
        indices,
        dss,
        z_vals,
        drgbs,
    )
    # scatter this iteration's per-ray outputs back into the full-size buffers
    rays_rgbd = rays_rgbd.at[indices].set(rays_rgbd_out)
    rays_T = rays_T.at[indices].set(rays_T_out)
    return terminate_cnt[0], terminated, rays_rgbd, rays_T
}: let
  deps = import ./deps;
in flake-utils.lib.eachSystem [ "x86_64-linux" "aarch64-linux" ] (system: let
  inherit (nixpkgs) lib;
  basePkgs = import nixpkgs {
    inherit system;
    overlays = [
      self.overlays.default
    ];
  };
in {
  devShells = let
    pyVer = "310";
    py = "python${pyVer}";
    # python package-set tweaks shared by the cpu and cuda shells
    jaxOverlays = final: prev: {
      # avoid rebuilding opencv4 with cuda for tensorflow-datasets
      opencv4 = prev.opencv4.override {
        enableCuda = false;
      };
      ${py} = prev.${py}.override {
        packageOverrides = finalScope: prevScope: {
          jax = prevScope.jax.overridePythonAttrs (o: { doCheck = false; });
          jaxlib = prevScope.jaxlib-bin;
          flax = prevScope.flax.overridePythonAttrs (o: {
            buildInputs = o.buildInputs ++ [ prevScope.pyyaml ];
            doCheck = false;
          });
          tensorflow = prevScope.tensorflow.override {
            # we only use tensorflow-datasets for data loading, it does not need to be built
            # with cuda support (building with cuda support is too demanding).
            cudaSupport = false;
          };
        };
      };
    };
    overlays = [
      inputs.nixgl.overlays.default
      self.overlays.default
      jaxOverlays
    ];
    cudaPkgs = import nixpkgs {
      inherit system overlays;
      config = {
        allowUnfree = true;
        cudaSupport = true;
        packageOverrides = pkgs: {
          linuxPackages = (import inputs.nixpkgs-with-nvidia-driver-fix {}).linuxPackages;
        };
      };
    };
    cpuPkgs = import nixpkgs {
      inherit system overlays;
      config = {
        allowUnfree = true;
        cudaSupport = false;  # NOTE: disable cuda for cpu env
      };
    };
    mkPythonDeps = { pp, extraPackages }: with pp; [
      ipython
      tqdm
      icecream
      # NOTE(review): pillow is listed twice in this list (see near the bottom).
      pillow
      ipdb
      colorama
      imageio
      ffmpeg-python
      pydantic
      natsort
      GitPython

      pkgs.dearpygui
      pkgs.pycolmap
      pkgs.tyro

      tensorflow
      keras
      jaxlib-bin
      jax
      optax
      flax

      pillow
      matplotlib
    ] ++ extraPackages;
    commonShellHook = ''
      export PYTHONBREAKPOINT=ipdb.set_trace
      export PYTHONDONTWRITEBYTECODE=1
      export PYTHONUNBUFFERED=1
      [[ "$-" == *i* ]] && exec "$SHELL"
    '';
  in rec {
    default = cudaDevShell;
    cudaDevShell = let  # impure
      isWsl = builtins.pathExists /usr/lib/wsl/lib;
    in cudaPkgs.mkShell {
      name = "cuda";
      buildInputs = [
        cudaPkgs.colmap-locked
        cudaPkgs.ffmpeg
        (cudaPkgs.${py}.withPackages (pp: mkPythonDeps {
          inherit pp;
          extraPackages = with pp; [
            pkgs.spherical-harmonics-encoding-jax
            pkgs.volume-rendering-jax
            pkgs.jax-tcnn
          ];
        }))
      ];
      # Work around an XLA compilation issue on old (<= 470) non-WSL nvidia drivers.
      # REF:
      #
      XLA_FLAGS = with builtins; let
        nvidiaDriverVersion =
          head (match ".*Module ([0-9\\.]+) .*" (readFile /proc/driver/nvidia/version));
        nvidiaDriverVersionMajor = lib.toInt (head (splitVersion nvidiaDriverVersion));
      in lib.optionalString
        (!isWsl &&
#!/usr/bin/env python3

from dataclasses import dataclass
from functools import partial, reduce
import os
from pathlib import Path
from typing import Annotated, List
from typing_extensions import assert_never

from PIL import Image
import numpy as np
import tyro

from utils.common import setup_logging
from utils.data import (
    add_border,
    blend_rgba_image_array,
    create_scene_from_single_camera_image_collection,
    create_scene_from_video,
    psnr,
    side_by_side,
)
from utils.types import RGBColor, SceneCreationOptions


@dataclass(frozen=True, kw_only=True)
class Concatenate:
    image_paths: tyro.conf.Positional[List[Path]]
    # output image save path, the path will be overwritten with a warning
    out: Path
    # if specified, concatenate vertically instead of horizontally
    vertical: bool=False
    # gap between adjacent images, in pixels
    gap: int=0
    # border in pixels
    border: int=0
    bg: RGBColor=(1.0, 1.0, 1.0)


@dataclass(frozen=True, kw_only=True)
class Metrics:
    gt: Path
    image_paths: tyro.conf.Positional[List[Path]]
    psnr: bool=True
    bg: RGBColor=(1.0, 1.0, 1.0)


@dataclass(frozen=True, kw_only=True)
class CreateScene:
    # path to a video or a directory of image collection
    src: tyro.conf.Positional[Path]

    # where to write the images and transforms_{train,val,test}.json
    root_dir: Path

    # how many frames to extract per second, only required when src is a video
    fps: int | None=None

    scene_opts: tyro.conf.OmitArgPrefixes[SceneCreationOptions]

CmdCat = Annotated[
    Concatenate,
    tyro.conf.subcommand(
        name="cat",
        prefix_name=False,
        description="concatenate images horizontally or vertically",
    ),
]
CmdMetrics = Annotated[
    Metrics,
    tyro.conf.subcommand(
        name="metrics",
        prefix_name=False,
        description="compute metrics between images",
    ),
]
CmdCreateScene = Annotated[
    CreateScene,
    tyro.conf.subcommand(
        name="create-scene",
        prefix_name=False,
        description="create a instant-ngp-compatible scene from a video or a directory of images",
    ),
]


Args = CmdCat | CmdCreateScene | CmdMetrics


def main(args: Args):
    """Entry point: dispatch to the subcommand selected on the command line."""
    logger = setup_logging("utils", level="DEBUG")
    if isinstance(args, Concatenate):
        if args.out.is_dir():
            logger.error("output path '{}' is a directory".format(args.out))
            exit(1)
        if args.out.exists():
            # `Logger.warn` is a deprecated alias of `Logger.warning`
            logger.warning("output path '{}' exists and will be overwritten".format(args.out))
            # BUGFIX: os.access() returns False for a *nonexistent* path, so the
            # writability check must only run when the file already exists; previously a
            # fresh output path always aborted with "is readonly".
            if not os.access(args.out, os.W_OK):
                logger.error("output path '{}' is readonly".format(args.out))
                exit(2)
        if args.out.suffix.lower() not in map(lambda x: "." + x, ["jpg", "jpeg", "png", "tif", "tiff", "bmp", "webp"]):
            logger.warning("the file extension '{}' might not be supported".format(args.out.suffix))

        # load all images, blending away any alpha channel against the requested background
        images = list(map(
            lambda img: blend_rgba_image_array(img, bg=args.bg) if img.shape[-1] == 4 else img,
            map(np.asarray, map(Image.open, args.image_paths)),
        ))
        height, width = images[0].shape[:2]
        # fold the images together pairwise, fixing the dimension orthogonal to the
        # concatenation direction to that of the first image
        oimg = reduce(
            partial(
                side_by_side,
                height=(None if args.vertical else height),
                width=(width if args.vertical else None),
                vertical=args.vertical,
                gap=args.gap,
            ),
            images,
        )
        oimg = add_border(oimg, border_pixels=args.border)
        logger.info("saving image ...")
        Image.fromarray(np.asarray(oimg)).save(args.out)
        logger.info("image ({}x{}) saved to '{}'".format(oimg.shape[1], oimg.shape[0], args.out))

    elif isinstance(args, Metrics):
        gt_image = np.asarray(Image.open(args.gt))
        if gt_image.shape[-1] == 4:
            gt_image = blend_rgba_image_array(gt_image, bg=args.bg)
        images = list(map(
            lambda img: blend_rgba_image_array(img, bg=args.bg) if img.shape[-1] == 4 else img,
            map(np.asarray, map(Image.open, args.image_paths)),
        ))
        for impath, img in zip(args.image_paths, images):
            if args.psnr:
                logger.info("psnr={} ({})".format(psnr(gt_image, img), impath))

    elif isinstance(args, CreateScene):
        if args.src.is_dir():
            create_scene_from_single_camera_image_collection(
                raw_images_dir=args.src,
                scene_root_dir=args.root_dir,
                opts=args.scene_opts,
            )
        else:
            assert args.fps is not None, "must specify extracted frames per second via --fps for video source"
            create_scene_from_video(
                video_path=args.src,
                scene_root_dir=args.root_dir,
                fps=args.fps,
                opts=args.scene_opts,
            )

    else:
        # exhaustiveness check: tyro already ensures the subcommand passed here is one of
        # the Args variants, so this branch is unreachable.
        # BUGFIX: assert_never expects the (Never-typed) value itself, not a message string.
        assert_never(args)


if __name__ == "__main__":
    args = tyro.cli(Args)
    main(args)


import functools

import chex
import jax
from jax.abstract_arrays import ShapedArray
from jax.interpreters import batching, mlir, xla
from jax.interpreters.mlir import ir
from jax.lib import xla_client
import jax.numpy as jnp

from . import cudaops

try:
    from jaxlib.mhlo_helpers import custom_call
except ModuleNotFoundError:
    # A more recent jaxlib would have `hlo_helpers` instead of `mhlo_helpers`
    #
    from jaxlib.hlo_helpers import custom_call


# register GPU XLA custom calls
for name, value in cudaops.get_registrations().items():
    xla_client.register_custom_call_target(name, value, platform="gpu")

# jit rules, infer returned shape according to input
def _spherical_harmonics_encoding_abstract(coord: jax.Array, hint: jax.Array):
    """
    Inputs:
        coord [..., 3] float: input coordinates
        hint [degree]: an array with shape [L], this is used to hint the function with the desired
                       spherical harmonics degrees
    """
    (*n, _), dtype = coord.shape, coord.dtype
    n = functools.reduce(lambda x, y: x * y, n)
    degree, = hint.shape
    dtype = jax.dtypes.canonicalize_dtype(coord.dtype)
    return ShapedArray(shape=(n, degree * degree), dtype=dtype)

# register the primitive
sh_enc_p = jax.core.Primitive("spherical_harmonics_encoding🌐")
sh_enc_p.multiple_results = False
sh_enc_p.def_impl(functools.partial(xla.apply_primitive, sh_enc_p))
sh_enc_p.def_abstract_eval(_spherical_harmonics_encoding_abstract)

# helper function for mapping given shapes to their default mlir
# helper function for mapping given shapes to their default mlir layouts
def default_layouts(*shapes):
    # Row-major (C-contiguous) layout: minor-to-major dimension order [rank-1, ..., 1, 0].
    return [range(len(shape) - 1, -1, -1) for shape in shapes]

# mlir lowering rule
def _spherical_harmonics_encoding_lowering_cuda(
    ctx: mlir.LoweringRuleContext,
    coord: ir.Value,
    hint: ir.Value,
):
    # Lower the primitive to the registered "spherical_harmonics_encoding_cuda_f32"
    # XLA custom call.  `hint` only carries the degree via its static shape; it is not
    # passed to the kernel as an operand.
    coord_type = ir.RankedTensorType(coord.type)
    coord_shape = coord_type.shape

    n, _ = coord_shape
    degree, = ir.RankedTensorType(hint.type).shape

    result_shape = (n, degree * degree)

    # serialized descriptor consumed by the CUDA kernel (problem size + degree)
    opaque = cudaops.make_spherical_harmonics_encoding_descriptor(n, degree)

    # Documentation says directly return the `custom_call` would suffice, but directly returning
    # here results in error "Output of translation rule must be iterable: ...", so let's make it
    # iterable.
    # NOTE:
    # A newer jaxlib (current 0.3.22) may require this to be a single custom_call(...), instead of
    # an iterable, as documentation suggests.
    # REF:
    # documentation:
    # tutorial:
    return [custom_call(
        "spherical_harmonics_encoding_cuda_f32",  # the name of the registered XLA custom call at the top of this script
        out_types=[
            ir.RankedTensorType.get(result_shape, coord_type.element_type),
        ],
        operands=[coord],
        backend_config=opaque,
        operand_layouts=default_layouts(coord_shape),
        result_layouts=default_layouts(result_shape),
    )]

mlir.register_lowering(
    prim=sh_enc_p,
    rule=_spherical_harmonics_encoding_lowering_cuda,
    platform="gpu",
)

# vmap support. REF:
def spherical_harmonics_encoding_batch(args, axes):
    """
    The primitive is already able to handle arbitrary shape (except for the last axis, which must
    have a dimension of 3), directly binding to the primitive impl should suffice.

    Inputs:
        args: Passed to def_impl, contains two tensors: `coord` and `hint`, where only `coord` is
              batched.  args is (coord, hint)
        axes: The axes that are being batched, one value for each arg, value is an integer if the
              arg is batched, value is None if the arg is not batched. In this case,
              coord.shape[axes[0]] = B, and axes[1] = None.

    NOTE(review): the batched operand is bound unchanged and the batched axis is reported
    as axes[0]; this looks like it assumes the batch axis is 0 — confirm against a vmap
    along a non-leading axis.
    """
    coord, hint = args
    assert coord.shape[-1] == 3, "spatial coordinates must be the last dimension"

    enc = sh_enc_p.bind(coord, hint)
    # or:
    # enc = spherical_harmonics_encoding(
    #     coord=coord,
    #     degree=hint.shape[0],
    # )

    # return the result, and the result axis that was batched
    return enc, axes[0]

batching.primitive_batchers[sh_enc_p] = spherical_harmonics_encoding_batch

# the only exposed function
def spherical_harmonics_encoding(coord: jax.Array, degree: int) -> jax.Array:
    """
    Spherical harmonics encoding with GPU acceleration, expects unit vectors as input.

    Inputs:
        coord [..., 3] float: input 3D coordinates
        degree int: highest degree used in spherical harmonics

    Returns:
        outputs [..., degree**2] float: encoded coordinates

    NOTE(review): despite the `[..., 3]` shape in the docstring, the assertion below
    restricts `coord` to rank 2 (`[n, 3]`); batched inputs go through the vmap rule.
    """
    chex.assert_rank(coord, 2)
    chex.assert_axis_dimension(coord, -1, 3)
    chex.assert_scalar_non_negative(degree)
    # `hint` is a dummy array whose static shape carries `degree` to the abstract eval
    # and lowering rules.
    return sh_enc_p.bind(coord, jnp.empty((degree,)))


#!/usr/bin/env python3

from pathlib import Path
from typing import Literal

from PIL import Image
from flax.training.train_state import TrainState
import jax
import jax.numpy as jnp
import jax.random as jran
import numpy as np
import optax
from tqdm import tqdm
import tyro

from models.imagefit import ImageFitter
from utils import common, data
from utils.args import ImageFitArgs
logger = common.setup_logging("imagefit")


@jax.jit
def train_step(state: TrainState, uvs, rgbs, perm):
    """One optimization step on the pixel subset selected by `perm`."""
    def loss(params, x, y):
        preds = state.apply_fn({"params": params}, x)
        loss = jnp.square(preds - y).mean()
        return loss

    loss_grad_fn = jax.value_and_grad(loss)

    loss, grads = loss_grad_fn(state.params, uvs[perm], rgbs[perm])
    state = state.apply_gradients(grads=grads)
    metrics = {
        # scale the mean loss back up so per-epoch totals can be averaged per pixel later
        "loss": loss * perm.shape[0],
    }
    return state, metrics


def train_epoch(
    image_metadata: data.ImageMetadata,
    permutation: data.Dataset,
    total_batches: int,
    state: TrainState,
    ep_log: int,
):
    """Run one epoch over `permutation` batches; returns (summed loss, new state)."""
    loss = 0
    for perm in tqdm(permutation, total=total_batches, desc="ep#{:03d}".format(ep_log), bar_format=common.tqdm_format):
        state, metrics = train_step(state, image_metadata.uvs, image_metadata.rgbs, perm)
        loss += metrics["loss"]
    return loss, state


@jax.jit
def eval_step(state, uvs, perm):
    """Predict RGB values for the pixel subset selected by `perm`."""
    preds = state.apply_fn({"params": state.params}, uvs[perm])
    return preds


# NOTE: shadows the `eval` builtin; kept for backward compatibility of the module interface.
def eval(
    image_array,
    image_metadata: data.ImageMetadata,
    state: TrainState,
):
    """Render the whole image in chunks with the trained model and write it into `image_array`."""
    H, W = image_array.shape[:2]

    @common.jit_jaxfn_with(static_argnames=["chunk_size"])
    def get_perms(chunk_size: int) -> list[jax.Array]:
        # split all pixel indices into roughly chunk_size-sized pieces
        all_perms = jnp.arange(H*W)
        if chunk_size >= H*W:
            n_chunks = 1
        else:
            n_chunks = H*W // chunk_size
        perms = jnp.array_split(all_perms, n_chunks)
        return perms

    for perm in tqdm(get_perms(chunk_size=2**15), desc="evaluating", bar_format=common.tqdm_format):
        # preds = state.apply_fn({"params": state.params}, uv)
        preds = eval_step(state, image_metadata.uvs, perm)
        image_array = data.set_pixels(image_array, image_metadata.xys, perm, preds)

    return image_array


def main(
    args: ImageFitArgs,
    in_image: Path,
    out_path: Path,
    encoding: Literal["hashgrid", "frequency"],
    # Enable this to suppress prompt if out_path exists and directly overwrite the file.
    overwrite: bool = False,
    encoding_prec: int = 32,
    model_summary: bool = False,
):
    """Fit a coordinate network to a single image and save the reconstruction each epoch."""
    logger.setLevel(args.common.logging.upper())

    if not out_path.parent.is_dir():
        # BUGFIX: logging.Logger has no `err` method; was `logger.err(...)` which would
        # raise AttributeError instead of reporting the problem.
        logger.error("Output path's parent '{}' does not exist or is not a directory!".format(out_path.parent))
        exit(1)

    if out_path.exists() and not overwrite:
        # `Logger.warn` is a deprecated alias of `Logger.warning`
        logger.warning("Output path '{}' exists and will be overwritten!".format(out_path))
        try:
            # default to "No" on empty input: only a leading y/Y continues
            r = input("Continue? [y/N] ")
            if (r.strip() + "n").lower()[0] != "y":
                exit(0)
        except EOFError:
            print()
            exit(0)
        except KeyboardInterrupt:
            print()
            exit(0)

    encoding_dtype = getattr(jnp, "float{}".format(encoding_prec))
    dtype = getattr(jnp, "float{}".format(args.common.prec))

    # deterministic
    K = common.set_deterministic(args.common.seed)

    # model parameters
    K, key = jran.split(K, 2)
    model, init_input = (
        ImageFitter(encoding=encoding, encoding_dtype=encoding_dtype),
        jnp.zeros((1, 2), dtype=dtype),
    )
    variables = model.init(key, init_input)
    if model_summary:
        print(model.tabulate(key, init_input))

    # training state
    state = TrainState.create(
        apply_fn=model.apply,
        params=variables["params"],
        tx=optax.adam(
            learning_rate=args.train.lr,
            b1=0.9,
            b2=0.99,
            # paper:
            # the small value of 𝜖 = 10^{−15} can significantly accelerate the convergence of the
            # hash table entries when their gradients are sparse and weak.
            eps=1e-15,
        ),
    )

    # data
    in_image = np.asarray(Image.open(in_image))
    image_metadata = data.make_image_metadata(
        image=in_image,
        bg=(1.0, 1.0, 1.0),
    )

    for ep in range(args.train.n_epochs):
        ep_log = ep + 1
        K, key = jran.split(K, 2)
        permutation = data.make_permutation_dataset(
            key,
            size=image_metadata.W * image_metadata.H,
            shuffle=True
        )\
            .batch(args.train.bs, drop_remainder=True)\
            .repeat(args.data.loop)
        loss, state = train_epoch(
            image_metadata=image_metadata,
            permutation=permutation.take(args.train.n_batches).as_numpy_iterator(),
            total_batches=args.train.n_batches,
            state=state,
            ep_log=ep_log,
        )

        # PIL size is (W, H), numpy shape is (H, W): reverse before constructing the canvas
        image = np.asarray(Image.new("RGB", in_image.shape[:2][::-1]))
        image = eval(image, image_metadata, state)
        logger.debug("saving image of shape {} to {}".format(image.shape, out_path))
        Image.fromarray(np.asarray(image)).save(out_path)

        logger.info(
            "epoch#{:03d}: per-pixel loss={:.2e}, psnr={}".format(
                ep_log,
                loss / (image_metadata.H * image_metadata.W),
                data.psnr(in_image, image),
            )
        )


if __name__ == "__main__":
    tyro.cli(main)
incremental_mapping( 14 | return reconstructions; 15 | } 16 | 17 | +// Copied from colmap/exe/sfm.cc 18 | +Reconstruction bundle_adjustment( 19 | + const py::object input_path_, 20 | + const py::object output_path_, 21 | + const BundleAdjustmentOptions& ba_options 22 | +) { 23 | + std::string input_path = py::str(input_path_).cast(); 24 | + std::string output_path = py::str(output_path_).cast(); 25 | + 26 | + THROW_CHECK_DIR_EXISTS(input_path); 27 | + THROW_CHECK_DIR_EXISTS(output_path); 28 | + 29 | + Reconstruction reconstruction; 30 | + reconstruction.Read(input_path); 31 | + 32 | + OptionManager options; 33 | + *options.bundle_adjustment = ba_options; 34 | + 35 | + BundleAdjustmentController ba_controller(options, &reconstruction); 36 | + ba_controller.Start(); 37 | + ba_controller.Wait(); 38 | + 39 | + reconstruction.Write(output_path); 40 | + 41 | + return reconstruction; 42 | +} 43 | + 44 | std::map incremental_mapping( 45 | const py::object database_path_, 46 | const py::object image_path_, 47 | @@ -183,6 +211,71 @@ void init_sfm(py::module& m) { 48 | make_dataclass(PyIncrementalMapperOptions); 49 | auto mapper_options = PyIncrementalMapperOptions().cast(); 50 | 51 | + using BAOpts = BundleAdjustmentOptions; 52 | + auto PyBALossFunctionType = 53 | + py::enum_(m, "LossFunctionType") 54 | + .value("TRIVIAL", BAOpts::LossFunctionType::TRIVIAL) 55 | + .value("SOFT_L1", BAOpts::LossFunctionType::SOFT_L1) 56 | + .value("CAUCHY", BAOpts::LossFunctionType::CAUCHY); 57 | + AddStringToEnumConstructor(PyBALossFunctionType); 58 | + using CSOpts = ceres::Solver::Options; 59 | + auto PyCeresSolverOptions = 60 | + py::class_(m, 61 | + "CeresSolverOptions", 62 | + // If ceres::Solver::Options is registered by pycolmap AND a downstream 63 | + // library, importing the downstream library results in error: 64 | + // ImportError: generic_type: type "CeresSolverOptions" is already registered! 65 | + // Adding a `py::module_local()` fixes this. 
66 | + // https://github.com/pybind/pybind11/issues/439#issuecomment-1338251822 67 | + py::module_local()) 68 | + .def(py::init<>()) 69 | + .def_readwrite("function_tolerance", &CSOpts::function_tolerance) 70 | + .def_readwrite("gradient_tolerance", &CSOpts::gradient_tolerance) 71 | + .def_readwrite("parameter_tolerance", &CSOpts::parameter_tolerance) 72 | + .def_readwrite("minimizer_progress_to_stdout", &CSOpts::minimizer_progress_to_stdout) 73 | + .def_readwrite("minimizer_progress_to_stdout", &CSOpts::minimizer_progress_to_stdout) 74 | + .def_readwrite("max_num_iterations", &CSOpts::max_num_iterations) 75 | + .def_readwrite("max_linear_solver_iterations", &CSOpts::max_linear_solver_iterations) 76 | + .def_readwrite("max_num_consecutive_invalid_steps", &CSOpts::max_num_consecutive_invalid_steps) 77 | + .def_readwrite("max_consecutive_nonmonotonic_steps", &CSOpts::max_consecutive_nonmonotonic_steps) 78 | + .def_readwrite("num_threads", &CSOpts::num_threads); 79 | + make_dataclass(PyCeresSolverOptions); 80 | + auto PyBundleAdjustmentOptions = 81 | + py::class_(m, "BundleAdjustmentOptions") 82 | + .def(py::init<>()) 83 | + .def_readwrite("loss_function_type", 84 | + &BAOpts::loss_function_type, 85 | + "Loss function types: Trivial (non-robust) and Cauchy (robust) loss.") 86 | + .def_readwrite("loss_function_scale", 87 | + &BAOpts::loss_function_scale, 88 | + "Scaling factor determines residual at which robustification takes place.") 89 | + .def_readwrite("refine_focal_length", 90 | + &BAOpts::refine_focal_length, 91 | + "Whether to refine the focal length parameter group.") 92 | + .def_readwrite("refine_principal_point", 93 | + &BAOpts::refine_principal_point, 94 | + "Whether to refine the principal point parameter group.") 95 | + .def_readwrite("refine_extra_params", 96 | + &BAOpts::refine_extra_params, 97 | + "Whether to refine the extra parameter group.") 98 | + .def_readwrite("refine_extrinsics", 99 | + &BAOpts::refine_extrinsics, 100 | + "Whether to refine the 
extrinsic parameter group.") 101 | + .def_readwrite("print_summary", 102 | + &BAOpts::print_summary, 103 | + "Whether to print a final summary.") 104 | + .def_readwrite("min_num_residuals_for_multi_threading", 105 | + &BAOpts::min_num_residuals_for_multi_threading, 106 | + "Minimum number of residuals to enable multi-threading. Note that " 107 | + "single-threaded is typically better for small bundle adjustment problems " 108 | + "due to the overhead of threading. " 109 | + ) 110 | + .def_readwrite("solver_options", 111 | + &BAOpts::solver_options, 112 | + "Ceres-Solver options."); 113 | + make_dataclass(PyBundleAdjustmentOptions); 114 | + auto ba_options = PyBundleAdjustmentOptions().cast(); 115 | + 116 | m.def("triangulate_points", 117 | &triangulate_points, 118 | py::arg("reconstruction"), 119 | @@ -207,6 +300,12 @@ void init_sfm(py::module& m) { 120 | py::arg("input_path") = py::str(""), 121 | "Triangulate 3D points from known poses"); 122 | 123 | + m.def("bundle_adjustment", 124 | + &bundle_adjustment, 125 | + py::arg("input_path"), 126 | + py::arg("output_path"), 127 | + py::arg("options") = ba_options); 128 | + 129 | m.def("incremental_mapping", 130 | static_cast (*)(const py::object, 131 | const py::object, 132 | -------------------------------------------------------------------------------- /deps/volume-rendering-jax/src/volrendjax/marching/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from . 
from typing import Tuple

import jax
import jax.numpy as jnp

from . import impl


def march_rays(
    # static
    total_samples: int,
    diagonal_n_steps: int,
    K: int,
    G: int,
    bound: float,
    stepsize_portion: float,

    # inputs
    rays_o: jax.Array,
    rays_d: jax.Array,
    t_starts: jax.Array,
    t_ends: jax.Array,
    noises: jax.Array,
    occupancy_bitfield: jax.Array,
    # NOTE(fix): the function returns 9 values (see the return statement below); the previous
    # annotation declared only a 5-tuple.
) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
    """
    Given a pack of rays (`rays_o`, `rays_d`), their intersection time with the scene bounding box
    (`t_starts`, `t_ends`), and an occupancy grid (`occupancy_bitfield`), generate samples along
    each ray.

    Inputs:
        total_samples `int`: capacity of the output sample arrays (`idcs`, `xyzs`, `dirs`, `dss`,
                             `z_vals` all have `total_samples` rows)
        diagonal_n_steps `int`: the length of a minimal ray marching step is calculated internally
                                as:
                                    Δ𝑡 := √3 / diagonal_n_steps;
                                the NGP paper uses diagonal_n_steps=1024 (as described in appendix
                                E.1).
        K `int`: total number of cascades of `occupancy_bitfield`
        G `int`: occupancy grid resolution, the paper uses 128 for every cascade
        bound `float`: the half length of the longest axis of the scene's bounding box,
                       e.g. the `bound` of the bounding box [-1, 1]^3 is 1
        stepsize_portion: next step size is calculated as t * stepsize_portion, the paper uses 1/256

        rays_o `[n_rays, 3]`: ray origins
        rays_d `[n_rays, 3]`: **unit** vectors representing ray directions
        t_starts `[n_rays]`: time of the ray entering the scene bounding box
        t_ends `[n_rays]`: time of the ray leaving the scene bounding box
        noises `broadcastable to [n_rays]`: noises to perturb the starting point of ray marching
        occupancy_bitfield `[K*(G**3)//8]`: the occupancy grid represented as a bit array, grid
                                            cells are laid out in Morton (z-curve) order, as
                                            described in appendix E.2 of the NGP paper

    Returns (in the exact order of the return statement):
        measured_batch_size_before_compaction `int`: total number of generated samples of all rays
        ray_is_valid `bool` `[n_rays]`: a mask, where a true value denotes the ray's gradients
                                        should flow, even if there are no samples generated for it
        rays_n_samples `[n_rays]`: number of samples of each ray, its sum is `total_samples`
                                   referenced below
        rays_sample_startidx `[n_rays]`: indices of each ray's first sample
        idcs `[total_samples]`: indices indicating which ray the i-th sample comes from
        xyzs `[total_samples, 3]`: spatial coordinates of the generated samples, invalid array
                                   locations are masked out with zeros
        dirs `[total_samples, 3]`: unit direction vectors of the ray each sample belongs to
                                   (previously mis-documented as "spatial coordinates" — a
                                   copy-paste of the `xyzs` entry), invalid array locations are
                                   masked out with zeros
        dss `[total_samples]`: `ds`s of each sample, for a more detailed explanation of this
                               notation, see documentation of function `volrendjax.integrate_rays`,
                               invalid array locations are masked out with zeros
        z_vals `[total_samples]`: samples' distances to their origins, invalid array locations are
                                  masked out with zeros
    """
    n_rays, _ = rays_o.shape
    # `noises` may be a scalar or any shape broadcastable to [n_rays]; normalize here so the
    # primitive always sees a per-ray array.
    noises = jnp.broadcast_to(noises, (n_rays,))

    next_sample_write_location, number_of_exceeded_samples, ray_is_valid, rays_n_samples, rays_sample_startidx, idcs, xyzs, dirs, dss, z_vals = impl.march_rays_p.bind(
        # arrays
        rays_o,
        rays_d,
        t_starts,
        t_ends,
        noises,
        occupancy_bitfield,

        # static args
        total_samples=total_samples,
        diagonal_n_steps=diagonal_n_steps,
        K=K,
        G=G,
        bound=bound,
        stepsize_portion=stepsize_portion,
    )

    # samples that did not fit into the `total_samples` capacity are subtracted back out
    measured_batch_size_before_compaction = next_sample_write_location[0] - number_of_exceeded_samples[0]

    return measured_batch_size_before_compaction, ray_is_valid, rays_n_samples, rays_sample_startidx, idcs, xyzs, dirs, dss, z_vals


def march_rays_inference(
    # static
    diagonal_n_steps: int,
    K: int,
    G: int,
    march_steps_cap: int,
    bound: float,
    stepsize_portion: float,

    # inputs
    rays_o: jax.Array,
    rays_d: jax.Array,
    t_starts: jax.Array,
    t_ends: jax.Array,
    occupancy_bitfield: jax.Array,
    next_ray_index_in: jax.Array,
    terminated: jax.Array,
    indices: jax.Array,
):
    """
    March a bounded number of steps (`march_steps_cap`) for the subset of rays currently being
    rendered, for iterative inference-time rendering.

    Inputs:
        diagonal_n_steps, K, G, bound, stepsize_portion: see explanations in function `march_rays`
        march_steps_cap `int`: maximum steps to march for each ray in this iteration

        rays_o `float` `[n_total_rays, 3]`: ray origins
        rays_d `float` `[n_total_rays, 3]`: ray directions
        t_starts `float` `n_total_rays`: distance of each ray's starting point to its origin
        t_ends `float` `n_total_rays`: distance of each ray's ending point to its origin
        occupancy_bitfield `uint8` `[K*(G**3)//8]`: the occupancy grid represented as a bit array
        next_ray_index_in `uint32`: helper variable to keep record of the latest ray that got
                                    rendered
        terminated `bool` `[n_rays]`: output of `integrate_rays_inference`, a binary mask
                                      indicating each ray's termination status
        indices `[n_rays]`: each ray's location in the global arrays

    Returns:
        next_ray_index `uint32` `[1]`: for use in next iteration
        indices `uint32` `[n_rays]`: for use in the integrate_rays_inference immediately after
        n_samples `uint32` `[n_rays]`: number of generated samples of each ray in question
        t_starts `float` `[n_total_rays]`: the input `t_starts` with the entries of the marched
                                           rays advanced, for use in next iteration
        xyzs `float` `[n_rays, march_steps_cap, 3]`: each sample's XYZ coordinate
        dss `float` `[n_rays, march_steps_cap]`: `ds` of each sample
        z_vals `float` `[n_rays, march_steps_cap]`: distance of each sample to their ray origins
    """
    next_ray_index, indices, n_samples, t_starts_out, xyzs, dss, z_vals = impl.march_rays_inference_p.bind(
        rays_o,
        rays_d,
        t_starts,
        t_ends,
        occupancy_bitfield,
        next_ray_index_in,
        terminated,
        indices,

        diagonal_n_steps=diagonal_n_steps,
        K=K,
        G=G,
        march_steps_cap=march_steps_cap,
        bound=bound,
        stepsize_portion=stepsize_portion,
    )
    # scatter the advanced per-ray `t` values back into the global `t_starts` array
    t_starts = t_starts.at[indices].set(t_starts_out)
    return next_ray_index, indices, n_samples, t_starts, xyzs, dss, z_vals
from typing import Dict, Tuple

import jax
import jax.numpy as jnp
import jax.random as jran
import optax

from models.renderers import render_rays_train
from utils import common, data
from utils.types import NeRFState, SceneData


__all__ = [
    "make_optimizer",
    "train_step",
]


def make_optimizer(lr: float) -> optax.GradientTransformation:
    """Build the NeRF training optimizer.

    Two Adam instances are combined via `optax.multi_transform`: one (with the NGP paper's
    eps=1e-15) for the NeRF network and background model, and a fixed-lr one for the appearance
    embeddings.  A weak L2 weight decay is applied to the MLP weights only.

    Args:
        lr: initial learning rate for the network optimizer.

    Returns:
        The chained `optax.GradientTransformation`.
    """
    lr_sch = optax.exponential_decay(
        init_value=lr,
        transition_steps=10_000,
        decay_rate=1/3,  # decay to `1/3 * init_lr` after `transition_steps` steps
        staircase=True,  # use integer division to determine lr drop step
        transition_begin=10_000,  # hold the initial lr value for the initial 10k steps (but first lr drop happens at 20k steps because `staircase` is specified)
        end_value=lr / 100,  # stop decaying at `1/100 * init_lr`
    )
    optimizer_network = optax.adam(
        learning_rate=lr_sch,
        b1=0.9,
        b2=0.99,
        # paper:
        #   the small value of 𝜖 = 10^{−15} can significantly accelerate the convergence of the
        #   hash table entries when their gradients are sparse and weak.
        eps=1e-15,
        eps_root=1e-15,
    )
    optimizer_ae = optax.adam(
        learning_rate=1e-4,
        b1=.9,
        b2=.99,
        eps=1e-8,
        eps_root=0,
    )
    return optax.chain(
        optax.multi_transform(
            transforms={
                "network": optimizer_network,
                "ae": optimizer_ae,
            },
            param_labels={
                "nerf": "network",
                "bg": "network",
                "appearance_embeddings": "ae",
            },
        ),
        optax.add_decayed_weights(
            # In NeRF experiments, the network can converge to a reasonably low loss during the
            # first ~50k training steps (with 1024 rays per batch and 1024 samples per ray), but the
            # loss becomes NaN after about 50~150k training steps.
            # paper:
            #   To prevent divergence after long training periods, we apply a weak L2 regularization
            #   (factor 10^{−6}) to the neural network weights, ...
            weight_decay=1e-6,
            # paper:
            #   ... to the neural network weights, but not to the hash table entries.
            mask={
                "nerf": {
                    "density_mlp": True,
                    "rgb_mlp": True,
                    "position_encoder": False,
                },
                "bg": True,
                "appearance_embeddings": False,
            },
        ),
    )


@common.jit_jaxfn_with(
    static_argnames=["total_samples"],
    donate_argnums=(0,),  # NOTE: this only works for positional arguments, see
)
def train_step(
    state: NeRFState,
    /,
    KEY: jran.KeyArray,
    total_samples: int,
    scene: SceneData,
    perm: jax.Array,
) -> Tuple[NeRFState, Dict[str, jax.Array | float]]:
    """Run one optimization step on a permutation `perm` of the scene's pixels.

    Args:
        state: current training state; donated, so callers must use the returned state.
        KEY: PRNG key consumed for background sampling and rendering noise.
        total_samples: static per-batch sample capacity passed through to `render_rays_train`.
        scene: scene data providing per-view transforms and ground-truth RGBA pixels.
        perm: flat indices selecting the pixels of this batch.

    Returns:
        The updated state and the batch metrics dict produced by the loss function.
    """
    # indices of views and pixels
    view_idcs, pixel_idcs = scene.get_view_indices(perm), scene.get_pixel_indices(perm)

    # TODO:
    #   merge this and `models.renderers.make_rays_worldspace` as a single function
    def make_rays_worldspace() -> Tuple[jax.Array, jax.Array]:
        # [N], [N]
        x, y = (
            jnp.mod(pixel_idcs, scene.meta.camera.width),
            jnp.floor_divide(pixel_idcs, scene.meta.camera.width),
        )
        # [N, 3]
        d_cam = scene.meta.camera.make_ray_directions_from_pixel_coordinates(x, y, use_pixel_center=True)

        # [N, 3]
        o_world = scene.transforms[view_idcs, -3:]

        # [N, 3, 3]
        R_cws = scene.transforms[view_idcs, :9].reshape(-1, 3, 3)
        # [N, 3]
        # equivalent to performing `d_cam[i] @ R_cws[i].T` for each i in [0, N)
        d_world = (d_cam[:, None, :] * R_cws).sum(-1)

        return o_world, d_world

    # CAVEAT: gradient is only calculated w.r.t. the first parameter of this function
    # (`params_to_optimize` here), any parameters that need to be optimized should be taken from
    # this parameter, instead from the outer-scope `state.params`.
    def loss_fn(params_to_optimize, gt_rgba_f32, KEY):
        o_world, d_world = make_rays_worldspace()
        appearance_embeddings = (
            params_to_optimize["appearance_embeddings"][view_idcs]
            if "appearance_embeddings" in params_to_optimize
            else jnp.empty(0)
        )
        if state.use_background_model:
            bg = state.bg_fn(
                {"params": params_to_optimize["bg"]},
                o_world,
                d_world,
                appearance_embeddings,
            )
        elif state.render.random_bg:
            KEY, key = jran.split(KEY, 2)
            bg = jran.uniform(key, shape=(o_world.shape[0], 3), dtype=jnp.float32, minval=0, maxval=1)
        else:
            bg = jnp.asarray(state.render.bg)
        KEY, key = jran.split(KEY, 2)
        batch_metrics, pred_rgbds, tv = render_rays_train(
            KEY=key,
            o_world=o_world,
            d_world=d_world,
            appearance_embeddings=appearance_embeddings,
            bg=bg,
            total_samples=total_samples,
            state=state.replace(params=params_to_optimize),
        )
        pred_rgbs, pred_depths = jnp.array_split(pred_rgbds, [3], axis=-1)
        gt_rgbs = data.blend_rgba_image_array(imgarr=gt_rgba_f32, bg=bg)
        batch_metrics["loss"] = {
            # huber loss on rgb, averaged over valid rays only
            "rgb": jnp.where(
                batch_metrics["ray_is_valid"],
                optax.huber_loss(pred_rgbs, gt_rgbs, delta=0.1).mean(axis=-1),
                0.,
            ).sum() / batch_metrics["n_valid_rays"],
            "total_variation": tv,
        }
        loss = jax.tree_util.tree_reduce(lambda x, y: x + y, batch_metrics["loss"])
        return loss, batch_metrics

    loss_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    KEY, key = jran.split(KEY, 2)
    (_, batch_metrics), grads = loss_grad_fn(
        state.params,
        scene.rgbas_u8[perm].astype(jnp.float32) / 255,
        key,
    )
    state = state.apply_gradients(grads=grads)
    return state, batch_metrics


def format_metrics(metrics: Dict[str, jax.Array | float]) -> str:
    """Render the per-batch training metrics as a compact single-line progress string."""
    loss = metrics["loss"]
    return "batch_size={}/{} samp./ray={:.1f}/{:.1f} n_rays={} loss:{{rgb={:.2e}({:.2f}dB),tv={:.2e}}}".format(
        metrics["measured_batch_size"],
        metrics["measured_batch_size_before_compaction"],
        metrics["measured_batch_size"] / metrics["n_valid_rays"],
        metrics["measured_batch_size_before_compaction"] / metrics["n_valid_rays"],
        metrics["n_valid_rays"],
        loss["rgb"],
        data.linear_to_db(loss["rgb"], maxval=1.),
        loss["total_variation"],
    )
from concurrent.futures import ThreadPoolExecutor
from typing import List
from typing_extensions import assert_never

from PIL import Image
from flax.training import checkpoints
import jax
import jax.numpy as jnp
import jax.random as jran
import numpy as np

from models.nerfs import make_nerf_ngp, make_skysphere_background_model_ngp
from models.renderers import render_image_inference
from utils import common, data
from utils.args import NeRFTestingArgs
from utils.types import NeRFState, RenderedImage, RigidTransformation


def test(KEY: jran.KeyArray, args: NeRFTestingArgs, logger: common.Logger) -> int:
    """Render the testing trajectory from a trained checkpoint and save images/videos.

    Returns a process exit code: 0 on success, 1 if the checkpoint path does not exist,
    2 if the checkpoint loaded is empty (step == 0).
    """
    args.logs_dir.mkdir(parents=True, exist_ok=True)
    logger = common.setup_logging(
        "nerf.test",
        file=args.logs_dir.joinpath("test.log"),
        level=args.common.logging.upper(),
        file_level="DEBUG",
    )
    if not args.ckpt.exists():
        logger.error("specified checkpoint '{}' does not exist".format(args.ckpt))
        return 1

    scene_data = data.load_scene(
        srcs=args.frames,
        scene_options=args.scene,
        sort_frames=args.sort_frames,
    )

    scene_meta = scene_data.meta

    # NOTE(review): this warns whenever metric reporting is requested; presumably
    # report_metrics implies camera overrides elsewhere — confirm against the args definition.
    if args.report_metrics:
        # `logger.warn` is a deprecated alias of `logger.warning`
        logger.warning("will not load gt images because either the intrinsics or the extrinsics of the camera have been changed")
    if args.trajectory == "orbit":
        scene_meta = scene_meta.make_frames_with_orbiting_trajectory(args.orbit)
        logger.info("generated {} camera transforms for testing".format(len(scene_meta.frames)))
    else:
        logger.debug("loading testing frames from {}".format(args.frames))
        logger.info("loaded {} camera transforms for testing".format(len(scene_meta.frames)))

    if args.camera_override.enabled:
        scene_meta = scene_meta.replace(camera=args.camera_override.update_camera(scene_meta.camera))

    # load parameters
    logger.debug("loading checkpoint from '{}'".format(args.ckpt))
    state: NeRFState = checkpoints.restore_checkpoint(
        args.ckpt,
        target=NeRFState.empty(
            raymarch=args.raymarch,
            render=args.render,
            scene_options=args.scene,
            scene_meta=scene_meta,
            nerf_fn=make_nerf_ngp(bound=scene_meta.bound, inference=True).apply,
            bg_fn=make_skysphere_background_model_ngp(bound=scene_meta.bound).apply if scene_meta.bg else None,
        ),
    )
    # WARN:
    #   flax.checkpoints.restore_checkpoint() returns a pytree with all arrays of numpy's array
    #   type, which slows down inference.  use jax.device_put() to move them to jax's default
    #   device.
    state = jax.device_put(state)
    if state.step == 0:
        logger.error("an empty checkpoint was loaded from '{}'".format(args.ckpt))
        return 2
    logger.info("checkpoint loaded from '{}' (step={})".format(args.ckpt, int(state.step)))

    rendered_images: List[RenderedImage] = []
    try:
        n_frames = len(scene_meta.frames)
        logger.info("starting testing (totally {} transform(s) to test)".format(n_frames))
        # fix typo in the progress-bar label ("resolultion" -> "resolution")
        for test_i in common.tqdm(range(n_frames), desc="testing (resolution: {}x{})".format(scene_meta.camera.width, scene_meta.camera.height)):
            logger.debug("testing on frame {}".format(scene_meta.frames[test_i]))
            transform = RigidTransformation(
                rotation=scene_meta.frames[test_i].transform_matrix_jax_array[:3, :3],
                translation=scene_meta.frames[test_i].transform_matrix_jax_array[:3, 3],
            )
            KEY, key = jran.split(KEY, 2)
            bg, rgb, depth, _ = data.to_cpu(render_image_inference(
                KEY=key,
                transform_cw=transform,
                state=state,
            ))
            rendered_images.append(RenderedImage(
                bg=bg,
                rgb=rgb,
                depth=depth,  # call to data.mono_to_rgb is deferred below so as to minimize impact on rendering speed
            ))
    except KeyboardInterrupt:
        logger.warning("keyboard interrupt, tested {} images".format(len(rendered_images)))

    if args.trajectory == "loaded":
        if len(rendered_images) == 0:
            logger.warning("tested 0 image, not calculating psnr")
        else:
            # blend each ground-truth RGBA view against the background actually used while
            # rendering, so psnr compares like with like
            gt_rgbs_f32 = map(
                lambda test_view, rendered_image: data.blend_rgba_image_array(
                    test_view.image_rgba_u8.astype(jnp.float32) / 255,
                    rendered_image.bg,
                ),
                scene_data.all_views,
                rendered_images,
            )
            logger.debug("calculating psnr")
            mean_psnr = sum(map(
                data.psnr,
                map(data.f32_to_u8, gt_rgbs_f32),
                map(lambda ri: ri.rgb, rendered_images),
            )) / len(rendered_images)
            logger.info("tested {} images, mean psnr={}".format(len(rendered_images), mean_psnr))

    elif args.trajectory == "orbit":
        logger.debug("using generated orbiting trajectory, not calculating psnr")

    else:
        assert_never("")

    save_dest = args.logs_dir.joinpath("test")
    save_dest.mkdir(parents=True, exist_ok=True)

    if "video" in args.save_as:
        dest_rgb_video = save_dest.joinpath("rgb.mp4")
        dest_depth_video = save_dest.joinpath("depth.mp4")

        logger.debug("saving predicted color images as a video at '{}'".format(dest_rgb_video))
        data.write_video(
            dest_rgb_video,  # reuse the path computed above instead of rebuilding it
            map(lambda img: img.rgb, rendered_images),
            fps=args.fps,
            loop=args.loop,
        )

        logger.debug("saving predicted disparities as a video at '{}'".format(dest_depth_video))
        data.write_video(
            dest_depth_video,  # reuse the path computed above instead of rebuilding it
            map(lambda img: common.compose(data.mono_to_rgb, data.f32_to_u8)(img.depth), rendered_images),
            fps=args.fps,
            loop=args.loop,
        )

    if "image" in args.save_as:
        dest_rgb = save_dest.joinpath("rgb")
        dest_depth = save_dest.joinpath("depth")

        dest_rgb.mkdir(parents=True, exist_ok=True)
        dest_depth.mkdir(parents=True, exist_ok=True)

        logger.debug("saving as images")
        # one thread per image keeps PNG encoding off the main thread
        def save_rgb_and_depth(save_i: int, img: RenderedImage):
            common.compose(
                np.asarray,
                Image.fromarray
            )(img.rgb).save(dest_rgb.joinpath("{:04d}.png".format(save_i)))
            common.compose(
                data.mono_to_rgb,
                data.f32_to_u8,
                np.asarray,
                Image.fromarray
            )(img.depth).save(dest_depth.joinpath("{:04d}.png".format(save_i)))
        for _ in common.tqdm(
            ThreadPoolExecutor().map(
                save_rgb_and_depth,
                range(len(rendered_images)),
                rendered_images,
            ),
            total=len(rendered_images),
            desc="| saving images",
        ):
            pass
    return 0
from jax.interpreters import mlir
from jax.interpreters.mlir import ir

from .. import volrendutils_cuda

try:
    from jaxlib.mhlo_helpers import custom_call
except ModuleNotFoundError:
    # A more recent jaxlib would have `hlo_helpers` instead of `mhlo_helpers`
    from jaxlib.hlo_helpers import custom_call


# helper function for mapping given shapes to their default mlir layouts
def default_layouts(*shapes):
    return [range(len(shape) - 1, -1, -1) for shape in shapes]


def march_rays_lowering_rule(
    # lowering rules receive a LoweringRuleContext (the previous `mlir.LoweringRule`
    # annotation named the rule type itself, not the context argument)
    ctx: mlir.LoweringRuleContext,

    # arrays
    rays_o: ir.Value,
    rays_d: ir.Value,
    t_starts: ir.Value,
    t_ends: ir.Value,
    noises: ir.Value,
    occupancy_bitfield: ir.Value,

    # static args
    total_samples: int,  # int
    diagonal_n_steps: int,  # int
    K: int,  # int
    G: int,  # int
    bound: float,  # float
    stepsize_portion: float,  # float
):
    """MLIR lowering rule for the `march_rays` primitive.

    Serializes the static arguments into an opaque descriptor and emits a custom call
    targeting the `march_rays` CUDA kernel.
    """
    n_rays, _ = ir.RankedTensorType(rays_o.type).shape

    opaque = volrendutils_cuda.make_marching_descriptor(
        n_rays,
        total_samples,
        diagonal_n_steps,
        K,
        G,
        bound,
        stepsize_portion,
    )

    shapes = {
        "in.rays_o": (n_rays, 3),
        "in.rays_d": (n_rays, 3),
        "in.t_starts": (n_rays,),
        "in.t_ends": (n_rays,),
        "in.noises": (n_rays,),
        "in.occupancy_bitfield": (K*G*G*G//8,),

        "helper.next_sample_write_location": (1,),
        "helper.number_of_exceeded_samples": (1,),
        "helper.ray_is_valid": (n_rays,),

        "out.rays_n_samples": (n_rays,),
        "out.rays_sample_startidx": (n_rays,),
        "out.idcs": (total_samples,),
        "out.xyzs": (total_samples, 3),
        "out.dirs": (total_samples, 3),
        "out.dss": (total_samples,),
        "out.z_vals": (total_samples,),
    }

    return custom_call(
        call_target_name="march_rays",
        out_types=[
            ir.RankedTensorType.get(shapes["helper.next_sample_write_location"], ir.IntegerType.get_unsigned(32)),
            ir.RankedTensorType.get(shapes["helper.number_of_exceeded_samples"], ir.IntegerType.get_unsigned(32)),
            ir.RankedTensorType.get(shapes["helper.ray_is_valid"], ir.IntegerType.get_signless(1)),
            ir.RankedTensorType.get(shapes["out.rays_n_samples"], ir.IntegerType.get_unsigned(32)),
            ir.RankedTensorType.get(shapes["out.rays_sample_startidx"], ir.IntegerType.get_unsigned(32)),
            ir.RankedTensorType.get(shapes["out.idcs"], ir.IntegerType.get_unsigned(32)),
            ir.RankedTensorType.get(shapes["out.xyzs"], ir.F32Type.get()),
            ir.RankedTensorType.get(shapes["out.dirs"], ir.F32Type.get()),
            ir.RankedTensorType.get(shapes["out.dss"], ir.F32Type.get()),
            ir.RankedTensorType.get(shapes["out.z_vals"], ir.F32Type.get()),
        ],
        operands=[
            rays_o,
            rays_d,
            t_starts,
            t_ends,
            noises,
            occupancy_bitfield,
        ],
        backend_config=opaque,
        operand_layouts=default_layouts(
            shapes["in.rays_o"],
            shapes["in.rays_d"],
            shapes["in.t_starts"],
            shapes["in.t_ends"],
            shapes["in.noises"],
            shapes["in.occupancy_bitfield"],
        ),
        result_layouts=default_layouts(
            shapes["helper.next_sample_write_location"],
            shapes["helper.number_of_exceeded_samples"],
            shapes["helper.ray_is_valid"],
            shapes["out.rays_n_samples"],
            shapes["out.rays_sample_startidx"],
            shapes["out.idcs"],
            shapes["out.xyzs"],
            shapes["out.dirs"],
            shapes["out.dss"],
            shapes["out.z_vals"],
        ),
    )


def march_rays_inference_lowering_rule(
    ctx: mlir.LoweringRuleContext,

    # arrays
    rays_o: ir.BlockArgument,
    rays_d: ir.BlockArgument,
    t_starts: ir.BlockArgument,
    t_ends: ir.BlockArgument,
    occupancy_bitfield: ir.BlockArgument,
    next_ray_index_in: ir.BlockArgument,
    terminated: ir.BlockArgument,
    indices_in: ir.BlockArgument,

    # static args
    diagonal_n_steps: int,
    K: int,
    G: int,
    march_steps_cap: int,
    bound: float,
    stepsize_portion: float,
):
    """MLIR lowering rule for the `march_rays_inference` primitive.

    Mirrors `march_rays_lowering_rule`, targeting the `march_rays_inference` CUDA kernel.
    """
    (n_total_rays, _), (n_rays,) = ir.RankedTensorType(rays_o.type).shape, ir.RankedTensorType(terminated.type).shape

    opaque = volrendutils_cuda.make_marching_inference_descriptor(
        n_total_rays,
        n_rays,
        diagonal_n_steps,
        K,
        G,
        march_steps_cap,
        bound,
        stepsize_portion,
    )

    shapes = {
        "in.rays_o": (n_total_rays, 3),
        "in.rays_d": (n_total_rays, 3),
        "in.t_starts": (n_total_rays,),
        "in.t_ends": (n_total_rays,),
        "in.occupancy_bitfield": (K*G*G*G//8,),
        "in.next_ray_index_in": (1,),
        "in.terminated": (n_rays,),
        "in.indices_in": (n_rays,),

        "out.next_ray_index": (1,),
        "out.indices_out": (n_rays,),
        "out.n_samples": (n_rays,),
        "out.t_starts": (n_rays,),
        "out.xyzs": (n_rays, march_steps_cap, 3),
        "out.dss": (n_rays, march_steps_cap),
        "out.z_vals": (n_rays, march_steps_cap),
    }

    return custom_call(
        call_target_name="march_rays_inference",
        out_types=[
            ir.RankedTensorType.get(shapes["out.next_ray_index"], ir.IntegerType.get_unsigned(32)),
            ir.RankedTensorType.get(shapes["out.indices_out"], ir.IntegerType.get_unsigned(32)),
            ir.RankedTensorType.get(shapes["out.n_samples"], ir.IntegerType.get_unsigned(32)),
            ir.RankedTensorType.get(shapes["out.t_starts"], ir.F32Type.get()),
            ir.RankedTensorType.get(shapes["out.xyzs"], ir.F32Type.get()),
            ir.RankedTensorType.get(shapes["out.dss"], ir.F32Type.get()),
            ir.RankedTensorType.get(shapes["out.z_vals"], ir.F32Type.get()),
        ],
        operands=[
            rays_o,
            rays_d,
            t_starts,
            t_ends,
            occupancy_bitfield,
            next_ray_index_in,
            terminated,
            indices_in,
        ],
        backend_config=opaque,
        operand_layouts=default_layouts(
            shapes["in.rays_o"],
            shapes["in.rays_d"],
            shapes["in.t_starts"],
            shapes["in.t_ends"],
            shapes["in.occupancy_bitfield"],
            shapes["in.next_ray_index_in"],
            shapes["in.terminated"],
            shapes["in.indices_in"],
        ),
        result_layouts=default_layouts(
            shapes["out.next_ray_index"],
            shapes["out.indices_out"],
            shapes["out.n_samples"],
            shapes["out.t_starts"],
            shapes["out.xyzs"],
            shapes["out.dss"],
            shapes["out.z_vals"],
        ),
    )
// hashgrid.cu -- forward/backward launchers bridging JAX custom calls to
// tiny-cuda-nn's hash-grid encoding kernels.
//
// NOTE(review): text extraction stripped every `<...>` template-argument list
// from this file; the template arguments below are reconstructed from the
// declared variable types and tiny-cuda-nn's kernel signatures -- TODO confirm
// against the original source / tcnn headers.
#include <algorithm>
#include <cstdint>
#include <stdexcept>

#include "fmt/core.h"
#include "tcnnutils.h"


namespace jaxtcnn {

namespace {

void hashgrid_forward_launcher(cudaStream_t stream, void **buffers, char const *opaque, std::size_t opaque_len) {
    std::uint32_t static constexpr DIM = 3u;

    // buffer indexing helper
    std::uint32_t __buffer_idx = 0;
    auto const next_buffer = [&]() { return buffers[__buffer_idx++]; };

    HashGridDescriptor const &desc =
        *deserialize<HashGridDescriptor>(opaque, opaque_len);
    std::uint32_t const n_coords = desc.n_coords;
    std::uint32_t const L = desc.L;
    std::uint32_t const F = desc.F;
    std::uint32_t const N_min = desc.N_min;
    float const per_level_scale = desc.per_level_scale;

    // inputs
    std::uint32_t * const offset_table_data = static_cast<std::uint32_t *>(next_buffer());  // [L+1]
    float const * const __restrict__ coords_rm = static_cast<float const *>(next_buffer());  // [dim, n_coords]
    float const * const __restrict__ params = static_cast<float const *>(next_buffer());  // [n_params, F]

    // outputs
    float * const __restrict__ encoded_positions_rm = static_cast<float *>(next_buffer());  // [L*F, n_coords]
    float * const __restrict__ dy_dcoords_rm = static_cast<float *>(next_buffer());  // [dim*L*F, n_coords]

    // prepare input data for tcnn::kernel_grid
    tcnn::GridOffsetTable offset_table{{}, L + 1};
    CUDA_CHECK_THROW(cudaMemcpyAsync(offset_table.data, offset_table_data, (L + 1) * sizeof(std::uint32_t),
                                     cudaMemcpyKind::cudaMemcpyDeviceToHost, stream));
    // FIX: `offset_table` lives in host memory and is captured by value into the
    // kernel launch below; without synchronizing, the async device-to-host copy
    // may not have completed when the launch arguments are marshalled.
    CUDA_CHECK_THROW(cudaStreamSynchronize(stream));
    CUDA_CHECK_THROW(cudaMemsetAsync(dy_dcoords_rm, 0x00, DIM * L * F * n_coords * sizeof(float), stream));
    // tcnn::GPUMatrixDynamic positions_in(coords, n_points, DIM);
    tcnn::MatrixView<float const> positions_in(coords_rm, n_coords, 1);  // row major

    // kernel launch
    std::uint32_t static constexpr n_threads = 512;
    dim3 const blocks = { tcnn::div_round_up(n_coords, n_threads), L, 1 };
#define PARAMS \
    n_coords \
    , L * F \
    , offset_table \
    , N_min \
    , log2f(per_level_scale) \
    , 0.f \
    , 1e3f \
    , nullptr \
    , tcnn::InterpolationType::Linear \
    , tcnn::GridType::Hash \
    , params \
    , positions_in \
    , encoded_positions_rm \
    , dy_dcoords_rm

    if (F == 2) {
        std::uint32_t static constexpr N_FEATURES_PER_LEVEL = 2;
        tcnn::kernel_grid<float, DIM, N_FEATURES_PER_LEVEL><<<blocks, n_threads, 0, stream>>>(
            PARAMS
        );
    } else if (F == 4) {
        std::uint32_t static constexpr N_FEATURES_PER_LEVEL = 4;
        tcnn::kernel_grid<float, DIM, N_FEATURES_PER_LEVEL><<<blocks, n_threads, 0, stream>>>(
            PARAMS
        );
    } else {
        throw std::runtime_error{
            fmt::format("supported values of F (n_features_per_level) are [2, 4], got {}", F)
        };
    }
#undef PARAMS

    CUDA_CHECK_THROW(cudaGetLastError());
}

void hashgrid_backward_launcher(cudaStream_t stream, void **buffers, char const *opaque, std::size_t opaque_len) {
    // buffer indexing helper
    std::uint32_t __buffer_idx = 0;
    auto const next_buffer = [&]() { return buffers[__buffer_idx++]; };

    HashGridDescriptor const &desc =
        *deserialize<HashGridDescriptor>(opaque, opaque_len);
    std::uint32_t const n_coords = desc.n_coords;
    std::uint32_t const L = desc.L;
    std::uint32_t const F = desc.F;
    std::uint32_t const N_min = desc.N_min;
    float const per_level_scale = desc.per_level_scale;

    // input
    std::uint32_t * const offset_table_data = static_cast<std::uint32_t *>(next_buffer());  // [L+1]
    float const * const __restrict__ coords_rm = static_cast<float const *>(next_buffer());  // [dim, n_coords]
    float const * const __restrict__ dL_dy_rm = static_cast<float const *>(next_buffer());  // [L*F, n_coords]
    float const * const __restrict__ dy_dcoords_rm = static_cast<float const *>(next_buffer());  // [dim*L*F, n_coords]

    // output
    float * const __restrict__ dL_dparams = static_cast<float *>(next_buffer());  // [n_params, F]
    float * const __restrict__ dL_dcoords_rm = static_cast<float *>(next_buffer());  // [dim, n_coords]

    // prepare input data for tcnn::kernel_grid_backward
    tcnn::GridOffsetTable offset_table{{}, L + 1};
    CUDA_CHECK_THROW(cudaMemcpyAsync(offset_table.data, offset_table_data, (L + 1) * sizeof(std::uint32_t),
                                     cudaMemcpyKind::cudaMemcpyDeviceToHost, stream));
    // FIX: `offset_table.data[L]` is read on the host in the very next statement;
    // the async device-to-host copy must be complete before that read.
    CUDA_CHECK_THROW(cudaStreamSynchronize(stream));
    CUDA_CHECK_THROW(cudaMemsetAsync(dL_dparams, 0x00, offset_table.data[L] * F * sizeof(float), stream));
    tcnn::MatrixView<float const> positions_in(coords_rm, n_coords, 1);  // row major

    // kernel launch
    std::uint32_t static constexpr n_threads = 256;

#define PARAMS \
    n_coords \
    , L * F \
    , offset_table \
    , N_min \
    , log2f(per_level_scale) \
    , 1e3f \
    , nullptr \
    , false \
    , tcnn::InterpolationType::Linear \
    , tcnn::GridType::Hash \
    , dL_dparams \
    , positions_in \
    , dL_dy_rm

    std::uint32_t static constexpr DIM = 3u;
    if (F == 2u) {
        std::uint32_t static constexpr N_FEATURES_PER_LEVEL = 2u;
        std::uint32_t static constexpr n_features_per_thread = std::min(2u, N_FEATURES_PER_LEVEL);
        const dim3 blocks = { tcnn::div_round_up(n_coords * F / n_features_per_thread, n_threads), L, 1 };
        tcnn::kernel_grid_backward<float, float, DIM, N_FEATURES_PER_LEVEL, n_features_per_thread><<<blocks, n_threads, 0, stream>>>(
            PARAMS
        );
    } else if (F == 4) {
        std::uint32_t static constexpr N_FEATURES_PER_LEVEL = 4u;
        std::uint32_t static constexpr n_features_per_thread = std::min(2u, N_FEATURES_PER_LEVEL);
        const dim3 blocks = { tcnn::div_round_up(n_coords * F / n_features_per_thread, n_threads), L, 1 };
        tcnn::kernel_grid_backward<float, float, DIM, N_FEATURES_PER_LEVEL, n_features_per_thread><<<blocks, n_threads, 0, stream>>>(
            PARAMS
        );
    } else {
        throw std::runtime_error{
            fmt::format("supported values of F (n_features_per_level) are 2, 4, got {}", F)
        };
    }
#undef PARAMS

    CUDA_CHECK_THROW(cudaGetLastError());

    // gradients w.r.t. input coordinates
    // prepare input data for tcnn::kernel_grid_backward_input
    tcnn::MatrixView<float> dL_dinput_view(dL_dcoords_rm, n_coords, 1);

    tcnn::linear_kernel(tcnn::kernel_grid_backward_input<float, float>, 0, stream,
        n_coords,
        L * F,
        dL_dy_rm,
        dy_dcoords_rm,
        dL_dinput_view
    );

    CUDA_CHECK_THROW(cudaGetLastError());
}

}  // anonymous namespace

// entry point registered as the XLA custom-call target for the forward pass
void hashgrid_encode(
    cudaStream_t stream,
    void **buffers,
    char const *opaque,
    std::size_t opaque_len
) {
    hashgrid_forward_launcher(stream, buffers, opaque, opaque_len);
}

// entry point registered as the XLA custom-call target for the backward pass
void hashgrid_encode_backward(
    cudaStream_t stream,
    void **buffers,
    char const *opaque,
    std::size_t opaque_len
) {
    hashgrid_backward_launcher(stream, buffers, opaque, opaque_len);
}

}  // namespace jaxtcnn
import mlir 2 | from jax.interpreters.mlir import ir 3 | 4 | from .. import volrendutils_cuda 5 | 6 | try: 7 | from jaxlib.mhlo_helpers import custom_call 8 | except ModuleNotFoundError: 9 | # A more recent jaxlib would have `hlo_helpers` instead of `mhlo_helpers` 10 | # 11 | from jaxlib.hlo_helpers import custom_call 12 | 13 | 14 | # helper function for mapping given shapes to their default mlir layouts 15 | def default_layouts(*shapes): 16 | return [range(len(shape) - 1, -1, -1) for shape in shapes] 17 | 18 | 19 | def integrate_rays_lowering_rule( 20 | ctx: mlir.LoweringRuleContext, 21 | 22 | rays_sample_startidx: ir.Value, 23 | rays_n_samples: ir.Value, 24 | 25 | bgs: ir.Value, 26 | dss: ir.Value, 27 | z_vals: ir.Value, 28 | drgbs: ir.Value, 29 | ): 30 | n_rays, = ir.RankedTensorType(rays_sample_startidx.type).shape 31 | total_samples, = ir.RankedTensorType(z_vals.type).shape 32 | 33 | opaque = volrendutils_cuda.make_integrating_descriptor(n_rays, total_samples) 34 | 35 | shapes = { 36 | "in.rays_sample_startidx": (n_rays,), 37 | "in.rays_n_samples": (n_rays,), 38 | 39 | "in.bgs": (n_rays, 3), 40 | "in.dss": (total_samples,), 41 | "in.z_vals": (total_samples,), 42 | "in.drgbs": (total_samples, 4), 43 | 44 | "helper.measured_batch_size": (1,), 45 | 46 | "out.final_rgbds": (n_rays, 4), 47 | "out.final_opacities": (n_rays,), 48 | } 49 | 50 | return custom_call( 51 | call_target_name="integrate_rays", 52 | out_types=[ 53 | ir.RankedTensorType.get(shapes["helper.measured_batch_size"], ir.IntegerType.get_unsigned(32)), 54 | ir.RankedTensorType.get(shapes["out.final_rgbds"], ir.F32Type.get()), 55 | ir.RankedTensorType.get(shapes["out.final_opacities"], ir.F32Type.get()), 56 | ], 57 | operands=[ 58 | rays_sample_startidx, 59 | rays_n_samples, 60 | bgs, 61 | dss, 62 | z_vals, 63 | drgbs, 64 | ], 65 | backend_config=opaque, 66 | operand_layouts=default_layouts( 67 | shapes["in.rays_sample_startidx"], 68 | shapes["in.rays_n_samples"], 69 | shapes["in.bgs"], 70 | 
shapes["in.dss"], 71 | shapes["in.z_vals"], 72 | shapes["in.drgbs"], 73 | ), 74 | result_layouts=default_layouts( 75 | shapes["helper.measured_batch_size"], 76 | shapes["out.final_rgbds"], 77 | shapes["out.final_opacities"], 78 | ), 79 | ) 80 | 81 | 82 | def integrate_rays_backward_lowring_rule( 83 | ctx: mlir.LoweringRuleContext, 84 | 85 | rays_sample_startidx: ir.Value, 86 | rays_n_samples: ir.Value, 87 | 88 | # original inputs 89 | bgs: ir.Value, 90 | dss: ir.Value, 91 | z_vals: ir.Value, 92 | drgbs: ir.Value, 93 | 94 | # original outputs 95 | final_rgbds: ir.Value, 96 | final_opacities: ir.Value, 97 | 98 | # gradient inputs 99 | dL_dfinal_rgbds: ir.Value, 100 | 101 | # static argument 102 | near_distance: float, 103 | ): 104 | n_rays, = ir.RankedTensorType(rays_sample_startidx.type).shape 105 | total_samples, = ir.RankedTensorType(z_vals.type).shape 106 | 107 | opaque = volrendutils_cuda.make_integrating_backward_descriptor(n_rays, total_samples, near_distance) 108 | 109 | shapes = { 110 | "in.rays_sample_startidx": (n_rays,), 111 | "in.rays_n_samples": (n_rays,), 112 | 113 | "in.bgs": (n_rays, 3), 114 | "in.dss": (total_samples,), 115 | "in.z_vals": (total_samples,), 116 | "in.drgbs": (total_samples, 4), 117 | 118 | "in.final_rgbds": (n_rays, 4), 119 | "in.final_opacities": (n_rays,), 120 | 121 | "in.dL_dfinal_rgbds": (n_rays, 4), 122 | 123 | "out.dL_dbgs": (n_rays, 3), 124 | "out.dL_dz_vals": (total_samples,), 125 | "out.dL_ddrgbs": (total_samples, 4), 126 | } 127 | 128 | return custom_call( 129 | call_target_name="integrate_rays_backward", 130 | out_types=[ 131 | ir.RankedTensorType.get(shapes["out.dL_dbgs"], ir.F32Type.get()), 132 | ir.RankedTensorType.get(shapes["out.dL_dz_vals"], ir.F32Type.get()), 133 | ir.RankedTensorType.get(shapes["out.dL_ddrgbs"], ir.F32Type.get()), 134 | ], 135 | operands=[ 136 | rays_sample_startidx, 137 | rays_n_samples, 138 | 139 | bgs, 140 | dss, 141 | z_vals, 142 | drgbs, 143 | 144 | final_rgbds, 145 | final_opacities, 146 | 
147 | dL_dfinal_rgbds, 148 | ], 149 | backend_config=opaque, 150 | operand_layouts=default_layouts( 151 | shapes["in.rays_sample_startidx"], 152 | shapes["in.rays_n_samples"], 153 | shapes["in.bgs"], 154 | shapes["in.dss"], 155 | shapes["in.z_vals"], 156 | shapes["in.drgbs"], 157 | 158 | shapes["in.final_rgbds"], 159 | shapes["in.final_opacities"], 160 | 161 | shapes["in.dL_dfinal_rgbds"], 162 | ), 163 | result_layouts=default_layouts( 164 | shapes["out.dL_dbgs"], 165 | shapes["out.dL_dz_vals"], 166 | shapes["out.dL_ddrgbs"], 167 | ), 168 | ) 169 | 170 | 171 | def integrate_rays_inference_lowering_rule( 172 | ctx: mlir.LoweringRuleContext, 173 | 174 | rays_bg: ir.Value, 175 | rays_rgbd: ir.Value, 176 | rays_T: ir.Value, 177 | 178 | n_samples: ir.Value, 179 | indices: ir.Value, 180 | dss: ir.Value, 181 | z_vals: ir.Value, 182 | drgbs: ir.Value, 183 | ): 184 | (n_total_rays, _) = ir.RankedTensorType(rays_rgbd.type).shape 185 | (n_rays, march_steps_cap) = ir.RankedTensorType(dss.type).shape 186 | 187 | opaque = volrendutils_cuda.make_integrating_inference_descriptor(n_total_rays, n_rays, march_steps_cap) 188 | 189 | shapes = { 190 | "in.rays_bg": (n_total_rays, 3), 191 | "in.rays_rgbd": (n_total_rays, 4), 192 | "in.rays_T": (n_total_rays,), 193 | 194 | "in.n_samples": (n_rays,), 195 | "in.indices": (n_rays,), 196 | "in.dss": (n_rays, march_steps_cap), 197 | "in.z_vals": (n_rays, march_steps_cap), 198 | "in.drgbs": (n_rays, march_steps_cap, 4), 199 | 200 | "out.terminate_cnt": (1,), 201 | "out.terminated": (n_rays,), 202 | "out.rays_rgbd": (n_rays, 4), 203 | "out.rays_T": (n_rays,), 204 | } 205 | 206 | return custom_call( 207 | call_target_name="integrate_rays_inference", 208 | out_types=[ 209 | ir.RankedTensorType.get(shapes["out.terminate_cnt"], ir.IntegerType.get_unsigned(32)), 210 | ir.RankedTensorType.get(shapes["out.terminated"], ir.IntegerType.get_signless(1)), 211 | ir.RankedTensorType.get(shapes["out.rays_rgbd"], ir.F32Type.get()), 212 | 
ir.RankedTensorType.get(shapes["out.rays_T"], ir.F32Type.get()), 213 | ], 214 | operands=[ 215 | rays_bg, 216 | rays_rgbd, 217 | rays_T, 218 | 219 | n_samples, 220 | indices, 221 | dss, 222 | z_vals, 223 | drgbs, 224 | ], 225 | backend_config=opaque, 226 | operand_layouts=default_layouts( 227 | shapes["in.rays_bg"], 228 | shapes["in.rays_rgbd"], 229 | shapes["in.rays_T"], 230 | 231 | shapes["in.n_samples"], 232 | shapes["in.indices"], 233 | shapes["in.dss"], 234 | shapes["in.z_vals"], 235 | shapes["in.drgbs"], 236 | ), 237 | result_layouts=default_layouts( 238 | shapes["out.terminate_cnt"], 239 | shapes["out.terminated"], 240 | shapes["out.rays_rgbd"], 241 | shapes["out.rays_T"], 242 | ), 243 | ) 244 | -------------------------------------------------------------------------------- /utils/args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import Tuple 4 | 5 | import tyro 6 | 7 | from utils.types import ( 8 | CameraOverrideOptions, 9 | LogLevel, 10 | OrbitTrajectoryOptions, 11 | RayMarchingOptions, 12 | RenderingOptions, 13 | SceneOptions, 14 | TransformsProvider, 15 | ) 16 | 17 | 18 | @dataclass(frozen=True, kw_only=True) 19 | class CommonArgs: 20 | # log level 21 | logging: LogLevel = "INFO" 22 | 23 | # random seed 24 | seed: int = 1_000_000_007 25 | 26 | # display model information after model init 27 | summary: bool=False 28 | 29 | 30 | @dataclass(frozen=True, kw_only=True) 31 | class TrainingArgs: 32 | # learning rate 33 | lr: float 34 | 35 | # scalar multiplied to total variation loss, set this to a positive value to enable calculation 36 | # of TV loss 37 | tv_scale: float 38 | 39 | # batch size 40 | bs: int 41 | 42 | # number of epochs to train 43 | epochs: int 44 | 45 | # batches per epoch 46 | iters: int 47 | 48 | # loop within training data for this number of iterations, this helps reduce the effective 49 | # dataloader overhead. 
50 | data_loop: int 51 | 52 | # will validate every `validate_every` epochs, set this to a large value to disable validation 53 | validate_every: int 54 | 55 | # number of latest checkpoints to keep 56 | keep: int=1 57 | 58 | # how many epochs should a new checkpoint to be kept (in addition to keeping the last `keep` 59 | # checkpoints) 60 | keep_every: int | None=8 61 | 62 | @property 63 | def keep_every_n_steps(self) -> int | None: 64 | if self.keep_every is None: 65 | return None 66 | else: 67 | return self.keep_every * self.iters 68 | 69 | @dataclass(frozen=True, kw_only=True) 70 | class ImageFitArgs: 71 | common: tyro.conf.OmitArgPrefixes[CommonArgs]=CommonArgs() 72 | train: tyro.conf.OmitArgPrefixes[TrainingArgs]=TrainingArgs( 73 | # paper: 74 | # We observed fastest convergence with a learning rate of 10^{-4} for signed distance 75 | # functions and 10^{-2} otherwise 76 | # 77 | # We use a smaller learning rate since our batch size is much smaller the paper (see below). 78 | lr=1e-3, 79 | tv_scale=0., 80 | # paper: 81 | # ...as well a a batch size of 2^{14} for neural radiance caching and 2^{18} otherwise. 82 | # 83 | # In our case, setting the batch size to a larger number hinders data loading performance, 84 | # and thus causes the GPU not being fully occupied. On the other hand, setting the batch 85 | # size to a smaller one utilizes the GPU fully, but the iterations per second capped at 86 | # some rate which results in lower throughput. setting bs to 2^{10} achieves a satisfying 87 | # tradeoff here. 
88 | bs=2**10, 89 | epochs=32, 90 | iters=2**30, 91 | data_loop=1, 92 | validate_every=1, 93 | ) 94 | 95 | 96 | @dataclass(frozen=True, kw_only=True) 97 | class NeRFArgsBase: 98 | raymarch: RayMarchingOptions 99 | render: RenderingOptions 100 | scene: SceneOptions 101 | 102 | common: tyro.conf.OmitArgPrefixes[CommonArgs]=CommonArgs() 103 | 104 | 105 | @dataclass(frozen=True, kw_only=True) 106 | class _SharedNeRFTrainingArgs(NeRFArgsBase): 107 | # each experiment run will save its own code, config, training logs, and checkpoints into a 108 | # separate subdirectory of this path 109 | exp_dir: Path 110 | 111 | # directories or transform.json files containing data for training 112 | frames_train: tyro.conf.Positional[Tuple[Path, ...]] 113 | 114 | # an optional description of this run 115 | note: str | None=None 116 | 117 | # directories or transform.json files containing data for validation 118 | frames_val: Tuple[Path, ...]=() 119 | 120 | # if specified, continue training from this checkpoint 121 | ckpt: Path | None=None 122 | 123 | # training hyper parameters 124 | train: tyro.conf.OmitArgPrefixes[TrainingArgs]=TrainingArgs( 125 | # This is a relatively large learning rate, should be used jointly with 126 | # `threasholded_exponential` as density activation, and random color as supervision for 127 | # transparent pixels. 
128 | lr=1e-2, 129 | tv_scale=0., 130 | bs=1024 * (1<<10), 131 | epochs=50, 132 | iters=2**10, 133 | data_loop=1, 134 | validate_every=10, 135 | ) 136 | 137 | # raymarching/rendering options during training 138 | raymarch: RayMarchingOptions=RayMarchingOptions( 139 | diagonal_n_steps=1<<10, 140 | perturb=True, 141 | density_grid_res=128, 142 | ) 143 | render: RenderingOptions=RenderingOptions( 144 | bg=(1.0, 1.0, 1.0), # white, but ignored by default due to random_bg=True 145 | random_bg=True, 146 | ) 147 | scene: SceneOptions=SceneOptions( 148 | sharpness_threshold=-1., 149 | resolution_scale=1.0, 150 | camera_near=0.3, 151 | max_mem_mbytes=2500, # ~300 1920x1080 8bit RGBA images 152 | ) 153 | 154 | # raymarching/rendering options for validating during training 155 | raymarch_eval: RayMarchingOptions=RayMarchingOptions( 156 | diagonal_n_steps=1<<10, 157 | perturb=False, 158 | density_grid_res=128, 159 | ) 160 | render_eval: RenderingOptions=RenderingOptions( 161 | bg=(0.0, 0.0, 0.0), # black 162 | random_bg=False, 163 | ) 164 | 165 | 166 | @dataclass(frozen=True, kw_only=True) 167 | class NeRFTrainingArgs(_SharedNeRFTrainingArgs): ... 
168 | 169 | 170 | @dataclass(frozen=True, kw_only=True) 171 | class NeRFTestingArgs(NeRFArgsBase): 172 | # testing logs and results are saved to this directory, overwriting its content if any 173 | logs_dir: Path 174 | 175 | frames: tyro.conf.Positional[Tuple[Path, ...]] 176 | 177 | camera_override: CameraOverrideOptions=CameraOverrideOptions() 178 | 179 | # use checkpoint from this path (can be a directory) for testing 180 | ckpt: Path 181 | 182 | # if specified, render with a generated orbiting trajectory instead of the loaded frame 183 | # transformations 184 | trajectory: TransformsProvider="loaded" 185 | 186 | orbit: OrbitTrajectoryOptions 187 | 188 | # naturally sort frames according to their file names before testing 189 | sort_frames: bool=False 190 | 191 | # if specified value contains "video", a video will be saved; if specified value contains 192 | # "image", rendered images will be saved. Value can contain both "video" and "image", e.g., 193 | # `--save-as "video-image"` will save both video and images. 
194 | save_as: str="image and video" 195 | 196 | # specifies frames per second for saved video 197 | fps: int=24 198 | 199 | # loop rendered images this many times in saved video 200 | loop: int=3 201 | 202 | # raymarching/rendering options during testing 203 | raymarch: RayMarchingOptions=RayMarchingOptions( 204 | diagonal_n_steps=1<<10, 205 | perturb=False, 206 | density_grid_res=128, 207 | ) 208 | render: RenderingOptions=RenderingOptions( 209 | bg=(0.0, 0.0, 0.0), # black 210 | random_bg=False, 211 | ) 212 | scene: SceneOptions=SceneOptions( 213 | sharpness_threshold=-1., 214 | resolution_scale=1.0, 215 | camera_near=0.3, 216 | max_mem_mbytes=0, # this is testing, no images is loaded to GPU 217 | ) 218 | 219 | @property 220 | def report_metrics(self) -> bool: 221 | return self.camera_override.enabled or self.trajectory != "loaded" 222 | 223 | 224 | @dataclass(frozen=True, kw_only=True) 225 | class NeRFGUIArgs(_SharedNeRFTrainingArgs): 226 | 227 | @dataclass(frozen=True, kw_only=True) 228 | class ViewportOptions: 229 | W: int=1024 230 | H: int=768 231 | 232 | resolution_scale: float=0.3 233 | 234 | control_window_width: int=300 235 | 236 | #max number of loss steps shown on gui 237 | max_show_loss_step: int=200 238 | 239 | viewport: ViewportOptions=ViewportOptions() 240 | 241 | train: TrainingArgs=TrainingArgs( 242 | lr=1e-2, 243 | tv_scale=0., 244 | bs=1<<18, 245 | iters=5, # render a frame every 5 steps 246 | epochs=50, # ignored 247 | data_loop=1, # ignored 248 | validate_every=10, # ignored 249 | ) 250 | -------------------------------------------------------------------------------- /deps/volume-rendering-jax/lib/impl/volrend.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #define STRINGIFY(x) #x 7 | #define STR(x) STRINGIFY(x) 8 | #define FILE_LINE __FILE__ ":" STR(__LINE__) 9 | #define CUDA_CHECK_THROW(x) \ 10 | do { \ 11 | cudaError_t result = x; \ 12 | if (result 
!= cudaSuccess) \ 13 | throw std::runtime_error( \ 14 | std::string(FILE_LINE " " #x " failed with error ") \ 15 | + cudaGetErrorString(result)); \ 16 | } while(0) 17 | 18 | static constexpr float SQRT3 = 1.732050807568877293527446341505872367f; 19 | 20 | namespace volrendjax { 21 | 22 | // Static parameters passed to `integrate_rays` 23 | struct IntegratingDescriptor { 24 | // number of input rays 25 | std::uint32_t n_rays; 26 | 27 | // sum of number of samples of each ray 28 | std::uint32_t total_samples; 29 | }; 30 | 31 | // Static parameters passed to `integrate_rays_backward` 32 | struct IntegratingBackwardDescriptor { 33 | // number of input rays 34 | std::uint32_t n_rays; 35 | 36 | // sum of number of samples of each ray 37 | std::uint32_t total_samples; 38 | 39 | // camera's near distance, samples behind the camera's near plane with non-negligible 40 | // (> ~exp(-10)) densities will be penalized 41 | float near_distance; 42 | }; 43 | 44 | // Static parameters passed to `integrate_rays_inference` 45 | struct IntegratingInferenceDescriptor { 46 | // total number of rays to march 47 | std::uint32_t n_total_rays; 48 | 49 | // number of input rays 50 | std::uint32_t n_rays; 51 | 52 | // see MarchingInferenceDescriptor 53 | std::uint32_t march_steps_cap; 54 | }; 55 | 56 | // Static parameters passed to `march_rays` 57 | struct MarchingDescriptor { 58 | // number of input rays 59 | std::uint32_t n_rays; 60 | 61 | // number of available slots to write generated samples to, i.e. 
the length of output samples 62 | // array 63 | std::uint32_t total_samples; 64 | 65 | // the length of a minimal ray marching step is calculated as Δ𝑡 := 2*√3/`diagonal_n_steps` 66 | // (Appendix E.1 of the NGP paper) 67 | std::uint32_t diagonal_n_steps; 68 | 69 | // paper: we maintain a cascade of 𝐾 multiscale occupancy grids, where 𝐾 = 1 for all synthetic 70 | // NeRF scenes (single grid) and 𝐾 ∈ [1, 5] for larger real-world scenes (up to 5 grids, 71 | // depending on scene size) 72 | std::uint32_t K; 73 | 74 | // density grid resolution, the paper uses 128 for every cascade 75 | std::uint32_t G; 76 | 77 | // the half-length of the longest axis of the scene’s bounding box. E.g. the `bound` of the 78 | // bounding box [-1, 1]^3 is 1. 79 | float bound; 80 | 81 | // next step size is calculated as: 82 | // clamp(z_val[i] * stepsize_portion, sqrt3/1024.f, 2 * bound * sqrt3/1024.f) 83 | // where bound is the half-length of the largest axis of the scene’s bounding box, as mentioned 84 | // in Appendix E.1 of the NGP paper (the intercept theorem) 85 | float stepsize_portion; 86 | }; 87 | 88 | // Static parameters passed to `march_rays_inference` 89 | struct MarchingInferenceDescriptor { 90 | // total number of rays to march 91 | std::uint32_t n_total_rays; 92 | 93 | // number of rays to march in this iteration 94 | std::uint32_t n_rays; 95 | 96 | // same things as `diagonal_n_steps`, `K`, and `G` in `MarchingDescriptor` 97 | std::uint32_t diagonal_n_steps, K, G; 98 | 99 | // max steps to march, this is only used for early stopping marching, the minimal ray marching 100 | // step is still determined by `diagonal_n_steps` 101 | std::uint32_t march_steps_cap; 102 | 103 | // same thing as `bound` in `MarchingDescriptor` 104 | float bound; 105 | 106 | // same thing as `stepsize_portion` in `MarchingDescriptor` 107 | float stepsize_portion; 108 | }; 109 | 110 | struct Morton3DDescriptor { 111 | // number of entries to process 112 | std::uint32_t length; 113 | }; 114 | 115 | 
// Static parameters passed to `pack_density_into_bits` 116 | struct PackbitsDescriptor { 117 | std::uint32_t n_bytes; 118 | }; 119 | 120 | // functions to register 121 | void pack_density_into_bits( 122 | cudaStream_t stream, 123 | void **buffers, 124 | const char *opaque, 125 | std::size_t opaque_len 126 | ); 127 | 128 | void march_rays( 129 | cudaStream_t stream, 130 | void **buffers, 131 | const char *opaque, 132 | std::size_t opaque_len 133 | ); 134 | 135 | void march_rays_inference( 136 | cudaStream_t stream, 137 | void **buffers, 138 | const char *opaque, 139 | std::size_t opaque_len 140 | ); 141 | 142 | /// morton3d 143 | void morton3d( 144 | cudaStream_t stream, 145 | void **buffers, 146 | const char *opaque, 147 | std::size_t opaque_len 148 | ); 149 | void morton3d_invert( 150 | cudaStream_t stream, 151 | void **buffers, 152 | const char *opaque, 153 | std::size_t opaque_len 154 | ); 155 | 156 | void integrate_rays( 157 | cudaStream_t stream, 158 | void **buffers, 159 | const char *opaque, 160 | std::size_t opaque_len 161 | ); 162 | 163 | void integrate_rays_backward( 164 | cudaStream_t stream, 165 | void **buffers, 166 | const char *opaque, 167 | std::size_t opaque_len 168 | ); 169 | 170 | void integrate_rays_inference( 171 | cudaStream_t stream, 172 | void **buffers, 173 | const char *opaque, 174 | std::size_t opaque_len 175 | ); 176 | 177 | 178 | #ifdef __CUDACC__ 179 | inline __device__ float clampf(float val, float lo, float hi) { 180 | return fminf(fmaxf(val, lo), hi); 181 | } 182 | inline __device__ int clampi(int val, int lo, int hi) { 183 | return min(max(val, lo), hi); 184 | } 185 | inline __host__ __device__ float signf(const float x) { 186 | return copysignf(1.0, x); 187 | } 188 | 189 | template 190 | struct vec3 { 191 | T x, y, z; 192 | T __alignment; 193 | inline __device__ T L_inf() const { 194 | return fmaxf(fabsf(this->x), fmaxf(fabsf(this->y), fabsf(this->z))); 195 | } 196 | inline __device__ vec3 operator+() const { 197 | return *this; 
198 | } 199 | inline __device__ vec3 operator+(vec3 const & rhs) const { 200 | return { 201 | this->x + rhs.x, 202 | this->y + rhs.y, 203 | this->z + rhs.z, 204 | }; 205 | } 206 | inline __device__ vec3 operator+(T const & rhs) const { 207 | return (*this) + vec3 {rhs, rhs, rhs}; 208 | } 209 | inline __device__ vec3 operator-() const { 210 | return { 211 | -this->x, 212 | -this->y, 213 | -this->z, 214 | }; 215 | } 216 | inline __device__ vec3 operator-(vec3 const & rhs) const { 217 | return (*this) + (-rhs); 218 | } 219 | inline __device__ vec3 operator-(T const & rhs) const { 220 | return (*this) - vec3 {rhs, rhs, rhs}; 221 | } 222 | inline __device__ vec3 operator*(vec3 const & rhs) const { 223 | return { 224 | this->x * rhs.x, 225 | this->y * rhs.y, 226 | this->z * rhs.z, 227 | }; 228 | } 229 | inline __device__ vec3 operator*(T const & rhs) const { 230 | return (*this) * vec3 {rhs, rhs, rhs}; 231 | } 232 | inline __device__ vec3 operator/(vec3 const & rhs) const { 233 | return { 234 | this->x / rhs.x, 235 | this->y / rhs.y, 236 | this->z / rhs.z, 237 | }; 238 | } 239 | inline __device__ vec3 operator/(T const & rhs) const { 240 | return (*this) / vec3 {rhs, rhs, rhs}; 241 | } 242 | }; 243 | inline __device__ vec3 sign_vec3f(vec3 const & v) { 244 | return { 245 | signf(v.x), 246 | signf(v.y), 247 | signf(v.z), 248 | }; 249 | } 250 | inline __device__ vec3 floor_vec3f(vec3 const & v) { 251 | return { 252 | floorf(v.x), 253 | floorf(v.y), 254 | floorf(v.z), 255 | }; 256 | } 257 | template 258 | inline __device__ vec3 operator+(T const & lhs, vec3 const & rhs) { 259 | return rhs + lhs; 260 | } 261 | template 262 | inline __device__ vec3 operator-(T const & lhs, vec3 const & rhs) { 263 | return -rhs + lhs; 264 | } 265 | template 266 | inline __device__ vec3 operator*(T const & lhs, vec3 const & rhs) { 267 | return rhs * lhs; 268 | } 269 | template 270 | inline __device__ vec3 operator/(T const & lhs, vec3 const & rhs) { 271 | return vec3 {lhs, lhs, lhs} / rhs; 272 | 
} 273 | using vec3f = vec3; 274 | using vec3u = vec3; 275 | #endif 276 | 277 | } // namespace volrendjax 278 | -------------------------------------------------------------------------------- /deps/volume-rendering-jax/lib/ffi.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "impl/volrend.h" 9 | 10 | namespace volrendjax { 11 | 12 | template 13 | pybind11::bytes to_pybind11_bytes(T const &descriptor) { 14 | return pybind11::bytes(serialize(descriptor)); 15 | } 16 | 17 | template 18 | pybind11::capsule encapsulate_function(T *fn) { 19 | return pybind11::capsule(bit_cast(fn), "xla._CUSTOM_CALL_TARGET"); 20 | } 21 | 22 | // expose gpu function 23 | namespace { 24 | 25 | pybind11::dict get_packbits_registrations() { 26 | pybind11::dict dict; 27 | dict["pack_density_into_bits"] = encapsulate_function(pack_density_into_bits); 28 | return dict; 29 | } 30 | 31 | pybind11::dict get_marching_registrations() { 32 | pybind11::dict dict; 33 | dict["march_rays"] = encapsulate_function(march_rays); 34 | dict["march_rays_inference"] = encapsulate_function(march_rays_inference); 35 | return dict; 36 | } 37 | 38 | pybind11::dict get_morton3d_registrations() { 39 | pybind11::dict dict; 40 | dict["morton3d"] = encapsulate_function(morton3d); 41 | dict["morton3d_invert"] = encapsulate_function(morton3d_invert); 42 | return dict; 43 | } 44 | 45 | pybind11::dict get_integrating_registrations() { 46 | pybind11::dict dict; 47 | dict["integrate_rays"] = encapsulate_function(integrate_rays); 48 | dict["integrate_rays_backward"] = encapsulate_function(integrate_rays_backward); 49 | dict["integrate_rays_inference"] = encapsulate_function(integrate_rays_inference); 50 | return dict; 51 | } 52 | 53 | PYBIND11_MODULE(volrendutils_cuda, m) { 54 | m.def("get_packbits_registrations", &get_packbits_registrations); 55 | m.def("make_packbits_descriptor", 56 | [](std::uint32_t const 
n_bytes) { 57 | if (n_bytes == 0) { 58 | throw std::runtime_error("expected n_bytes to be a positive integer, got 0"); 59 | } 60 | return to_pybind11_bytes(PackbitsDescriptor{ 61 | .n_bytes = n_bytes, 62 | }); 63 | }, 64 | "Static arguments passed to the `pack_density_into_bits` function.\n\n" 65 | "Args:\n" 66 | " n_bytes: sum of number of byetes of all cascades of occupancy bitfields\n" 67 | ); 68 | 69 | m.def("get_marching_registrations", &get_marching_registrations); 70 | m.def("make_marching_descriptor", 71 | [](std::uint32_t const n_rays 72 | , std::uint32_t const total_samples 73 | , std::uint32_t const diagonal_n_steps 74 | , std::uint32_t const K 75 | , std::uint32_t const G 76 | , float const bound 77 | , float const stepsize_portion) { 78 | if (K == 0) { 79 | throw std::runtime_error("expected K to be a positive integer, got 0"); 80 | } 81 | return to_pybind11_bytes(MarchingDescriptor{ 82 | .n_rays = n_rays, 83 | .total_samples = total_samples, 84 | .diagonal_n_steps = diagonal_n_steps, 85 | .K = K, 86 | .G = G, 87 | .bound = bound, 88 | .stepsize_portion = stepsize_portion, 89 | }); 90 | }, 91 | "Static arguments passed to the `march_rays` function.\n\n" 92 | "Args:\n" 93 | " n_rays: number of input rays\n" 94 | " total_samples: number of available slots to write generated samples to, i.e. the\n" 95 | " length of output samples array\n" 96 | " diagonal_n_steps: used to calculate the length of a minimal ray marching step\n" 97 | " K: total number of cascades of the occupancy bitfield\n" 98 | " G: occupancy grid resolution, the paper uses 128 for every cascade\n" 99 | " bound: the half length of the longest axis of the scene’s bounding box,\n" 100 | " e.g. 
the `bound` of the bounding box [-1, 1]^3 is 1\n" 101 | " stepsize_portion: next step size is calculated as t * stepsize_portion,\n" 102 | " the paper uses 1/256\n" 103 | ); 104 | m.def("make_marching_inference_descriptor", 105 | [](std::uint32_t const n_total_rays 106 | , std::uint32_t const n_rays 107 | , std::uint32_t const diagonal_n_steps 108 | , std::uint32_t const K 109 | , std::uint32_t const G 110 | , std::uint32_t const march_steps_cap 111 | , float const bound 112 | , float const stepsize_portion) { 113 | if (K == 0) { 114 | throw std::runtime_error("expected K to be a positive integer, got 0"); 115 | } 116 | return to_pybind11_bytes(MarchingInferenceDescriptor{ 117 | .n_total_rays = n_total_rays, 118 | .n_rays = n_rays, 119 | .diagonal_n_steps = diagonal_n_steps, 120 | .K = K, 121 | .G = G, 122 | .march_steps_cap = march_steps_cap, 123 | .bound = bound, 124 | .stepsize_portion = stepsize_portion, 125 | }); 126 | }, 127 | "Static arguments passed to the `march_rays_inference` function.\n\n" 128 | "Args:\n" 129 | " n_total_rays: total number of rays to march\n" 130 | " n_rays: number of rays to march during this iteration\n" 131 | " diagonal_n_steps: used to calculate the length of a minimal ray marching step\n" 132 | " K: total number of cascades of the occupancy bitfield\n" 133 | " G: occupancy grid resolution, the paper uses 128 for every cascade\n" 134 | " march_steps_cap: max number of samples to generate for each ray\n" 135 | " bound: the half length of the longest axis of the scene’s bounding box,\n" 136 | " e.g. 
the `bound` of the bounding box [-1, 1]^3 is 1\n" 137 | " stepsize_portion: next step size is calculated as t * stepsize_portion,\n" 138 | " the paper uses 1/256\n" 139 | ); 140 | 141 | m.def("get_morton3d_registrations", &get_morton3d_registrations); 142 | m.def( 143 | "make_morton3d_descriptor", 144 | [](std::uint32_t const length) { 145 | return to_pybind11_bytes(Morton3DDescriptor { .length = length }); 146 | }, 147 | "Static arguments passed to the `morton3d` or `morton3d_invert` functions.\n\n" 148 | "Args:\n" 149 | " length: number of entries to process\n" 150 | "\n" 151 | "Returns:\n" 152 | " Serialized bytes that can be passed as the opaque parameter to `morton3d` or\n" 153 | " `morton3d_invert` functions" 154 | ); 155 | 156 | m.def("get_integrating_registrations", &get_integrating_registrations); 157 | m.def("make_integrating_descriptor", 158 | [](std::uint32_t const n_rays, std::uint32_t const total_samples) { 159 | return to_pybind11_bytes(IntegratingDescriptor{ 160 | .n_rays = n_rays, 161 | .total_samples = total_samples, 162 | }); 163 | }, 164 | "Static arguments passed to the `integrate_rays` function.\n\n" 165 | "Args:\n" 166 | " n_rays: number of rays\n" 167 | " total_samples: sum of number of samples on each ray\n" 168 | "\n" 169 | "Returns:\n" 170 | " Serialized bytes that can be passed as the opaque parameter to `integrate_rays`\n" 171 | " or `integrate_rays_backward`" 172 | ); 173 | m.def("make_integrating_backward_descriptor", 174 | [](std::uint32_t const n_rays, std::uint32_t const total_samples, float const near_distance) { 175 | return to_pybind11_bytes(IntegratingBackwardDescriptor{ 176 | .n_rays = n_rays, 177 | .total_samples = total_samples, 178 | .near_distance = near_distance, 179 | }); 180 | }, 181 | "Static arguments passed to the `integrate_rays_backward` function.\n\n" 182 | "Args:\n" 183 | " n_rays: number of rays\n" 184 | " total_samples: sum of number of samples on each ray\n" 185 | " near_distance: camera's near distance, 
samples behind the camera's near plane with\n" 186 | " non-negligible introduce a penalty on their densities\n" 187 | "\n" 188 | "Returns:\n" 189 | " Serialized bytes that can be passed as the opaque parameter to `integrate_rays`\n" 190 | " or `integrate_rays_backward`" 191 | ); 192 | m.def("make_integrating_inference_descriptor", 193 | [](std::uint32_t const n_total_rays 194 | , std::uint32_t const n_rays 195 | , std::uint32_t const march_steps_cap) { 196 | return to_pybind11_bytes(IntegratingInferenceDescriptor{ 197 | .n_total_rays = n_total_rays, 198 | .n_rays = n_rays, 199 | .march_steps_cap = march_steps_cap, 200 | }); 201 | }, 202 | "Static arguments passed to the `integrate_rays_inference`\n\n" 203 | "Args:\n" 204 | " n_total_rays: total number of rays to march\n" 205 | " n_rays: number of rays to integrate during this iteration\n" 206 | " march_steps_cap: see MarchingInferenceDescriptor\n" 207 | ); 208 | }; 209 | 210 | } // namespace 211 | 212 | } // namespace volrendjax 213 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023, Gaoyang Zhang, Yingxi Chen, and the jaxngp contributors 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------