├── .gitmodules ├── gridencoder ├── __init__.py ├── src │ ├── bindings.cpp │ ├── gridencoder.h │ └── gridencoder.cu ├── backend.py ├── setup.py └── grid.py ├── raymarching ├── __init__.py ├── src │ ├── bindings.cpp │ ├── raymarching.h │ ├── pcg32.h │ └── raymarching.cu ├── backend.py ├── setup.py └── raymarching.py ├── shencoder ├── __init__.py ├── src │ ├── bindings.cpp │ ├── shencoder.h │ └── shencoder.cu ├── backend.py ├── setup.py └── sphere_harmonics.py ├── .gitignore ├── scripts ├── install_ext.sh ├── run.sh └── run_gui.sh ├── requirements.txt ├── loss.py ├── activation.py ├── LICENSE ├── nerf ├── network.py ├── bg_utils.py ├── provider.py ├── network_cc.py ├── gui.py └── renderer.py ├── encoding.py ├── readme.md └── main_nerf.py /.gitmodules: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gridencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .grid import GridEncoder -------------------------------------------------------------------------------- /raymarching/__init__.py: -------------------------------------------------------------------------------- 1 | from .raymarching import * -------------------------------------------------------------------------------- /shencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .sphere_harmonics import SHEncoder -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | build/ 3 | *.egg-info/ 4 | *.so 5 | 6 | tmp* 7 | data/ 8 | trial*/ 9 | .vs/ -------------------------------------------------------------------------------- /scripts/install_ext.sh: -------------------------------------------------------------------------------- 1 | cd raymarching 2 | pip install . 3 | cd .. 4 | 5 | cd shencoder 6 | pip install . 7 | cd .. 8 | 9 | # cd gridencoder 10 | # pip install . 11 | # cd .. 
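The `scripts/install_ext.sh` above installs the raymarching and shencoder CUDA extensions with `pip` (gridencoder is left commented out). A minimal post-install smoke test might look like the sketch below; it is an assumption-laden example, requiring a CUDA-capable PyTorch build since the kernels are CUDA-only:

```python
# Hypothetical smoke test: verify the compiled shencoder extension imports and runs.
import torch
from shencoder import SHEncoder

enc = SHEncoder(input_dim=3, degree=4)
x = torch.rand(8, 3, device='cuda') * 2 - 1  # inputs must lie in [-1, 1]
y = enc(x)
print(y.shape)  # torch.Size([8, 16]), since output_dim = degree ** 2
```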
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch-ema 2 | ninja 3 | trimesh 4 | opencv-python 5 | tensorboardX 6 | torch 7 | numpy 8 | pandas 9 | tqdm 10 | matplotlib 11 | PyMCubes 12 | rich 13 | pysdf 14 | dearpygui 15 | scipy 16 | git+https://github.com/openai/CLIP.git 17 | -------------------------------------------------------------------------------- /shencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include "shencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)"); 7 | m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)"); 8 | } -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def mape_loss(pred, target): 6 | # pred, target: [B, 1], torch tensor 7 | difference = (pred - target).abs() 8 | scale = 1 / (target.abs() + 1e-2) 9 | loss = difference * scale 10 | 11 | return loss.mean() -------------------------------------------------------------------------------- /gridencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include "gridencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); 7 | m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); 8 | } -------------------------------------------------------------------------------- /activation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.cuda.amp import custom_bwd, custom_fwd 4 | 5 | class _trunc_exp(Function): 6 | @staticmethod 7 | @custom_fwd(cast_inputs=torch.float) 8 | def forward(ctx, x): 9 | ctx.save_for_backward(x) 10 | return torch.exp(x) 11 | 12 | @staticmethod 13 | @custom_bwd 14 | def backward(ctx, g): 15 | x = ctx.saved_tensors[0] 16 | return g * torch.exp(x.clamp(-15, 15)) # clamp the exponent in backward to avoid inf gradients under amp 17 | 18 | trunc_exp = _trunc_exp.apply -------------------------------------------------------------------------------- /shencoder/src/shencoder.h: -------------------------------------------------------------------------------- 1 | # pragma once 2 | 3 | #include <stdint.h> 4 | #include <torch/torch.h> 5 | 6 | // inputs: [B, D], float, in [-1, 1] 7 | // outputs: [B, F], float 8 | 9 | // encode_forward(inputs, outputs, B, input_dim, degree, calc_grad_inputs, dy_dx) 10 | void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const bool calc_grad_inputs, at::Tensor dy_dx); 11 | 12 | // sh_encode_backward(grad, inputs, B, input_dim, degree, ctx.calc_grad_inputs, dy_dx, grad_inputs) 13 | void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs); -------------------------------------------------------------------------------- /raymarching/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include "raymarching.h" 4 | 5 |
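// Note: TORCH_EXTENSION_NAME in the module below is defined by the PyTorch build
// system and expands to the extension name ('_raymarching' here, matching setup.py
// and the load() call in backend.py), so the same bindings serve JIT and pip builds.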
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | // utils 7 | m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)"); 8 | // train 9 | m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)"); 10 | m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)"); 11 | m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)"); 12 | // infer 13 | m.def("march_rays", &march_rays, "march rays (CUDA)"); 14 | m.def("composite_rays", &composite_rays, "composite rays (CUDA)"); 15 | m.def("compact_rays", &compact_rays, "compact rays (CUDA)"); 16 | } -------------------------------------------------------------------------------- /gridencoder/src/gridencoder.h: -------------------------------------------------------------------------------- 1 | #ifndef _HASH_ENCODE_H 2 | #define _HASH_ENCODE_H 3 | 4 | #include <stdint.h> 5 | #include <torch/torch.h> 6 | 7 | // inputs: [B, D], float, in [0, 1] 8 | // embeddings: [sO, C], float 9 | // offsets: [L + 1], uint32_t 10 | // outputs: [B, L * C], float 11 | // H: base resolution 12 | void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t gridtype); 13 | void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype); 14 | 15 | #endif -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 hawkey 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | #!
/bin/bash 2 | 3 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_nerf.py --text "bouquet of flowers sitting in a clear glass vase" --workspace trial_seed42_flowers --cuda_ray --fp16 # --gui 4 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_nerf.py --text "a small green vase displays some small yellow blooms" --workspace trial_seed42_vase --cuda_ray --fp16 # --gui 5 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_nerf.py --text "a slug crawling on the ground around flower petals" --workspace trial_seed42_slug --cuda_ray --fp16 # --gui 6 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_nerf.py --text "a man" --workspace trial_seed42_man --cuda_ray --fp16 # --gui 7 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_nerf.py --text "armchair in the shape of an avocado" --workspace trial_seed42_avocado_chair --cuda_ray --fp16 # --gui 8 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_nerf.py --text "teapot in the shape of an avocado" --workspace trial_seed42_avocado_teapot --cuda_ray --fp16 # --gui 9 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_nerf.py --text "a bird that has many colors on it" --workspace trial_seed42_bird --cuda_ray --fp16 # --gui 10 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_nerf.py --text "a blue jug in a garden filled with mud" --workspace trial_seed42_jug --cuda_ray --fp16 # --gui 11 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_nerf.py --text "cthulhu" --workspace trial_seed42_cthulhu --cuda_ray --fp16 # --gui -------------------------------------------------------------------------------- /shencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 21 | if paths: 22 | return paths[0] 23 | 24 | # If cl.exe is not on path, try to find it.
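# (Windows only) The check below runs 'where cl.exe' to see whether the MSVC
# compiler is already on PATH; if not, find_cl_path() scans the standard Visual
# Studio install locations and appends the newest toolchain so the JIT build can proceed.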
25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_sh_encoder', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'shencoder.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /gridencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 21 | if paths: 22 | return paths[0] 23 | 24 | # If cl.exe is not on path, try to find it. 25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_grid_encoder', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'gridencoder.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /raymarching/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 21 | if paths: 22 | return paths[0] 23 | 24 | # If cl.exe is not on path, try to find it. 
25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_raymarching', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'raymarching.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /scripts/run_gui.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # "a bird that has many colors on it" 4 | # "an illustration of a pumpkin on the vine" 5 | # "an armchair in the shape of an avocado" 6 | # "a bouquet of roses in a vase" 7 | # "a sculpture of a rooster." 8 | # "a tray that has meat and carrots on a table." 9 | # "a small green vase displays some small yellow blooms." 10 | # "a high-quality 3d render of a jenga tower" 11 | # "a slug crawling on the ground around flower petals." 12 | # "bouquet of flowers sitting in a clear glass vase." 13 | 14 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_nerf.py "a fox" --workspace trial --cuda_ray --gui 15 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_nerf.py "a rainbow cat" --workspace trial --fp16 --ff --gui --seed 10 16 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_nerf.py "a watermelon with a knife on it" --workspace trial --fp16 --ff --cuda_ray --gui --seed 42 17 | 18 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_nerf.py --text "a red fox." --workspace trial --cuda_ray --gui --fp16 --seed 2 19 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_nerf.py --image ./data/redfox.jpg --workspace trial --cuda_ray --gui --fp16 --seed 2 #--dir_text 20 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_nerf.py --text "a red fox in a black sofa." --workspace trial --cuda_ray --gui --fp16 --seed 43 #--dir_text 21 | 22 | 23 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_nerf.py --text "an illustration of a pumpkin on the vine" --workspace trial --cuda_ray --gui --seed 42 -------------------------------------------------------------------------------- /shencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 
26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | setup( 33 | name='shencoder', # package name, import this to use python API 34 | ext_modules=[ 35 | CUDAExtension( 36 | name='_shencoder', # extension name, import this to use CUDA API 37 | sources=[os.path.join(_src_path, 'src', f) for f in [ 38 | 'shencoder.cu', 39 | 'bindings.cpp', 40 | ]], 41 | extra_compile_args={ 42 | 'cxx': c_flags, 43 | 'nvcc': nvcc_flags, 44 | } 45 | ), 46 | ], 47 | cmdclass={ 48 | 'build_ext': BuildExtension, 49 | } 50 | ) -------------------------------------------------------------------------------- /gridencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | setup( 33 | name='gridencoder', # package name, import this to use python API 34 | ext_modules=[ 35 | CUDAExtension( 36 | name='_gridencoder', # extension name, import this to use CUDA API 37 | sources=[os.path.join(_src_path, 'src', f) for f in [ 38 | 'gridencoder.cu', 39 | 'bindings.cpp', 40 | ]], 41 | extra_compile_args={ 42 | 'cxx': c_flags, 43 | 'nvcc': nvcc_flags, 44 | } 45 | ), 46 | ], 47 | cmdclass={ 48 | 'build_ext': BuildExtension, 49 | } 50 | ) -------------------------------------------------------------------------------- /raymarching/src/raymarching.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <stdint.h> 4 | #include <torch/torch.h> 5 | 6 | 7 | void near_far_from_aabb(at::Tensor rays_o, at::Tensor rays_d, at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars); 8 | 9 | void march_rays_train(at::Tensor rays_o, at::Tensor rays_d, at::Tensor grid, const float mean_density, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, at::Tensor nears, at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, const uint32_t perturb); 10 | void composite_rays_train_forward(at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, const uint32_t M, const uint32_t N, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); 11 | void composite_rays_train_backward(at::Tensor grad_weights_sum, at::Tensor
grad_image, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, at::Tensor weights_sum, at::Tensor image, const uint32_t M, const uint32_t N, at::Tensor grad_sigmas, at::Tensor grad_rgbs); 12 | 13 | void march_rays(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor rays_o, at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, at::Tensor density_grid, const float mean_density, at::Tensor nears, at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, const uint32_t perturb); 14 | void composite_rays(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); 15 | void compact_rays(const uint32_t n_alive, at::Tensor rays_alive, at::Tensor rays_alive_old, at::Tensor rays_t, at::Tensor rays_t_old, at::Tensor alive_counter); -------------------------------------------------------------------------------- /raymarching/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | ''' 33 | Usage: 34 | 35 | python setup.py build_ext --inplace # build extensions locally, do not install (can only be used from the parent directory) 36 | 37 | python setup.py install # build extensions and install (copy) to PATH. 38 | pip install . # ditto but better (e.g., dependency & metadata handling) 39 | 40 | python setup.py develop # build extensions and install (symbolic) to PATH. 41 | pip install -e .
# ditto but better (e.g., dependency & metadata handling) 42 | 43 | ''' 44 | setup( 45 | name='raymarching', # package name, import this to use python API 46 | ext_modules=[ 47 | CUDAExtension( 48 | name='_raymarching', # extension name, import this to use CUDA API 49 | sources=[os.path.join(_src_path, 'src', f) for f in [ 50 | 'raymarching.cu', 51 | 'bindings.cpp', 52 | ]], 53 | extra_compile_args={ 54 | 'cxx': c_flags, 55 | 'nvcc': nvcc_flags, 56 | } 57 | ), 58 | ], 59 | cmdclass={ 60 | 'build_ext': BuildExtension, 61 | } 62 | ) -------------------------------------------------------------------------------- /nerf/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from encoding import get_encoder 6 | from .renderer import NeRFRenderer 7 | 8 | 9 | class NeRFNetwork(NeRFRenderer): 10 | def __init__(self, 11 | encoding="frequency", 12 | num_layers=6, 13 | hidden_dim=128, 14 | bound=1, 15 | **kwargs, 16 | ): 17 | super().__init__(bound, **kwargs) 18 | 19 | # sigma network 20 | self.num_layers = num_layers 21 | self.hidden_dim = hidden_dim 22 | self.encoder, self.in_dim = get_encoder(encoding) 23 | 24 | sigma_net = [] 25 | for l in range(num_layers): 26 | if l == 0: 27 | in_dim = self.in_dim 28 | else: 29 | in_dim = hidden_dim 30 | 31 | if l == num_layers - 1: 32 | out_dim = 4 33 | else: 34 | out_dim = hidden_dim 35 | 36 | sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) 37 | 38 | self.sigma_net = nn.ModuleList(sigma_net) 39 | 40 | 41 | def forward(self, x, d): 42 | # x: [B, N, 3], in [-bound, bound] 43 | # d: [B, N, 3], normalized in [-1, 1] 44 | 45 | # shift origin 46 | x = x + self.origin 47 | 48 | # sigma 49 | x = self.encoder(x, size=self.bound) 50 | 51 | h = x 52 | for l in range(self.num_layers): 53 | h = self.sigma_net[l](h) 54 | if l != self.num_layers - 1: 55 | h = F.relu(h, inplace=True) 56 | 57 | sigma = F.softplus(h[..., 0]) 58 | color = torch.sigmoid(h[..., 1:]) 59 | 60 | return sigma, color 61 | 62 | def density(self, x): 63 | # x: [B, N, 3], in [-bound, bound] 64 | 65 | # shift origin 66 | x = x + self.origin 67 | 68 | x = self.encoder(x, size=self.bound) 69 | 70 | h = x 71 | for l in range(self.num_layers): 72 | h = self.sigma_net[l](h) 73 | if l != self.num_layers - 1: 74 | h = F.relu(h, inplace=True) 75 | 76 | sigma = F.softplus(h[..., 0]) 77 | 78 | return { 79 | 'sigma': sigma, 80 | } 81 | 82 | 83 | # optimizer utils 84 | def get_params(self, lr1): 85 | return [ 86 | {'params': self.sigma_net.parameters(), 'lr': lr1}, 87 | ] -------------------------------------------------------------------------------- /encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class FreqEncoder(nn.Module): 6 | def __init__(self, input_dim, max_freq_log2, N_freqs, 7 | log_sampling=True, include_input=True, 8 | periodic_fns=(torch.sin, torch.cos)): 9 | 10 | super().__init__() 11 | 12 | self.input_dim = input_dim 13 | self.include_input = include_input 14 | self.periodic_fns = periodic_fns 15 | 16 | self.output_dim = 0 17 | if self.include_input: 18 | self.output_dim += self.input_dim 19 | 20 | self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns) 21 | 22 | if log_sampling: 23 | self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs) 24 | else: 25 | self.freq_bands = torch.linspace(2. ** 0., 2.
** max_freq_log2, N_freqs) 26 | 27 | self.freq_bands = self.freq_bands.numpy().tolist() 28 | 29 | def forward(self, input, **kwargs): 30 | 31 | out = [] 32 | if self.include_input: 33 | out.append(input) 34 | 35 | for i in range(len(self.freq_bands)): 36 | freq = self.freq_bands[i] 37 | for p_fn in self.periodic_fns: 38 | out.append(p_fn(input * freq)) 39 | 40 | out = torch.cat(out, dim=-1) 41 | 42 | 43 | return out 44 | 45 | def get_encoder(encoding, input_dim=3, 46 | multires=6, 47 | degree=4, 48 | num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, 49 | **kwargs): 50 | 51 | if encoding == 'None': 52 | return lambda x, **kwargs: None, 0 53 | 54 | elif encoding == 'identity': 55 | return lambda x, **kwargs: x, input_dim 56 | 57 | elif encoding == 'frequency': 58 | encoder = FreqEncoder(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True) 59 | 60 | elif encoding == 'sphere_harmonics': 61 | from shencoder import SHEncoder 62 | encoder = SHEncoder(input_dim=input_dim, degree=degree) 63 | 64 | elif encoding == 'hashgrid': 65 | from gridencoder import GridEncoder 66 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash') 67 | 68 | elif encoding == 'tiledgrid': 69 | from gridencoder import GridEncoder 70 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled') 71 | 72 | elif encoding == 'ash': 73 | from ashencoder import AshEncoder 74 | encoder = AshEncoder(input_dim=input_dim, output_dim=16, log2_hashmap_size=log2_hashmap_size, resolution=desired_resolution) 75 | 76 | else: 77 | raise NotImplementedError() 78 | 79 | return encoder, encoder.output_dim -------------------------------------------------------------------------------- /shencoder/sphere_harmonics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _shencoder as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | class _sh_encoder(Function): 15 | @staticmethod 16 | @custom_fwd(cast_inputs=torch.half) 17 | def forward(ctx, inputs, degree, calc_grad_inputs=False): 18 | # inputs: [B, input_dim], float in [-1, 1] 19 | # RETURN: [B, F], float 20 | 21 | inputs = inputs.contiguous() 22 | B, input_dim = inputs.shape # batch size, coord dim 23 | output_dim = degree ** 2 24 | 25 | outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) 26 | 27 | if calc_grad_inputs: 28 | dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device) 29 | else: 30 | dy_dx = torch.empty(1, dtype=inputs.dtype, device=inputs.device) 31 | 32 | _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, calc_grad_inputs, dy_dx) 33 | 34 | ctx.save_for_backward(inputs, dy_dx) 35 | ctx.dims = [B, input_dim, degree] 36 | ctx.calc_grad_inputs = calc_grad_inputs 37 | 38 | return outputs 39 | 40 | @staticmethod 41 | @once_differentiable 42 | @custom_bwd 43 | def backward(ctx, grad): 44 | # grad: [B, C * C] 45 | 46 | if 
ctx.calc_grad_inputs: 47 | grad = grad.contiguous() 48 | inputs, dy_dx = ctx.saved_tensors 49 | B, input_dim, degree = ctx.dims 50 | grad_inputs = torch.zeros_like(inputs) 51 | _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs) 52 | return grad_inputs, None, None 53 | else: 54 | return None, None, None 55 | 56 | 57 | 58 | sh_encode = _sh_encoder.apply 59 | 60 | 61 | class SHEncoder(nn.Module): 62 | def __init__(self, input_dim=3, degree=4): 63 | super().__init__() 64 | 65 | self.input_dim = input_dim # coord dims, must be 3 66 | self.degree = degree # 1 ~ 8 67 | self.output_dim = degree ** 2 68 | 69 | assert self.input_dim == 3, "SH encoder only supports input dim == 3" 70 | assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]" 71 | 72 | def __repr__(self): 73 | return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}" 74 | 75 | def forward(self, inputs, size=1): 76 | # inputs: [..., input_dim], normalized real world positions in [-size, size] 77 | # return: [..., degree^2] 78 | 79 | inputs = inputs / size # [-1, 1] 80 | 81 | prefix_shape = list(inputs.shape[:-1]) 82 | inputs = inputs.reshape(-1, self.input_dim) 83 | 84 | outputs = sh_encode(inputs, self.degree, inputs.requires_grad) 85 | outputs = outputs.reshape(prefix_shape + [self.output_dim]) 86 | 87 | return outputs -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # dreamfields-torch (WIP) 2 | 3 | A PyTorch implementation of [dreamfields](https://github.com/google-research/google-research/tree/master/dreamfields) as described in [Zero-Shot Text-Guided Object Generation with Dream Fields](https://arxiv.org/abs/2112.01455). 4 | 5 | An example of a neural field generated from the prompt "cthulhu", viewed in real time: 6 | 7 | https://user-images.githubusercontent.com/25863658/158593558-a52fe215-4276-41eb-a588-cf60c9461cf3.mp4 8 | 9 | # Install 10 | 11 | The code framework is based on [torch-ngp](https://github.com/ashawkey/torch-ngp). 12 | 13 | ```bash 14 | git clone https://github.com/ashawkey/dreamfields-torch.git 15 | cd dreamfields-torch 16 | ``` 17 | 18 | ### Install with pip 19 | ```bash 20 | pip install -r requirements.txt 21 | 22 | # (optional) install the tcnn backbone 23 | pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch 24 | ``` 25 | 26 | 27 | ### Build extension (optional) 28 | By default, we use [`load`](https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.load) to build the extension at runtime. 29 | However, this may be inconvenient sometimes. 30 | Therefore, we also provide the `setup.py` to build each extension: 31 | ```bash 32 | # install all extension modules 33 | bash scripts/install_ext.sh 34 | 35 | # if you want to install manually, here is an example: 36 | cd raymarching 37 | python setup.py build_ext --inplace # build ext only, do not install (can only be used from the parent directory) 38 | pip install . # install to python path (you still need the raymarching/ folder, since this only installs the built extension.) 39 | ``` 40 | 41 | ### Tested environments 42 | * Ubuntu 20 with torch 1.10 & CUDA 11.3 on a TITAN RTX. 43 | * Windows 10 with torch 1.11 & CUDA 11.3 on a RTX 3070. 44 | 45 | Currently, `--ff` only supports GPUs with CUDA architecture `>= 70`.
46 | For GPUs with lower architecture, `--tcnn` can still be used, but the speed will be slower compared to more recent GPUs. 47 | 48 | # Usage 49 | 50 | First time running will take some time to compile the CUDA extensions. 51 | 52 | ```bash 53 | # text-guided generation 54 | python main_nerf.py --text "cthulhu" --workspace trial --cuda_ray --fp16 55 | 56 | # use the GUI 57 | python main_nerf.py --text "cthulhu" --workspace trial --cuda_ray --fp16 --gui 58 | 59 | # [experimental] image-guided generation (also use the CLIP loss) 60 | python main_nerf.py --image /path/to/image --workspace trial --cuda_ray --fp16 61 | 62 | ``` 63 | 64 | check the `scripts` directory for more provided examples. 65 | 66 | 67 | # Difference from the original implementation 68 | 69 | * Mip-nerf is not implemented, currently only the original nerf is supported. 70 | * Sampling poses with an elevation range in [-30, 30] degrees, instead of fixed at 30 degree. 71 | * Use the origin loss. 72 | 73 | 74 | # Update Logs 75 | * 5.18: major update. 76 | * 3.16: basic reproduction. 77 | 78 | 79 | # Acknowledgement 80 | 81 | * The great paper and official JAX implementation of [dreamfields](https://ajayj.com/dreamfields): 82 | ``` 83 | @article{jain2021dreamfields, 84 | author = {Jain, Ajay and Mildenhall, Ben and Barron, Jonathan T. and Abbeel, Pieter and Poole, Ben}, 85 | title = {Zero-Shot Text-Guided Object Generation with Dream Fields}, 86 | journal = {arXiv}, 87 | month = {December}, 88 | year = {2021}, 89 | } 90 | ``` 91 | 92 | * The GUI is developed with [DearPyGui](https://github.com/hoffstadt/DearPyGui). 93 | -------------------------------------------------------------------------------- /nerf/bg_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Direct JAX port of some helpers in lucid.optvis. 17 | 18 | Ported from https://github.com/tensorflow/lucid/blob/master/lucid/optvis/param/ 19 | """ 20 | 21 | import torch 22 | import numpy as np 23 | 24 | 25 | # Constants from lucid/optvis/param/color.py 26 | color_correlation_svd_sqrt = np.asarray( 27 | [[0.26, 0.09, 0.02], 28 | [0.27, 0.00, -0.05], 29 | [0.27, -0.09, 0.03]]).astype("float32") 30 | max_norm_svd_sqrt = np.max(np.linalg.norm(color_correlation_svd_sqrt, axis=0)) 31 | 32 | color_mean = [0.48, 0.46, 0.41] 33 | 34 | 35 | # this is a pain for channel-first image. I cannot afford to transpose again and again, since speed matters.. 36 | # this function is just too complicated compared to the other two random bg algorithms, and will cause imbalance in speed. 37 | # TODO: 38 | def _linear_correlate_color(t): 39 | """Multiply input by sqrt of empirical (ImageNet) color correlation matrix. 
40 | 41 | If you interpret t's innermost dimension as describing colors in a 42 | decorrelated version of the color space (which is a very natural way to 43 | describe colors -- see discussion in Feature Visualization article) the way 44 | to map back to normal colors is to multiply by the square root of your color 45 | correlations. 46 | 47 | Args: 48 | t: input whitened color array, with trailing dimension 3. 49 | 50 | Returns: 51 | t_correlated: RGB color array. 52 | """ 53 | assert t.shape[-1] == 3 54 | t_flat = np.reshape(t, [-1, 3]) 55 | color_correlation_normalized = (color_correlation_svd_sqrt / max_norm_svd_sqrt) 56 | t_flat = np.matmul(t_flat, color_correlation_normalized.T) 57 | t_correlated = np.reshape(t_flat, t.shape) 58 | return t_correlated 59 | 60 | 61 | def constrain_l_inf(x): 62 | # NOTE(jainajay): does not use custom grad unlike Lucid 63 | return x / torch.clamp(torch.abs(x), min=1.0) # torch.maximum needs a tensor rhs; clamp is equivalent here 64 | 65 | 66 | def to_valid_rgb(t, decorrelated=False, sigmoid=True): 67 | """Transform inner dimension of t to valid rgb colors. 68 | 69 | In practice this consists of two parts: 70 | (1) If requested, transform the colors from a decorrelated color space to RGB. 71 | (2) Constrain the color channels to be in [0,1], either using a sigmoid 72 | function or clipping. 73 | 74 | Args: 75 | t: Input tensor, trailing dimension will be interpreted as colors and 76 | transformed/constrained. 77 | decorrelated: If True, the input tensor's colors are interpreted as coming 78 | from a whitened space. 79 | sigmoid: If True, the colors are constrained elementwise using sigmoid. If 80 | False, colors are constrained by clipping infinity norm. 81 | 82 | Returns: 83 | t with the innermost dimension transformed. 84 | """ 85 | if decorrelated: 86 | t = _linear_correlate_color(t) 87 | if decorrelated and not sigmoid: 88 | t += color_mean 89 | 90 | if sigmoid: 91 | return torch.sigmoid(t) 92 | 93 | return constrain_l_inf(2 * t - 1) / 2 + 0.5 94 | 95 | 96 | def rfft2d_freqs(h, w): 97 | """Computes 2D spectrum frequencies.""" 98 | fy = torch.fft.fftfreq(h)[:, None] 99 | # when we have an odd input dimension we need to keep one additional 100 | # frequency and later cut off 1 pixel 101 | fx = torch.fft.fftfreq(w)[:w // 2 + 1 + w % 2] 102 | return torch.sqrt(fx * fx + fy * fy) 103 | 104 | 105 | def rand_fft_image(shape, sd=None, decay_power=1): 106 | """Generate a random background.""" 107 | b, ch, h, w = shape 108 | sd = 0.01 if sd is None else sd 109 | 110 | imgs = [] 111 | for _ in range(b): 112 | freqs = rfft2d_freqs(h, w) 113 | fh, fw = freqs.shape 114 | spectrum_var = sd * torch.randn(2, ch, fh, fw, dtype=torch.float32) 115 | spectrum = torch.complex(spectrum_var[0], spectrum_var[1]) 116 | spectrum_scale = 1.0 / torch.clamp(freqs, min=1.0 / max(h, w))**decay_power 117 | # Scale the spectrum by the square-root of the number of pixels 118 | # to get a unitary transformation. This allows using similar 119 | # learning rates to pixel-wise optimisation.
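# (The 1/freqs**decay_power scaling above boosts low frequencies, which is what
# makes the sampled backgrounds smooth rather than per-pixel white noise.)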
120 | spectrum_scale *= np.sqrt(w * h) # w * h is a plain Python int, so use np.sqrt rather than torch.sqrt 121 | scaled_spectrum = spectrum * spectrum_scale 122 | # img = tf.signal.irfft2d(scaled_spectrum) 123 | img = torch.fft.irfft2(scaled_spectrum) 124 | # in case of odd input dimension we cut off the additional pixel 125 | # we get from irfft2d length computation 126 | img = img[:ch, :h, :w] 127 | imgs.append(img) 128 | 129 | return torch.stack(imgs) / 4.0 130 | 131 | 132 | def image_sample(key, shape, decorrelated=True, sd=None, decay_power=1): 133 | raw_spatial = rand_fft_image(shape, sd=sd, decay_power=decay_power) # key is unused in this torch port 134 | return to_valid_rgb(raw_spatial, decorrelated=decorrelated) 135 |
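A minimal usage sketch for the Fourier-space background sampler above (the shape and `sd` value are illustrative assumptions; `key` is unused in this torch port, and `decorrelated=False` keeps the computation on the pure-torch path):

```python
# Hypothetical example: sample a batch of smooth random backgrounds.
bg = image_sample(None, (4, 3, 128, 128), decorrelated=False, sd=0.2)
print(bg.shape)  # torch.Size([4, 3, 128, 128]), values in (0, 1) via sigmoid
```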
-------------------------------------------------------------------------------- /main_nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | 4 | from nerf.provider import NeRFDataset 5 | from nerf.utils import * 6 | 7 | #torch.autograd.set_detect_anomaly(True) 8 | 9 | if __name__ == '__main__': 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--text', default=None, help="text prompt") 13 | parser.add_argument('--image', default=None, help="ref image prompt") 14 | parser.add_argument('--test', action='store_true', help="test mode") 15 | parser.add_argument('--workspace', type=str, default='workspace') 16 | parser.add_argument('--seed', type=int, default=0) 17 | ### training options 18 | parser.add_argument('--iters', type=int, default=30000, help="training iters") 19 | parser.add_argument('--lr', type=float, default=5e-4, help="initial learning rate") 20 | parser.add_argument('--ckpt', type=str, default='latest') 21 | parser.add_argument('--num_rays', type=int, default=4096) 22 | parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch") 23 | parser.add_argument('--num_steps', type=int, default=512, help="num steps sampled per ray (only valid when not using --cuda_ray)") 24 | parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when not using --cuda_ray)") 25 | parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)") 26 | 27 | ### network backbone options 28 | parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training") 29 | parser.add_argument('--cc', action='store_true', help="use TensoRF") 30 | ### dataset options 31 | parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box(-bound, bound)") 32 | parser.add_argument('--dt_gamma', type=float, default=0, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") 33 | parser.add_argument('--w', type=int, default=128, help="render width for CLIP training (<=224)") 34 | parser.add_argument('--h', type=int, default=128, help="render height for CLIP training (<=224)") 35 | ### GUI options 36 | parser.add_argument('--gui', action='store_true', help="start a GUI") 37 | parser.add_argument('--W', type=int, default=800, help="GUI width") 38 | parser.add_argument('--H', type=int, default=800, help="GUI height") 39 | parser.add_argument('--radius', type=float, default=3, help="default GUI camera radius from center") 40 | parser.add_argument('--fovy', type=float, default=90, help="default GUI camera fovy") 41 | parser.add_argument('--max_spp', type=int, default=64, help="GUI rendering max sample per pixel") 42 | ### other options 43 | parser.add_argument('--tau_0', type=float, default=0.5, help="target mean transparency 0") 44 | parser.add_argument('--tau_1', type=float, default=0.8, help="target mean transparency 1") 45 | parser.add_argument('--tau_step', type=float, default=500, help="steps to anneal from tau_0 to tau_1") 46 | parser.add_argument('--aug_copy', type=int, default=8, help="augmentation copy for each rendered image before feeding into CLIP") 47 | parser.add_argument('--dir_text', action='store_true', help="direction encoded text prompt") 48 | 49 | opt = parser.parse_args() 50 | 51 | assert not (opt.text is None and opt.image is None) 52 | 53 | 54 | if opt.cc: 55 | from nerf.network_cc import NeRFNetwork 56 | else: 57 | from nerf.network import NeRFNetwork 58 | 59 | print(opt) 60 | 61 | seed_everything(opt.seed) 62 | 63 | model = NeRFNetwork( 64 | bound=opt.bound, 65 | cuda_ray=opt.cuda_ray, 66 | density_scale=1, 67 | ) 68 | 69 | print(model) 70 | 71 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 72 | if opt.test: 73 | 74 | trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, fp16=opt.fp16, use_checkpoint='latest') 75 | 76 | if opt.gui: 77 | from nerf.gui import NeRFGUI 78 | gui = NeRFGUI(opt, trainer) 79 | gui.render() 80 | 81 | else: 82 | test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, radius=opt.radius, fovy=opt.fovy, size=10).dataloader() 83 | 84 | trainer.test(test_loader) 85 | 86 | else: 87 | 88 | optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15) 89 | 90 | train_loader = NeRFDataset(opt, device=device, type='train', H=opt.h, W=opt.w, radius=opt.radius, fovy=opt.fovy, size=100).dataloader() 91 | 92 | # decay to 0.1 * init_lr at last iter step 93 | scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1)) 94 | 95 | trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, optimizer=optimizer, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, eval_interval=20) 96 | 97 | if opt.gui: 98 | from nerf.gui import NeRFGUI 99 | trainer.train_loader = train_loader # attach dataloader to trainer 100 | 101 | gui = NeRFGUI(opt, trainer) 102 | gui.render() 103 | 104 | else: 105 | valid_loader = NeRFDataset(opt, device=device, type='val', H=opt.H, W=opt.W, radius=opt.radius, fovy=opt.fovy, size=10).dataloader() 106 | 107 | max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32) 108 | trainer.train(train_loader, valid_loader, max_epoch) 109 | 110 | # also test 111 | test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H,
W=opt.W, radius=opt.radius, fovy=opt.fovy, size=10).dataloader() 112 | trainer.test(test_loader) -------------------------------------------------------------------------------- /nerf/provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import json 5 | import tqdm 6 | import numpy as np 7 | from scipy.spatial.transform import Slerp, Rotation 8 | 9 | import trimesh 10 | 11 | import torch 12 | from torch.utils.data import DataLoader 13 | 14 | from .utils import get_rays 15 | 16 | 17 | # ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50 18 | def nerf_matrix_to_ngp(pose, scale=0.33): 19 | # for the fox dataset, 0.33 scales camera radius to ~ 2 20 | new_pose = np.array([ 21 | [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale], 22 | [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale], 23 | [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale], 24 | [0, 0, 0, 1], 25 | ], dtype=np.float32) 26 | return new_pose 27 | 28 | 29 | def visualize_poses(poses, size=0.1): 30 | # poses: [B, 4, 4] 31 | 32 | axes = trimesh.creation.axis(axis_length=4) 33 | sphere = trimesh.creation.icosphere(radius=1) 34 | objects = [axes, sphere] 35 | 36 | for pose in poses: 37 | # a camera is visualized with 8 line segments. 38 | pos = pose[:3, 3] 39 | a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] 40 | b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] 41 | c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] 42 | d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] 43 | 44 | segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a]]) 45 | segs = trimesh.load_path(segs) 46 | objects.append(segs) 47 | 48 | trimesh.Scene(objects).show() 49 | 50 | def get_view_direction(thetas, phis): 51 | # phis [B,]; thetas: [B,] 52 | # front = 0 0-90 53 | # side (left) = 1 90-180 54 | # back = 2 180-270 55 | # side (right) = 3 270-360 56 | # top = 4 0-45 57 | # bottom = 5 135-180 58 | res = np.zeros(phis.shape[0], dtype=np.int64) 59 | # first determine by phis 60 | res[phis < (np.pi / 2)] = 0 61 | res[(phis >= (np.pi / 2)) & (phis < np.pi)] = 1 62 | res[(phis >= np.pi) & (phis < (3 * np.pi / 2))] = 2 63 | res[(phis >= (3 * np.pi / 2)) & (phis < (2 * np.pi))] = 3 64 | # override by thetas 65 | res[thetas < (np.pi / 4)] = 4 66 | res[thetas > (3 * np.pi / 4)] = 5 67 | return res 68 | 69 | 70 | def rand_poses(size, device, radius=1, theta_range=[np.pi/3, 2 * np.pi/3], phi_range=[0, 2*np.pi]): 71 | ''' generate random poses from an orbit camera 72 | Args: 73 | size: batch size of generated poses. 74 | device: where to allocate the output. 
75 | radius: camera radius 76 | theta_range: [min, max], should be in [0, \pi] 77 | phi_range: [min, max], should be in [0, 2\pi] 78 | Return: 79 | poses: [size, 4, 4] 80 | ''' 81 | 82 | def normalize(vectors): 83 | return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10) 84 | 85 | thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0] 86 | phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0] 87 | 88 | centers = torch.stack([ 89 | radius * torch.sin(thetas) * torch.sin(phis), 90 | radius * torch.cos(thetas), 91 | radius * torch.sin(thetas) * torch.cos(phis), 92 | ], dim=-1) # [B, 3] 93 | 94 | # lookat 95 | forward_vector = - normalize(centers) 96 | up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1) # confused at the coordinate system... 97 | right_vector = normalize(torch.cross(forward_vector, up_vector, dim=-1)) 98 | up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1)) 99 | 100 | poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1) 101 | poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1) 102 | poses[:, :3, 3] = centers 103 | 104 | return poses 105 | 106 | 107 | class NeRFDataset: 108 | def __init__(self, opt, device, type='train', H=128, W=128, radius=3, fovy=90, size=100): 109 | super().__init__() 110 | 111 | self.opt = opt 112 | self.device = device 113 | self.type = type # train, val, test 114 | 115 | self.H = H 116 | self.W = W 117 | self.radius = radius 118 | self.fovy = fovy 119 | self.size = size 120 | 121 | self.training = self.type in ['train', 'all'] 122 | self.num_rays = self.opt.num_rays if self.training else -1 123 | 124 | fl_y = self.H / (2 * np.tan(np.radians(self.fovy) / 2)) 125 | fl_x = fl_y 126 | cx = self.H / 2 127 | cy = self.W / 2 128 | self.intrinsics = np.array([fl_x, fl_y, cx, cy]) 129 | 130 | # [debug] visualize poses 131 | #poses = rand_poses(100, 'cpu', radius=self.radius).detach().numpy() 132 | #visualize_poses(poses) 133 | 134 | 135 | def collate(self, index): 136 | 137 | B = len(index) # always 1 138 | 139 | # random pose 140 | poses = rand_poses(B, self.device, radius=self.radius) 141 | 142 | # sample a low-resolution but full image for CLIP 143 | rays = get_rays(poses, self.intrinsics, self.H, self.W, -1) 144 | 145 | return { 146 | 'H': self.H, 147 | 'W': self.W, 148 | 'rays_o': rays['rays_o'], 149 | 'rays_d': rays['rays_d'], 150 | } 151 | 152 | def dataloader(self): 153 | loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0) 154 | loader._data = self # an ugly fix... we need to access error_map & poses in trainer. 
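# Each 'batch' from this loader is a single randomly-posed render request: poses
# are sampled on the fly in collate(), so the dataset indices act only as a counter.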
155 | return loader -------------------------------------------------------------------------------- /gridencoder/grid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _gridencoder as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | _gridtype_to_id = { 15 | 'hash': 0, 16 | 'tiled': 1, 17 | } 18 | 19 | class _grid_encode(Function): 20 | @staticmethod 21 | @custom_fwd(cast_inputs=torch.half) 22 | def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0): 23 | # inputs: [B, D], float in [0, 1] 24 | # embeddings: [sO, C], float 25 | # offsets: [L + 1], int 26 | # RETURN: [B, F], float 27 | 28 | inputs = inputs.contiguous() 29 | embeddings = embeddings.contiguous() 30 | offsets = offsets.contiguous() 31 | 32 | B, D = inputs.shape # batch size, coord dim 33 | L = offsets.shape[0] - 1 # level 34 | C = embeddings.shape[1] # embedding dim for each level 35 | S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f 36 | H = base_resolution # base resolution 37 | 38 | # L first, optimize cache for cuda kernel, but needs an extra permute later 39 | outputs = torch.empty(L, B, C, device=inputs.device, dtype=inputs.dtype) 40 | 41 | if calc_grad_inputs: 42 | dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=inputs.dtype) 43 | else: 44 | dy_dx = torch.empty(1, device=inputs.device, dtype=inputs.dtype) 45 | 46 | _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, calc_grad_inputs, dy_dx, gridtype) 47 | 48 | # permute back to [B, L * C] 49 | outputs = outputs.permute(1, 0, 2).reshape(B, L * C) 50 | 51 | ctx.save_for_backward(inputs, embeddings, offsets, dy_dx) 52 | ctx.dims = [B, D, C, L, S, H, gridtype] 53 | ctx.calc_grad_inputs = calc_grad_inputs 54 | 55 | return outputs 56 | 57 | @staticmethod 58 | #@once_differentiable 59 | @custom_bwd 60 | def backward(ctx, grad): 61 | 62 | inputs, embeddings, offsets, dy_dx = ctx.saved_tensors 63 | B, D, C, L, S, H, gridtype = ctx.dims 64 | calc_grad_inputs = ctx.calc_grad_inputs 65 | 66 | # grad: [B, L * C] --> [L, B, C] 67 | grad = grad.view(B, L, C).permute(1, 0, 2).contiguous() 68 | 69 | grad_embeddings = torch.zeros_like(embeddings) 70 | 71 | if calc_grad_inputs: 72 | grad_inputs = torch.zeros_like(inputs) 73 | else: 74 | grad_inputs = torch.zeros(1, device=inputs.device, dtype=inputs.dtype) 75 | 76 | _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype) 77 | 78 | if calc_grad_inputs: 79 | return grad_inputs, grad_embeddings, None, None, None, None, None 80 | else: 81 | return None, grad_embeddings, None, None, None, None, None 82 | 83 | 84 | grid_encode = _grid_encode.apply 85 | 86 | 87 | class GridEncoder(nn.Module): 88 | def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash'): 89 | super().__init__() 90 | 91 | # the finest resolution desired at the last level; if provided, overrides per_level_scale 92 | if desired_resolution is not None: 93 | per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))
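# Worked example with the values used in encoding.py (desired_resolution=2048,
# base_resolution=16, num_levels=16): per_level_scale = exp2(log2(2048 / 16) / 15)
# = exp2(7 / 15) ≈ 1.382, so the per-level resolution grows geometrically from 16
# up to 2048 at the finest level.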
94 | 95 | self.input_dim = input_dim # coord dims, 2 or 3 96 | self.num_levels = num_levels # num levels, each level multiplies the resolution by per_level_scale 97 | self.level_dim = level_dim # encode channels per level 98 | self.per_level_scale = per_level_scale # multiply resolution by this scale at each level. 99 | self.log2_hashmap_size = log2_hashmap_size 100 | self.base_resolution = base_resolution 101 | self.output_dim = num_levels * level_dim 102 | self.gridtype = gridtype 103 | self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash" 104 | 105 | if level_dim % 2 != 0: 106 | print('[WARN] detected HashGrid level_dim % 2 != 0, which will cause very slow backward if fp16 is also enabled! (maybe fix later)') 107 | 108 | # allocate parameters 109 | offsets = [] 110 | offset = 0 111 | self.max_params = 2 ** log2_hashmap_size 112 | for i in range(num_levels): 113 | resolution = int(np.ceil(base_resolution * per_level_scale ** i)) 114 | params_in_level = min(self.max_params, (resolution + 1) ** input_dim) # limit max number 115 | #params_in_level = np.ceil(params_in_level / 8) * 8 # make divisible 116 | offsets.append(offset) 117 | offset += params_in_level 118 | offsets.append(offset) 119 | offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) 120 | self.register_buffer('offsets', offsets) 121 | 122 | self.n_params = offsets[-1] * level_dim 123 | 124 | # parameters 125 | self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) 126 | 127 | self.reset_parameters() 128 | 129 | def reset_parameters(self): 130 | std = 1e-4 131 | self.embeddings.data.uniform_(-std, std) 132 | 133 | def __repr__(self): 134 | return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} base_resolution={self.base_resolution} per_level_scale={self.per_level_scale} params={tuple(self.embeddings.shape)} gridtype={self.gridtype}" 135 | 136 | def forward(self, inputs, bound=1): 137 | # inputs: [..., input_dim], normalized real world positions in [-bound, bound] 138 | # return: [..., num_levels * level_dim] 139 | 140 | inputs = (inputs + bound) / (2 * bound) # map to [0, 1] 141 | 142 | #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) 143 | 144 | prefix_shape = list(inputs.shape[:-1]) 145 | inputs = inputs.view(-1, self.input_dim) 146 | 147 | outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id) 148 | outputs = outputs.view(prefix_shape + [self.output_dim]) 149 | 150 | #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) 151 | 152 | return outputs -------------------------------------------------------------------------------- /nerf/network_cc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | from encoding import get_encoder 8 | from activation import trunc_exp 9 | from nerf.renderer import NeRFRenderer 10 | 11 | 12 | class NeRFNetwork(NeRFRenderer): 13 | def __init__(self, 14 | resolution=[256] * 3, 15 | rank_line=32, 16 | rank_plane=32, 17 | bound=1, 18 | **kwargs 19 | ): 20 | super().__init__(bound, **kwargs) 21 | 22 | self.resolution = np.asarray(resolution) 23 | 24 | self.out_dim = 4 25 | 26 | self.vec_ids = [0, 1, 2] 27 | self.mat_ids = [[0, 1], [0, 2], [1, 2]] 28 | 29 | self.rank_line = rank_line 30 | self.rank_plane = rank_plane
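# TensoRF-style low-rank field: three 1D line factors (U) whose samples are
# multiplied together, three 2D plane factors (V) likewise, and singular-value
# matrices (S) that mix the rank features into out_dim = 4 channels (density + RGB).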
rank_plane 31 | 32 | # line 33 | self.U = nn.ParameterList() 34 | for i in range(len(self.vec_ids)): 35 | vec_id = self.vec_ids[i] 36 | self.U.append(nn.Parameter(torch.randn(1, rank_line, resolution[vec_id], 1) * 0.1)) # [1, R, H, 1] 37 | 38 | # plane 39 | self.V = nn.ParameterList() 40 | for i in range(len(self.mat_ids)): 41 | mat_id_0, mat_id_1 = self.mat_ids[i] 42 | self.V.append(nn.Parameter(torch.randn(1, rank_plane, resolution[mat_id_1], resolution[mat_id_0]) * 0.1)) # [1, R, H, W] 43 | 44 | # singular values (for line and plane, separately) 45 | self.S = nn.ParameterList() 46 | self.S.append(nn.Parameter(torch.ones(self.out_dim, rank_line))) 47 | self.S.append(nn.Parameter(torch.ones(self.out_dim, rank_plane))) 48 | 49 | torch.nn.init.kaiming_normal_(self.S[0].data) 50 | torch.nn.init.kaiming_normal_(self.S[1].data) 51 | 52 | 53 | def transform(self, x): 54 | # x: [N, 3], in [-bound, bound] 55 | # y: transformed x in the object's coordinate system, normalized into [-1, 1] 56 | 57 | #x = x + self.origin 58 | 59 | aabb = self.aabb_train if self.training else self.aabb_infer 60 | y = 2 * (x - aabb[:3]) / (aabb[3:] - aabb[:3]) - 1 # in [-1, 1] (may have outliers, but no matter since grid_sample uses zero padding.) 61 | 62 | return y 63 | 64 | 65 | def get_feat(self, x): 66 | # x: [N, 3], in [-1, 1] 67 | 68 | N = x.shape[0] 69 | 70 | vec_coord = torch.stack((x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])) 71 | vec_coord = torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1).detach().view(3, -1, 1, 2) # [3, N, 1, 2], fake 2d coord 72 | 73 | mat_coord = torch.stack((x[..., self.mat_ids[0]], x[..., self.mat_ids[1]], x[..., self.mat_ids[2]])).detach().view(3, -1, 1, 2) # [3, N, 1, 2] 74 | 75 | 76 | vec_feat = F.grid_sample(self.U[0], vec_coord[[0]], align_corners=True).view(-1, N) * \ 77 | F.grid_sample(self.U[1], vec_coord[[1]], align_corners=True).view(-1, N) * \ 78 | F.grid_sample(self.U[2], vec_coord[[2]], align_corners=True).view(-1, N) # [R1, N] 79 | 80 | mat_feat = F.grid_sample(self.V[0], mat_coord[[0]], align_corners=True).view(-1, N) * \ 81 | F.grid_sample(self.V[1], mat_coord[[1]], align_corners=True).view(-1, N) * \ 82 | F.grid_sample(self.V[2], mat_coord[[2]], align_corners=True).view(-1, N) # [R2, N] 83 | 84 | S_vec = self.S[0] 85 | S_mat = self.S[1] 86 | 87 | vec_feat = S_vec @ vec_feat # [out_dim, N] 88 | mat_feat = S_mat @ mat_feat # [out_dim, N] 89 | 90 | hybrid_feat = (vec_feat + mat_feat).T.contiguous() # [out_dim, N] --> [N, out_dim] 91 | 92 | return hybrid_feat 93 | 94 | def density(self, x): 95 | # x: [N, 3], in [-bound, bound] 96 | 97 | # normalize to [-1, 1] 98 | x_model = self.transform(x) 99 | 100 | feat = self.get_feat(x_model) # [N, out_dim] 101 | sigma = trunc_exp(feat[..., 0]) 102 | 103 | return { 104 | 'sigma': sigma, 105 | 'feat': feat, 106 | } 107 | 108 | # allow masked inference 109 | def color(self, x, d, mask=None, feat=None, **kwargs): 110 | # x: [N, 3] in [-bound, bound] 111 | # mask: [N,], bool, indicates where we actually need to compute rgb. 
112 | # feat: [N, out_dim] 113 | N = x.shape[0] 114 | 115 | h = feat[..., 1:] # [N, 3] 116 | rgbs = torch.sigmoid(h) 117 | 118 | return rgbs 119 | 120 | # L1 penalty for loss 121 | def density_loss(self): 122 | loss = 0 123 | for i in range(len(self.U)): 124 | loss += torch.mean(torch.abs(self.U[i])) + torch.mean(torch.abs(self.V[i])) 125 | return loss 126 | 127 | 128 | @torch.no_grad() 129 | def upsample_model(self, resolution): 130 | for i in range(len(self.U)): 131 | vec_id = self.vec_ids[i % 3] 132 | self.U[i] = torch.nn.Parameter(F.interpolate(self.U[i].data, size=(resolution[vec_id], 1), mode='bilinear', align_corners=True)) 133 | 134 | for i in range(len(self.V)): 135 | mat_id_0, mat_id_1 = self.mat_ids[i % 3] 136 | self.V[i] = torch.nn.Parameter(F.interpolate(self.V[i].data, size=(resolution[mat_id_1], resolution[mat_id_0]), mode='bilinear', align_corners=True)) 137 | 138 | self.resolution = resolution 139 | 140 | 141 | @torch.no_grad() 142 | def shrink_model(self): 143 | # shrink aabb_train and the model so it only represents the space inside aabb_train. 144 | 145 | H = self.density_grid.shape[1] 146 | half_grid_size = self.bound / H 147 | thresh = min(0.01, self.mean_density) 148 | 149 | # get new aabb from the coarsest density grid (TODO: from the finest that covers current aabb?) 150 | valid_grid = self.density_grid[self.cascade - 1] > thresh # [H, W, D] 151 | valid_pos = torch.nonzero(valid_grid) # [Nz, 3], in [0, H - 1] 152 | #plot_pointcloud(valid_pos.detach().cpu().numpy()) # lots of noisy outliers in hashnerf... 153 | valid_pos = (2 * valid_pos / (H - 1) - 1) * (self.bound - half_grid_size) # [Nz, 3], in [-b+hgs, b-hgs] 154 | min_pos = valid_pos.amin(0) - half_grid_size # [3] 155 | max_pos = valid_pos.amax(0) + half_grid_size # [3] 156 | 157 | # shrink model 158 | reso = torch.LongTensor(self.resolution).to(self.aabb_train.device) 159 | units = (self.aabb_train[3:] - self.aabb_train[:3]) / reso 160 | tl = (min_pos - self.aabb_train[:3]) / units 161 | br = (max_pos - self.aabb_train[:3]) / units 162 | tl = torch.round(tl).long().clamp(min=0) 163 | br = torch.minimum(torch.round(br).long(), reso) 164 | 165 | for i in range(len(self.U)): 166 | vec_id = self.vec_ids[i % 3] 167 | self.U[i] = nn.Parameter(self.U[i].data[..., tl[vec_id]:br[vec_id], :]) 168 | 169 | for i in range(len(self.V)): 170 | mat_id_0, mat_id_1 = self.mat_ids[i % 3] 171 | self.V[i] = nn.Parameter(self.V[i].data[..., tl[mat_id_1]:br[mat_id_1], tl[mat_id_0]:br[mat_id_0]]) 172 | 173 | self.aabb_train = torch.cat([min_pos, max_pos], dim=0) # [6] 174 | 175 | print(f'[INFO] shrink slice: {tl.cpu().numpy().tolist()} - {br.cpu().numpy().tolist()}') 176 | print(f'[INFO] new aabb: {self.aabb_train.cpu().numpy().tolist()}') 177 | 178 | 179 | # optimizer utils 180 | def get_params(self, lr1, lr2=None): 181 | if lr2 is None: 182 | lr2 = lr1 183 | return [ 184 | {'params': self.U, 'lr': lr1}, 185 | {'params': self.V, 'lr': lr1}, 186 | {'params': self.S, 'lr': lr2}, 187 | ] -------------------------------------------------------------------------------- /raymarching/src/pcg32.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tiny self-contained version of the PCG Random Number Generation for C++ 3 | * put together from pieces of the much larger C/C++ codebase. 
4 | * Wenzel Jakob, February 2015 5 | * 6 | * The PCG random number generator was developed by Melissa O'Neill 7 | * 8 | * 9 | * Licensed under the Apache License, Version 2.0 (the "License"); 10 | * you may not use this file except in compliance with the License. 11 | * You may obtain a copy of the License at 12 | * 13 | * http://www.apache.org/licenses/LICENSE-2.0 14 | * 15 | * Unless required by applicable law or agreed to in writing, software 16 | * distributed under the License is distributed on an "AS IS" BASIS, 17 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | * See the License for the specific language governing permissions and 19 | * limitations under the License. 20 | * 21 | * For additional information about the PCG random number generation scheme, 22 | * including its license and other licensing options, visit 23 | * 24 | * http://www.pcg-random.org 25 | * 26 | * Note: This code was modified to work with CUDA by the tiny-cuda-nn authors. 27 | */ 28 | 29 | #pragma once 30 | 31 | #define PCG32_DEFAULT_STATE 0x853c49e6748fea9bULL 32 | #define PCG32_DEFAULT_STREAM 0xda3e39cb94b95bdbULL 33 | #define PCG32_MULT 0x5851f42d4c957f2dULL 34 | 35 | #include 36 | #include 37 | #include 38 | 39 | #include 40 | #include 41 | #include 42 | 43 | /// PCG32 Pseudorandom number generator 44 | struct pcg32 { 45 | /// Initialize the pseudorandom number generator with default seed 46 | __host__ __device__ pcg32() : state(PCG32_DEFAULT_STATE), inc(PCG32_DEFAULT_STREAM) {} 47 | 48 | /// Initialize the pseudorandom number generator with the \ref seed() function 49 | __host__ __device__ pcg32(uint64_t initstate, uint64_t initseq = 1u) { seed(initstate, initseq); } 50 | 51 | /** 52 | * \brief Seed the pseudorandom number generator 53 | * 54 | * Specified in two parts: a state initializer and a sequence selection 55 | * constant (a.k.a. stream id) 56 | */ 57 | __host__ __device__ void seed(uint64_t initstate, uint64_t initseq = 1) { 58 | state = 0U; 59 | inc = (initseq << 1u) | 1u; 60 | next_uint(); 61 | state += initstate; 62 | next_uint(); 63 | } 64 | 65 | /// Generate a uniformly distributed unsigned 32-bit random number 66 | __host__ __device__ uint32_t next_uint() { 67 | uint64_t oldstate = state; 68 | state = oldstate * PCG32_MULT + inc; 69 | uint32_t xorshifted = (uint32_t) (((oldstate >> 18u) ^ oldstate) >> 27u); 70 | uint32_t rot = (uint32_t) (oldstate >> 59u); 71 | return (xorshifted >> rot) | (xorshifted << ((~rot + 1u) & 31)); 72 | } 73 | 74 | /// Generate a uniformly distributed number, r, where 0 <= r < bound 75 | __host__ __device__ uint32_t next_uint(uint32_t bound) { 76 | // To avoid bias, we need to make the range of the RNG a multiple of 77 | // bound, which we do by dropping output less than a threshold. 78 | // A naive scheme to calculate the threshold would be to do 79 | // 80 | // uint32_t threshold = 0x100000000ull % bound; 81 | // 82 | // but 64-bit div/mod is slower than 32-bit div/mod (especially on 83 | // 32-bit platforms). In essence, we do 84 | // 85 | // uint32_t threshold = (0x100000000ull-bound) % bound; 86 | // 87 | // because this version will calculate the same modulus, but the LHS 88 | // value is less than 2^32. 89 | 90 | uint32_t threshold = (~bound+1u) % bound; 91 | 92 | // Uniformity guarantees that this loop will terminate. In practice, it 93 | // should usually terminate quickly; on average (assuming all bounds are 94 | // equally likely), 82.25% of the time, we can expect it to require just 95 | // one iteration. 
In the worst case, someone passes a bound of 2^31 + 1 96 | // (i.e., 2147483649), which invalidates almost 50% of the range. In 97 | // practice, bounds are typically small and only a tiny amount of the range 98 | // is eliminated. 99 | for (;;) { 100 | uint32_t r = next_uint(); 101 | if (r >= threshold) 102 | return r % bound; 103 | } 104 | } 105 | 106 | /// Generate a single precision floating point value on the interval [0, 1) 107 | __host__ __device__ float next_float() { 108 | /* Trick from MTGP: generate an uniformly distributed 109 | single precision number in [1,2) and subtract 1. */ 110 | union { 111 | uint32_t u; 112 | float f; 113 | } x; 114 | x.u = (next_uint() >> 9) | 0x3f800000u; 115 | return x.f - 1.0f; 116 | } 117 | 118 | /** 119 | * \brief Generate a double precision floating point value on the interval [0, 1) 120 | * 121 | * \remark Since the underlying random number generator produces 32 bit output, 122 | * only the first 32 mantissa bits will be filled (however, the resolution is still 123 | * finer than in \ref next_float(), which only uses 23 mantissa bits) 124 | */ 125 | __host__ __device__ double next_double() { 126 | /* Trick from MTGP: generate an uniformly distributed 127 | double precision number in [1,2) and subtract 1. */ 128 | union { 129 | uint64_t u; 130 | double d; 131 | } x; 132 | x.u = ((uint64_t) next_uint() << 20) | 0x3ff0000000000000ULL; 133 | return x.d - 1.0; 134 | } 135 | 136 | /** 137 | * \brief Multi-step advance function (jump-ahead, jump-back) 138 | * 139 | * The method used here is based on Brown, "Random Number Generation 140 | * with Arbitrary Stride", Transactions of the American Nuclear 141 | * Society (Nov. 1994). The algorithm is very similar to fast 142 | * exponentiation. 143 | * 144 | * The default value of 2^32 ensures that the PRNG is advanced 145 | * sufficiently far that there is (likely) no overlap with 146 | * previously drawn random numbers, even if small advancements. 147 | * are made inbetween. 148 | */ 149 | __host__ __device__ void advance(int64_t delta_ = (1ll<<32)) { 150 | uint64_t 151 | cur_mult = PCG32_MULT, 152 | cur_plus = inc, 153 | acc_mult = 1u, 154 | acc_plus = 0u; 155 | 156 | /* Even though delta is an unsigned integer, we can pass a signed 157 | integer to go backwards, it just goes "the long way round". 
*/ 158 | uint64_t delta = (uint64_t) delta_; 159 | 160 | while (delta > 0) { 161 | if (delta & 1) { 162 | acc_mult *= cur_mult; 163 | acc_plus = acc_plus * cur_mult + cur_plus; 164 | } 165 | cur_plus = (cur_mult + 1) * cur_plus; 166 | cur_mult *= cur_mult; 167 | delta /= 2; 168 | } 169 | state = acc_mult * state + acc_plus; 170 | } 171 | 172 | /// Compute the distance between two PCG32 pseudorandom number generators 173 | __host__ __device__ int64_t operator-(const pcg32 &other) const { 174 | assert(inc == other.inc); 175 | 176 | uint64_t 177 | cur_mult = PCG32_MULT, 178 | cur_plus = inc, 179 | cur_state = other.state, 180 | the_bit = 1u, 181 | distance = 0u; 182 | 183 | while (state != cur_state) { 184 | if ((state & the_bit) != (cur_state & the_bit)) { 185 | cur_state = cur_state * cur_mult + cur_plus; 186 | distance |= the_bit; 187 | } 188 | assert((state & the_bit) == (cur_state & the_bit)); 189 | the_bit <<= 1; 190 | cur_plus = (cur_mult + 1ULL) * cur_plus; 191 | cur_mult *= cur_mult; 192 | } 193 | 194 | return (int64_t) distance; 195 | } 196 | 197 | /// Equality operator 198 | __host__ __device__ bool operator==(const pcg32 &other) const { return state == other.state && inc == other.inc; } 199 | 200 | /// Inequality operator 201 | __host__ __device__ bool operator!=(const pcg32 &other) const { return state != other.state || inc != other.inc; } 202 | 203 | uint64_t state; // RNG state. All values are possible. 204 | uint64_t inc; // Controls which RNG sequence (stream) is selected. Must *always* be odd. 205 | }; -------------------------------------------------------------------------------- /raymarching/raymarching.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Function 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _raymarching as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | 15 | # ---------------------------------------- 16 | # utils 17 | # ---------------------------------------- 18 | 19 | class _near_far_from_aabb(Function): 20 | @staticmethod 21 | @custom_fwd(cast_inputs=torch.float32) 22 | def forward(ctx, rays_o, rays_d, aabb, min_near=0.2): 23 | ''' near_far_from_aabb, CUDA implementation 24 | Calculate rays' intersection time (near and far) with aabb 25 | Args: 26 | rays_o: float, [N, 3] 27 | rays_d: float, [N, 3] 28 | aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax) 29 | min_near: float, scalar 30 | Returns: 31 | nears: float, [N] 32 | fars: float, [N] 33 | ''' 34 | if not rays_o.is_cuda: rays_o = rays_o.cuda() 35 | if not rays_d.is_cuda: rays_d = rays_d.cuda() 36 | 37 | rays_o = rays_o.contiguous().view(-1, 3) 38 | rays_d = rays_d.contiguous().view(-1, 3) 39 | 40 | N = rays_o.shape[0] # num rays 41 | 42 | nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) 43 | fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) 44 | 45 | _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars) 46 | 47 | return nears, fars 48 | 49 | near_far_from_aabb = _near_far_from_aabb.apply 50 | 51 | # ---------------------------------------- 52 | # train functions 53 | # ---------------------------------------- 54 | 55 | class _march_rays_train(Function): 56 | @staticmethod 57 | @custom_fwd(cast_inputs=torch.float32) 58 | def forward(ctx, rays_o, rays_d, bound, density_grid, mean_density, nears, fars, step_counter=None, mean_count=-1, 
perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024): 59 | ''' march rays to generate points (forward only) 60 | Args: 61 | rays_o/d: float, [N, 3] 62 | bound: float, scalar 63 | density_grid: float, [C, H, H, H] 64 | mean_density: float, scalar 65 | nears/fars: float, [N] 66 | step_counter: int32, (2), used to count the actual number of generated points. 67 | mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeds this threshold.) 68 | perturb: bool 69 | align: int, pad output so its size is divisible by align, set to -1 to disable. 70 | force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays. 71 | dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerates ray marching if > 0. (very significant effect, but generally leads to worse performance) 72 | max_steps: int, max number of sampled points along each ray, also affects min_stepsize. 73 | Returns: 74 | xyzs: float, [M, 3], all generated points' coords. (all rays concatenated, need to use `rays` to extract points belonging to each ray) 75 | dirs: float, [M, 3], all generated points' view dirs. 76 | deltas: float, [M, 2], all generated points' deltas. (first for RGB, second for Depth) 77 | rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 2]] --> points belonging to rays[i, 0] 78 | ''' 79 | 80 | if not rays_o.is_cuda: rays_o = rays_o.cuda() 81 | if not rays_d.is_cuda: rays_d = rays_d.cuda() 82 | if not density_grid.is_cuda: density_grid = density_grid.cuda() 83 | 84 | rays_o = rays_o.contiguous().view(-1, 3) 85 | rays_d = rays_d.contiguous().view(-1, 3) 86 | density_grid = density_grid.contiguous() 87 | 88 | N = rays_o.shape[0] # num rays 89 | 90 | C = density_grid.shape[0] # grid cascade 91 | H = density_grid.shape[1] # grid resolution 92 | 93 | M = N * max_steps # initial max number of points in total 94 | 95 | # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp) 96 | # It estimates the max number of points to enable faster training, but will lead to randomly ignored rays if underestimated. 97 | if not force_all_rays and mean_count > 0: 98 | if align > 0: 99 | mean_count += align - mean_count % align 100 | M = mean_count 101 | 102 | xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) 103 | dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) 104 | deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) 105 | rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps 106 | 107 | if step_counter is None: 108 | step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter 109 | 110 | _backend.march_rays_train(rays_o, rays_d, density_grid, mean_density, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, perturb) # m is the number of points actually used 111 | 112 | #print(step_counter, M) 113 | 114 | # only used at the first (few) epochs. 
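# when mean_count is unknown (or force_all_rays is set), the buffers above were sized for the
# worst case (N * max_steps), so the block below truncates them to the m points the kernel
# actually wrote; reading m back from step_counter is a D2H copy that synchronizes the stream.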
115 | if force_all_rays or mean_count <= 0: 116 | m = step_counter[0].item() # D2H copy 117 | if align > 0: 118 | m += align - m % align 119 | xyzs = xyzs[:m] 120 | dirs = dirs[:m] 121 | deltas = deltas[:m] 122 | 123 | torch.cuda.empty_cache() 124 | 125 | return xyzs, dirs, deltas, rays 126 | 127 | march_rays_train = _march_rays_train.apply 128 | 129 | 130 | class _composite_rays_train(Function): 131 | @staticmethod 132 | @custom_fwd(cast_inputs=torch.float32) 133 | def forward(ctx, sigmas, rgbs, deltas, rays): 134 | ''' composite rays' rgbs, according to the ray marching formula. 135 | Args: 136 | rgbs: float, [M, 3] 137 | sigmas: float, [M,] 138 | deltas: float, [M, 2] 139 | rays: int32, [N, 3] 140 | Returns: 141 | weights_sum: float, [N,], the alpha channel 142 | depth: float, [N,], the depth 143 | image: float, [N, 3], the RGB channel (after multiplying alpha!) 144 | ''' 145 | 146 | sigmas = sigmas.contiguous() 147 | rgbs = rgbs.contiguous() 148 | 149 | M = sigmas.shape[0] 150 | N = rays.shape[0] 151 | 152 | weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) 153 | depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) 154 | image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device) 155 | 156 | _backend.composite_rays_train_forward(sigmas, rgbs, deltas, rays, M, N, weights_sum, depth, image) 157 | 158 | ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, depth, image) 159 | ctx.dims = [M, N] 160 | 161 | return weights_sum, depth, image 162 | 163 | @staticmethod 164 | @custom_bwd 165 | def backward(ctx, grad_weights_sum, grad_depth, grad_image): 166 | 167 | # NOTE: grad_depth is not used now! It won't be propagated to sigmas. 168 | 169 | grad_weights_sum = grad_weights_sum.contiguous() 170 | grad_image = grad_image.contiguous() 171 | 172 | sigmas, rgbs, deltas, rays, weights_sum, depth, image = ctx.saved_tensors 173 | M, N = ctx.dims 174 | 175 | grad_sigmas = torch.zeros_like(sigmas) 176 | grad_rgbs = torch.zeros_like(rgbs) 177 | 178 | _backend.composite_rays_train_backward(grad_weights_sum, grad_image, sigmas, rgbs, deltas, rays, weights_sum, image, M, N, grad_sigmas, grad_rgbs) 179 | 180 | return grad_sigmas, grad_rgbs, None, None 181 | 182 | 183 | composite_rays_train = _composite_rays_train.apply 184 | 185 | # ---------------------------------------- 186 | # infer functions 187 | # ---------------------------------------- 188 | 189 | class _march_rays(Function): 190 | @staticmethod 191 | @custom_fwd(cast_inputs=torch.float32) 192 | def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_grid, mean_density, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024): 193 | ''' march rays to generate points (forward only, for inference) 194 | Args: 195 | n_alive: int, number of alive rays 196 | n_step: int, how many steps we march 197 | rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive) 198 | rays_t: float, [N], the alive rays' time, we only use the first n_alive. 199 | rays_o/d: float, [N, 3] 200 | bound: float, scalar 201 | density_grid: float, [C, H, H, H] 202 | mean_density: float, scalar 203 | nears/fars: float, [N] 204 | align: int, pad output so its size is divisible by align, set to -1 to disable. 205 | perturb: bool/int, int > 0 is used as the random seed. 206 | dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerates ray marching if > 0. 
(very significant effect, but generally leads to worse performance) 207 | max_steps: int, max number of sampled points along each ray, also affects min_stepsize. 208 | Returns: 209 | xyzs: float, [n_alive * n_step, 3], all generated points' coords 210 | dirs: float, [n_alive * n_step, 3], all generated points' view dirs. 211 | deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). 212 | ''' 213 | 214 | if not rays_o.is_cuda: rays_o = rays_o.cuda() 215 | if not rays_d.is_cuda: rays_d = rays_d.cuda() 216 | 217 | rays_o = rays_o.contiguous().view(-1, 3) 218 | rays_d = rays_d.contiguous().view(-1, 3) 219 | 220 | C = density_grid.shape[0] # grid cascade 221 | H = density_grid.shape[1] # grid resolution 222 | M = n_alive * n_step 223 | 224 | if align > 0: 225 | M += align - (M % align) 226 | 227 | xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) 228 | dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) 229 | deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth 230 | 231 | _backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_grid, mean_density, near, far, xyzs, dirs, deltas, perturb) 232 | 233 | return xyzs, dirs, deltas 234 | 235 | march_rays = _march_rays.apply 236 | 237 | 238 | class _composite_rays(Function): 239 | @staticmethod 240 | @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float 241 | def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image): 242 | ''' composite rays' rgbs, according to the ray marching formula. (for inference) 243 | Args: 244 | n_alive: int, number of alive rays 245 | n_step: int, how many steps we march 246 | rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive) 247 | rays_t: float, [N], the alive rays' time, we only use the first n_alive. 248 | sigmas: float, [n_alive * n_step,] 249 | rgbs: float, [n_alive * n_step, 3] 250 | deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). 251 | In-place Outputs: 252 | weights_sum: float, [N,], the alpha channel 253 | depth: float, [N,], the depth value 254 | image: float, [N, 3], the RGB channel (after multiplying alpha!) 255 | ''' 256 | _backend.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image) 257 | return tuple() 258 | 259 | 260 | composite_rays = _composite_rays.apply 261 | 262 | 263 | class _compact_rays(Function): 264 | @staticmethod 265 | @custom_fwd(cast_inputs=torch.float32) 266 | def forward(ctx, n_alive, rays_alive, rays_alive_old, rays_t, rays_t_old, alive_counter): 267 | ''' compact rays, remove dead rays and reallocate alive rays, to accelerate next ray marching. 268 | Args: 269 | n_alive: int, number of alive rays 270 | rays_alive_old: int, [N] 271 | rays_t_old: float, [N], dead rays are marked by rays_t < 0 272 | alive_counter: int, [1], used to count the remaining alive rays. 
273 | In-place Outputs: 274 | rays_alive: int, [N] 275 | rays_t: float, [N] 276 | ''' 277 | _backend.compact_rays(n_alive, rays_alive, rays_alive_old, rays_t, rays_t_old, alive_counter) 278 | return tuple() 279 | 280 | compact_rays = _compact_rays.apply -------------------------------------------------------------------------------- /nerf/gui.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import dearpygui.dearpygui as dpg 5 | from scipy.spatial.transform import Rotation as R 6 | 7 | from nerf.utils import * 8 | 9 | 10 | class OrbitCamera: 11 | def __init__(self, W, H, r=2, fovy=60): 12 | self.W = W 13 | self.H = H 14 | self.radius = r # camera distance from center 15 | self.fovy = fovy # in degree 16 | self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point 17 | self.rot = R.from_quat([1, 0, 0, 0]) # init camera matrix: [[1, 0, 0], [0, -1, 0], [0, 0, 1]] (to suit ngp convention) 18 | self.up = np.array([0, 1, 0], dtype=np.float32) # need to be normalized! 19 | 20 | # pose 21 | @property 22 | def pose(self): 23 | # first move camera to radius 24 | res = np.eye(4, dtype=np.float32) 25 | res[2, 3] -= self.radius 26 | # rotate 27 | rot = np.eye(4, dtype=np.float32) 28 | rot[:3, :3] = self.rot.as_matrix() 29 | res = rot @ res 30 | # translate 31 | res[:3, 3] -= self.center 32 | return res 33 | 34 | # intrinsics 35 | @property 36 | def intrinsics(self): 37 | focal = self.H / (2 * np.tan(np.radians(self.fovy) / 2)) 38 | return np.array([focal, focal, self.W // 2, self.H // 2]) 39 | 40 | def orbit(self, dx, dy): 41 | # rotate along camera up/side axis! 42 | side = self.rot.as_matrix()[:3, 0] # why this is side --> ? # already normalized. 43 | rotvec_x = self.up * np.radians(-0.1 * dx) 44 | rotvec_y = side * np.radians(-0.1 * dy) 45 | self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot 46 | 47 | # wrong: rotate along global x/y axis 48 | #self.rot = R.from_euler('xy', [-dy * 0.1, -dx * 0.1], degrees=True) * self.rot 49 | 50 | def scale(self, delta): 51 | self.radius *= 1.1 ** (-delta) 52 | 53 | def pan(self, dx, dy, dz=0): 54 | # pan in camera coordinate system (careful on the sensitivity!) 55 | self.center += 0.001 * self.rot.as_matrix()[:3, :3] @ np.array([dx, dy, dz]) 56 | 57 | # wrong: pan in global coordinate system 58 | #self.center += 0.001 * np.array([-dx, -dy, dz]) 59 | 60 | 61 | 62 | class NeRFGUI: 63 | def __init__(self, opt, trainer, debug=True): 64 | self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters. 
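# fields this GUI reads from opt (a sketch of the assumed interface, inferred from the code
# in this file rather than an exhaustive contract):
#   W, H, radius, fovy, max_spp, test, text, image, dt_gamma, bound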
65 | self.W = opt.W 66 | self.H = opt.H 67 | self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy) 68 | self.trainer = trainer 69 | self.debug = debug 70 | self.bg_color = None 71 | self.training = False 72 | self.step = 0 # training step 73 | 74 | self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32) 75 | self.need_update = True # camera moved, should reset accumulation 76 | self.spp = 1 # sample per pixel 77 | 78 | self.dynamic_resolution = True 79 | self.downscale = 1 80 | self.train_steps = 16 81 | 82 | dpg.create_context() 83 | self.register_dpg() 84 | self.test_step() 85 | 86 | 87 | def __del__(self): 88 | dpg.destroy_context() 89 | 90 | 91 | def train_step(self): 92 | 93 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) 94 | starter.record() 95 | 96 | outputs = self.trainer.train_gui(self.trainer.train_loader, step=self.train_steps) 97 | 98 | ender.record() 99 | torch.cuda.synchronize() 100 | t = starter.elapsed_time(ender) 101 | 102 | self.step += self.train_steps 103 | self.need_update = True 104 | 105 | dpg.set_value("_log_train_time", f'{t:.4f}ms') 106 | dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}') 107 | 108 | # dynamic train steps 109 | # max allowed train time per-frame is 500 ms 110 | full_t = t / self.train_steps * 16 111 | train_steps = min(16, max(4, int(16 * 500 / full_t))) 112 | if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8: 113 | self.train_steps = train_steps 114 | 115 | 116 | def test_step(self): 117 | # TODO: seems we have to move data from GPU --> CPU --> GPU? 118 | 119 | if self.need_update or self.spp < self.opt.max_spp: 120 | 121 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) 122 | starter.record() 123 | 124 | outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, self.bg_color, self.spp, self.downscale) 125 | 126 | ender.record() 127 | torch.cuda.synchronize() 128 | t = starter.elapsed_time(ender) 129 | 130 | # update dynamic resolution 131 | if self.dynamic_resolution: 132 | # max allowed infer time per-frame is 200 ms 133 | full_t = t / (self.downscale ** 2) 134 | downscale = min(1, max(1/4, math.sqrt(200 / full_t))) 135 | if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8: 136 | self.downscale = downscale 137 | 138 | if self.need_update: 139 | self.render_buffer = outputs['image'] 140 | self.spp = 1 141 | self.need_update = False 142 | else: 143 | self.render_buffer = (self.render_buffer * self.spp + outputs['image']) / (self.spp + 1) 144 | self.spp += 1 145 | 146 | dpg.set_value("_log_infer_time", f'{t:.4f}ms') 147 | dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}') 148 | dpg.set_value("_log_spp", self.spp) 149 | dpg.set_value("_texture", self.render_buffer) 150 | 151 | 152 | def register_dpg(self): 153 | 154 | ### register texture 155 | 156 | with dpg.texture_registry(show=False): 157 | dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture") 158 | 159 | ### register window 160 | 161 | # the rendered image, as the primary window 162 | with dpg.window(tag="_primary_window", width=self.W, height=self.H): 163 | 164 | # add the texture 165 | dpg.add_image("_texture") 166 | 167 | dpg.set_primary_window("_primary_window", True) 168 | 169 | # control window 170 | with 
dpg.window(label="Control", tag="_control_window", width=400, height=300): 171 | 172 | # text prompt 173 | if self.opt.text is not None: 174 | dpg.add_text("text: " + self.opt.text, tag="_log_prompt_text") 175 | if self.opt.image is not None: 176 | dpg.add_text("ref image: " + self.opt.image, tag="_log_prompt_image") 177 | 178 | # button theme 179 | with dpg.theme() as theme_button: 180 | with dpg.theme_component(dpg.mvButton): 181 | dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18)) 182 | dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47)) 183 | dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83)) 184 | dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5) 185 | dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3) 186 | 187 | # time 188 | if not self.opt.test: 189 | with dpg.group(horizontal=True): 190 | dpg.add_text("Train time: ") 191 | dpg.add_text("no data", tag="_log_train_time") 192 | 193 | with dpg.group(horizontal=True): 194 | dpg.add_text("Infer time: ") 195 | dpg.add_text("no data", tag="_log_infer_time") 196 | 197 | with dpg.group(horizontal=True): 198 | dpg.add_text("SPP: ") 199 | dpg.add_text("1", tag="_log_spp") 200 | 201 | # train button 202 | if not self.opt.test: 203 | with dpg.collapsing_header(label="Train", default_open=True): 204 | with dpg.group(horizontal=True): 205 | dpg.add_text("Train: ") 206 | 207 | def callback_train(sender, app_data): 208 | if self.training: 209 | self.training = False 210 | dpg.configure_item("_button_train", label="start") 211 | else: 212 | self.training = True 213 | dpg.configure_item("_button_train", label="stop") 214 | 215 | dpg.add_button(label="start", tag="_button_train", callback=callback_train) 216 | dpg.bind_item_theme("_button_train", theme_button) 217 | 218 | def callback_reset(sender, app_data): 219 | @torch.no_grad() 220 | def weight_reset(m: nn.Module): 221 | reset_parameters = getattr(m, "reset_parameters", None) 222 | if callable(reset_parameters): 223 | m.reset_parameters() 224 | self.trainer.model.apply(fn=weight_reset) 225 | self.trainer.model.reset_extra_state() # for cuda_ray density_grid and step_counter 226 | self.need_update = True 227 | 228 | dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset) 229 | dpg.bind_item_theme("_button_reset", theme_button) 230 | 231 | 232 | with dpg.group(horizontal=True): 233 | dpg.add_text("Checkpoint: ") 234 | 235 | def callback_save(sender, app_data): 236 | self.trainer.save_checkpoint(full=True, best=False) 237 | dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1])) 238 | self.trainer.epoch += 1 # use epoch to indicate different calls. 239 | 240 | dpg.add_button(label="save", tag="_button_save", callback=callback_save) 241 | dpg.bind_item_theme("_button_save", theme_button) 242 | 243 | dpg.add_text("", tag="_log_ckpt") 244 | 245 | # save mesh 246 | with dpg.group(horizontal=True): 247 | dpg.add_text("Marching Cubes: ") 248 | 249 | def callback_mesh(sender, app_data): 250 | self.trainer.save_mesh(resolution=256, threshold=10) 251 | dpg.set_value("_log_mesh", "saved " + f'{self.trainer.name}_{self.trainer.epoch}.ply') 252 | self.trainer.epoch += 1 # use epoch to indicate different calls. 
253 | 254 | dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh) 255 | dpg.bind_item_theme("_button_mesh", theme_button) 256 | 257 | dpg.add_text("", tag="_log_mesh") 258 | 259 | with dpg.group(horizontal=True): 260 | dpg.add_text("", tag="_log_train_log") 261 | 262 | 263 | # rendering options 264 | with dpg.collapsing_header(label="Options", default_open=True): 265 | 266 | # dynamic rendering resolution 267 | with dpg.group(horizontal=True): 268 | 269 | def callback_set_dynamic_resolution(sender, app_data): 270 | if self.dynamic_resolution: 271 | self.dynamic_resolution = False 272 | self.downscale = 1 273 | else: 274 | self.dynamic_resolution = True 275 | self.need_update = True 276 | 277 | dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution) 278 | dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution") 279 | 280 | # bg_color picker 281 | def callback_change_bg(sender, app_data): 282 | self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32) # only need RGB in [0, 1] 283 | self.need_update = True 284 | 285 | dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg) 286 | 287 | # fov slider 288 | def callback_set_fovy(sender, app_data): 289 | self.cam.fovy = app_data 290 | self.need_update = True 291 | 292 | dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy) 293 | 294 | # dt_gamma slider 295 | def callback_set_dt_gamma(sender, app_data): 296 | self.opt.dt_gamma = app_data 297 | self.need_update = True 298 | 299 | dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma) 300 | 301 | # aabb slider 302 | def callback_set_aabb(sender, app_data, user_data): 303 | # user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax) 304 | self.trainer.model.aabb_infer[user_data] = app_data 305 | 306 | # also change train aabb ? [better not...] 
307 | #self.trainer.model.aabb_train[user_data] = app_data 308 | 309 | self.need_update = True 310 | 311 | dpg.add_separator() 312 | dpg.add_text("Axis-aligned bounding box:") 313 | 314 | with dpg.group(horizontal=True): 315 | dpg.add_slider_float(label="x", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0) 316 | dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3) 317 | 318 | with dpg.group(horizontal=True): 319 | dpg.add_slider_float(label="y", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1) 320 | dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4) 321 | 322 | with dpg.group(horizontal=True): 323 | dpg.add_slider_float(label="z", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2) 324 | dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5) 325 | 326 | 327 | # debug info 328 | if self.debug: 329 | with dpg.collapsing_header(label="Debug"): 330 | # pose 331 | dpg.add_separator() 332 | dpg.add_text("Camera Pose:") 333 | dpg.add_text(str(self.cam.pose), tag="_log_pose") 334 | 335 | 336 | ### register camera handler 337 | 338 | def callback_camera_drag_rotate(sender, app_data): 339 | 340 | if not dpg.is_item_focused("_primary_window"): 341 | return 342 | 343 | dx = app_data[1] 344 | dy = app_data[2] 345 | 346 | self.cam.orbit(dx, dy) 347 | self.need_update = True 348 | 349 | if self.debug: 350 | dpg.set_value("_log_pose", str(self.cam.pose)) 351 | 352 | 353 | def callback_camera_wheel_scale(sender, app_data): 354 | 355 | if not dpg.is_item_focused("_primary_window"): 356 | return 357 | 358 | delta = app_data 359 | 360 | self.cam.scale(delta) 361 | self.need_update = True 362 | 363 | if self.debug: 364 | dpg.set_value("_log_pose", str(self.cam.pose)) 365 | 366 | 367 | def callback_camera_drag_pan(sender, app_data): 368 | 369 | if not dpg.is_item_focused("_primary_window"): 370 | return 371 | 372 | dx = app_data[1] 373 | dy = app_data[2] 374 | 375 | self.cam.pan(dx, dy) 376 | self.need_update = True 377 | 378 | if self.debug: 379 | dpg.set_value("_log_pose", str(self.cam.pose)) 380 | 381 | 382 | with dpg.handler_registry(): 383 | dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate) 384 | dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale) 385 | dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan) 386 | 387 | 388 | dpg.create_viewport(title='torch-ngp', width=self.W, height=self.H, resizable=False) 389 | 390 | # TODO: seems dearpygui doesn't support resizing texture... 391 | # def callback_resize(sender, app_data): 392 | # self.W = app_data[0] 393 | # self.H = app_data[1] 394 | # # how to reload texture ??? 
395 | 396 | # dpg.set_viewport_resize_callback(callback_resize) 397 | 398 | ### global theme 399 | with dpg.theme() as theme_no_padding: 400 | with dpg.theme_component(dpg.mvAll): 401 | # set all padding to 0 to avoid scroll bar 402 | dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core) 403 | dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core) 404 | dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core) 405 | 406 | dpg.bind_item_theme("_primary_window", theme_no_padding) 407 | 408 | dpg.setup_dearpygui() 409 | 410 | #dpg.show_metrics() 411 | 412 | dpg.show_viewport() 413 | 414 | 415 | def render(self): 416 | 417 | while dpg.is_dearpygui_running(): 418 | # update texture every frame 419 | if self.training: 420 | self.train_step() 421 | self.test_step() 422 | dpg.render_dearpygui_frame() -------------------------------------------------------------------------------- /gridencoder/src/gridencoder.cu: -------------------------------------------------------------------------------- 1 | #include <cuda.h> 2 | #include <cuda_fp16.h> 3 | #include <cuda_runtime.h> 4 | 5 | #include <ATen/cuda/CUDAContext.h> 6 | #include <torch/torch.h> 7 | 8 | #include <algorithm> 9 | #include <stdexcept> 10 | 11 | #include <stdint.h> 12 | #include <cstdio> 13 | 14 | 15 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 16 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") 17 | #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") 18 | #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") 19 | 20 | 21 | // requires CUDA >= 10 and ARCH >= 70 22 | // this is very slow compared to float or __half2, do not use! 23 | static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) { 24 | return atomicAdd(reinterpret_cast<__half*>(address), val); 25 | } 26 | 27 | 28 | template <typename T> 29 | static inline __host__ __device__ T div_round_up(T val, T divisor) { 30 | return (val + divisor - 1) / divisor; 31 | } 32 | 33 | 34 | template <uint32_t D> 35 | __device__ uint32_t fast_hash(const uint32_t pos_grid[D]) { 36 | static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions."); 37 | 38 | // While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence 39 | // and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional 40 | // coordinates. 41 | constexpr uint32_t primes[7] = { 1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737 }; 42 | 43 | uint32_t result = 0; 44 | #pragma unroll 45 | for (uint32_t i = 0; i < D; ++i) { 46 | result ^= pos_grid[i] * primes[i]; 47 | } 48 | 49 | return result; 50 | } 51 | 52 | 53 | template <uint32_t D, uint32_t C> 54 | __device__ uint32_t get_grid_index(const uint32_t gridtype, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) { 55 | uint32_t stride = 1; 56 | uint32_t index = 0; 57 | 58 | #pragma unroll 59 | for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) { 60 | index += pos_grid[d] * stride; 61 | stride *= (resolution + 1); 62 | } 63 | 64 | // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97. 
65 | // gridtype: 0 == hash, 1 == tiled 66 | if (gridtype == 0 && stride > hashmap_size) { 67 | index = fast_hash<D>(pos_grid); 68 | } 69 | 70 | return (index % hashmap_size) * C + ch; 71 | } 72 | 73 | 74 | template <typename scalar_t, uint32_t D, uint32_t C> 75 | __global__ void kernel_grid( 76 | const scalar_t * __restrict__ inputs, 77 | const scalar_t * __restrict__ grid, 78 | const int * __restrict__ offsets, 79 | scalar_t * __restrict__ outputs, 80 | const uint32_t B, const uint32_t L, const float S, const uint32_t H, 81 | const bool calc_grad_inputs, 82 | scalar_t * __restrict__ dy_dx, 83 | const uint32_t gridtype 84 | ) { 85 | const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x; 86 | 87 | if (b >= B) return; 88 | 89 | const uint32_t level = blockIdx.y; 90 | 91 | // locate 92 | grid += (uint32_t)offsets[level] * C; 93 | inputs += b * D; 94 | outputs += level * B * C + b * C; 95 | 96 | // check input range (should be in [0, 1]) 97 | bool flag_oob = false; 98 | #pragma unroll 99 | for (uint32_t d = 0; d < D; d++) { 100 | if (inputs[d] < 0 || inputs[d] > 1) { 101 | flag_oob = true; 102 | } 103 | } 104 | // if input out of bound, just set output to 0 105 | if (flag_oob) { 106 | #pragma unroll 107 | for (uint32_t ch = 0; ch < C; ch++) { 108 | outputs[ch] = 0; 109 | } 110 | if (calc_grad_inputs) { 111 | dy_dx += b * D * L * C + level * D * C; // B L D C 112 | #pragma unroll 113 | for (uint32_t d = 0; d < D; d++) { 114 | #pragma unroll 115 | for (uint32_t ch = 0; ch < C; ch++) { 116 | dy_dx[d * C + ch] = 0; 117 | } 118 | } 119 | } 120 | return; 121 | } 122 | 123 | const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; 124 | const float scale = exp2f(level * S) * H - 1.0f; 125 | const uint32_t resolution = (uint32_t)ceil(scale) + 1; 126 | 127 | // calculate coordinate 128 | float pos[D]; 129 | uint32_t pos_grid[D]; 130 | 131 | #pragma unroll 132 | for (uint32_t d = 0; d < D; d++) { 133 | pos[d] = (float)inputs[d] * scale + 0.5f; 134 | pos_grid[d] = floorf(pos[d]); 135 | pos[d] -= (float)pos_grid[d]; 136 | } 137 | 138 | //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]); 139 | 140 | // interpolate 141 | scalar_t results[C] = {0}; // temp results in register 142 | 143 | #pragma unroll 144 | for (uint32_t idx = 0; idx < (1 << D); idx++) { 145 | float w = 1; 146 | uint32_t pos_grid_local[D]; 147 | 148 | #pragma unroll 149 | for (uint32_t d = 0; d < D; d++) { 150 | if ((idx & (1 << d)) == 0) { 151 | w *= 1 - pos[d]; 152 | pos_grid_local[d] = pos_grid[d]; 153 | } else { 154 | w *= pos[d]; 155 | pos_grid_local[d] = pos_grid[d] + 1; 156 | } 157 | } 158 | 159 | uint32_t index = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid_local); 160 | 161 | // writing to register (fast) 162 | #pragma unroll 163 | for (uint32_t ch = 0; ch < C; ch++) { 164 | results[ch] += w * grid[index + ch]; 165 | } 166 | 167 | //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]); 168 | } 169 | 170 | // writing to global memory (slow) 171 | #pragma unroll 172 | for (uint32_t ch = 0; ch < C; ch++) { 173 | outputs[ch] = results[ch]; 174 | } 175 | 176 | // prepare dy_dx for calc_grad_inputs 177 | // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9 178 | if (calc_grad_inputs) { 179 | 180 | dy_dx += b * D * L * C + level * D * C; // B L D C 181 | 182 | #pragma unroll 183 | for (uint32_t gd = 0; gd < D; gd++) { 184 | 185 | scalar_t results_grad[C] = {0}; 186 | 187 | #pragma unroll 188 | for (uint32_t idx = 0; idx < 
(1 << (D - 1)); idx++) { 189 | float w = scale; 190 | uint32_t pos_grid_local[D]; 191 | 192 | #pragma unroll 193 | for (uint32_t nd = 0; nd < D - 1; nd++) { 194 | const uint32_t d = (nd >= gd) ? (nd + 1) : nd; 195 | 196 | if ((idx & (1 << nd)) == 0) { 197 | w *= 1 - pos[d]; 198 | pos_grid_local[d] = pos_grid[d]; 199 | } else { 200 | w *= pos[d]; 201 | pos_grid_local[d] = pos_grid[d] + 1; 202 | } 203 | } 204 | 205 | pos_grid_local[gd] = pos_grid[gd]; 206 | uint32_t index_left = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid_local); 207 | pos_grid_local[gd] = pos_grid[gd] + 1; 208 | uint32_t index_right = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid_local); 209 | 210 | #pragma unroll 211 | for (uint32_t ch = 0; ch < C; ch++) { 212 | results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]); 213 | } 214 | } 215 | 216 | #pragma unroll 217 | for (uint32_t ch = 0; ch < C; ch++) { 218 | dy_dx[gd * C + ch] = results_grad[ch]; 219 | } 220 | } 221 | } 222 | } 223 | 224 | 225 | template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C> 226 | __global__ void kernel_grid_backward( 227 | const scalar_t * __restrict__ grad, 228 | const scalar_t * __restrict__ inputs, 229 | const scalar_t * __restrict__ grid, 230 | const int * __restrict__ offsets, 231 | scalar_t * __restrict__ grad_grid, 232 | const uint32_t B, const uint32_t L, const float S, const uint32_t H, 233 | const uint32_t gridtype 234 | ) { 235 | const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C; 236 | if (b >= B) return; 237 | 238 | const uint32_t level = blockIdx.y; 239 | const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C; 240 | 241 | // locate 242 | grad_grid += offsets[level] * C; 243 | inputs += b * D; 244 | grad += level * B * C + b * C + ch; // L, B, C 245 | 246 | const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; 247 | const float scale = exp2f(level * S) * H - 1.0f; 248 | const uint32_t resolution = (uint32_t)ceil(scale) + 1; 249 | 250 | // check input range (should be in [0, 1]) 251 | #pragma unroll 252 | for (uint32_t d = 0; d < D; d++) { 253 | if (inputs[d] < 0 || inputs[d] > 1) { 254 | return; // grad is init as 0, so we simply return. 
255 | } 256 | } 257 | 258 | // calculate coordinate 259 | float pos[D]; 260 | uint32_t pos_grid[D]; 261 | 262 | #pragma unroll 263 | for (uint32_t d = 0; d < D; d++) { 264 | pos[d] = (float)inputs[d] * scale + 0.5f; 265 | pos_grid[d] = floorf(pos[d]); 266 | pos[d] -= (float)pos_grid[d]; 267 | } 268 | 269 | scalar_t grad_cur[N_C] = {0}; // fetch to register 270 | #pragma unroll 271 | for (uint32_t c = 0; c < N_C; c++) { 272 | grad_cur[c] = grad[c]; 273 | } 274 | 275 | // interpolate 276 | #pragma unroll 277 | for (uint32_t idx = 0; idx < (1 << D); idx++) { 278 | float w = 1; 279 | uint32_t pos_grid_local[D]; 280 | 281 | #pragma unroll 282 | for (uint32_t d = 0; d < D; d++) { 283 | if ((idx & (1 << d)) == 0) { 284 | w *= 1 - pos[d]; 285 | pos_grid_local[d] = pos_grid[d]; 286 | } else { 287 | w *= pos[d]; 288 | pos_grid_local[d] = pos_grid[d] + 1; 289 | } 290 | } 291 | 292 | uint32_t index = get_grid_index<D, C>(gridtype, ch, hashmap_size, resolution, pos_grid_local); 293 | 294 | // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0 295 | // TODO: use float which is better than __half, if N_C % 2 != 0 296 | if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) { 297 | #pragma unroll 298 | for (uint32_t c = 0; c < N_C; c += 2) { 299 | // process two __half at once (by interpreting as a __half2) 300 | __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])}; 301 | atomicAdd((__half2*)&grad_grid[index + c], v); 302 | } 303 | // float, or __half when N_C % 2 != 0 (which means C == 1) 304 | } else { 305 | #pragma unroll 306 | for (uint32_t c = 0; c < N_C; c++) { 307 | atomicAdd(&grad_grid[index + c], w * grad_cur[c]); 308 | } 309 | } 310 | } 311 | } 312 | 313 | 314 | template <typename scalar_t, uint32_t D, uint32_t C> 315 | __global__ void kernel_input_backward( 316 | const scalar_t * __restrict__ grad, 317 | const scalar_t * __restrict__ dy_dx, 318 | scalar_t * __restrict__ grad_inputs, 319 | uint32_t B, uint32_t L 320 | ) { 321 | const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; 322 | if (t >= B * D) return; 323 | 324 | const uint32_t b = t / D; 325 | const uint32_t d = t - b * D; 326 | 327 | dy_dx += b * L * D * C; 328 | 329 | scalar_t result = 0; 330 | 331 | # pragma unroll 332 | for (int l = 0; l < L; l++) { 333 | # pragma unroll 334 | for (int ch = 0; ch < C; ch++) { 335 | result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch]; 336 | } 337 | } 338 | 339 | grad_inputs[t] = result; 340 | } 341 | 342 | 343 | template <typename scalar_t, uint32_t D> 344 | void kernel_grid_wrapper(const scalar_t *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype) { 345 | static constexpr uint32_t N_THREAD = 512; 346 | const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 }; 347 | switch (C) { 348 | case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype); break; 349 | case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype); break; 350 | case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype); break; 351 | case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype); break; 352 | default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; 353 | } 354 | } 355 | 356 | // inputs: [B, D], float, in [0, 1] 357 | // 
embeddings: [sO, C], float 358 | // offsets: [L + 1], uint32_t 359 | // outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.) 360 | // H: base resolution 361 | // dy_dx: [B, L * D * C] 362 | template <typename scalar_t> 363 | void grid_encode_forward_cuda(const scalar_t *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype) { 364 | switch (D) { 365 | case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype); break; 366 | case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype); break; 367 | default: throw std::runtime_error{"GridEncoding: D must be 2 or 3."}; 368 | } 369 | 370 | } 371 | 372 | template <typename scalar_t, uint32_t D> 373 | void kernel_grid_backward_wrapper(const scalar_t *grad, const scalar_t *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype) { 374 | static constexpr uint32_t N_THREAD = 256; 375 | const uint32_t N_C = std::min(2u, C); // n_features_per_thread 376 | const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 }; 377 | switch (C) { 378 | case 1: 379 | kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype); 380 | if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L); 381 | break; 382 | case 2: 383 | kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype); 384 | if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L); 385 | break; 386 | case 4: 387 | kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype); 388 | if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L); 389 | break; 390 | case 8: 391 | kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype); 392 | if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L); 393 | break; 394 | default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; 395 | } 396 | } 397 | 398 | 399 | // grad: [L, B, C], float 400 | // inputs: [B, D], float, in [0, 1] 401 | // embeddings: [sO, C], float 402 | // offsets: [L + 1], uint32_t 403 | // grad_embeddings: [sO, C] 404 | // H: base resolution 405 | template <typename scalar_t> 406 | void grid_encode_backward_cuda(const scalar_t *grad, const scalar_t *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype) { 407 | switch (D) { 408 | case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype); break; 409 | case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype); break; 410 | default: throw std::runtime_error{"GridEncoding: D must be 2 or 3."}; 411 | } 412 | } 413 | 414 | 
415 | 416 | void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t gridtype) { 417 | CHECK_CUDA(inputs); 418 | CHECK_CUDA(embeddings); 419 | CHECK_CUDA(offsets); 420 | CHECK_CUDA(outputs); 421 | CHECK_CUDA(dy_dx); 422 | 423 | CHECK_CONTIGUOUS(inputs); 424 | CHECK_CONTIGUOUS(embeddings); 425 | CHECK_CONTIGUOUS(offsets); 426 | CHECK_CONTIGUOUS(outputs); 427 | CHECK_CONTIGUOUS(dy_dx); 428 | 429 | CHECK_IS_FLOATING(inputs); 430 | CHECK_IS_FLOATING(embeddings); 431 | CHECK_IS_INT(offsets); 432 | CHECK_IS_FLOATING(outputs); 433 | CHECK_IS_FLOATING(dy_dx); 434 | 435 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 436 | inputs.scalar_type(), "grid_encode_forward", ([&] { 437 | grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, S, H, calc_grad_inputs, dy_dx.data_ptr<scalar_t>(), gridtype); 438 | })); 439 | } 440 | 441 | void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype) { 442 | CHECK_CUDA(grad); 443 | CHECK_CUDA(inputs); 444 | CHECK_CUDA(embeddings); 445 | CHECK_CUDA(offsets); 446 | CHECK_CUDA(grad_embeddings); 447 | CHECK_CUDA(dy_dx); 448 | CHECK_CUDA(grad_inputs); 449 | 450 | CHECK_CONTIGUOUS(grad); 451 | CHECK_CONTIGUOUS(inputs); 452 | CHECK_CONTIGUOUS(embeddings); 453 | CHECK_CONTIGUOUS(offsets); 454 | CHECK_CONTIGUOUS(grad_embeddings); 455 | CHECK_CONTIGUOUS(dy_dx); 456 | CHECK_CONTIGUOUS(grad_inputs); 457 | 458 | CHECK_IS_FLOATING(grad); 459 | CHECK_IS_FLOATING(inputs); 460 | CHECK_IS_FLOATING(embeddings); 461 | CHECK_IS_INT(offsets); 462 | CHECK_IS_FLOATING(grad_embeddings); 463 | CHECK_IS_FLOATING(dy_dx); 464 | CHECK_IS_FLOATING(grad_inputs); 465 | 466 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 467 | grad.scalar_type(), "grid_encode_backward", ([&] { 468 | grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<scalar_t>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H, calc_grad_inputs, dy_dx.data_ptr<scalar_t>(), grad_inputs.data_ptr<scalar_t>(), gridtype); 469 | })); 470 | 471 | } 472 | -------------------------------------------------------------------------------- /nerf/renderer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import trimesh 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import raymarching 10 | from .utils import custom_meshgrid 11 | 12 | def sample_pdf(bins, weights, n_samples, det=False): 13 | # This implementation is from NeRF 14 | # bins: [B, T], old_z_vals 15 | # weights: [B, T - 1], bin weights. 16 | # return: [B, n_samples], new_z_vals 17 | 18 | # Get pdf 19 | weights = weights + 1e-5 # prevent nans 20 | pdf = weights / torch.sum(weights, -1, keepdim=True) 21 | cdf = torch.cumsum(pdf, -1) 22 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) 23 | # Take uniform samples 24 | if det: 25 | u = torch.linspace(0. + 0.5 / n_samples, 1. 
- 0.5 / n_samples, steps=n_samples).to(weights.device) 26 | u = u.expand(list(cdf.shape[:-1]) + [n_samples]) 27 | else: 28 | u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device) 29 | 30 | # Invert CDF 31 | u = u.contiguous() 32 | inds = torch.searchsorted(cdf, u, right=True) 33 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 34 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 35 | inds_g = torch.stack([below, above], -1) # (B, n_samples, 2) 36 | 37 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 38 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 39 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 40 | 41 | denom = (cdf_g[..., 1] - cdf_g[..., 0]) 42 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 43 | t = (u - cdf_g[..., 0]) / denom 44 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 45 | 46 | return samples 47 | 48 | 49 | @torch.cuda.amp.autocast(enabled=False) 50 | def near_far_from_bound(rays_o, rays_d, bound, type='cube'): 51 | # rays: [B, N, 3], [B, N, 3] 52 | # bound: int, radius for ball or half-edge-length for cube 53 | # return near [B, N, 1], far [B, N, 1] 54 | 55 | radius = rays_o.norm(dim=-1, keepdim=True) 56 | 57 | if type == 'sphere': 58 | near = radius - bound # [B, N, 1] 59 | far = radius + bound 60 | 61 | elif type == 'cube': 62 | tmin = (-bound - rays_o) / (rays_d + 1e-15) # [B, N, 3] 63 | tmax = (bound - rays_o) / (rays_d + 1e-15) 64 | near = torch.where(tmin < tmax, tmin, tmax).max(dim=-1, keepdim=True)[0] 65 | far = torch.where(tmin > tmax, tmin, tmax).min(dim=-1, keepdim=True)[0] 66 | # if far < near, means no intersection, set both near and far to inf (1e9 here) 67 | mask = far < near 68 | near[mask] = 1e9 69 | far[mask] = 1e9 70 | # restrict near to a minimal value 71 | near = torch.clamp(near, min=0.05) 72 | 73 | return near, far 74 | 75 | 76 | def plot_pointcloud(pc, color=None): 77 | # pc: [N, 3] 78 | # color: [N, 3/4] 79 | print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0)) 80 | pc = trimesh.PointCloud(pc, color) 81 | # axis 82 | axes = trimesh.creation.axis(axis_length=4) 83 | # sphere 84 | sphere = trimesh.creation.icosphere(radius=1) 85 | trimesh.Scene([pc, axes, sphere]).show() 86 | 87 | 88 | class NeRFRenderer(nn.Module): 89 | def __init__(self, 90 | bound=1, 91 | cuda_ray=False, 92 | density_scale=1, # scale up deltas (or sigmas), to make the density grid more sharp. larger value than 1 usually improves performance. 93 | ): 94 | super().__init__() 95 | 96 | self.bound = bound 97 | self.cascade = 1 + math.ceil(math.log2(bound)) 98 | self.density_scale = density_scale 99 | 100 | # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax) 101 | # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing. 
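        # [editorial note] e.g. bound=2 yields aabb = (-2, -2, -2, 2, 2, 2) below;
        # a rectangular scene could shrink it to, say, (-2, -1, -1, 2, 1, 1)
        # without touching the cascade / hash-grid math, which only sees `bound`.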
102 | aabb_train = torch.FloatTensor([-bound, -bound, -bound, bound, bound, bound]) 103 | aabb_infer = aabb_train.clone() 104 | self.register_buffer('aabb_train', aabb_train) 105 | self.register_buffer('aabb_infer', aabb_infer) 106 | 107 | # extra state for cuda raymarching 108 | self.cuda_ray = cuda_ray 109 | if cuda_ray: 110 | # density grid 111 | density_grid = torch.zeros([self.cascade] + [128] * 3) # [CAS, H, H, H] 112 | self.register_buffer('density_grid', density_grid) 113 | self.mean_density = 0 114 | self.iter_density = 0 115 | # step counter 116 | step_counter = torch.zeros(16, 2, dtype=torch.int32) # 16 is hardcoded for averaging... 117 | self.register_buffer('step_counter', step_counter) 118 | self.mean_count = 0 119 | self.local_step = 0 120 | 121 | # ema origin 122 | origin = torch.zeros(3, dtype=torch.float32) 123 | self.register_buffer('origin', origin) 124 | 125 | def forward(self, x, d): 126 | raise NotImplementedError() 127 | 128 | def density(self, x): 129 | raise NotImplementedError() 130 | 131 | def color(self, x, d, mask=None, **kwargs): 132 | raise NotImplementedError() 133 | 134 | def reset_extra_state(self): 135 | if not self.cuda_ray: 136 | return 137 | # density grid 138 | self.density_grid.zero_() 139 | self.mean_density = 0 140 | self.iter_density = 0 141 | # step counter 142 | self.step_counter.zero_() 143 | self.mean_count = 0 144 | self.local_step = 0 145 | 146 | def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, perturb=False, **kwargs): 147 | # rays_o, rays_d: [B, N, 3], assumes B == 1 148 | # bg_color: [3] in range [0, 1] 149 | # return: image: [B, N, 3], depth: [B, N] 150 | 151 | prefix = rays_o.shape[:-1] 152 | rays_o = rays_o.contiguous().view(-1, 3) 153 | rays_d = rays_d.contiguous().view(-1, 3) 154 | 155 | N = rays_o.shape[0] # N = B * N, in fact 156 | device = rays_o.device 157 | 158 | # choose aabb 159 | aabb = self.aabb_train if self.training else self.aabb_infer 160 | 161 | # sample steps 162 | nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb) 163 | nears.unsqueeze_(-1) 164 | fars.unsqueeze_(-1) 165 | 166 | #print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}') 167 | 168 | z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0) # [1, T] 169 | z_vals = z_vals.expand((N, num_steps)) # [N, T] 170 | z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars] 171 | 172 | # perturb z_vals 173 | sample_dist = (fars - nears) / num_steps 174 | if perturb: 175 | z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist 176 | #z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs. 177 | 178 | # generate xyzs 179 | xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N, 1, 3] * [N, T, 1] -> [N, T, 3] 180 | xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip. 
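        # [editorial note] shape check: rays_o/rays_d become [N, 1, 3] after
        # unsqueeze and z_vals.unsqueeze(-1) is [N, T, 1], so broadcasting yields
        # one 3D sample per (ray, step); the clip above keeps every sample inside
        # the (possibly rectangular) aabb before the network is queried.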
181 | 182 | #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy()) 183 | 184 | # query SDF and RGB 185 | density_outputs = self.density(xyzs.reshape(-1, 3)) 186 | 187 | #sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T] 188 | for k, v in density_outputs.items(): 189 | density_outputs[k] = v.view(N, num_steps, -1) 190 | 191 | # upsample z_vals (nerf-like) 192 | if upsample_steps > 0: 193 | with torch.no_grad(): 194 | 195 | deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1] 196 | deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1) 197 | 198 | alphas = 1 - torch.exp(-deltas * self.density_scale * density_outputs['sigma'].squeeze(-1)) # [N, T] 199 | alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+1] 200 | weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T] 201 | 202 | # sample new z_vals 203 | z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1]) # [N, T-1] 204 | new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training).detach() # [N, t] 205 | 206 | new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1) # [N, 1, 3] * [N, t, 1] -> [N, t, 3] 207 | new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:]) # a manual clip. 208 | 209 | # only forward new points to save computation 210 | new_density_outputs = self.density(new_xyzs.reshape(-1, 3)) 211 | #new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t] 212 | for k, v in new_density_outputs.items(): 213 | new_density_outputs[k] = v.view(N, upsample_steps, -1) 214 | 215 | # re-order 216 | z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t] 217 | z_vals, z_index = torch.sort(z_vals, dim=1) 218 | 219 | xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3] 220 | xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs)) 221 | 222 | for k in density_outputs: 223 | tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1) 224 | density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output)) 225 | 226 | deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1] 227 | deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1) 228 | alphas = 1 - torch.exp(-deltas * self.density_scale * density_outputs['sigma'].squeeze(-1)) # [N, T+t] 229 | alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+t+1] 230 | weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t] 231 | 232 | mask = weights > 1e-4 # hard coded 233 | 234 | dirs = rays_d.view(-1, 1, 3).expand_as(xyzs) 235 | for k, v in density_outputs.items(): 236 | density_outputs[k] = v.view(-1, v.shape[-1]) 237 | 238 | rgbs = self.color(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), mask=mask.reshape(-1), **density_outputs) 239 | rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3] 240 | 241 | #print(xyzs.shape, 'valid_rgb:', mask.sum().item()) 242 | 243 | # calculate weight_sum (mask) 244 | weights_sum = weights.sum(dim=-1) # [N] 245 | 246 | # calculate depth 247 | ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1) 248 | depth = torch.sum(weights * ori_z_vals, dim=-1) 249 | 250 | # calculate color 251 | image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1] 252 | 253 | # mix background color 254 | # if bg_color is None: 255 | # bg_color = 1 256 | 257 | #image = image + (1 - weights_sum).unsqueeze(-1) * 
bg_color 258 | 259 | image = image.view(*prefix, 3) 260 | depth = depth.view(*prefix) 261 | 262 | # tmp: reg loss in mip-nerf 360 263 | # z_vals_shifted = torch.cat([z_vals[..., 1:], sample_dist * torch.ones_like(z_vals[..., :1])], dim=-1) 264 | # mid_zs = (z_vals + z_vals_shifted) / 2 # [N, T] 265 | # loss_dist = (torch.abs(mid_zs.unsqueeze(1) - mid_zs.unsqueeze(2)) * (weights.unsqueeze(1) * weights.unsqueeze(2))).sum() + 1/3 * ((z_vals_shifted - z_vals) * (weights ** 2)).sum() 266 | 267 | return { 268 | 'depth': depth, 269 | 'image': image, 270 | } 271 | 272 | 273 | def run_cuda(self, rays_o, rays_d, dt_gamma=0, perturb=False, force_all_rays=False, **kwargs): 274 | # rays_o, rays_d: [B, N, 3], assumes B == 1 275 | # return: image: [B, N, 3], depth: [B, N] 276 | 277 | prefix = rays_o.shape[:-1] 278 | rays_o = rays_o.contiguous().view(-1, 3) 279 | rays_d = rays_d.contiguous().view(-1, 3) 280 | 281 | N = rays_o.shape[0] # N = B * N, in fact 282 | device = rays_o.device 283 | 284 | # pre-calculate near far 285 | nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer) 286 | 287 | 288 | 289 | if self.training: 290 | # setup counter 291 | counter = self.step_counter[self.local_step % 16] 292 | counter.zero_() # set to 0 293 | self.local_step += 1 294 | 295 | xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_grid, self.mean_density, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, 512) 296 | #print(xyzs.shape, rays.shape) 297 | 298 | #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy()) 299 | 300 | sigmas, rgbs = self(xyzs, dirs) 301 | # density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb. 302 | # sigmas = density_outputs['sigma'] 303 | sigmas = self.density_scale * sigmas 304 | # rgbs = self.color(xyzs, dirs, **density_outputs) 305 | 306 | #print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})') 307 | 308 | weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, deltas, rays) 309 | 310 | # tracking origin (assert B == 1) 311 | sigmas = sigmas.unsqueeze(-1) # [M, 1] 312 | total_sigma = sigmas.sum(dim=0, keepdim=True) # [1, 1] 313 | cur_origin = (xyzs * sigmas / total_sigma).sum(dim=0) # [M, 3] --> [3] 314 | self.origin = 0.999 * self.origin + 0.001 * cur_origin.detach() 315 | 316 | else: 317 | 318 | # allocate outputs 319 | # if using autocast, these must be initialized as half so they aren't autocasted and lose their reference. 320 | #dtype = torch.half if torch.is_autocast_enabled() else torch.float32 321 | # output should always be float32! only network inference uses half. 322 | dtype = torch.float32 323 | 324 | weights_sum = torch.zeros(N, dtype=dtype, device=device) 325 | depth = torch.zeros(N, dtype=dtype, device=device) 326 | image = torch.zeros(N, 3, dtype=dtype, device=device) 327 | 328 | n_alive = N 329 | alive_counter = torch.zeros([1], dtype=torch.int32, device=device) 330 | 331 | rays_alive = torch.zeros(2, n_alive, dtype=torch.int32, device=device) # 2 is used to loop old/new 332 | rays_t = torch.zeros(2, n_alive, dtype=dtype, device=device) 333 | 334 | 335 | step = 0 336 | i = 0 337 | while step < 1024: # hard coded max step 338 | 339 | # count alive rays 340 | if step == 0: 341 | # init rays at first step.
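                    # [editorial note] rays_alive / rays_t are double-buffered
                    # (leading dim 2): compact_rays streams the survivors from
                    # buffer i % 2 into (i + 1) % 2, so no in-place scatter is needed.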
342 | torch.arange(n_alive, out=rays_alive[0]) 343 | rays_t[0] = nears 344 | else: 345 | alive_counter.zero_() 346 | raymarching.compact_rays(n_alive, rays_alive[i % 2], rays_alive[(i + 1) % 2], rays_t[i % 2], rays_t[(i + 1) % 2], alive_counter) 347 | n_alive = alive_counter.item() # must invoke D2H copy here 348 | 349 | # exit loop 350 | if n_alive <= 0: 351 | break 352 | 353 | # decide compact_steps 354 | n_step = max(min(N // n_alive, 8), 1) 355 | 356 | xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive[i % 2], rays_t[i % 2], rays_o, rays_d, self.bound, self.density_grid, self.mean_density, nears, fars, 128, perturb, dt_gamma) 357 | 358 | sigmas, rgbs = self(xyzs, dirs) 359 | # density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb. 360 | # sigmas = density_outputs['sigma'] 361 | # rgbs = self.color(xyzs, dirs, **density_outputs) 362 | sigmas = self.density_scale * sigmas 363 | 364 | raymarching.composite_rays(n_alive, n_step, rays_alive[i % 2], rays_t[i % 2], sigmas, rgbs, deltas, weights_sum, depth, image) 365 | 366 | #print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}') 367 | 368 | step += n_step 369 | i += 1 370 | 371 | #image = image + (1 - weights_sum).unsqueeze(-1) * bg_color 372 | image = image.view(*prefix, 3) 373 | 374 | depth = torch.clamp(depth - nears, min=0) / (fars - nears) 375 | depth = depth.view(*prefix) 376 | 377 | weights_sum = weights_sum.reshape(*prefix) 378 | 379 | mask = (nears < fars).reshape(*prefix) 380 | 381 | results = { 382 | 'image': image, 383 | 'depth': depth, 384 | 'weights_sum': weights_sum, 385 | 'mask': mask, 386 | } 387 | 388 | if self.training: 389 | results['origin'] = cur_origin 390 | 391 | return results 392 | 393 | @torch.no_grad() 394 | def mark_untrained_grid(self, poses, intrinsic, S=64): 395 | # poses: [B, 4, 4] 396 | # intrinsic: [3, 3] 397 | 398 | if not self.cuda_ray: 399 | return 400 | 401 | if isinstance(poses, np.ndarray): 402 | poses = torch.from_numpy(poses) 403 | 404 | B = poses.shape[0] 405 | 406 | fx, fy, cx, cy = intrinsic 407 | 408 | resolution = self.density_grid.shape[1] 409 | 410 | X = torch.linspace(-1, 1, resolution).split(S) 411 | Y = torch.linspace(-1, 1, resolution).split(S) 412 | Z = torch.linspace(-1, 1, resolution).split(S) 413 | 414 | count = torch.zeros_like(self.density_grid) 415 | poses = poses.to(count.device) 416 | 417 | # 5-level loop, forgive me... 418 | for xi, xs in enumerate(X): 419 | for yi, ys in enumerate(Y): 420 | for zi, zs in enumerate(Z): 421 | lx, ly, lz = len(xs), len(ys), len(zs) 422 | # construct points 423 | xx, yy, zz = custom_meshgrid(xs, ys, zs) 424 | world_xyzs = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1).unsqueeze(0).to(count.device) # [1, N, 3] 425 | 426 | # cascading 427 | for cas in range(self.cascade): 428 | bound = min(2 ** cas, self.bound) 429 | half_grid_size = bound / resolution 430 | # scale to current cascade's resolution 431 | cas_world_xyzs = world_xyzs * (bound - half_grid_size) 432 | 433 | # split batch to avoid OOM 434 | head = 0 435 | while head < B: 436 | tail = min(head + S, B) 437 | 438 | # world2cam transform (poses is c2w, so we need to transpose it. Another transpose is needed for batched matmul, so the final form is without transpose.) 
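                            # [editorial note] i.e. x_cam = R^T (x_world - t); with row-vector
                            # batches, (x - t) @ R computes exactly R^T (x - t), hence no
                            # explicit transpose appears below.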
439 | cam_xyzs = cas_world_xyzs - poses[head:tail, :3, 3].unsqueeze(1) 440 | cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3] # [S, N, 3] 441 | 442 | # query if point is covered by any camera 443 | mask_z = cam_xyzs[:, :, 2] > 0 # [S, N] 444 | mask_x = torch.abs(cam_xyzs[:, :, 0]) < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2 445 | mask_y = torch.abs(cam_xyzs[:, :, 1]) < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2 446 | mask = (mask_z & mask_x & mask_y).sum(0).reshape(lx, ly, lz) # [N] --> [lx, ly, lz] 447 | 448 | # update count 449 | count[cas, xi * S: xi * S + lx, yi * S: yi * S + ly, zi * S: zi * S + lz] += mask 450 | head += S 451 | 452 | # mark untrained grid as -1 453 | self.density_grid[count == 0] = -1 454 | 455 | #print(f'[mark untrained grid] {(count == 0).sum()} from {resolution ** 3 * self.cascade}') 456 | 457 | @torch.no_grad() 458 | def update_extra_state(self, decay=0.95, S=128): 459 | # call before each epoch to update extra states. 460 | 461 | if not self.cuda_ray: 462 | return 463 | 464 | ### update density grid 465 | resolution = self.density_grid.shape[1] 466 | 467 | # TODO: random sample coordinates after a warm up, instead of always uniformly query all cascades! 468 | # TODO: cast to bit mask to accelerate. 469 | # TODO: check `bitfield_max_pool`, should apply max pool across consequent cascades! (this is a must if using random sampling...) 470 | # too difficult in pytorch... 471 | 472 | tmp_grid = torch.zeros_like(self.density_grid) 473 | 474 | X = torch.linspace(-1, 1, resolution).split(S) 475 | Y = torch.linspace(-1, 1, resolution).split(S) 476 | Z = torch.linspace(-1, 1, resolution).split(S) 477 | for xi, xs in enumerate(X): 478 | for yi, ys in enumerate(Y): 479 | for zi, zs in enumerate(Z): 480 | lx, ly, lz = len(xs), len(ys), len(zs) 481 | # construct points 482 | xx, yy, zz = custom_meshgrid(xs, ys, zs) 483 | xyzs = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [-1, 1] 484 | 485 | # cascading 486 | for cas in range(self.cascade): 487 | bound = min(2 ** cas, self.bound) 488 | half_grid_size = bound / resolution 489 | # scale to current cascade's resolution 490 | cas_xyzs = xyzs * (bound - half_grid_size) 491 | # add noise in [-hgs, hgs] 492 | cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size 493 | # query density 494 | sigmas = self.density(cas_xyzs.to(tmp_grid.device))['sigma'].reshape(lx, ly, lz).detach() 495 | # from `scalbnf(MIN_CONE_STEPSIZE(), 0)`, check `splat_grid_samples_nerf_max_nearest_neighbor` 496 | # scale == 2 * sqrt(3) / 1024 497 | sigmas *= self.density_scale * 0.003383 498 | # assign 499 | tmp_grid[cas, xi * S: xi * S + lx, yi * S: yi * S + ly, zi * S: zi * S + lz] = sigmas 500 | 501 | 502 | # ema update 503 | valid_mask = self.density_grid >= 0 504 | self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask]) 505 | self.mean_density = torch.mean(self.density_grid[valid_mask]).item() 506 | self.iter_density += 1 507 | 508 | ### update step counter 509 | total_step = min(16, self.local_step) 510 | if total_step > 0: 511 | self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step) 512 | self.local_step = 0 513 | 514 | print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > 0.01).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}') 515 | 516 | 517 | def render(self, rays_o, 
rays_d, staged=False, max_ray_batch=4096, bg_color=None, perturb=False, **kwargs): 518 | # rays_o, rays_d: [B, N, 3], assumes B == 1 519 | # return: pred_rgb: [B, N, 3] 520 | 521 | if self.cuda_ray: 522 | _run = self.run_cuda 523 | else: 524 | _run = self.run 525 | 526 | B, N = rays_o.shape[:2] 527 | device = rays_o.device 528 | 529 | # never stage when cuda_ray 530 | if staged and not self.cuda_ray: 531 | depth = torch.empty((B, N), device=device) 532 | image = torch.empty((B, N, 3), device=device) 533 | 534 | for b in range(B): 535 | head = 0 536 | while head < N: 537 | tail = min(head + max_ray_batch, N) 538 | results_ = _run(rays_o[b:b+1, head:tail], rays_d[b:b+1, head:tail], bg_color=bg_color, perturb=perturb, **kwargs) 539 | depth[b:b+1, head:tail] = results_['depth'] 540 | image[b:b+1, head:tail] = results_['image'] 541 | head += max_ray_batch 542 | 543 | results = {} 544 | results['depth'] = depth 545 | results['image'] = image 546 | 547 | else: 548 | results = _run(rays_o, rays_d, bg_color=bg_color, perturb=perturb, **kwargs) 549 | 550 | return results -------------------------------------------------------------------------------- /raymarching/src/raymarching.cu: -------------------------------------------------------------------------------- 1 | #include <cuda.h> 2 | #include <cuda_fp16.h> 3 | #include <cuda_runtime.h> 4 | 5 | #include <ATen/cuda/CUDAContext.h> 6 | #include <torch/torch.h> 7 | 8 | #include <cstdio> 9 | #include <stdint.h> 10 | #include <stdexcept> 11 | #include <limits> 12 | 13 | #include "pcg32.h" 14 | 15 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 16 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") 17 | #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") 18 | #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") 19 | 20 | 21 | inline constexpr __device__ float DENSITY_THRESH() { return 0.01f; } 22 | inline constexpr __device__ float SQRT3() { return 1.73205080757f; } 23 | inline constexpr __device__ float MIN_NEAR() { return 0.05f; } 24 | 25 | 26 | template <typename T> 27 | __host__ __device__ T div_round_up(T val, T divisor) { 28 | return (val + divisor - 1) / divisor; 29 | } 30 | 31 | inline __host__ __device__ float signf(const float x) { 32 | return copysignf(1.0, x); 33 | } 34 | 35 | inline __host__ __device__ float clamp(const float x, const float min, const float max) { 36 | return fminf(max, fmaxf(min, x)); 37 | } 38 | 39 | inline __host__ __device__ void swapf(float& a, float& b) { 40 | float c = a; a = b; b = c; 41 | } 42 | 43 | inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) { 44 | const float mx = fmaxf(fabsf(x), fmaxf(fabs(y), fabs(z))); 45 | int exponent; 46 | frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ... 47 | return fminf(max_cascade - 1, fmaxf(0, exponent)); 48 | } 49 | 50 | 51 | //////////////////////////////////////////////////// 52 | ///////////// utils ///////////// 53 | //////////////////////////////////////////////////// 54 | 55 | // rays_o/d: [N, 3] 56 | // nears/fars: [N] 57 | // scalar_t should always be float in use.
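// [editorial note] The kernel below is the classic slab test: per axis, intersect
// the ray with both bounding planes, swap so near <= far, then keep the running
// max of the nears and min of the fars; if they ever cross, the ray misses the
// box and both distances are set to std::numeric_limits<scalar_t>::max()
// (otherwise near is clamped up to min_near).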
58 | template <typename scalar_t> 59 | __global__ void kernel_near_far_from_aabb( 60 | const scalar_t * __restrict__ rays_o, 61 | const scalar_t * __restrict__ rays_d, 62 | const scalar_t * __restrict__ aabb, 63 | const uint32_t N, 64 | const float min_near, 65 | scalar_t * nears, scalar_t * fars 66 | ) { 67 | // parallel per ray 68 | const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; 69 | if (n >= N) return; 70 | 71 | // locate 72 | rays_o += n * 3; 73 | rays_d += n * 3; 74 | 75 | const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; 76 | const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; 77 | const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; 78 | 79 | // get near far (assume cube scene) 80 | float near = (aabb[0] - ox) * rdx; 81 | float far = (aabb[3] - ox) * rdx; 82 | if (near > far) swapf(near, far); 83 | 84 | float near_y = (aabb[1] - oy) * rdy; 85 | float far_y = (aabb[4] - oy) * rdy; 86 | if (near_y > far_y) swapf(near_y, far_y); 87 | 88 | if (near > far_y || near_y > far) { 89 | nears[n] = fars[n] = std::numeric_limits<scalar_t>::max(); 90 | return; 91 | } 92 | 93 | if (near_y > near) near = near_y; 94 | if (far_y < far) far = far_y; 95 | 96 | float near_z = (aabb[2] - oz) * rdz; 97 | float far_z = (aabb[5] - oz) * rdz; 98 | if (near_z > far_z) swapf(near_z, far_z); 99 | 100 | if (near > far_z || near_z > far) { 101 | nears[n] = fars[n] = std::numeric_limits<scalar_t>::max(); 102 | return; 103 | } 104 | 105 | if (near_z > near) near = near_z; 106 | if (far_z < far) far = far_z; 107 | 108 | if (near < min_near) near = min_near; 109 | 110 | nears[n] = near; 111 | fars[n] = far; 112 | } 113 | 114 | 115 | void near_far_from_aabb(at::Tensor rays_o, at::Tensor rays_d, at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) { 116 | 117 | static constexpr uint32_t N_THREAD = 256; 118 | 119 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 120 | rays_o.scalar_type(), "near_far_from_aabb", ([&] { 121 | kernel_near_far_from_aabb<scalar_t><<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), aabb.data_ptr<scalar_t>(), N, min_near, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>()); 122 | })); 123 | } 124 | 125 | //////////////////////////////////////////////////// 126 | ///////////// training ///////////// 127 | //////////////////////////////////////////////////// 128 | 129 | // rays_o/d: [N, 3] 130 | // grid: [C, H, H, H] 131 | // xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2] 132 | // dirs: [M, 3] 133 | // rays: [N, 3], idx, offset, num_steps 134 | template <typename scalar_t> 135 | __global__ void kernel_march_rays_train( 136 | const scalar_t * __restrict__ rays_o, 137 | const scalar_t * __restrict__ rays_d, 138 | const scalar_t * __restrict__ grid, 139 | const float mean_density, 140 | const float bound, 141 | const float dt_gamma, const uint32_t max_steps, 142 | const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, 143 | const scalar_t* __restrict__ nears, 144 | const scalar_t* __restrict__ fars, 145 | scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas, 146 | int * rays, 147 | int * counter, 148 | const uint32_t perturb 149 | ) { 150 | // parallel per ray 151 | const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; 152 | if (n >= N) return; 153 | 154 | const float density_thresh = fminf(DENSITY_THRESH(), mean_density); 155 | 156 | // locate 157 | rays_o += n * 3; 158 | rays_d += n * 3; 159 | 160 | // ray marching 161 | const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; 162 | const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; 163 | const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; 164 | 165 | const float
near = nears[n]; 166 | const float far = fars[n]; 167 | 168 | const float dt_min = 2 * SQRT3() / max_steps; 169 | 170 | float t0 = near; 171 | 172 | if (perturb) { 173 | pcg32 rng((uint64_t)n); 174 | t0 += dt_min * rng.next_float(); 175 | } 176 | 177 | // first pass: estimation of num_steps 178 | float t = t0; 179 | uint32_t num_steps = 0; 180 | 181 | //if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far); 182 | 183 | while (t < far && num_steps < max_steps) { 184 | // current point 185 | const float x = clamp(ox + t * dx, -bound, bound); 186 | const float y = clamp(oy + t * dy, -bound, bound); 187 | const float z = clamp(oz + t * dz, -bound, bound); 188 | 189 | // get mip level 190 | // TODO: check why using mip_from_dt... 191 | const int level = mip_from_pos(x, y, z, C); // range in [0, C - 1] 192 | const float mip_bound = fminf(exp2f((float)level), bound); 193 | const float mip_rbound = 1 / mip_bound; 194 | const float dt_max = 2 * mip_bound / H; // dt_max is dependent on the current mip level 195 | 196 | // convert to nearest grid position 197 | const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); 198 | const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); 199 | const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); 200 | 201 | const uint32_t index = level * H * H * H + nx * H * H + ny * H + nz; 202 | const float density = grid[index]; 203 | 204 | // if occupied, advance a small step, and write to output 205 | //if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, num_steps); 206 | 207 | if (density > density_thresh) { 208 | num_steps++; 209 | const float dt = clamp(t * dt_gamma, dt_min, dt_max); 210 | t += dt; 211 | // else, skip a large step (basically skip a voxel grid) 212 | } else { 213 | // calc distance to next voxel 214 | const float tx = (((nx + 0.5f + 0.5f * signf(dx)) / (H - 1) * 2 - 1) * mip_bound - x) * rdx; 215 | const float ty = (((ny + 0.5f + 0.5f * signf(dy)) / (H - 1) * 2 - 1) * mip_bound - y) * rdy; 216 | const float tz = (((nz + 0.5f + 0.5f * signf(dz)) / (H - 1) * 2 - 1) * mip_bound - z) * rdz; 217 | const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); 218 | // step until next voxel 219 | do { 220 | const float dt = clamp(t * dt_gamma, dt_min, dt_max); 221 | t += dt; 222 | } while (t < tt); 223 | } 224 | } 225 | 226 | //printf("[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\n", n, num_steps, near, far, dt_min, (far - near) / dt_min); 227 | 228 | // second pass: actually locate and write points & dirs 229 | uint32_t point_index = atomicAdd(counter, num_steps); 230 | uint32_t ray_index = atomicAdd(counter + 1, 1); 231 | 232 | //printf("[n=%d] num_steps=%d, point_index=%d, ray_index=%d\n", n, num_steps, point_index, ray_index); 233 | 234 | // write rays 235 | rays[ray_index * 3] = n; 236 | rays[ray_index * 3 + 1] = point_index; 237 | rays[ray_index * 3 + 2] = num_steps; 238 | 239 | if (num_steps == 0) return; 240 | if (point_index + num_steps >= M) return; 241 | 242 | xyzs += point_index * 3; 243 | dirs += point_index * 3; 244 | deltas += point_index * 2; 245 | 246 | t = t0; 247 | uint32_t step = 0; 248 | 249 | float last_t = t; 250 | 251 | while (t < far && step < num_steps) { 252 | // current point 253 | const float x = clamp(ox + t * dx, -bound, bound); 254 | const float y = clamp(oy + t * dy, -bound, bound); 255 | const float z = clamp(oz + t * dz, -bound, bound); 256 | 257 | // get mip level 258 | // TODO: check why using
mip_from_dt... 259 | const int level = mip_from_pos(x, y, z, C); // range in [0, C - 1] 260 | const float mip_bound = fminf(exp2f((float)level), bound); 261 | const float mip_rbound = 1 / mip_bound; 262 | const float dt_max = 2 * mip_bound / H; // dt_max is dependent on the current mip level 263 | 264 | // convert to nearest grid position 265 | const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); 266 | const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); 267 | const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); 268 | 269 | // query grid 270 | const uint32_t index = level * H * H * H + nx * H * H + ny * H + nz; 271 | const float density = grid[index]; 272 | 273 | // if occupied, advance a small step, and write to output 274 | if (density > density_thresh) { 275 | // write step 276 | xyzs[0] = x; 277 | xyzs[1] = y; 278 | xyzs[2] = z; 279 | dirs[0] = dx; 280 | dirs[1] = dy; 281 | dirs[2] = dz; 282 | const float dt = clamp(t * dt_gamma, dt_min, dt_max); 283 | t += dt; 284 | deltas[0] = dt; 285 | deltas[1] = t - last_t; // used to calc depth 286 | last_t = t; 287 | xyzs += 3; 288 | dirs += 3; 289 | deltas += 2; 290 | step++; 291 | // else, skip a large step (basically skip a voxel grid) 292 | } else { 293 | // calc distance to next voxel 294 | const float tx = (((nx + 0.5f + 0.5f * signf(dx)) / (H - 1) * 2 - 1) * mip_bound - x) * rdx; 295 | const float ty = (((ny + 0.5f + 0.5f * signf(dy)) / (H - 1) * 2 - 1) * mip_bound - y) * rdy; 296 | const float tz = (((nz + 0.5f + 0.5f * signf(dz)) / (H - 1) * 2 - 1) * mip_bound - z) * rdz; 297 | const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); 298 | // step until next voxel 299 | do { 300 | const float dt = clamp(t * dt_gamma, dt_min, dt_max); 301 | t += dt; 302 | } while (t < tt); 303 | } 304 | } 305 | } 306 | 307 | 308 | // sigmas: [M] 309 | // rgbs: [M, 3] 310 | // deltas: [M, 2] 311 | // rays: [N, 3], idx, offset, num_steps 312 | // weights_sum: [N], final pixel alpha 313 | // depth: [N,] 314 | // image: [N, 3] 315 | template <typename scalar_t> 316 | __global__ void kernel_composite_rays_train_forward( 317 | const scalar_t * __restrict__ sigmas, 318 | const scalar_t * __restrict__ rgbs, 319 | const scalar_t * __restrict__ deltas, 320 | const int * __restrict__ rays, 321 | const uint32_t M, const uint32_t N, 322 | scalar_t * weights_sum, 323 | scalar_t * depth, 324 | scalar_t * image 325 | ) { 326 | // parallel per ray 327 | const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; 328 | if (n >= N) return; 329 | 330 | // locate 331 | uint32_t index = rays[n * 3]; 332 | uint32_t offset = rays[n * 3 + 1]; 333 | uint32_t num_steps = rays[n * 3 + 2]; 334 | 335 | // empty ray, or a ray that exceeds the max step count.
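    // [editorial note] the guard below zeroes the outputs for such rays; live rays
    // follow the standard volume-rendering recurrence:
    //   alpha_i = 1 - exp(-sigma_i * dt_i),  w_i = alpha_i * T_i,
    //   T_{i+1} = T_i * (1 - alpha_i),  T_0 = 1,
    // accumulating color = sum_i w_i * rgb_i and depth = sum_i w_i * t_i.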
336 | if (num_steps == 0 || offset + num_steps >= M) { 337 | weights_sum[index] = 0; 338 | depth[index] = 0; 339 | image[index * 3] = 0; 340 | image[index * 3 + 1] = 0; 341 | image[index * 3 + 2] = 0; 342 | return; 343 | } 344 | 345 | sigmas += offset; 346 | rgbs += offset * 3; 347 | deltas += offset * 2; 348 | 349 | // accumulate 350 | uint32_t step = 0; 351 | 352 | scalar_t T = 1.0f; 353 | scalar_t r = 0, g = 0, b = 0, ws = 0, t = 0, d = 0; 354 | 355 | while (step < num_steps) { 356 | 357 | const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); 358 | const scalar_t weight = alpha * T; 359 | 360 | // minimal remaining transmittance 361 | //if (weight < 1e-4f) break; 362 | 363 | r += weight * rgbs[0]; 364 | g += weight * rgbs[1]; 365 | b += weight * rgbs[2]; 366 | 367 | t += deltas[1]; // real delta 368 | d += weight * t; 369 | 370 | ws += weight; 371 | 372 | T *= 1.0f - alpha; 373 | 374 | //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); 375 | 376 | // locate 377 | sigmas++; 378 | rgbs += 3; 379 | deltas += 2; 380 | 381 | step++; 382 | } 383 | 384 | //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); 385 | 386 | // write 387 | weights_sum[index] = ws; // weights_sum 388 | depth[index] = d; 389 | image[index * 3] = r; 390 | image[index * 3 + 1] = g; 391 | image[index * 3 + 2] = b; 392 | } 393 | 394 | 395 | // grad_weights_sum: [N,] 396 | // grad: [N, 3] 397 | // sigmas: [M] 398 | // rgbs: [M, 3] 399 | // deltas: [M, 2] 400 | // rays: [N, 3], idx, offset, num_steps 401 | // weights_sum: [N,], weights_sum here 402 | // image: [N, 3] 403 | // grad_sigmas: [M] 404 | // grad_rgbs: [M, 3] 405 | template <typename scalar_t> 406 | __global__ void kernel_composite_rays_train_backward( 407 | const scalar_t * __restrict__ grad_weights_sum, 408 | const scalar_t * __restrict__ grad_image, 409 | const scalar_t * __restrict__ sigmas, 410 | const scalar_t * __restrict__ rgbs, 411 | const scalar_t * __restrict__ deltas, 412 | const int * __restrict__ rays, 413 | const scalar_t * __restrict__ weights_sum, 414 | const scalar_t * __restrict__ image, 415 | const uint32_t M, const uint32_t N, 416 | scalar_t * grad_sigmas, 417 | scalar_t * grad_rgbs 418 | ) { 419 | // parallel per ray 420 | const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; 421 | if (n >= N) return; 422 | 423 | // locate 424 | uint32_t index = rays[n * 3]; 425 | uint32_t offset = rays[n * 3 + 1]; 426 | uint32_t num_steps = rays[n * 3 + 2]; 427 | 428 | if (num_steps == 0 || offset + num_steps >= M) return; 429 | 430 | grad_weights_sum += index; 431 | grad_image += index * 3; 432 | weights_sum += index; 433 | image += index * 3; 434 | sigmas += offset; 435 | rgbs += offset * 3; 436 | deltas += offset * 2; 437 | grad_sigmas += offset; 438 | grad_rgbs += offset * 3; 439 | 440 | // accumulate 441 | uint32_t step = 0; 442 | 443 | scalar_t T = 1.0f; 444 | const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0]; 445 | scalar_t r = 0, g = 0, b = 0, ws = 0; 446 | 447 | while (step < num_steps) { 448 | 449 | const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); 450 | const scalar_t weight = alpha * T; 451 | 452 | //if (weight < 1e-4f) break; 453 | 454 | r += weight * rgbs[0]; 455 | g += weight * rgbs[1]; 456 | b += weight * rgbs[2]; 457 | ws += weight; 458 | 459 | T *= 1.0f - alpha; 460 | 461 | // write grad_rgbs 462 | grad_rgbs[0] = grad_image[0] * weight; 463 | grad_rgbs[1] = grad_image[1] * weight; 464 | grad_rgbs[2] = grad_image[2]
* weight; 465 | 466 | // grad_sigmas: computed in the same forward sweep via a suffix-sum trick, since (r_final - r) etc. are exactly the colors not yet composited. 467 | grad_sigmas[0] = deltas[0] * ( 468 | grad_image[0] * (T * rgbs[0] - (r_final - r)) + 469 | grad_image[1] * (T * rgbs[1] - (g_final - g)) + 470 | grad_image[2] * (T * rgbs[2] - (b_final - b)) + 471 | grad_weights_sum[0] * (T - (ws_final - ws)) 472 | ); 473 | 474 | //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r); 475 | 476 | // locate 477 | sigmas++; 478 | rgbs += 3; 479 | deltas += 2; 480 | grad_sigmas++; 481 | grad_rgbs += 3; 482 | 483 | step++; 484 | } 485 | } 486 | 487 | 488 | void march_rays_train(at::Tensor rays_o, at::Tensor rays_d, at::Tensor grid, const float mean_density, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, at::Tensor nears, at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, const uint32_t perturb) { 489 | 490 | static constexpr uint32_t N_THREAD = 256; 491 | 492 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 493 | rays_o.scalar_type(), "march_rays_train", ([&] { 494 | kernel_march_rays_train<scalar_t><<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), grid.data_ptr<scalar_t>(), mean_density, bound, dt_gamma, max_steps, N, C, H, M, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), counter.data_ptr<int>(), perturb); 495 | })); 496 | } 497 | 498 | 499 | void composite_rays_train_forward(at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, const uint32_t M, const uint32_t N, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) { 500 | 501 | static constexpr uint32_t N_THREAD = 256; 502 | 503 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 504 | sigmas.scalar_type(), "composite_rays_train_forward", ([&] { 505 | kernel_composite_rays_train_forward<scalar_t><<<div_round_up(N, N_THREAD), N_THREAD>>>(sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), M, N, weights_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>()); 506 | })); 507 | } 508 | 509 | 510 | void composite_rays_train_backward(at::Tensor grad_weights_sum, at::Tensor grad_image, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, at::Tensor weights_sum, at::Tensor image, const uint32_t M, const uint32_t N, at::Tensor grad_sigmas, at::Tensor grad_rgbs) { 511 | 512 | static constexpr uint32_t N_THREAD = 256; 513 | 514 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 515 | grad_image.scalar_type(), "composite_rays_train_backward", ([&] { 516 | kernel_composite_rays_train_backward<scalar_t><<<div_round_up(N, N_THREAD), N_THREAD>>>(grad_weights_sum.data_ptr<scalar_t>(), grad_image.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), weights_sum.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), M, N, grad_sigmas.data_ptr<scalar_t>(), grad_rgbs.data_ptr<scalar_t>()); 517 | })); 518 | } 519 | 520 | 521 | //////////////////////////////////////////////////// 522 | ///////////// inference ///////////// 523 | //////////////////////////////////////////////////// 524 | 525 | template <typename scalar_t> 526 | __global__ void kernel_march_rays( 527 | const uint32_t n_alive, 528 | const uint32_t n_step, 529 | const int* __restrict__ rays_alive, 530 | const scalar_t* __restrict__ rays_t, 531 | const scalar_t* __restrict__ rays_o, 532 | const scalar_t* __restrict__ rays_d, 533 | const float bound, 534 | const float dt_gamma, const uint32_t max_steps, 535 | const uint32_t C, const uint32_t H, 536 | const scalar_t * __restrict__ grid, 537 | const float mean_density, 538 | const scalar_t* __restrict__ nears, 539
| const scalar_t* __restrict__ fars, 540 | scalar_t* xyzs, scalar_t* dirs, scalar_t* deltas, 541 | const uint32_t perturb 542 | ) { 543 | const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; 544 | if (n >= n_alive) return; 545 | 546 | const int index = rays_alive[n]; // ray id 547 | float t = rays_t[n]; // current ray's t 548 | 549 | const float density_thresh = fminf(DENSITY_THRESH(), mean_density); 550 | 551 | // locate 552 | rays_o += index * 3; 553 | rays_d += index * 3; 554 | xyzs += n * n_step * 3; 555 | dirs += n * n_step * 3; 556 | deltas += n * n_step * 2; 557 | 558 | const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; 559 | const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; 560 | const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; 561 | const float near = nears[index], far = fars[index]; 562 | 563 | const float dt_min = 2 * SQRT3() / max_steps; 564 | 565 | // march for n_step steps, record points 566 | uint32_t step = 0; 567 | 568 | // introduce some randomness (pass in spp as perturb here) 569 | if (perturb) { 570 | pcg32 rng((uint64_t)n, (uint64_t)perturb); 571 | t += dt_min * rng.next_float(); 572 | } 573 | 574 | float last_t = t; 575 | 576 | while (t < far && step < n_step) { 577 | // current point 578 | const float x = clamp(ox + t * dx, -bound, bound); 579 | const float y = clamp(oy + t * dy, -bound, bound); 580 | const float z = clamp(oz + t * dz, -bound, bound); 581 | 582 | // get mip level 583 | // TODO: check why using mip_from_dt... 584 | const int level = mip_from_pos(x, y, z, C); // range in [0, C - 1] 585 | const float mip_bound = fminf(exp2f((float)level), bound); 586 | const float mip_rbound = 1 / mip_bound; 587 | const float dt_max = 2 * mip_bound / H; 588 | 589 | // convert to nearest grid position 590 | const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); 591 | const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); 592 | const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); 593 | 594 | const uint32_t index = level * H * H * H + nx * H * H + ny * H + nz; 595 | const float density = grid[index]; 596 | 597 | // if occupied, advance a small step, and write to output 598 | if (density > density_thresh) { 599 | // write step 600 | xyzs[0] = x; 601 | xyzs[1] = y; 602 | xyzs[2] = z; 603 | dirs[0] = dx; 604 | dirs[1] = dy; 605 | dirs[2] = dz; 606 | // calc dt 607 | const float dt = clamp(t * dt_gamma, dt_min, dt_max); 608 | t += dt; 609 | deltas[0] = dt; 610 | deltas[1] = t - last_t; // used to calc depth 611 | last_t = t; 612 | // step 613 | xyzs += 3; 614 | dirs += 3; 615 | deltas += 2; 616 | step++; 617 | 618 | // else, skip a large step (basically skip a voxel grid) 619 | } else { 620 | // calc distance to next voxel 621 | const float tx = (((nx + 0.5f + 0.5f * signf(dx)) / (H - 1) * 2 - 1) * mip_bound - x) * rdx; 622 | const float ty = (((ny + 0.5f + 0.5f * signf(dy)) / (H - 1) * 2 - 1) * mip_bound - y) * rdy; 623 | const float tz = (((nz + 0.5f + 0.5f * signf(dz)) / (H - 1) * 2 - 1) * mip_bound - z) * rdz; 624 | const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); 625 | // step until next voxel 626 | do { 627 | const float dt = clamp(t * dt_gamma, dt_min, dt_max); 628 | t += dt; 629 | } while (t < tt); 630 | } 631 | } 632 | } 633 | 634 | 635 | void march_rays(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor rays_o, at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C,
const uint32_t H, at::Tensor density_grid, const float mean_density, at::Tensor near, at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, const uint32_t perturb) { 636 | static constexpr uint32_t N_THREAD = 256; 637 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 638 | rays_o.scalar_type(), "march_rays", ([&] { 639 | kernel_march_rays<scalar_t><<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), bound, dt_gamma, max_steps, C, H, density_grid.data_ptr<scalar_t>(), mean_density, near.data_ptr<scalar_t>(), far.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), perturb); 640 | })); 641 | } 642 | 643 | 644 | template <typename scalar_t> 645 | __global__ void kernel_composite_rays( 646 | const uint32_t n_alive, 647 | const uint32_t n_step, 648 | const int* __restrict__ rays_alive, 649 | scalar_t* rays_t, 650 | const scalar_t* __restrict__ sigmas, 651 | const scalar_t* __restrict__ rgbs, 652 | const scalar_t* __restrict__ deltas, 653 | scalar_t* weights_sum, scalar_t* depth, scalar_t* image 654 | ) { 655 | const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; 656 | if (n >= n_alive) return; 657 | 658 | const int index = rays_alive[n]; // ray id 659 | scalar_t t = rays_t[n]; // current ray's t 660 | 661 | // locate 662 | sigmas += n * n_step; 663 | rgbs += n * n_step * 3; 664 | deltas += n * n_step * 2; 665 | 666 | weights_sum += index; 667 | depth += index; 668 | image += index * 3; 669 | 670 | scalar_t weight_sum = weights_sum[0]; 671 | scalar_t d = depth[0]; 672 | scalar_t r = image[0]; 673 | scalar_t g = image[1]; 674 | scalar_t b = image[2]; 675 | 676 | // accumulate 677 | uint32_t step = 0; 678 | while (step < n_step) { 679 | 680 | // ray is terminated if delta == 0 681 | if (deltas[0] == 0) break; 682 | 683 | const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); 684 | 685 | /* 686 | T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j) 687 | w_i = alpha_i * T_i 688 | --> 689 | T_i = 1 - \sum_{j=0}^{i-1} w_j 690 | */ 691 | const scalar_t T = 1 - weight_sum; 692 | const scalar_t weight = alpha * T; 693 | weight_sum += weight; 694 | 695 | t += deltas[1]; // real delta 696 | d += weight * t; 697 | r += weight * rgbs[0]; 698 | g += weight * rgbs[1]; 699 | b += weight * rgbs[2]; 700 | 701 | //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); 702 | 703 | // ray is terminated if T is too small 704 | if (T < 1e-4) break; 705 | 706 | // locate 707 | sigmas++; 708 | rgbs += 3; 709 | deltas += 2; 710 | step++; 711 | } 712 | 713 | //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); 714 | 715 | // rays_t = -1 means ray is terminated early. 716 | if (step < n_step) { 717 | rays_t[n] = -1; 718 | } else { 719 | rays_t[n] = t; 720 | } 721 | 722 | weights_sum[0] = weight_sum; // this is the thing I needed!
723 | depth[0] = d; 724 | image[0] = r; 725 | image[1] = g; 726 | image[2] = b; 727 | } 728 | 729 | 730 | void composite_rays(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) { 731 | static constexpr uint32_t N_THREAD = 256; 732 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 733 | image.scalar_type(), "composite_rays", ([&] { 734 | kernel_composite_rays<scalar_t><<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), weights.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>()); 735 | })); 736 | } 737 | 738 | 739 | template <typename scalar_t> 740 | __global__ void kernel_compact_rays( 741 | const uint32_t n_alive, 742 | int* rays_alive, 743 | const int* __restrict__ rays_alive_old, 744 | scalar_t* rays_t, 745 | const scalar_t* __restrict__ rays_t_old, 746 | int* alive_counter 747 | ) { 748 | const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; 749 | if (n >= n_alive) return; 750 | 751 | // rays_t_old[n] < 0 means ray died in last composite kernel. 752 | if (rays_t_old[n] >= 0) { 753 | const int index = atomicAdd(alive_counter, 1); 754 | rays_alive[index] = rays_alive_old[n]; 755 | rays_t[index] = rays_t_old[n]; 756 | } 757 | } 758 | 759 | 760 | void compact_rays(const uint32_t n_alive, at::Tensor rays_alive, at::Tensor rays_alive_old, at::Tensor rays_t, at::Tensor rays_t_old, at::Tensor alive_counter) { 761 | static constexpr uint32_t N_THREAD = 256; 762 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 763 | rays_t.scalar_type(), "compact_rays", ([&] { 764 | kernel_compact_rays<scalar_t><<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, rays_alive.data_ptr<int>(), rays_alive_old.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), rays_t_old.data_ptr<scalar_t>(), alive_counter.data_ptr<int>()); 765 | })); 766 | } -------------------------------------------------------------------------------- /shencoder/src/shencoder.cu: -------------------------------------------------------------------------------- 1 | #include <stdint.h> 2 | 3 | #include <cuda.h> 4 | #include <cuda_fp16.h> 5 | #include <cuda_runtime.h> 6 | 7 | #include <ATen/cuda/CUDAContext.h> 8 | #include <torch/torch.h> 9 | 10 | #include <algorithm> 11 | #include <stdexcept> 12 | 13 | #include <cstdio> 14 | 15 | 16 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 17 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") 18 | #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") 19 | #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") 20 | 21 | 22 | template <typename T> 23 | __host__ __device__ T div_round_up(T val, T divisor) { 24 | return (val + divisor - 1) / divisor; 25 | } 26 | 27 | template <typename scalar_t> 28 | __global__ void kernel_sh( 29 | const scalar_t * __restrict__ inputs, 30 | scalar_t * outputs, 31 | uint32_t B, uint32_t D, uint32_t C, 32 | const bool calc_grad_inputs, 33 | scalar_t * dy_dx 34 | ) { 35 | const uint32_t b = threadIdx.x + blockIdx.x * blockDim.x; 36 | if (b >= B) return; 37 | 38 | const uint32_t C2 = C * C; 39 | 40 | // locate 41 | inputs += b * D; 42 | outputs += b * C2; 43 | 44 | scalar_t x = inputs[0], y = inputs[1], z = inputs[2]; 45 | 46 | scalar_t xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z, xyz=xy*z; 47 | scalar_t x4=x2*x2, y4=y2*y2, z4=z2*z2; 48 | scalar_t x6=x4*x2, y6=y4*y2, z6=z4*z2; 49 | 50 | auto write_sh = [&]() { 51 | outputs[0] = 0.28209479177387814f ; // 1/(2*sqrt(pi)) 52 | if (C <= 1)
{ return; } 53 | outputs[1] = -0.48860251190291987f*y ; // -sqrt(3)*y/(2*sqrt(pi)) 54 | outputs[2] = 0.48860251190291987f*z ; // sqrt(3)*z/(2*sqrt(pi)) 55 | outputs[3] = -0.48860251190291987f*x ; // -sqrt(3)*x/(2*sqrt(pi)) 56 | if (C <= 2) { return; } 57 | outputs[4] = 1.0925484305920792f*xy ; // sqrt(15)*xy/(2*sqrt(pi)) 58 | outputs[5] = -1.0925484305920792f*yz ; // -sqrt(15)*yz/(2*sqrt(pi)) 59 | outputs[6] = 0.94617469575755997f*z2 - 0.31539156525251999f ; // sqrt(5)*(3*z2 - 1)/(4*sqrt(pi)) 60 | outputs[7] = -1.0925484305920792f*xz ; // -sqrt(15)*xz/(2*sqrt(pi)) 61 | outputs[8] = 0.54627421529603959f*x2 - 0.54627421529603959f*y2 ; // sqrt(15)*(x2 - y2)/(4*sqrt(pi)) 62 | if (C <= 3) { return; } 63 | outputs[9] = 0.59004358992664352f*y*(-3.0f*x2 + y2) ; // sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi)) 64 | outputs[10] = 2.8906114426405538f*xy*z ; // sqrt(105)*xy*z/(2*sqrt(pi)) 65 | outputs[11] = 0.45704579946446572f*y*(1.0f - 5.0f*z2) ; // sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi)) 66 | outputs[12] = 0.3731763325901154f*z*(5.0f*z2 - 3.0f) ; // sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi)) 67 | outputs[13] = 0.45704579946446572f*x*(1.0f - 5.0f*z2) ; // sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi)) 68 | outputs[14] = 1.4453057213202769f*z*(x2 - y2) ; // sqrt(105)*z*(x2 - y2)/(4*sqrt(pi)) 69 | outputs[15] = 0.59004358992664352f*x*(-x2 + 3.0f*y2) ; // sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi)) 70 | if (C <= 4) { return; } 71 | outputs[16] = 2.5033429417967046f*xy*(x2 - y2) ; // 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi)) 72 | outputs[17] = 1.7701307697799304f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi)) 73 | outputs[18] = 0.94617469575756008f*xy*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi)) 74 | outputs[19] = 0.66904654355728921f*yz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi)) 75 | outputs[20] = -3.1735664074561294f*z2 + 3.7024941420321507f*z4 + 0.31735664074561293f ; // 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi)) 76 | outputs[21] = 0.66904654355728921f*xz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi)) 77 | outputs[22] = 0.47308734787878004f*(x2 - y2)*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi)) 78 | outputs[23] = 1.7701307697799304f*xz*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi)) 79 | outputs[24] = -3.7550144126950569f*x2*y2 + 0.62583573544917614f*x4 + 0.62583573544917614f*y4 ; // 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) 80 | if (C <= 5) { return; } 81 | outputs[25] = 0.65638205684017015f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) 82 | outputs[26] = 8.3026492595241645f*xy*z*(x2 - y2) ; // 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi)) 83 | outputs[27] = -0.48923829943525038f*y*(3.0f*x2 - y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi)) 84 | outputs[28] = 4.7935367849733241f*xy*z*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi)) 85 | outputs[29] = 0.45294665119569694f*y*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) 86 | outputs[30] = 0.1169503224534236f*z*(-70.0f*z2 + 63.0f*z4 + 15.0f) ; // sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi)) 87 | outputs[31] = 0.45294665119569694f*x*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) 88 | outputs[32] = 2.3967683924866621f*z*(x2 - y2)*(3.0f*z2 - 1.0f) ; // sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi)) 89 | outputs[33] = -0.48923829943525038f*x*(x2 - 3.0f*y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi)) 90 | outputs[34] = 
2.0756623148810411f*z*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) 91 | outputs[35] = 0.65638205684017015f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) 92 | if (C <= 6) { return; } 93 | outputs[36] = 1.3663682103838286f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) 94 | outputs[37] = 2.3666191622317521f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) 95 | outputs[38] = 2.0182596029148963f*xy*(x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) 96 | outputs[39] = -0.92120525951492349f*yz*(3.0f*x2 - y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) 97 | outputs[40] = 0.92120525951492349f*xy*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) 98 | outputs[41] = 0.58262136251873131f*yz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) 99 | outputs[42] = 6.6747662381009842f*z2 - 20.024298714302954f*z4 + 14.684485723822165f*z6 - 0.31784601133814211f ; // sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi)) 100 | outputs[43] = 0.58262136251873131f*xz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) 101 | outputs[44] = 0.46060262975746175f*(x2 - y2)*(11.0f*z2*(3.0f*z2 - 1.0f) - 7.0f*z2 + 1.0f) ; // sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi)) 102 | outputs[45] = -0.92120525951492349f*xz*(x2 - 3.0f*y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi)) 103 | outputs[46] = 0.50456490072872406f*(11.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) 104 | outputs[47] = 2.3666191622317521f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) 105 | outputs[48] = 10.247761577878714f*x2*y4 - 10.247761577878714f*x4*y2 + 0.6831841051919143f*x6 - 0.6831841051919143f*y6 ; // sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) 106 | if (C <= 7) { return; } 107 | outputs[49] = 0.70716273252459627f*y*(-21.0f*x2*y4 + 35.0f*x4*y2 - 7.0f*x6 + y6) ; // 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi)) 108 | outputs[50] = 5.2919213236038001f*xy*z*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) 109 | outputs[51] = -0.51891557872026028f*y*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi)) 110 | outputs[52] = 4.1513246297620823f*xy*z*(x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) 111 | outputs[53] = -0.15645893386229404f*y*(3.0f*x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) 112 | outputs[54] = 0.44253269244498261f*xy*z*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) 113 | outputs[55] = 0.090331607582517306f*y*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) 114 | outputs[56] = 0.068284276912004949f*z*(315.0f*z2 - 693.0f*z4 + 429.0f*z6 - 35.0f) ; // sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi)) 115 | outputs[57] = 0.090331607582517306f*x*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) 116 | outputs[58] = 0.07375544874083044f*z*(x2 - 
y2)*(143.0f*z2*(3.0f*z2 - 1.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi)) 117 | outputs[59] = -0.15645893386229404f*x*(x2 - 3.0f*y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) 118 | outputs[60] = 1.0378311574405206f*z*(13.0f*z2 - 3.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) 119 | outputs[61] = -0.51891557872026028f*x*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi)) 120 | outputs[62] = 2.6459606618019f*z*(15.0f*x2*y4 - 15.0f*x4*y2 + x6 - y6) ; // 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) 121 | outputs[63] = 0.70716273252459627f*x*(-35.0f*x2*y4 + 21.0f*x4*y2 - x6 + 7.0f*y6) ; // 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi)) 122 | }; 123 | 124 | write_sh(); 125 | 126 | if (calc_grad_inputs) { 127 | scalar_t *dx = dy_dx + b * D * C2; 128 | scalar_t *dy = dx + C2; 129 | scalar_t *dz = dy + C2; 130 | 131 | auto write_sh_dx = [&]() { 132 | dx[0] = 0.0f ; // 0 133 | if (C <= 1) { return; } 134 | dx[1] = 0.0f ; // 0 135 | dx[2] = 0.0f ; // 0 136 | dx[3] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi)) 137 | if (C <= 2) { return; } 138 | dx[4] = 1.0925484305920792f*y ; // sqrt(15)*y/(2*sqrt(pi)) 139 | dx[5] = 0.0f ; // 0 140 | dx[6] = 0.0f ; // 0 141 | dx[7] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi)) 142 | dx[8] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi)) 143 | if (C <= 3) { return; } 144 | dx[9] = -3.5402615395598609f*xy ; // -3*sqrt(70)*xy/(4*sqrt(pi)) 145 | dx[10] = 2.8906114426405538f*yz ; // sqrt(105)*yz/(2*sqrt(pi)) 146 | dx[11] = 0.0f ; // 0 147 | dx[12] = 0.0f ; // 0 148 | dx[13] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi)) 149 | dx[14] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi)) 150 | dx[15] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi)) 151 | if (C <= 4) { return; } 152 | dx[16] = 2.5033429417967046f*y*(3.0f*x2 - y2) ; // 3*sqrt(35)*y*(3*x2 - y2)/(4*sqrt(pi)) 153 | dx[17] = -10.620784618679583f*xy*z ; // -9*sqrt(70)*xy*z/(4*sqrt(pi)) 154 | dx[18] = 0.94617469575756008f*y*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*y*(7*z2 - 1)/(4*sqrt(pi)) 155 | dx[19] = 0.0f ; // 0 156 | dx[20] = 0.0f ; // 0 157 | dx[21] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi)) 158 | dx[22] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi)) 159 | dx[23] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi)) 160 | dx[24] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi)) 161 | if (C <= 5) { return; } 162 | dx[25] = 13.127641136803401f*xy*(-x2 + y2) ; // 15*sqrt(154)*xy*(-x2 + y2)/(8*sqrt(pi)) 163 | dx[26] = 8.3026492595241645f*yz*(3.0f*x2 - y2) ; // 3*sqrt(385)*yz*(3*x2 - y2)/(4*sqrt(pi)) 164 | dx[27] = 2.9354297966115022f*xy*(1.0f - 9.0f*z2) ; // 3*sqrt(770)*xy*(1 - 9*z2)/(16*sqrt(pi)) 165 | dx[28] = 4.7935367849733241f*yz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*yz*(3*z2 - 1)/(4*sqrt(pi)) 166 | dx[29] = 0.0f ; // 0 167 | dx[30] = 0.0f ; // 0 168 | dx[31] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) 169 | dx[32] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi)) 170 | dx[33] = -13.209434084751759f*x2*z2 
+ 1.4677148983057511f*x2 + 13.209434084751759f*y2*z2 - 1.4677148983057511f*y2 ; // 3*sqrt(770)*(-9*x2*z2 + x2 + 9*y2*z2 - y2)/(32*sqrt(pi)) 171 | dx[34] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi)) 172 | dx[35] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) 173 | if (C <= 6) { return; } 174 | dx[36] = 4.0991046311514854f*y*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 3*sqrt(6006)*y*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi)) 175 | dx[37] = 47.332383244635047f*xy*z*(-x2 + y2) ; // 15*sqrt(2002)*xy*z*(-x2 + y2)/(8*sqrt(pi)) 176 | dx[38] = 2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) 177 | dx[39] = 5.5272315570895412f*xy*z*(3.0f - 11.0f*z2) ; // 3*sqrt(2730)*xy*z*(3 - 11*z2)/(16*sqrt(pi)) 178 | dx[40] = 0.92120525951492349f*y*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*y*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) 179 | dx[41] = 0.0f ; // 0 180 | dx[42] = 0.0f ; // 0 181 | dx[43] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) 182 | dx[44] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) 183 | dx[45] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) 184 | dx[46] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi)) 185 | dx[47] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) 186 | dx[48] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) 187 | if (C <= 7) { return; } 188 | dx[49] = 9.9002782553443485f*xy*(10.0f*x2*y2 - 3.0f*x4 - 3.0f*y4) ; // 21*sqrt(715)*xy*(10*x2*y2 - 3*x4 - 3*y4)/(32*sqrt(pi)) 189 | dx[50] = 15.875763970811402f*yz*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 9*sqrt(10010)*yz*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi)) 190 | dx[51] = -10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // -15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi)) 191 | dx[52] = 4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) 192 | dx[53] = 0.93875360317376422f*xy*(66.0f*z2 - 143.0f*z4 - 3.0f) ; // 9*sqrt(35)*xy*(66*z2 - 143*z4 - 3)/(32*sqrt(pi)) 193 | dx[54] = 0.44253269244498261f*yz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*yz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) 194 | dx[55] = 0.0f ; // 0 195 | dx[56] = 0.0f ; // 0 196 | dx[57] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) 197 | dx[58] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) 198 | dx[59] = 30.97886890473422f*x2*z2 - 67.120882626924143f*x2*z4 - 1.4081304047606462f*x2 - 30.97886890473422f*y2*z2 + 67.120882626924143f*y2*z4 + 1.4081304047606462f*y2 ; // 9*sqrt(35)*(66*x2*z2 - 143*x2*z4 - 3*x2 - 66*y2*z2 + 143*y2*z4 + 3*y2)/(64*sqrt(pi)) 199 | dx[60] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi)) 200 | dx[61] = -0.51891557872026028f*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 4.0f*x2*(x2 - 5.0f*y2) + x4 + 5.0f*y4) ; // -3*sqrt(385)*(13*z2 - 1)*(-10*x2*y2 + 4*x2*(x2 - 5*y2) + x4 + 5*y4)/(64*sqrt(pi)) 201 | dx[62] = 
15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) 202 | dx[63] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi)) 203 | }; 204 | 205 | auto write_sh_dy = [&]() { 206 | dy[0] = 0.0f ; // 0 207 | if (C <= 1) { return; } 208 | dy[1] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi)) 209 | dy[2] = 0.0f ; // 0 210 | dy[3] = 0.0f ; // 0 211 | if (C <= 2) { return; } 212 | dy[4] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi)) 213 | dy[5] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi)) 214 | dy[6] = 0.0f ; // 0 215 | dy[7] = 0.0f ; // 0 216 | dy[8] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi)) 217 | if (C <= 3) { return; } 218 | dy[9] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi)) 219 | dy[10] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi)) 220 | dy[11] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi)) 221 | dy[12] = 0.0f ; // 0 222 | dy[13] = 0.0f ; // 0 223 | dy[14] = -2.8906114426405538f*yz ; // -sqrt(105)*yz/(2*sqrt(pi)) 224 | dy[15] = 3.5402615395598609f*xy ; // 3*sqrt(70)*xy/(4*sqrt(pi)) 225 | if (C <= 4) { return; } 226 | dy[16] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi)) 227 | dy[17] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi)) 228 | dy[18] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi)) 229 | dy[19] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi)) 230 | dy[20] = 0.0f ; // 0 231 | dy[21] = 0.0f ; // 0 232 | dy[22] = 0.94617469575756008f*y*(1.0f - 7.0f*z2) ; // 3*sqrt(5)*y*(1 - 7*z2)/(4*sqrt(pi)) 233 | dy[23] = 10.620784618679583f*xy*z ; // 9*sqrt(70)*xy*z/(4*sqrt(pi)) 234 | dy[24] = 2.5033429417967046f*y*(-3.0f*x2 + y2) ; // 3*sqrt(35)*y*(-3*x2 + y2)/(4*sqrt(pi)) 235 | if (C <= 5) { return; } 236 | dy[25] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) 237 | dy[26] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi)) 238 | dy[27] = -1.4677148983057511f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // -3*sqrt(770)*(x2 - y2)*(9*z2 - 1)/(32*sqrt(pi)) 239 | dy[28] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi)) 240 | dy[29] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) 241 | dy[30] = 0.0f ; // 0 242 | dy[31] = 0.0f ; // 0 243 | dy[32] = 4.7935367849733241f*yz*(1.0f - 3.0f*z2) ; // sqrt(1155)*yz*(1 - 3*z2)/(4*sqrt(pi)) 244 | dy[33] = 2.9354297966115022f*xy*(9.0f*z2 - 1.0f) ; // 3*sqrt(770)*xy*(9*z2 - 1)/(16*sqrt(pi)) 245 | dy[34] = 8.3026492595241645f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(385)*yz*(-3*x2 + y2)/(4*sqrt(pi)) 246 | dy[35] = 13.127641136803401f*xy*(x2 - y2) ; // 15*sqrt(154)*xy*(x2 - y2)/(8*sqrt(pi)) 247 | if (C <= 6) { return; } 248 | dy[36] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) 249 | dy[37] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) 250 | dy[38] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi)) 251 | dy[39] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // 
-3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) 252 | dy[40] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) 253 | dy[41] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) 254 | dy[42] = 0.0f ; // 0 255 | dy[43] = 0.0f ; // 0 256 | dy[44] = 0.92120525951492349f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // sqrt(2730)*y*(18*z2 - 33*z4 - 1)/(32*sqrt(pi)) 257 | dy[45] = 5.5272315570895412f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(16*sqrt(pi)) 258 | dy[46] = -2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) 259 | dy[47] = 47.332383244635047f*xy*z*(x2 - y2) ; // 15*sqrt(2002)*xy*z*(x2 - y2)/(8*sqrt(pi)) 260 | dy[48] = 4.0991046311514854f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(6006)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) 261 | if (C <= 7) { return; } 262 | dy[49] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi)) 263 | dy[50] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) 264 | dy[51] = 0.51891557872026028f*(13.0f*z2 - 1.0f)*(10.0f*x2*y2 - 5.0f*x4 + 4.0f*y2*(5.0f*x2 - y2) - y4) ; // 3*sqrt(385)*(13*z2 - 1)*(10*x2*y2 - 5*x4 + 4*y2*(5*x2 - y2) - y4)/(64*sqrt(pi)) 265 | dy[52] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi)) 266 | dy[53] = -0.46937680158688211f*(x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -9*sqrt(35)*(x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) 267 | dy[54] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) 268 | dy[55] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) 269 | dy[56] = 0.0f ; // 0 270 | dy[57] = 0.0f ; // 0 271 | dy[58] = 0.44253269244498261f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 3*sqrt(70)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) 272 | dy[59] = 0.93875360317376422f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 9*sqrt(35)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi)) 273 | dy[60] = -4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) 274 | dy[61] = 10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi)) 275 | dy[62] = 15.875763970811402f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 9*sqrt(10010)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) 276 | dy[63] = 9.9002782553443485f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 21*sqrt(715)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) 277 | }; 278 | 279 | auto write_sh_dz = [&]() { 280 | dz[0] = 0.0f ; // 0 281 | if (C <= 1) { return; } 282 | dz[1] = 0.0f ; // 0 283 | dz[2] = 0.48860251190291992f ; // sqrt(3)/(2*sqrt(pi)) 284 | dz[3] = 0.0f ; // 0 285 | if (C <= 2) { return; } 286 | dz[4] = 0.0f ; // 0 287 | dz[5] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi)) 288 | dz[6] = 1.8923493915151202f*z ; // 3*sqrt(5)*z/(2*sqrt(pi)) 289 | dz[7] = -1.0925484305920792f*x ; // -sqrt(15)*x/(2*sqrt(pi)) 290 | dz[8] = 0.0f ; // 0 291 | if (C <= 3) { return; } 292 | dz[9] = 0.0f ; // 0 293 | dz[10] = 2.8906114426405538f*xy ; // sqrt(105)*xy/(2*sqrt(pi)) 294 | dz[11] = -4.5704579946446566f*yz ; // 
-5*sqrt(42)*yz/(4*sqrt(pi)) 295 | dz[12] = 5.597644988851731f*z2 - 1.1195289977703462f ; // 3*sqrt(7)*(5*z2 - 1)/(4*sqrt(pi)) 296 | dz[13] = -4.5704579946446566f*xz ; // -5*sqrt(42)*xz/(4*sqrt(pi)) 297 | dz[14] = 1.4453057213202769f*x2 - 1.4453057213202769f*y2 ; // sqrt(105)*(x2 - y2)/(4*sqrt(pi)) 298 | dz[15] = 0.0f ; // 0 299 | if (C <= 4) { return; } 300 | dz[16] = 0.0f ; // 0 301 | dz[17] = 1.7701307697799304f*y*(-3.0f*x2 + y2) ; // 3*sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi)) 302 | dz[18] = 13.246445740605839f*xy*z ; // 21*sqrt(5)*xy*z/(2*sqrt(pi)) 303 | dz[19] = 2.0071396306718676f*y*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*y*(1 - 7*z2)/(8*sqrt(pi)) 304 | dz[20] = 14.809976568128603f*pow(z, 3) - 6.3471328149122579f*z ; // (105*z**3 - 45*z)/(4*sqrt(pi)) 305 | dz[21] = 2.0071396306718676f*x*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*x*(1 - 7*z2)/(8*sqrt(pi)) 306 | dz[22] = 6.6232228703029197f*z*(x2 - y2) ; // 21*sqrt(5)*z*(x2 - y2)/(4*sqrt(pi)) 307 | dz[23] = 1.7701307697799304f*x*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi)) 308 | dz[24] = 0.0f ; // 0 309 | if (C <= 5) { return; } 310 | dz[25] = 0.0f ; // 0 311 | dz[26] = 8.3026492595241645f*xy*(x2 - y2) ; // 3*sqrt(385)*xy*(x2 - y2)/(4*sqrt(pi)) 312 | dz[27] = 8.8062893898345074f*yz*(-3.0f*x2 + y2) ; // 9*sqrt(770)*yz*(-3*x2 + y2)/(16*sqrt(pi)) 313 | dz[28] = 4.7935367849733241f*xy*(9.0f*z2 - 1.0f) ; // sqrt(1155)*xy*(9*z2 - 1)/(4*sqrt(pi)) 314 | dz[29] = 12.682506233479513f*yz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*yz*(1 - 3*z2)/(4*sqrt(pi)) 315 | dz[30] = -24.559567715218954f*z2 + 36.839351572828434f*z4 + 1.754254836801354f ; // 15*sqrt(11)*(-14*z2 + 21*z4 + 1)/(16*sqrt(pi)) 316 | dz[31] = 12.682506233479513f*xz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*xz*(1 - 3*z2)/(4*sqrt(pi)) 317 | dz[32] = 2.3967683924866621f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // sqrt(1155)*(x2 - y2)*(9*z2 - 1)/(8*sqrt(pi)) 318 | dz[33] = 8.8062893898345074f*xz*(-x2 + 3.0f*y2) ; // 9*sqrt(770)*xz*(-x2 + 3*y2)/(16*sqrt(pi)) 319 | dz[34] = -12.453973889286246f*x2*y2 + 2.0756623148810411f*x4 + 2.0756623148810411f*y4 ; // 3*sqrt(385)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) 320 | dz[35] = 0.0f ; // 0 321 | if (C <= 6) { return; } 322 | dz[36] = 0.0f ; // 0 323 | dz[37] = 2.3666191622317521f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) 324 | dz[38] = 44.401711264127719f*xy*z*(x2 - y2) ; // 33*sqrt(91)*xy*z*(x2 - y2)/(4*sqrt(pi)) 325 | dz[39] = -2.7636157785447706f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*y*(3*x2 - y2)*(11*z2 - 1)/(32*sqrt(pi)) 326 | dz[40] = 11.054463114179082f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(8*sqrt(pi)) 327 | dz[41] = 2.9131068125936568f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*y*(18*z2 - 33*z4 - 1)/(16*sqrt(pi)) 328 | dz[42] = 2.6699064952403937f*z*(-30.0f*z2 + 33.0f*z4 + 5.0f) ; // 21*sqrt(13)*z*(-30*z2 + 33*z4 + 5)/(16*sqrt(pi)) 329 | dz[43] = 2.9131068125936568f*x*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*x*(18*z2 - 33*z4 - 1)/(16*sqrt(pi)) 330 | dz[44] = 5.5272315570895412f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(16*sqrt(pi)) 331 | dz[45] = -2.7636157785447706f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*x*(x2 - 3*y2)*(11*z2 - 1)/(32*sqrt(pi)) 332 | dz[46] = 11.10042781603193f*z*(-6.0f*x2*y2 + x4 + y4) ; // 33*sqrt(91)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) 333 | dz[47] = 2.3666191622317521f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) 334 | dz[48] = 0.0f ; // 0 335 | if (C <= 7) { return; } 336 | 
dz[49] = 0.0f ; // 0 337 | dz[50] = 5.2919213236038001f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) 338 | dz[51] = 13.491805046726766f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 39*sqrt(385)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) 339 | dz[52] = 12.453973889286248f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 9*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(8*sqrt(pi)) 340 | dz[53] = -6.8841930899409371f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*yz*(3*x2 - y2)*(13*z2 - 3)/(16*sqrt(pi)) 341 | dz[54] = 2.2126634622249131f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 15*sqrt(70)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi)) 342 | dz[55] = 1.6259689364853116f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) 343 | dz[56] = 64.528641681844675f*z2 - 236.60501950009714f*z4 + 205.05768356675085f*z6 - 2.3899496919201733f ; // 7*sqrt(15)*(135*z2 - 495*z4 + 429*z6 - 5)/(32*sqrt(pi)) 344 | dz[57] = 1.6259689364853116f*xz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*xz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) 345 | dz[58] = 0.07375544874083044f*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) + 132.0f*z2*(13.0f*z2 - 5.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*(x2 - y2)*(143*z2*(3*z2 - 1) + 132*z2*(13*z2 - 5) - 187*z2 + 45)/(64*sqrt(pi)) 346 | dz[59] = -6.8841930899409371f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*xz*(x2 - 3*y2)*(13*z2 - 3)/(16*sqrt(pi)) 347 | dz[60] = 3.1134934723215619f*(13.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 9*sqrt(385)*(13*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) 348 | dz[61] = 13.491805046726766f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 39*sqrt(385)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) 349 | dz[62] = 39.6894099270285f*x2*y4 - 39.6894099270285f*x4*y2 + 2.6459606618019f*x6 - 2.6459606618019f*y6 ; // 3*sqrt(10010)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) 350 | dz[63] = 0.0f ; // 0 351 | }; 352 | write_sh_dx(); 353 | write_sh_dy(); 354 | write_sh_dz(); 355 | } 356 | } 357 | 358 | 359 | template <typename scalar_t> 360 | __global__ void kernel_sh_backward( 361 | const scalar_t * __restrict__ grad, 362 | const scalar_t * __restrict__ inputs, 363 | uint32_t B, uint32_t D, uint32_t C, 364 | const scalar_t * __restrict__ dy_dx, 365 | scalar_t * grad_inputs 366 | ) { 367 | const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; // one thread per (batch, input-dim) pair 368 | const uint32_t b = t / D; 369 | if (b >= B) return; 370 | 371 | const uint32_t d = t - b * D; 372 | const uint32_t C2 = C * C; 373 | 374 | // locate 375 | grad += b * C2; 376 | dy_dx += b * D * C2 + d * C2; 377 | // chain rule: dL/dinput accumulates over all C^2 output channels; grad_inputs is assumed zero-initialized by the caller. 378 | for (int ch = 0; ch < C2; ch++) { 379 | grad_inputs[t] += grad[ch] * dy_dx[ch]; 380 | //printf("t=%d, b=%d, d=%d, ch=%d, grad=%f (+= %f * %f)\n", t, b, d, ch, grad_inputs[t], grad[ch], dy_dx[ch]); 381 | } 382 | 383 | } 384 | 385 | // inputs: [B, D], float, in [-1, 1] 386 | // outputs: [B, C * C], float 387 | template <typename scalar_t> 388 | void sh_encode_forward_cuda(const scalar_t *inputs, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const bool calc_grad_inputs, scalar_t *dy_dx) { 389 | static constexpr uint32_t N_THREADS = 256; 390 | kernel_sh<scalar_t><<<div_round_up(B, N_THREADS), N_THREADS>>>(inputs, outputs, B, D, C, calc_grad_inputs, dy_dx); 391 | } 392 | 393 | 394 | template <typename scalar_t> 395 | void sh_encode_backward_cuda(const scalar_t *grad, const scalar_t *inputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx, scalar_t *grad_inputs) { 396 | static constexpr uint32_t N_THREADS = 256; 397 | kernel_sh_backward<scalar_t><<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad, inputs, B, D, C, dy_dx, grad_inputs); 398 | } 399 | 400 | 401 | 402 | void
sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const bool calc_grad_inputs, at::Tensor dy_dx) { 403 | CHECK_CUDA(inputs); 404 | CHECK_CUDA(outputs); 405 | CHECK_CUDA(dy_dx); 406 | 407 | CHECK_CONTIGUOUS(inputs); 408 | CHECK_CONTIGUOUS(outputs); 409 | CHECK_CONTIGUOUS(dy_dx); 410 | 411 | CHECK_IS_FLOATING(inputs); 412 | CHECK_IS_FLOATING(outputs); 413 | CHECK_IS_FLOATING(dy_dx); 414 | 415 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 416 | inputs.scalar_type(), "sh_encode_forward_cuda", ([&] { 417 | sh_encode_forward_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), outputs.data_ptr<scalar_t>(), B, D, C, calc_grad_inputs, dy_dx.data_ptr<scalar_t>()); 418 | })); 419 | } 420 | 421 | void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs) { 422 | CHECK_CUDA(grad); 423 | CHECK_CUDA(inputs); 424 | CHECK_CUDA(dy_dx); 425 | CHECK_CUDA(grad_inputs); 426 | 427 | CHECK_CONTIGUOUS(grad); 428 | CHECK_CONTIGUOUS(inputs); 429 | CHECK_CONTIGUOUS(dy_dx); 430 | CHECK_CONTIGUOUS(grad_inputs); 431 | 432 | CHECK_IS_FLOATING(grad); 433 | CHECK_IS_FLOATING(inputs); 434 | CHECK_IS_FLOATING(dy_dx); 435 | CHECK_IS_FLOATING(grad_inputs); 436 | 437 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 438 | grad.scalar_type(), "sh_encode_backward_cuda", ([&] { 439 | sh_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<scalar_t>(), B, D, C, dy_dx.data_ptr<scalar_t>(), grad_inputs.data_ptr<scalar_t>()); 440 | })); 441 | } --------------------------------------------------------------------------------
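The sh_encode_forward / sh_encode_backward bindings above are meant to be driven by a torch.autograd.Function on the Python side (the repo ships such a wrapper in shencoder/sphere_harmonics.py, which exports SHEncoder). For orientation, here is a minimal illustrative sketch of what that wrapper has to do; it is not the repo's actual file, and the extension module name _shencoder and the size-1 placeholder dy_dx are assumptions of the sketch. Two contracts from the CUDA source matter: the binding checks dy_dx unconditionally even when calc_grad_inputs is false, so a dummy tensor must still be passed, and grad_inputs must be zero-initialized because kernel_sh_backward accumulates with +=.

import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

import _shencoder as _backend  # assumed name of the compiled extension

class _sh_encode(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float)
    def forward(ctx, inputs, degree, calc_grad_inputs=False):
        # inputs: [B, D] float, unit directions in [-1, 1]; outputs: [B, degree ** 2]
        inputs = inputs.contiguous()
        B, D = inputs.shape
        C2 = degree ** 2
        outputs = torch.empty(B, C2, dtype=inputs.dtype, device=inputs.device)
        if calc_grad_inputs:
            # dy_dx stores per-sample Jacobians: [B, D * C2]
            dy_dx = torch.empty(B, D * C2, dtype=inputs.dtype, device=inputs.device)
        else:
            # never written by the kernel, but the binding still checks device/contiguity/dtype
            dy_dx = torch.empty(1, dtype=inputs.dtype, device=inputs.device)
        _backend.sh_encode_forward(inputs, outputs, B, D, degree, calc_grad_inputs, dy_dx)
        ctx.save_for_backward(inputs, dy_dx)
        ctx.dims = (B, D, degree)
        ctx.calc_grad_inputs = calc_grad_inputs
        return outputs

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        if not ctx.calc_grad_inputs:
            return None, None, None
        inputs, dy_dx = ctx.saved_tensors
        B, D, degree = ctx.dims
        # zeros, not empty: kernel_sh_backward accumulates into grad_inputs with +=
        grad_inputs = torch.zeros_like(inputs)
        _backend.sh_encode_backward(grad.contiguous(), inputs, B, D, degree, dy_dx, grad_inputs)
        return grad_inputs, None, None

sh_encode = _sh_encode.apply

A call would then look like sh_encode(dirs, degree, dirs.requires_grad) with dirs a [B, 3] tensor of unit directions; degree must lie in 1..8 to match the unrolled tables in kernel_sh, giving degree ** 2 output channels.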