├── .gitignore
├── LICENSE
├── activation.py
├── assets
│   ├── teaser.jpg
│   └── teaser2.jpg
├── depth_tools
│   ├── download_models.sh
│   ├── dpt.py
│   └── extract_depth.py
├── encoding.py
├── freqencoder
│   ├── __init__.py
│   ├── backend.py
│   ├── freq.py
│   ├── setup.py
│   └── src
│       ├── bindings.cpp
│       ├── freqencoder.cu
│       └── freqencoder.h
├── gridencoder
│   ├── __init__.py
│   ├── backend.py
│   ├── grid.py
│   ├── setup.py
│   └── src
│       ├── bindings.cpp
│       ├── gridencoder.cu
│       └── gridencoder.h
├── loss.py
├── main.py
├── meshutils.py
├── nerf
│   ├── colmap_provider.py
│   ├── colmap_utils.py
│   ├── dtu_provider.py
│   ├── gui.py
│   ├── network.py
│   ├── provider.py
│   ├── renderer.py
│   └── utils.py
├── raymarching
│   ├── __init__.py
│   ├── backend.py
│   ├── raymarching.py
│   ├── setup.py
│   └── src
│       ├── bindings.cpp
│       ├── raymarching.cu
│       └── raymarching.h
├── readme.md
├── renderer.html
├── requirements.txt
├── scripts
│   ├── colmap2nerf.py
│   ├── downscale.py
│   ├── install_ext.sh
│   ├── remove_bg.py
│   ├── runall_360_indoor.sh
│   ├── runall_360_indoor_sdf.sh
│   ├── runall_360_outdoor.sh
│   ├── runall_llff.sh
│   ├── runall_outdoor_sdf.sh
│   ├── runall_syn.sh
│   └── runall_syn_sdf.sh
└── shencoder
    ├── __init__.py
    ├── backend.py
    ├── setup.py
    ├── sphere_harmonics.py
    └── src
        ├── bindings.cpp
        ├── shencoder.cu
        └── shencoder.h
/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | build/ 3 | *.egg-info/ 4 | *.so 5 | 6 | tmp* 7 | data/ 8 | data 9 | trial*/ 10 | .vs/ 11 | 12 | *.ckpt 13 | *wos* 14 | ablation.sh 15 | runall_firekeeper* 16 | runall_dtu* 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 hawkey 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /activation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.cuda.amp import custom_bwd, custom_fwd 4 | 5 | class _trunc_exp(Function): 6 | @staticmethod 7 | @custom_fwd(cast_inputs=torch.float32) # cast to float32 8 | def forward(ctx, x): 9 | ctx.save_for_backward(x) 10 | return torch.exp(x) 11 | 12 | @staticmethod 13 | @custom_bwd 14 | def backward(ctx, g): 15 | x = ctx.saved_tensors[0] 16 | return g * torch.exp(x.clamp(-15, 15)) 17 | 18 | trunc_exp = _trunc_exp.apply -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/nerf2mesh/ec7f930ccf768ba4d6e602360b4a6ff0300fe9c8/assets/teaser.jpg -------------------------------------------------------------------------------- /assets/teaser2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/nerf2mesh/ec7f930ccf768ba4d6e602360b4a6ff0300fe9c8/assets/teaser2.jpg -------------------------------------------------------------------------------- /depth_tools/download_models.sh: -------------------------------------------------------------------------------- 1 | gdown '1Jrh-bRnJEjyMCS7f-WsaFlccfPjJPPHI&confirm=t' -O ./omnidata_dpt_depth_v2.ckpt # omnidata depth (v2) -------------------------------------------------------------------------------- /depth_tools/extract_depth.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import tqdm 4 | from PIL import Image 5 | import argparse 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torchvision import transforms 13 | 14 | from dpt import DPTDepthModel 15 | 16 | IMAGE_SIZE = 384 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('path', type=str) 20 | parser.add_argument('--out_dir', type=str) 21 | parser.add_argument('--ckpt', type=str, default='./depth_tools/omnidata_dpt_depth_v2.ckpt') 22 | 23 | opt = parser.parse_args() 24 | 25 | if opt.path[-1] == '/': 26 | opt.path = opt.path[:-1] 27 | 28 | out_dir = os.path.join(os.path.dirname(opt.path), f'depths') 29 | 30 | os.makedirs(out_dir, exist_ok=True) 31 | 32 | map_location = (lambda storage, loc: storage.cuda()) if torch.cuda.is_available() else torch.device('cpu') 33 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 34 | 35 | model = DPTDepthModel(backbone='vitb_rn50_384') # DPT Hybrid 36 | 37 | print(f'[INFO] loading checkpoint from {opt.ckpt}') 38 | checkpoint = torch.load(opt.ckpt, map_location=map_location) 39 | 40 | if 'state_dict' in checkpoint: 41 | state_dict = {} 42 | for k, v in checkpoint['state_dict'].items(): 43 | state_dict[k[6:]] = v 44 | else: 45 | state_dict = checkpoint 46 | 47 | model.load_state_dict(state_dict) 48 | model.to(device) 49 | model.eval() 50 | 51 | trans_totensor = transforms.Compose([ 52 | transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), 53 | transforms.ToTensor(), 54 | transforms.Normalize(mean=0.5, std=0.5) 55 | ]) 56 | 57 | 58 | @torch.no_grad() 59 | def run_image(img_path): 60 | # img: filepath 61 | img = Image.open(img_path) 62 | W, H = img.size 63 | img_input = 
trans_totensor(img).unsqueeze(0).to(device) 64 | 65 | depth = model(img_input) 66 | 67 | depth = F.interpolate(depth.unsqueeze(1), size=(H, W), mode='bicubic', align_corners=False) 68 | depth = depth.squeeze().cpu().numpy() 69 | 70 | out_path = os.path.join(out_dir, os.path.splitext(os.path.basename(img_path))[0]) + '.npy' 71 | 72 | # plt.matshow(depth) 73 | # plt.show() 74 | 75 | # plt.matshow(img_input.detach().cpu()[0].permute(1,2,0).numpy()) 76 | # plt.show() 77 | # print(f'[INFO] {out_path} {depth.min()} {depth.max()} {depth.shape}') 78 | 79 | np.save(out_path, depth) 80 | 81 | 82 | img_paths = glob.glob(os.path.join(opt.path, '*')) 83 | for img_path in tqdm.tqdm(img_paths): 84 | run_image(img_path) -------------------------------------------------------------------------------- /encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | 8 | class FreqEncoder_torch(nn.Module): 9 | def __init__(self, input_dim, max_freq_log2, N_freqs, 10 | log_sampling=True, include_input=True, 11 | periodic_fns=(torch.sin, torch.cos)): 12 | 13 | super().__init__() 14 | 15 | self.input_dim = input_dim 16 | self.include_input = include_input 17 | self.periodic_fns = periodic_fns 18 | 19 | self.output_dim = 0 20 | if self.include_input: 21 | self.output_dim += self.input_dim 22 | 23 | self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns) 24 | 25 | if log_sampling: 26 | self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs) 27 | else: 28 | self.freq_bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, N_freqs) 29 | 30 | self.freq_bands = self.freq_bands.numpy().tolist() 31 | 32 | def forward(self, input, **kwargs): 33 | 34 | out = [] 35 | if self.include_input: 36 | out.append(input) 37 | 38 | for i in range(len(self.freq_bands)): 39 | freq = self.freq_bands[i] 40 | for p_fn in self.periodic_fns: 41 | out.append(p_fn(input * freq)) 42 | 43 | out = torch.cat(out, dim=-1) 44 | 45 | 46 | return out 47 | 48 | class TCNN_hashgrid(nn.Module): 49 | def __init__(self, num_levels, level_dim, log2_hashmap_size, base_resolution, desired_resolution, interpolation, **kwargs): 50 | super().__init__() 51 | import tinycudann as tcnn 52 | self.encoder = tcnn.Encoding( 53 | n_input_dims=3, 54 | encoding_config={ 55 | "otype": "HashGrid", 56 | "n_levels": num_levels, 57 | "n_features_per_level": level_dim, 58 | "log2_hashmap_size": log2_hashmap_size, 59 | "base_resolution": base_resolution, 60 | "per_level_scale": np.exp2(np.log2(desired_resolution / num_levels) / (num_levels - 1)), 61 | "interpolation": "Smoothstep" if interpolation == 'smoothstep' else "Linear", 62 | }, 63 | dtype=torch.float32, 64 | ) 65 | self.output_dim = self.encoder.n_output_dims # patch 66 | 67 | def forward(self, x, bound=1, **kwargs): 68 | return self.encoder((x + bound) / (2 * bound)) 69 | 70 | 71 | def get_encoder(encoding, input_dim=3, 72 | output_dim=1, resolution=300, mode='bilinear', # dense grid 73 | multires=6, # freq 74 | degree=4, # SH 75 | num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, # hash/tiled grid 76 | align_corners=False, interpolation='linear', # grid 77 | **kwargs): 78 | 79 | if encoding == 'None': 80 | return lambda x, **kwargs: x, input_dim 81 | 82 | elif encoding == 'frequency_torch': 83 | encoder = FreqEncoder_torch(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True) 84 | 
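    # Illustrative shape check (a sketch; the tensors below are hypothetical): with the defaults
    # input_dim=3, multires=6 and include_input=True, FreqEncoder_torch above yields
    # output_dim = 3 + 3 * 6 * 2 = 39 features per point, e.g.
    #   enc, out_dim = get_encoder('frequency_torch', input_dim=3, multires=6)  # out_dim == 39
    #   feats = enc(torch.rand(1024, 3))  # -> [1024, 39]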
85 | elif encoding == 'frequency': 86 | from freqencoder import FreqEncoder 87 | encoder = FreqEncoder(input_dim=input_dim, degree=multires) 88 | 89 | elif encoding == 'sh': 90 | from shencoder import SHEncoder 91 | encoder = SHEncoder(input_dim=input_dim, degree=degree) 92 | 93 | elif encoding == 'hashgrid': 94 | from gridencoder import GridEncoder 95 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners, interpolation=interpolation) 96 | 97 | elif encoding == 'hashgrid_tcnn': 98 | encoder = TCNN_hashgrid(num_levels=num_levels, level_dim=level_dim, log2_hashmap_size=log2_hashmap_size, base_resolution=base_resolution, desired_resolution=desired_resolution, interpolation=interpolation) 99 | 100 | elif encoding == 'tiledgrid': 101 | from gridencoder import GridEncoder 102 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners, interpolation=interpolation) 103 | 104 | else: 105 | raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sh, hashgrid, tiledgrid]') 106 | 107 | return encoder, encoder.output_dim -------------------------------------------------------------------------------- /freqencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .freq import FreqEncoder -------------------------------------------------------------------------------- /freqencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | from packaging import version 4 | import torch 5 | 6 | torch_version = torch.__version__ 7 | 8 | _src_path = os.path.dirname(os.path.abspath(__file__)) 9 | 10 | nvcc_flags = [ 11 | '-O3', '-std=c++14', 12 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 13 | '-use_fast_math' 14 | ] 15 | 16 | if os.name == "posix": 17 | c_flags = ['-O3', '-std=c++14'] 18 | if version.parse(torch_version) >= version.parse("2.1"): 19 | nvcc_flags = [ 20 | '-O3', '-std=c++17', 21 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 22 | '-use_fast_math' 23 | ] 24 | c_flags = ['-O3', '-std=c++17'] 25 | 26 | 27 | 28 | elif os.name == "nt": 29 | c_flags = ['/O2', '/std:c++17'] 30 | 31 | # find cl.exe 32 | def find_cl_path(): 33 | import glob 34 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 35 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 36 | if paths: 37 | return paths[0] 38 | 39 | # If cl.exe is not on path, try to find it. 
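# Note: this module builds the CUDA extension just-in-time through torch.utils.cpp_extension.load below,
# so the very first import can take a while. As an alternative (a hedged suggestion, not required by this
# file), the same extension can be pre-built from freqencoder/setup.py, e.g. `pip install ./freqencoder`
# from the repository root; scripts/install_ext.sh appears intended to script that ahead-of-time install.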
40 | if os.system("where cl.exe >nul 2>nul") != 0: 41 | cl_path = find_cl_path() 42 | if cl_path is None: 43 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 44 | os.environ["PATH"] += ";" + cl_path 45 | 46 | _backend = load(name='_freqencoder', 47 | extra_cflags=c_flags, 48 | extra_cuda_cflags=nvcc_flags, 49 | sources=[os.path.join(_src_path, 'src', f) for f in [ 50 | 'freqencoder.cu', 51 | 'bindings.cpp', 52 | ]], 53 | ) 54 | 55 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /freqencoder/freq.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _freqencoder as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | 15 | class _freq_encoder(Function): 16 | @staticmethod 17 | @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision 18 | def forward(ctx, inputs, degree, output_dim): 19 | # inputs: [B, input_dim], float 20 | # RETURN: [B, F], float 21 | 22 | if not inputs.is_cuda: inputs = inputs.cuda() 23 | inputs = inputs.contiguous() 24 | 25 | B, input_dim = inputs.shape # batch size, coord dim 26 | 27 | outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) 28 | 29 | _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) 30 | 31 | ctx.save_for_backward(inputs, outputs) 32 | ctx.dims = [B, input_dim, degree, output_dim] 33 | 34 | return outputs 35 | 36 | @staticmethod 37 | #@once_differentiable 38 | @custom_bwd 39 | def backward(ctx, grad): 40 | # grad: [B, C * C] 41 | 42 | grad = grad.contiguous() 43 | inputs, outputs = ctx.saved_tensors 44 | B, input_dim, degree, output_dim = ctx.dims 45 | 46 | grad_inputs = torch.zeros_like(inputs) 47 | _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) 48 | 49 | return grad_inputs, None, None 50 | 51 | 52 | freq_encode = _freq_encoder.apply 53 | 54 | 55 | class FreqEncoder(nn.Module): 56 | def __init__(self, input_dim=3, degree=4): 57 | super().__init__() 58 | 59 | self.input_dim = input_dim 60 | self.degree = degree 61 | self.output_dim = input_dim + input_dim * 2 * degree 62 | 63 | def __repr__(self): 64 | return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}" 65 | 66 | def forward(self, inputs, **kwargs): 67 | # inputs: [..., input_dim] 68 | # return: [..., ] 69 | 70 | prefix_shape = list(inputs.shape[:-1]) 71 | inputs = inputs.reshape(-1, self.input_dim) 72 | 73 | outputs = freq_encode(inputs, self.degree, self.output_dim) 74 | 75 | outputs = outputs.reshape(prefix_shape + [self.output_dim]) 76 | 77 | return outputs -------------------------------------------------------------------------------- /freqencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | from packaging import version 5 | import torch 6 | 7 | torch_version = torch.__version__ 8 | 9 | _src_path = os.path.dirname(os.path.abspath(__file__)) 10 | 11 | nvcc_flags = [ 12 | '-O3', '-std=c++14', 13 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', 
'-U__CUDA_NO_HALF2_OPERATORS__', 14 | '-use_fast_math' 15 | ] 16 | 17 | if os.name == "posix": 18 | c_flags = ['-O3', '-std=c++14'] 19 | if version.parse(torch_version) >= version.parse("2.1"): 20 | nvcc_flags = [ 21 | '-O3', '-std=c++17', 22 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 23 | '-use_fast_math' 24 | ] 25 | c_flags = ['-O3', '-std=c++17'] 26 | 27 | elif os.name == "nt": 28 | c_flags = ['/O2', '/std:c++17'] 29 | 30 | # find cl.exe 31 | def find_cl_path(): 32 | import glob 33 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 34 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 35 | if paths: 36 | return paths[0] 37 | 38 | # If cl.exe is not on path, try to find it. 39 | if os.system("where cl.exe >nul 2>nul") != 0: 40 | cl_path = find_cl_path() 41 | if cl_path is None: 42 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 43 | os.environ["PATH"] += ";" + cl_path 44 | 45 | setup( 46 | name='freqencoder', # package name, import this to use python API 47 | ext_modules=[ 48 | CUDAExtension( 49 | name='_freqencoder', # extension name, import this to use CUDA API 50 | sources=[os.path.join(_src_path, 'src', f) for f in [ 51 | 'freqencoder.cu', 52 | 'bindings.cpp', 53 | ]], 54 | extra_compile_args={ 55 | 'cxx': c_flags, 56 | 'nvcc': nvcc_flags, 57 | } 58 | ), 59 | ], 60 | cmdclass={ 61 | 'build_ext': BuildExtension, 62 | } 63 | ) -------------------------------------------------------------------------------- /freqencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "freqencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)"); 7 | m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)"); 8 | } -------------------------------------------------------------------------------- /freqencoder/src/freqencoder.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | 16 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 17 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") 18 | #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") 19 | #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") 20 | 21 | inline constexpr __device__ float PI() { return 3.141592653589793f; } 22 | 23 | template 24 | __host__ __device__ T div_round_up(T val, T divisor) { 25 | return (val + divisor - 1) / divisor; 26 | } 27 | 28 | // inputs: [B, D] 29 | // outputs: [B, C], C = D + D * deg * 2 30 | __global__ void kernel_freq( 31 | const float * __restrict__ inputs, 32 | uint32_t B, uint32_t D, uint32_t deg, uint32_t C, 33 | float * outputs 34 | ) { 35 | // parallel on per-element 36 | const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; 37 | if (t >= B * C) return; 38 | 39 | // get index 40 | const uint32_t b = t / C; 41 | const uint32_t c = t - b * C; // t 
% C; 42 | 43 | // locate 44 | inputs += b * D; 45 | outputs += t; 46 | 47 | // write self 48 | if (c < D) { 49 | outputs[0] = inputs[c]; 50 | // write freq 51 | } else { 52 | const uint32_t col = c / D - 1; 53 | const uint32_t d = c % D; 54 | const uint32_t freq = col / 2; 55 | const float phase_shift = (col % 2) * (PI() / 2); 56 | outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift); 57 | } 58 | } 59 | 60 | // grad: [B, C], C = D + D * deg * 2 61 | // outputs: [B, C] 62 | // grad_inputs: [B, D] 63 | __global__ void kernel_freq_backward( 64 | const float * __restrict__ grad, 65 | const float * __restrict__ outputs, 66 | uint32_t B, uint32_t D, uint32_t deg, uint32_t C, 67 | float * grad_inputs 68 | ) { 69 | // parallel on per-element 70 | const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; 71 | if (t >= B * D) return; 72 | 73 | const uint32_t b = t / D; 74 | const uint32_t d = t - b * D; // t % D; 75 | 76 | // locate 77 | grad += b * C; 78 | outputs += b * C; 79 | grad_inputs += t; 80 | 81 | // register 82 | float result = grad[d]; 83 | grad += D; 84 | outputs += D; 85 | 86 | for (uint32_t f = 0; f < deg; f++) { 87 | result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]); 88 | grad += 2 * D; 89 | outputs += 2 * D; 90 | } 91 | 92 | // write 93 | grad_inputs[0] = result; 94 | } 95 | 96 | 97 | void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) { 98 | CHECK_CUDA(inputs); 99 | CHECK_CUDA(outputs); 100 | 101 | CHECK_CONTIGUOUS(inputs); 102 | CHECK_CONTIGUOUS(outputs); 103 | 104 | CHECK_IS_FLOATING(inputs); 105 | CHECK_IS_FLOATING(outputs); 106 | 107 | static constexpr uint32_t N_THREADS = 128; 108 | 109 | kernel_freq<<>>(inputs.data_ptr(), B, D, deg, C, outputs.data_ptr()); 110 | } 111 | 112 | 113 | void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) { 114 | CHECK_CUDA(grad); 115 | CHECK_CUDA(outputs); 116 | CHECK_CUDA(grad_inputs); 117 | 118 | CHECK_CONTIGUOUS(grad); 119 | CHECK_CONTIGUOUS(outputs); 120 | CHECK_CONTIGUOUS(grad_inputs); 121 | 122 | CHECK_IS_FLOATING(grad); 123 | CHECK_IS_FLOATING(outputs); 124 | CHECK_IS_FLOATING(grad_inputs); 125 | 126 | static constexpr uint32_t N_THREADS = 128; 127 | 128 | kernel_freq_backward<<>>(grad.data_ptr(), outputs.data_ptr(), B, D, deg, C, grad_inputs.data_ptr()); 129 | } -------------------------------------------------------------------------------- /freqencoder/src/freqencoder.h: -------------------------------------------------------------------------------- 1 | # pragma once 2 | 3 | #include 4 | #include 5 | 6 | // _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) 7 | void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs); 8 | 9 | // _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) 10 | void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs); -------------------------------------------------------------------------------- /gridencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .grid import GridEncoder -------------------------------------------------------------------------------- 
/gridencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | from packaging import version 4 | import torch 5 | 6 | torch_version = torch.__version__ 7 | 8 | _src_path = os.path.dirname(os.path.abspath(__file__)) 9 | 10 | nvcc_flags = [ 11 | '-O3', '-std=c++14', 12 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 13 | ] 14 | 15 | if os.name == "posix": 16 | c_flags = ['-O3', '-std=c++14'] 17 | if version.parse(torch_version) >= version.parse("2.1"): 18 | nvcc_flags = [ 19 | '-O3', '-std=c++17', 20 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 21 | '-use_fast_math' 22 | ] 23 | c_flags = ['-O3', '-std=c++17'] 24 | elif os.name == "nt": 25 | c_flags = ['/O2', '/std:c++17'] 26 | 27 | # find cl.exe 28 | def find_cl_path(): 29 | import glob 30 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 31 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 32 | if paths: 33 | return paths[0] 34 | 35 | # If cl.exe is not on path, try to find it. 36 | if os.system("where cl.exe >nul 2>nul") != 0: 37 | cl_path = find_cl_path() 38 | if cl_path is None: 39 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 40 | os.environ["PATH"] += ";" + cl_path 41 | 42 | _backend = load(name='_grid_encoder', 43 | extra_cflags=c_flags, 44 | extra_cuda_cflags=nvcc_flags, 45 | sources=[os.path.join(_src_path, 'src', f) for f in [ 46 | 'gridencoder.cu', 47 | 'bindings.cpp', 48 | ]], 49 | ) 50 | 51 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /gridencoder/grid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _gridencoder as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | _gridtype_to_id = { 15 | 'hash': 0, 16 | 'tiled': 1, 17 | } 18 | 19 | _interp_to_id = { 20 | 'linear': 0, 21 | 'smoothstep': 1, 22 | } 23 | 24 | class _grid_encode(Function): 25 | @staticmethod 26 | @custom_fwd 27 | def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, interpolation=0, max_level=None): 28 | # inputs: [B, D], float in [0, 1] 29 | # embeddings: [sO, C], float 30 | # offsets: [L + 1], int 31 | # RETURN: [B, F], float 32 | 33 | inputs = inputs.contiguous() 34 | 35 | B, D = inputs.shape # batch size, coord dim 36 | L = offsets.shape[0] - 1 # level 37 | C = embeddings.shape[1] # embedding dim for each level 38 | S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f 39 | H = base_resolution # base resolution 40 | 41 | max_level = L if max_level is None else min(max_level, L) 42 | 43 | # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision) 44 | # if C % 2 != 0, force float, since half for atomicAdd is very slow. 
45 | if torch.is_autocast_enabled() and C % 2 == 0: 46 | embeddings = embeddings.to(torch.half) 47 | 48 | # L first, optimize cache for cuda kernel, but needs an extra permute later 49 | outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype) 50 | 51 | # zero init if we only calculate partial levels 52 | if max_level < L: outputs.zero_() 53 | 54 | if calc_grad_inputs: 55 | dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype) 56 | if max_level < L: dy_dx.zero_() 57 | else: 58 | dy_dx = None 59 | 60 | _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interpolation) 61 | 62 | # permute back to [B, L * C] 63 | outputs = outputs.permute(1, 0, 2).reshape(B, L * C) 64 | 65 | ctx.save_for_backward(inputs, embeddings, offsets, dy_dx) 66 | ctx.dims = [B, D, C, L, S, H, gridtype, interpolation, max_level] 67 | ctx.align_corners = align_corners 68 | 69 | return outputs 70 | 71 | @staticmethod 72 | #@once_differentiable 73 | @custom_bwd 74 | def backward(ctx, grad): 75 | 76 | inputs, embeddings, offsets, dy_dx = ctx.saved_tensors 77 | B, D, C, L, S, H, gridtype, interpolation, max_level = ctx.dims 78 | align_corners = ctx.align_corners 79 | 80 | # grad: [B, L * C] --> [L, B, C] 81 | grad = grad.view(B, L, C).permute(1, 0, 2).contiguous() 82 | 83 | grad_embeddings = torch.zeros_like(embeddings) 84 | 85 | if dy_dx is not None: 86 | grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype) 87 | else: 88 | grad_inputs = None 89 | 90 | _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interpolation) 91 | 92 | if dy_dx is not None: 93 | grad_inputs = grad_inputs.to(inputs.dtype) 94 | 95 | return grad_inputs, grad_embeddings, None, None, None, None, None, None, None, None 96 | 97 | 98 | 99 | grid_encode = _grid_encode.apply 100 | 101 | 102 | class GridEncoder(nn.Module): 103 | def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, interpolation='linear'): 104 | super().__init__() 105 | 106 | # the finest resolution desired at the last level, if provided, overridee per_level_scale 107 | if desired_resolution is not None: 108 | per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1)) 109 | 110 | self.input_dim = input_dim # coord dims, 2 or 3 111 | self.num_levels = num_levels # num levels, each level multiply resolution by 2 112 | self.level_dim = level_dim # encode channels per level 113 | self.per_level_scale = per_level_scale # multiply resolution by this scale at each level. 
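        # Worked example (a sketch using the defaults from encoding.py's get_encoder: base_resolution=16,
        # desired_resolution=2048, num_levels=16, level_dim=2): the formula above gives
        # per_level_scale = 2 ** (log2(2048 / 16) / 15) = 2 ** (7 / 15) ≈ 1.382, i.e. per-level grid
        # resolutions of roughly 16, 23, 31, ..., 2048, and output_dim = 16 * 2 = 32 features per point.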
114 | self.log2_hashmap_size = log2_hashmap_size 115 | self.base_resolution = base_resolution 116 | self.output_dim = num_levels * level_dim 117 | self.gridtype = gridtype 118 | self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash" 119 | self.interpolation = interpolation 120 | self.interp_id = _interp_to_id[interpolation] # "linear" or "smoothstep" 121 | self.align_corners = align_corners 122 | 123 | # allocate parameters 124 | offsets = [] 125 | offset = 0 126 | self.max_params = 2 ** log2_hashmap_size 127 | for i in range(num_levels): 128 | resolution = int(np.ceil(base_resolution * per_level_scale ** i)) 129 | params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number 130 | params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible 131 | offsets.append(offset) 132 | offset += params_in_level 133 | offsets.append(offset) 134 | offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) 135 | self.register_buffer('offsets', offsets) 136 | 137 | self.n_params = offsets[-1] * level_dim 138 | 139 | # parameters 140 | self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) 141 | 142 | self.reset_parameters() 143 | 144 | def reset_parameters(self): 145 | std = 1e-4 146 | self.embeddings.data.uniform_(-std, std) 147 | 148 | def __repr__(self): 149 | return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}" 150 | 151 | def forward(self, inputs, bound=1, max_level=None): 152 | # inputs: [..., input_dim], normalized real world positions in [-bound, bound] 153 | # max_level: only calculate first max_level levels (None will use all levels) 154 | # return: [..., num_levels * level_dim] 155 | 156 | inputs = (inputs + bound) / (2 * bound) # map to [0, 1] 157 | 158 | #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) 159 | 160 | prefix_shape = list(inputs.shape[:-1]) 161 | inputs = inputs.view(-1, self.input_dim) 162 | 163 | outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, self.interp_id, max_level) 164 | outputs = outputs.view(prefix_shape + [self.output_dim]) 165 | 166 | #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) 167 | 168 | return outputs 169 | 170 | # always run in float precision! 171 | @torch.cuda.amp.autocast(enabled=False) 172 | def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000): 173 | # inputs: [..., input_dim], float in [-b, b], location to calculate TV loss. 
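        # Typical call pattern (a sketch implied by the grad check below; `encoder`, `xyzs` and `optimizer`
        # are placeholder names): backpropagate first so self.embeddings.grad exists, add the TV gradient
        # in place, then step the optimizer:
        #   loss.backward()
        #   encoder.grad_total_variation(weight=1e-7, inputs=xyzs, bound=1)  # inputs=None samples randomly
        #   optimizer.step()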
174 | 175 | D = self.input_dim 176 | C = self.embeddings.shape[1] # embedding dim for each level 177 | L = self.offsets.shape[0] - 1 # level 178 | S = np.log2(self.per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f 179 | H = self.base_resolution # base resolution 180 | 181 | if inputs is None or inputs.size(0) == 0: 182 | # randomized in [0, 1] 183 | inputs = torch.rand(B, self.input_dim, device=self.embeddings.device) 184 | else: 185 | inputs = (inputs + bound) / (2 * bound) # map to [0, 1] 186 | inputs = inputs.view(-1, self.input_dim) 187 | B = inputs.shape[0] 188 | 189 | if self.embeddings.grad is None: 190 | raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!') 191 | 192 | _backend.grad_total_variation(inputs, self.embeddings, self.embeddings.grad, self.offsets, weight, B, D, C, L, S, H, self.gridtype_id, self.align_corners) 193 | -------------------------------------------------------------------------------- /gridencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | from packaging import version 5 | import torch 6 | 7 | torch_version = torch.__version__ 8 | 9 | _src_path = os.path.dirname(os.path.abspath(__file__)) 10 | 11 | nvcc_flags = [ 12 | '-O3', '-std=c++14', 13 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 14 | ] 15 | 16 | if os.name == "posix": 17 | c_flags = ['-O3', '-std=c++14'] 18 | if version.parse(torch_version) >= version.parse("2.1"): 19 | nvcc_flags = [ 20 | '-O3', '-std=c++17', 21 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 22 | '-use_fast_math' 23 | ] 24 | c_flags = ['-O3', '-std=c++17'] 25 | elif os.name == "nt": 26 | c_flags = ['/O2', '/std:c++17'] 27 | 28 | # find cl.exe 29 | def find_cl_path(): 30 | import glob 31 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 32 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 33 | if paths: 34 | return paths[0] 35 | 36 | # If cl.exe is not on path, try to find it. 
37 | if os.system("where cl.exe >nul 2>nul") != 0: 38 | cl_path = find_cl_path() 39 | if cl_path is None: 40 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 41 | os.environ["PATH"] += ";" + cl_path 42 | 43 | setup( 44 | name='gridencoder', # package name, import this to use python API 45 | ext_modules=[ 46 | CUDAExtension( 47 | name='_gridencoder', # extension name, import this to use CUDA API 48 | sources=[os.path.join(_src_path, 'src', f) for f in [ 49 | 'gridencoder.cu', 50 | 'bindings.cpp', 51 | ]], 52 | extra_compile_args={ 53 | 'cxx': c_flags, 54 | 'nvcc': nvcc_flags, 55 | } 56 | ), 57 | ], 58 | cmdclass={ 59 | 'build_ext': BuildExtension, 60 | } 61 | ) -------------------------------------------------------------------------------- /gridencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include "gridencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); 7 | m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); 8 | m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)"); 9 | } -------------------------------------------------------------------------------- /gridencoder/src/gridencoder.h: -------------------------------------------------------------------------------- 1 | #ifndef _HASH_ENCODE_H 2 | #define _HASH_ENCODE_H 3 | 4 | #include 5 | #include 6 | 7 | // inputs: [B, D], float, in [0, 1] 8 | // embeddings: [sO, C], float 9 | // offsets: [L + 1], uint32_t 10 | // outputs: [B, L * C], float 11 | // H: base resolution 12 | void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp); 13 | void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp); 14 | 15 | void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners); 16 | 17 | #endif -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | def mape_loss(pred, target, reduction='mean'): 8 | # pred, target: [B, 1], torch tensor 9 | difference = (pred - target).abs() 10 | scale = 1 / (target.abs() + 1e-2) 11 | loss = difference * scale 12 | 13 | if reduction == 'mean': 14 | loss = loss.mean() 15 | 16 | return loss 17 | 18 | def huber_loss(pred, target, delta=0.1, reduction='mean'): 19 | rel = (pred - target).abs() 20 | sqr = 0.5 / delta * rel * rel 21 | loss = torch.where(rel > 
delta, rel - 0.5 * delta, sqr) 22 | 23 | if reduction == 'mean': 24 | loss = loss.mean() 25 | 26 | return loss -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | 4 | from nerf.gui import NeRFGUI 5 | from nerf.network import NeRFNetwork 6 | from nerf.utils import * 7 | 8 | # torch.autograd.set_detect_anomaly(True) 9 | 10 | if __name__ == '__main__': 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('path', type=str) 14 | parser.add_argument('-O', action='store_true', help="recommended settings") 15 | parser.add_argument('--workspace', type=str, default='workspace') 16 | parser.add_argument('--seed', type=int, default=0) 17 | parser.add_argument('--stage', type=int, default=0, help="training stage") 18 | parser.add_argument('--ckpt', type=str, default='latest') 19 | parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training") 20 | parser.add_argument('--sdf', action='store_true', help="use sdf instead of density for nerf") 21 | parser.add_argument('--tcnn', action='store_true', help="use tcnn's gridencoder") 22 | parser.add_argument('--progressive_level', action='store_true', help="progressively increase max_level") 23 | 24 | ### testing options 25 | parser.add_argument('--test', action='store_true', help="test mode") 26 | parser.add_argument('--test_no_video', action='store_true', help="test mode: do not save video") 27 | parser.add_argument('--test_no_mesh', action='store_true', help="test mode: do not save mesh") 28 | parser.add_argument('--camera_traj', type=str, default='', help="nerfstudio compatible json file for camera trajactory") 29 | 30 | ### dataset options 31 | parser.add_argument('--data_format', type=str, default='nerf', choices=['nerf', 'colmap', 'dtu']) 32 | parser.add_argument('--train_split', type=str, default='train', choices=['train', 'trainval', 'all']) 33 | parser.add_argument('--preload', action='store_true', help="preload all data into GPU, accelerate training but use more GPU memory") 34 | parser.add_argument('--random_image_batch', action='store_true', help="randomly sample rays from all images per step in training stage 0, incompatible with enable_sparse_depth") 35 | parser.add_argument('--downscale', type=int, default=1, help="downscale training images") 36 | parser.add_argument('--bound', type=float, default=2, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.") 37 | parser.add_argument('--scale', type=float, default=-1, help="scale camera location into box[-bound, bound]^3, -1 means automatically determine based on camera poses..") 38 | parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location") 39 | parser.add_argument('--mesh', type=str, default='', help="template mesh for phase 2") 40 | parser.add_argument('--enable_cam_near_far', action='store_true', help="colmap mode: use the sparse points to estimate camera near far per view.") 41 | parser.add_argument('--enable_cam_center', action='store_true', help="use camera center instead of sparse point center (colmap dataset only)") 42 | parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera") 43 | parser.add_argument('--enable_sparse_depth', action='store_true', help="use sparse depth from colmap pts3d, only valid if using --data_formt colmap") 44 | 
parser.add_argument('--enable_dense_depth', action='store_true', help="use dense depth from omnidepth calibrated to colmap pts3d, only valid if using --data_formt colmap") 45 | 46 | ### training options 47 | parser.add_argument('--iters', type=int, default=30000, help="training iters") 48 | parser.add_argument('--lr', type=float, default=1e-2, help="initial learning rate") 49 | parser.add_argument('--lr_vert', type=float, default=1e-4, help="initial learning rate for vert optimization") 50 | parser.add_argument('--pos_gradient_boost', type=float, default=1, help="nvdiffrast option") 51 | parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch") 52 | parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)") 53 | parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)") 54 | parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)") 55 | parser.add_argument('--grid_size', type=int, default=128, help="density grid resolution") 56 | parser.add_argument('--mark_untrained', action='store_true', help="mark_untrained grid") 57 | parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") 58 | parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied") 59 | parser.add_argument('--diffuse_step', type=int, default=1000, help="training iters that only trains diffuse color for better initialization") 60 | parser.add_argument('--diffuse_only', action='store_true', help="only train diffuse color by overriding --diffuse_step") 61 | parser.add_argument('--background', type=str, default='random', choices=['white', 'random'], help="training background mode") 62 | parser.add_argument('--enable_offset_nerf_grad', action='store_true', help="allow grad to pass through nerf to train vertices offsets in stage 1, only work for small meshes (e.g., synthetic dataset)") 63 | parser.add_argument('--n_eval', type=int, default=5, help="eval $ times during training") 64 | parser.add_argument('--n_ckpt', type=int, default=50, help="save $ times during training") 65 | 66 | # batch size related 67 | parser.add_argument('--num_rays', type=int, default=4096, help="num rays sampled per image for each training step") 68 | parser.add_argument('--adaptive_num_rays', action='store_true', help="adaptive num rays for more efficient training") 69 | parser.add_argument('--num_points', type=int, default=2 ** 18, help="target num points for each training step, only work with adaptive num_rays") 70 | 71 | # stage 0 regularizations 72 | parser.add_argument('--lambda_density', type=float, default=0, help="loss scale") 73 | parser.add_argument('--lambda_entropy', type=float, default=0, help="loss scale") 74 | parser.add_argument('--lambda_tv', type=float, default=1e-8, help="loss scale") 75 | parser.add_argument('--lambda_depth', type=float, default=0.1, help="loss scale") 76 | parser.add_argument('--lambda_specular', type=float, default=1e-5, help="loss scale") 77 | parser.add_argument('--lambda_eikonal', type=float, default=0.1, help="loss scale") 78 | parser.add_argument('--lambda_rgb', type=float, default=1, help="loss scale") 79 | 
parser.add_argument('--lambda_mask', type=float, default=0.1, help="loss scale") 80 | 81 | # stage 1 regularizations 82 | parser.add_argument('--wo_smooth', action='store_true', help="disable all smoothness regularizations") 83 | parser.add_argument('--lambda_lpips', type=float, default=0, help="loss scale") 84 | parser.add_argument('--lambda_offsets', type=float, default=0.1, help="loss scale") 85 | parser.add_argument('--lambda_lap', type=float, default=0.001, help="loss scale") 86 | parser.add_argument('--lambda_normal', type=float, default=0, help="loss scale") 87 | parser.add_argument('--lambda_edgelen', type=float, default=0, help="loss scale") 88 | 89 | # unused 90 | parser.add_argument('--contract', action='store_true', help="apply L-INF ray contraction as in mip-nerf, only work for bound > 1, will override bound to 2.") 91 | parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable") 92 | parser.add_argument('--trainable_density_grid', action='store_true', help="update density_grid through loss functions, instead of directly update.") 93 | parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)") 94 | parser.add_argument('--ind_dim', type=int, default=0, help="individual code dim, 0 to turn off") 95 | parser.add_argument('--ind_num', type=int, default=500, help="number of individual codes, should be larger than training dataset size") 96 | 97 | ### mesh options 98 | # stage 0 99 | parser.add_argument('--mcubes_reso', type=int, default=512, help="resolution for marching cubes") 100 | parser.add_argument('--env_reso', type=int, default=256, help="max layers (resolution) for env mesh") 101 | parser.add_argument('--decimate_target', type=float, default=3e5, help="decimate target for number of triangles, <=0 to disable") 102 | parser.add_argument('--mesh_visibility_culling', action='store_true', help="cull mesh faces based on visibility in training dataset") 103 | parser.add_argument('--visibility_mask_dilation', type=int, default=5, help="visibility dilation") 104 | parser.add_argument('--clean_min_f', type=int, default=8, help="mesh clean: min face count for isolated mesh") 105 | parser.add_argument('--clean_min_d', type=int, default=5, help="mesh clean: min diameter for isolated mesh") 106 | 107 | # stage 1 108 | parser.add_argument('--ssaa', type=int, default=2, help="super sampling anti-aliasing ratio") 109 | parser.add_argument('--texture_size', type=int, default=4096, help="exported texture resolution") 110 | parser.add_argument('--refine', action='store_true', help="track face error and do subdivision") 111 | parser.add_argument("--refine_steps_ratio", type=float, action="append", default=[0.1, 0.2, 0.3, 0.4, 0.5, 0.7]) 112 | parser.add_argument('--refine_size', type=float, default=0.01, help="refine trig length") 113 | parser.add_argument('--refine_decimate_ratio', type=float, default=0.1, help="refine decimate ratio") 114 | parser.add_argument('--refine_remesh_size', type=float, default=0.02, help="remesh trig length") 115 | 116 | ### GUI options 117 | parser.add_argument('--vis_pose', action='store_true', help="visualize the poses") 118 | parser.add_argument('--gui', action='store_true', help="start a GUI") 119 | parser.add_argument('--W', type=int, default=1000, help="GUI width") 120 | parser.add_argument('--H', type=int, default=1000, help="GUI height") 121 | parser.add_argument('--radius', type=float, 
default=5, help="default GUI camera radius from center") 122 | parser.add_argument('--fovy', type=float, default=50, help="default GUI camera fovy") 123 | parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel") 124 | 125 | opt = parser.parse_args() 126 | 127 | opt.cuda_ray = True 128 | 129 | if opt.O: 130 | opt.fp16 = True 131 | opt.preload = True 132 | opt.mark_untrained = True 133 | opt.random_image_batch = True 134 | opt.mesh_visibility_culling = True 135 | opt.adaptive_num_rays = True 136 | opt.refine = True 137 | 138 | if opt.sdf: 139 | # opt.tcnn = True # tcnn supports 2nd order gradient, which is faster than finite difference. 140 | # opt.lambda_tv = 0 # tcnn does not support inplace TV 141 | opt.density_thresh = 0.001 # use smaller thresh to suit density scale from sdf 142 | if opt.stage == 0: 143 | opt.progressive_level = True 144 | 145 | # contract background 146 | if opt.bound > 1: 147 | opt.contract = True 148 | 149 | opt.enable_offset_nerf_grad = True # leads to sharper texture 150 | 151 | # just perform remesh periodically 152 | opt.refine_decimate_ratio = 0 # disable decimation 153 | opt.refine_size = 0 # disable subdivision 154 | 155 | if opt.contract: 156 | # mark untrained is not very correct in contraction mode... 157 | opt.mark_untrained = False 158 | 159 | # best rendering quality at the sacrifice of mesh quality 160 | if opt.wo_smooth: 161 | opt.lambda_offsets = 0 162 | opt.lambda_lap = 0 163 | opt.lambda_normal = 0 164 | 165 | if opt.enable_sparse_depth: 166 | print(f'[WARN] disable random image batch when depth supervision is used!') 167 | opt.random_image_batch = False 168 | 169 | if opt.patch_size > 1: 170 | # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss." 171 | assert opt.num_rays % (opt.patch_size ** 2) == 0, "num_rays should be divisible by patch_size ** 2." 172 | 173 | if opt.data_format == 'colmap': 174 | from nerf.colmap_provider import ColmapDataset as NeRFDataset 175 | elif opt.data_format == 'dtu': 176 | from nerf.dtu_provider import NeRFDataset 177 | else: # 'nerf' 178 | from nerf.provider import NeRFDataset 179 | 180 | # convert ratio to steps 181 | opt.refine_steps = [int(round(x * opt.iters)) for x in opt.refine_steps_ratio] 182 | 183 | seed_everything(opt.seed) 184 | 185 | model = NeRFNetwork(opt) 186 | 187 | criterion = torch.nn.MSELoss(reduction='none') 188 | # criterion = torch.nn.SmoothL1Loss(reduction='none') 189 | 190 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 191 | 192 | if opt.test: 193 | 194 | trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, use_checkpoint=opt.ckpt) 195 | 196 | if opt.gui: 197 | gui = NeRFGUI(opt, trainer) 198 | gui.render() 199 | 200 | else: 201 | if not opt.test_no_video: 202 | test_loader = NeRFDataset(opt, device=device, type='test').dataloader() 203 | 204 | if test_loader.has_gt: 205 | trainer.metrics = [PSNRMeter(), LPIPSMeter(device=device)] # set up metrics 206 | trainer.evaluate(test_loader) # blender has gt, so evaluate it. 
207 | 208 | trainer.test(test_loader, write_video=True) # test and save video 209 | 210 | if not opt.test_no_mesh: 211 | if opt.stage == 1: 212 | trainer.export_stage1(resolution=opt.texture_size) 213 | else: 214 | # need train loader to get camera poses for visibility test 215 | if opt.mesh_visibility_culling: 216 | train_loader = NeRFDataset(opt, device=device, type=opt.train_split).dataloader() 217 | trainer.save_mesh(resolution=opt.mcubes_reso, decimate_target=opt.decimate_target, dataset=train_loader._data if opt.mesh_visibility_culling else None) 218 | 219 | else: 220 | 221 | optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), eps=1e-15) 222 | 223 | train_loader = NeRFDataset(opt, device=device, type=opt.train_split).dataloader() 224 | 225 | max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32) 226 | save_interval = max(1, max_epoch // max(opt.n_ckpt, 1)) 227 | eval_interval = max(1, max_epoch // max(opt.n_eval, 1)) 228 | print(f'[INFO] max_epoch {max_epoch}, eval every {eval_interval}, save every {save_interval}.') 229 | 230 | if opt.ind_dim > 0: 231 | assert len(train_loader) < opt.ind_num, f"[ERROR] dataset too many frames: {len(train_loader)}, please increase --ind_num to at least this number!" 232 | 233 | # colmap can estimate a more compact AABB 234 | if opt.data_format == 'colmap': 235 | model.update_aabb(train_loader._data.pts_aabb) 236 | 237 | # scheduler = lambda optimizer: optim.lr_scheduler.MultiStepLR(optimizer, milestones=[opt.iters // 2, opt.iters * 3 // 4, opt.iters * 9 // 10], gamma=0.33) 238 | # scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1)) 239 | scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.01 + 0.99 * (iter / 500) if iter <= 500 else 0.1 ** ((iter - 500) / (opt.iters - 500))) 240 | 241 | trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=0.95 if opt.stage == 0 else None, fp16=opt.fp16, lr_scheduler=scheduler, scheduler_update_every_step=True, use_checkpoint=opt.ckpt, eval_interval=eval_interval, save_interval=save_interval) 242 | 243 | if opt.gui: 244 | gui = NeRFGUI(opt, trainer, train_loader) 245 | gui.render() 246 | 247 | else: 248 | valid_loader = NeRFDataset(opt, device=device, type='val').dataloader() 249 | 250 | trainer.metrics = [PSNRMeter(),] 251 | trainer.train(train_loader, valid_loader, max_epoch) 252 | 253 | # last validation 254 | trainer.metrics = [PSNRMeter(), LPIPSMeter(device=device)] 255 | trainer.evaluate(valid_loader) 256 | 257 | # also test 258 | test_loader = NeRFDataset(opt, device=device, type='test').dataloader() 259 | 260 | if test_loader.has_gt: 261 | trainer.evaluate(test_loader) # blender has gt, so evaluate it. 
262 | 263 | trainer.test(test_loader, write_video=True) # test and save video 264 | 265 | if opt.stage == 1: 266 | trainer.export_stage1(resolution=opt.texture_size) 267 | else: 268 | trainer.save_mesh(resolution=opt.mcubes_reso, decimate_target=opt.decimate_target, dataset=train_loader._data if opt.mesh_visibility_culling else None) -------------------------------------------------------------------------------- /meshutils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pymeshlab as pml 3 | 4 | def isotropic_explicit_remeshing(verts, faces): 5 | 6 | _ori_vert_shape = verts.shape 7 | _ori_face_shape = faces.shape 8 | 9 | m = pml.Mesh(verts, faces) 10 | ms = pml.MeshSet() 11 | ms.add_mesh(m, 'mesh') # will copy! 12 | 13 | # filters 14 | # ms.apply_coord_taubin_smoothing() 15 | ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.PercentageValue(1)) 16 | 17 | # extract mesh 18 | m = ms.current_mesh() 19 | verts = m.vertex_matrix() 20 | faces = m.face_matrix() 21 | 22 | print(f'[INFO] isotropic explicit remesh: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}') 23 | 24 | return verts, faces 25 | 26 | 27 | def decimate_mesh(verts, faces, target, backend='pymeshlab', remesh=False, optimalplacement=True): 28 | # optimalplacement: default is True, but for flat meshes it must be turned off to prevent spike artifacts. 29 | 30 | _ori_vert_shape = verts.shape 31 | _ori_face_shape = faces.shape 32 | 33 | if backend == 'pyfqmr': 34 | import pyfqmr 35 | solver = pyfqmr.Simplify() 36 | solver.setMesh(verts, faces) 37 | solver.simplify_mesh(target_count=int(target), preserve_border=False, verbose=False) 38 | verts, faces, normals = solver.getMesh() 39 | else: 40 | 41 | m = pml.Mesh(verts, faces) 42 | ms = pml.MeshSet() 43 | ms.add_mesh(m, 'mesh') # will copy! 44 | 45 | # filters 46 | # ms.meshing_decimation_clustering(threshold=pml.PercentageValue(1)) 47 | ms.meshing_decimation_quadric_edge_collapse(targetfacenum=int(target), optimalplacement=optimalplacement) 48 | 49 | if remesh: 50 | ms.apply_coord_taubin_smoothing() 51 | ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.PercentageValue(1)) 52 | 53 | # extract mesh 54 | m = ms.current_mesh() 55 | verts = m.vertex_matrix() 56 | faces = m.face_matrix() 57 | 58 | print(f'[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}') 59 | 60 | return verts, faces 61 | 62 | 63 | def remove_masked_trigs(verts, faces, mask, dilation=5): 64 | # mask: 0 == keep, 1 == remove 65 | 66 | _ori_vert_shape = verts.shape 67 | _ori_face_shape = faces.shape 68 | 69 | m = pml.Mesh(verts, faces, f_scalar_array=mask) # mask as the quality 70 | ms = pml.MeshSet() 71 | ms.add_mesh(m, 'mesh') # will copy! 72 | 73 | # select faces 74 | ms.compute_selection_by_condition_per_face(condselect='fq == 0') # select kept faces 75 | # dilate to avoid holes... 
76 | for _ in range(dilation): 77 | ms.apply_selection_dilatation() 78 | ms.apply_selection_inverse(invfaces=True) # invert 79 | 80 | # delete faces 81 | ms.meshing_remove_selected_faces() 82 | 83 | # clean unref verts 84 | ms.meshing_remove_unreferenced_vertices() 85 | 86 | # extract 87 | m = ms.current_mesh() 88 | verts = m.vertex_matrix() 89 | faces = m.face_matrix() 90 | 91 | print(f'[INFO] mesh mask trigs: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}') 92 | 93 | return verts, faces 94 | 95 | 96 | def remove_masked_verts(verts, faces, mask): 97 | # mask: 0 == keep, 1 == remove 98 | 99 | _ori_vert_shape = verts.shape 100 | _ori_face_shape = faces.shape 101 | 102 | m = pml.Mesh(verts, faces, v_scalar_array=mask) # mask as the quality 103 | ms = pml.MeshSet() 104 | ms.add_mesh(m, 'mesh') # will copy! 105 | 106 | # select verts 107 | ms.compute_selection_by_condition_per_vertex(condselect='q == 1') 108 | 109 | # delete verts and connected faces 110 | ms.meshing_remove_selected_vertices() 111 | 112 | # extract 113 | m = ms.current_mesh() 114 | verts = m.vertex_matrix() 115 | faces = m.face_matrix() 116 | 117 | print(f'[INFO] mesh mask verts: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}') 118 | 119 | return verts, faces 120 | 121 | 122 | def remove_selected_verts(verts, faces, query='(x < 1) && (x > -1) && (y < 1) && (y > -1) && (z < 1 ) && (z > -1)'): 123 | 124 | _ori_vert_shape = verts.shape 125 | _ori_face_shape = faces.shape 126 | 127 | m = pml.Mesh(verts, faces) 128 | ms = pml.MeshSet() 129 | ms.add_mesh(m, 'mesh') # will copy! 130 | 131 | # select verts 132 | ms.compute_selection_by_condition_per_vertex(condselect=query) 133 | 134 | # delete verts and connected faces 135 | ms.meshing_remove_selected_vertices() 136 | 137 | # extract 138 | m = ms.current_mesh() 139 | verts = m.vertex_matrix() 140 | faces = m.face_matrix() 141 | 142 | print(f'[INFO] mesh remove verts: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}') 143 | 144 | return verts, faces 145 | 146 | def clean_mesh(verts, faces, v_pct=1, min_f=8, min_d=5, repair=True, remesh=True): 147 | # verts: [N, 3] 148 | # faces: [N, 3] 149 | 150 | _ori_vert_shape = verts.shape 151 | _ori_face_shape = faces.shape 152 | 153 | m = pml.Mesh(verts, faces) 154 | ms = pml.MeshSet() 155 | ms.add_mesh(m, 'mesh') # will copy! 
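    # A note on the pattern used throughout this file (the call chain below is only an
    # illustrative sketch, using the default arguments defined above): every helper wraps
    # the numpy verts/faces in a pml.Mesh, adds it to a MeshSet (which copies the data),
    # runs some meshing_* filters, and reads vertex_matrix()/face_matrix() back into numpy.
    #   verts, faces = clean_mesh(verts, faces, v_pct=1, min_f=8, min_d=5, repair=True, remesh=True)
    #   verts, faces = decimate_mesh(verts, faces, target=300000, backend='pymeshlab')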
156 | 157 | # filters 158 | ms.meshing_remove_unreferenced_vertices() # verts not refed by any faces 159 | 160 | if v_pct > 0: 161 | ms.meshing_merge_close_vertices(threshold=pml.PercentageValue(v_pct)) # 1/10000 of bounding box diagonal 162 | 163 | ms.meshing_remove_duplicate_faces() # faces defined by the same verts 164 | ms.meshing_remove_null_faces() # faces with area == 0 165 | 166 | if min_d > 0: 167 | ms.meshing_remove_connected_component_by_diameter(mincomponentdiag=pml.PercentageValue(min_d)) 168 | 169 | if min_f > 0: 170 | ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f) 171 | 172 | if repair: 173 | # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True) 174 | ms.meshing_repair_non_manifold_edges(method=0) 175 | ms.meshing_repair_non_manifold_vertices(vertdispratio=0) 176 | 177 | if remesh: 178 | # ms.apply_coord_taubin_smoothing() 179 | ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.PercentageValue(1)) 180 | 181 | # extract mesh 182 | m = ms.current_mesh() 183 | verts = m.vertex_matrix() 184 | faces = m.face_matrix() 185 | 186 | print(f'[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}') 187 | 188 | return verts, faces 189 | 190 | 191 | def decimate_and_refine_mesh(verts, faces, mask, decimate_ratio=0.1, refine_size=0.01, refine_remesh_size=0.02): 192 | # verts: [N, 3] 193 | # faces: [M, 3] 194 | # mask: [M], 0 denotes do nothing, 1 denotes decimation, 2 denotes subdivision 195 | 196 | _ori_vert_shape = verts.shape 197 | _ori_face_shape = faces.shape 198 | 199 | m = pml.Mesh(verts, faces, f_scalar_array=mask) 200 | ms = pml.MeshSet() 201 | ms.add_mesh(m, 'mesh') # will copy! 202 | 203 | # decimate and remesh 204 | ms.compute_selection_by_condition_per_face(condselect='fq == 1') 205 | if decimate_ratio > 0: 206 | ms.meshing_decimation_quadric_edge_collapse(targetfacenum=int((1 - decimate_ratio) * (mask == 1).sum()), selected=True) 207 | 208 | if refine_remesh_size > 0: 209 | ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.PureValue(refine_remesh_size), selectedonly=True) 210 | 211 | # repair 212 | ms.set_selection_none(allfaces=True) 213 | ms.meshing_repair_non_manifold_edges(method=0) 214 | ms.meshing_repair_non_manifold_vertices(vertdispratio=0) 215 | 216 | # refine 217 | if refine_size > 0: 218 | ms.compute_selection_by_condition_per_face(condselect='fq == 2') 219 | ms.meshing_surface_subdivision_midpoint(threshold=pml.PureValue(refine_size), selected=True) 220 | 221 | # ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.PureValue(refine_size), selectedonly=True) 222 | 223 | # extract mesh 224 | m = ms.current_mesh() 225 | verts = m.vertex_matrix() 226 | faces = m.face_matrix() 227 | 228 | print(f'[INFO] mesh decimating & subdividing: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}') 229 | 230 | return verts, faces 231 | 232 | 233 | # in meshutils.py 234 | def select_bad_and_flat_faces_by_normal(verts, faces, usear=False, aratio=0.02, nfratio_bad=120, nfratio_flat=5): 235 | m = pml.Mesh(verts, faces) 236 | ms = pml.MeshSet() 237 | ms.add_mesh(m, 'mesh') 238 | 239 | ms.compute_selection_bad_faces(usear=usear, aratio=aratio, usenf=True, nfratio=nfratio_bad) 240 | m = ms.current_mesh() 241 | bad_faces_mask = m.face_selection_array() 242 | # print('bad_faces_mask cnt: ', sum(bad_faces_mask * 1.0)) 243 | ms.set_selection_none(allfaces=True) 244 | 245 | ms.compute_selection_bad_faces(usear=usear, aratio=aratio, 
usenf=True, nfratio=nfratio_flat) 246 | m = ms.current_mesh() 247 | flat_faces_mask = m.face_selection_array() == False # reverse 248 | # print('flat_faces_mask cnt: ', sum(flat_faces_mask * 1.0)) 249 | ms.set_selection_none(allfaces=True) 250 | 251 | return bad_faces_mask, flat_faces_mask 252 | -------------------------------------------------------------------------------- /nerf/colmap_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. 
Schoenberger (jsch at inf.ethz.ch) 31 | 32 | import os 33 | import sys 34 | import collections 35 | import numpy as np 36 | import struct 37 | 38 | 39 | CameraModel = collections.namedtuple( 40 | "CameraModel", ["model_id", "model_name", "num_params"]) 41 | Camera = collections.namedtuple( 42 | "Camera", ["id", "model", "width", "height", "params"]) 43 | BaseImage = collections.namedtuple( 44 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 45 | Point3D = collections.namedtuple( 46 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 47 | 48 | class Image(BaseImage): 49 | def qvec2rotmat(self): 50 | return qvec2rotmat(self.qvec) 51 | 52 | 53 | CAMERA_MODELS = { 54 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 55 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 56 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 57 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 58 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 59 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 60 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 61 | CameraModel(model_id=7, model_name="FOV", num_params=5), 62 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 63 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 64 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 65 | } 66 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \ 67 | for camera_model in CAMERA_MODELS]) 68 | 69 | 70 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 71 | """Read and unpack the next bytes from a binary file. 72 | :param fid: 73 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 74 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 75 | :param endian_character: Any of {@, =, <, >, !} 76 | :return: Tuple of read and unpacked values. 
77 | """ 78 | data = fid.read(num_bytes) 79 | return struct.unpack(endian_character + format_char_sequence, data) 80 | 81 | 82 | def read_cameras_text(path): 83 | """ 84 | see: src/base/reconstruction.cc 85 | void Reconstruction::WriteCamerasText(const std::string& path) 86 | void Reconstruction::ReadCamerasText(const std::string& path) 87 | """ 88 | cameras = {} 89 | with open(path, "r") as fid: 90 | while True: 91 | line = fid.readline() 92 | if not line: 93 | break 94 | line = line.strip() 95 | if len(line) > 0 and line[0] != "#": 96 | elems = line.split() 97 | camera_id = int(elems[0]) 98 | model = elems[1] 99 | width = int(elems[2]) 100 | height = int(elems[3]) 101 | params = np.array(tuple(map(float, elems[4:]))) 102 | cameras[camera_id] = Camera(id=camera_id, model=model, 103 | width=width, height=height, 104 | params=params) 105 | return cameras 106 | 107 | 108 | def read_cameras_binary(path_to_model_file): 109 | """ 110 | see: src/base/reconstruction.cc 111 | void Reconstruction::WriteCamerasBinary(const std::string& path) 112 | void Reconstruction::ReadCamerasBinary(const std::string& path) 113 | """ 114 | cameras = {} 115 | with open(path_to_model_file, "rb") as fid: 116 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 117 | for camera_line_index in range(num_cameras): 118 | camera_properties = read_next_bytes( 119 | fid, num_bytes=24, format_char_sequence="iiQQ") 120 | camera_id = camera_properties[0] 121 | model_id = camera_properties[1] 122 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 123 | width = camera_properties[2] 124 | height = camera_properties[3] 125 | num_params = CAMERA_MODEL_IDS[model_id].num_params 126 | params = read_next_bytes(fid, num_bytes=8*num_params, 127 | format_char_sequence="d"*num_params) 128 | cameras[camera_id] = Camera(id=camera_id, 129 | model=model_name, 130 | width=width, 131 | height=height, 132 | params=np.array(params)) 133 | assert len(cameras) == num_cameras 134 | return cameras 135 | 136 | 137 | def read_images_text(path): 138 | """ 139 | see: src/base/reconstruction.cc 140 | void Reconstruction::ReadImagesText(const std::string& path) 141 | void Reconstruction::WriteImagesText(const std::string& path) 142 | """ 143 | images = {} 144 | with open(path, "r") as fid: 145 | while True: 146 | line = fid.readline() 147 | if not line: 148 | break 149 | line = line.strip() 150 | if len(line) > 0 and line[0] != "#": 151 | elems = line.split() 152 | image_id = int(elems[0]) 153 | qvec = np.array(tuple(map(float, elems[1:5]))) 154 | tvec = np.array(tuple(map(float, elems[5:8]))) 155 | camera_id = int(elems[8]) 156 | image_name = elems[9] 157 | elems = fid.readline().split() 158 | xys = np.column_stack([tuple(map(float, elems[0::3])), 159 | tuple(map(float, elems[1::3]))]) 160 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 161 | images[image_id] = Image( 162 | id=image_id, qvec=qvec, tvec=tvec, 163 | camera_id=camera_id, name=image_name, 164 | xys=xys, point3D_ids=point3D_ids) 165 | return images 166 | 167 | 168 | def read_images_binary(path_to_model_file): 169 | """ 170 | see: src/base/reconstruction.cc 171 | void Reconstruction::ReadImagesBinary(const std::string& path) 172 | void Reconstruction::WriteImagesBinary(const std::string& path) 173 | """ 174 | images = {} 175 | with open(path_to_model_file, "rb") as fid: 176 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 177 | for image_index in range(num_reg_images): 178 | binary_image_properties = read_next_bytes( 179 | fid, num_bytes=64, format_char_sequence="idddddddi") 
180 | image_id = binary_image_properties[0] 181 | qvec = np.array(binary_image_properties[1:5]) 182 | tvec = np.array(binary_image_properties[5:8]) 183 | camera_id = binary_image_properties[8] 184 | image_name = "" 185 | current_char = read_next_bytes(fid, 1, "c")[0] 186 | while current_char != b"\x00": # look for the ASCII 0 entry 187 | image_name += current_char.decode("utf-8") 188 | current_char = read_next_bytes(fid, 1, "c")[0] 189 | num_points2D = read_next_bytes(fid, num_bytes=8, 190 | format_char_sequence="Q")[0] 191 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 192 | format_char_sequence="ddq"*num_points2D) 193 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 194 | tuple(map(float, x_y_id_s[1::3]))]) 195 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 196 | images[image_id] = Image( 197 | id=image_id, qvec=qvec, tvec=tvec, 198 | camera_id=camera_id, name=image_name, 199 | xys=xys, point3D_ids=point3D_ids) 200 | return images 201 | 202 | 203 | def read_points3D_text(path): 204 | """ 205 | see: src/base/reconstruction.cc 206 | void Reconstruction::ReadPoints3DText(const std::string& path) 207 | void Reconstruction::WritePoints3DText(const std::string& path) 208 | """ 209 | points3D = {} 210 | with open(path, "r") as fid: 211 | while True: 212 | line = fid.readline() 213 | if not line: 214 | break 215 | line = line.strip() 216 | if len(line) > 0 and line[0] != "#": 217 | elems = line.split() 218 | point3D_id = int(elems[0]) 219 | xyz = np.array(tuple(map(float, elems[1:4]))) 220 | rgb = np.array(tuple(map(int, elems[4:7]))) 221 | error = float(elems[7]) 222 | image_ids = np.array(tuple(map(int, elems[8::2]))) 223 | point2D_idxs = np.array(tuple(map(int, elems[9::2]))) 224 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, 225 | error=error, image_ids=image_ids, 226 | point2D_idxs=point2D_idxs) 227 | return points3D 228 | 229 | 230 | def read_points3d_binary(path_to_model_file): 231 | """ 232 | see: src/base/reconstruction.cc 233 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 234 | void Reconstruction::WritePoints3DBinary(const std::string& path) 235 | """ 236 | points3D = {} 237 | with open(path_to_model_file, "rb") as fid: 238 | num_points = read_next_bytes(fid, 8, "Q")[0] 239 | for point_line_index in range(num_points): 240 | binary_point_line_properties = read_next_bytes( 241 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 242 | point3D_id = binary_point_line_properties[0] 243 | xyz = np.array(binary_point_line_properties[1:4]) 244 | rgb = np.array(binary_point_line_properties[4:7]) 245 | error = np.array(binary_point_line_properties[7]) 246 | track_length = read_next_bytes( 247 | fid, num_bytes=8, format_char_sequence="Q")[0] 248 | track_elems = read_next_bytes( 249 | fid, num_bytes=8*track_length, 250 | format_char_sequence="ii"*track_length) 251 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 252 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) 253 | points3D[point3D_id] = Point3D( 254 | id=point3D_id, xyz=xyz, rgb=rgb, 255 | error=error, image_ids=image_ids, 256 | point2D_idxs=point2D_idxs) 257 | return points3D 258 | 259 | 260 | def read_model(path, ext): 261 | if ext == ".txt": 262 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) 263 | images = read_images_text(os.path.join(path, "images" + ext)) 264 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext) 265 | else: 266 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) 267 | 
images = read_images_binary(os.path.join(path, "images" + ext)) 268 | points3D = read_points3d_binary(os.path.join(path, "points3D") + ext) 269 | return cameras, images, points3D 270 | 271 | 272 | def qvec2rotmat(qvec): 273 | return np.array([ 274 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 275 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 276 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 277 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 278 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 279 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 280 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 281 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 282 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 283 | 284 | 285 | def rotmat2qvec(R): 286 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 287 | K = np.array([ 288 | [Rxx - Ryy - Rzz, 0, 0, 0], 289 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 290 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 291 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 292 | eigvals, eigvecs = np.linalg.eigh(K) 293 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 294 | if qvec[0] < 0: 295 | qvec *= -1 296 | return -------------------------------------------------------------------------------- /nerf/dtu_provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import json 5 | import tqdm 6 | import numpy as np 7 | from scipy.spatial.transform import Slerp, Rotation 8 | 9 | import trimesh 10 | 11 | import torch 12 | from torch.utils.data import DataLoader 13 | 14 | from .utils import get_rays, create_dodecahedron_cameras 15 | 16 | 17 | # ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50 18 | def nerf_matrix_to_ngp(pose, scale=0.33, offset=[0, 0, 0]): 19 | pose[:3, 3] = pose[:3, 3] * scale + np.array(offset) 20 | pose = pose.astype(np.float32) 21 | return pose 22 | 23 | def visualize_poses(poses, size=0.1, bound=1): 24 | # poses: [B, 4, 4] 25 | 26 | axes = trimesh.creation.axis(axis_length=4) 27 | box = trimesh.primitives.Box(extents=[2*bound]*3).as_outline() 28 | box.colors = np.array([[128, 128, 128]] * len(box.entities)) 29 | objects = [axes, box] 30 | 31 | for pose in poses: 32 | # a camera is visualized with 8 line segments. 
33 | pos = pose[:3, 3] 34 | a = pos + size * pose[:3, 0] + size * pose[:3, 1] - size * pose[:3, 2] 35 | b = pos - size * pose[:3, 0] + size * pose[:3, 1] - size * pose[:3, 2] 36 | c = pos - size * pose[:3, 0] - size * pose[:3, 1] - size * pose[:3, 2] 37 | d = pos + size * pose[:3, 0] - size * pose[:3, 1] - size * pose[:3, 2] 38 | 39 | dir = (a + b + c + d) / 4 - pos 40 | dir = dir / (np.linalg.norm(dir) + 1e-8) 41 | o = pos + dir * 3 42 | 43 | segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a], [pos, o]]) 44 | segs = trimesh.load_path(segs) 45 | objects.append(segs) 46 | 47 | trimesh.Scene(objects).show() 48 | 49 | def load_K_Rt_from_P(P): 50 | 51 | out = cv2.decomposeProjectionMatrix(P) 52 | K = out[0] 53 | R = out[1] 54 | t = out[2] 55 | 56 | K = K / K[2, 2] 57 | intrinsic = np.array([K[0, 0], K[1, 1], K[0, 2], K[1, 2]]) 58 | 59 | pose = np.eye(4, dtype=np.float32) 60 | pose[:3, :3] = R.transpose() 61 | pose[:3, 3] = (t[:3] / t[3])[:, 0] 62 | 63 | return intrinsic, pose 64 | 65 | class NeRFDataset: 66 | def __init__(self, opt, device, type='train', n_test=10): 67 | super().__init__() 68 | 69 | self.opt = opt 70 | self.device = device 71 | self.type = type # train, val, test 72 | self.downscale = opt.downscale 73 | self.root_path = opt.path 74 | self.preload = opt.preload # preload data into GPU 75 | self.scale = opt.scale # camera radius scale to make sure camera are inside the bounding box. 76 | self.offset = opt.offset # camera offset 77 | self.bound = opt.bound # bounding box half length, also used as the radius to random sample poses. 78 | self.fp16 = opt.fp16 # if preload, load into fp16. 79 | 80 | if self.scale == -1: 81 | print(f'[WARN] --data_format dtu cannot auto-choose --scale, use 1 as default.') 82 | self.scale = 1 83 | 84 | self.training = self.type in ['train', 'all', 'trainval'] 85 | 86 | camera_dict = np.load(os.path.join(self.root_path, 'cameras_sphere.npz')) 87 | image_paths = sorted(glob.glob(os.path.join(self.root_path, 'image', '*.png'))) 88 | mask_paths = sorted(glob.glob(os.path.join(self.root_path, 'mask', '*.png'))) 89 | 90 | world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(len(image_paths))] 91 | scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(len(image_paths))] 92 | 93 | intrinsics = [] 94 | poses = [] 95 | 96 | for scale_mat, world_mat in zip(scale_mats, world_mats): 97 | P = world_mat @ scale_mat 98 | P = P[:3, :4] 99 | intrinsic, pose = load_K_Rt_from_P(P) 100 | intrinsics.append(intrinsic) 101 | pose = nerf_matrix_to_ngp(pose, scale=self.scale, offset=self.offset) 102 | poses.append(pose) 103 | 104 | self.intrinsics = torch.from_numpy(np.stack(intrinsics)).float() # [N, 4] 105 | self.poses = np.stack(poses) # [N, 4, 4] 106 | 107 | self.poses[:, :3, 1:3] *= -1 108 | self.poses = self.poses[:, [1, 0, 2, 3], :] 109 | self.poses[:, 2] *= -1 110 | 111 | # we have to actually read an image to get H and W later. 
112 | self.H = self.W = None 113 | # make split 114 | if self.type == 'test': 115 | 116 | poses = [] 117 | 118 | if self.opt.camera_traj == 'circle': 119 | 120 | print(f'[INFO] use circular camera traj for testing.') 121 | 122 | # circle 360 pose 123 | # radius = np.linalg.norm(self.poses[:, :3, 3], axis=-1).mean(0) 124 | radius = 0.1 125 | theta = np.deg2rad(80) 126 | for i in range(100): 127 | phi = np.deg2rad(i / 100 * 360) 128 | center = np.array([ 129 | radius * np.sin(theta) * np.sin(phi), 130 | radius * np.sin(theta) * np.cos(phi), 131 | radius * np.cos(theta), 132 | ]) 133 | # look at 134 | def normalize(v): 135 | return v / (np.linalg.norm(v) + 1e-10) 136 | forward_v = normalize(center) 137 | up_v = np.array([0, 0, 1]) 138 | right_v = normalize(np.cross(forward_v, up_v)) 139 | up_v = normalize(np.cross(right_v, forward_v)) 140 | # make pose 141 | pose = np.eye(4) 142 | pose[:3, :3] = np.stack((right_v, up_v, forward_v), axis=-1) 143 | pose[:3, 3] = center 144 | poses.append(pose) 145 | 146 | self.poses = np.stack(poses, axis=0) 147 | 148 | # choose some random poses, and interpolate between. 149 | else: 150 | 151 | fs = np.random.choice(len(self.poses), 5, replace=False) 152 | pose0 = self.poses[fs[0]] 153 | for i in range(1, len(fs)): 154 | pose1 = self.poses[fs[i]] 155 | rots = Rotation.from_matrix(np.stack([pose0[:3, :3], pose1[:3, :3]])) 156 | slerp = Slerp([0, 1], rots) 157 | for i in range(n_test + 1): 158 | ratio = np.sin(((i / n_test) - 0.5) * np.pi) * 0.5 + 0.5 159 | pose = np.eye(4, dtype=np.float32) 160 | pose[:3, :3] = slerp(ratio).as_matrix() 161 | pose[:3, 3] = (1 - ratio) * pose0[:3, 3] + ratio * pose1[:3, 3] 162 | poses.append(pose) 163 | pose0 = pose1 164 | 165 | self.poses = np.stack(poses, axis=0) 166 | 167 | # fix intrinsics for test case 168 | self.intrinsics = self.intrinsics[[0]].repeat(self.poses.shape[0], 1) 169 | 170 | self.images = None 171 | self.H = self.W = 512 172 | 173 | # manually split a valid set (the first frame). 174 | else: 175 | if type == 'train': 176 | image_paths = image_paths[1:] 177 | mask_paths = mask_paths[1:] 178 | self.poses = self.poses[1:] 179 | self.intrinsics = self.intrinsics[1:] 180 | elif type == 'val': 181 | image_paths = image_paths[:1] 182 | mask_paths = mask_paths[:1] 183 | self.poses = self.poses[:1] 184 | self.intrinsics = self.intrinsics[:1] 185 | # else 'all' or 'trainval' : use all frames 186 | 187 | # read images 188 | self.images = [] 189 | for i in tqdm.tqdm(range(len(image_paths)), desc=f'Loading {type} data'): 190 | 191 | f_path = image_paths[i] 192 | m_path = mask_paths[i] 193 | 194 | image = cv2.imread(f_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] 195 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 196 | 197 | # if use mask, add as an alpha channel 198 | mask = cv2.imread(m_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] 199 | image = np.concatenate([image, mask[..., :1]], axis=-1) 200 | 201 | if self.H is None or self.W is None: 202 | self.H = image.shape[0] // self.downscale 203 | self.W = image.shape[1] // self.downscale 204 | 205 | if image.shape[0] != self.H or image.shape[1] != self.W: 206 | image = cv2.resize(image, (self.W, self.H), interpolation=cv2.INTER_AREA) 207 | 208 | self.images.append(image) 209 | 210 | if self.images is not None: 211 | self.images = torch.from_numpy(np.stack(self.images, axis=0).astype(np.uint8)) # [N, H, W, C] 212 | 213 | # [debug] uncomment to view all training poses. 
214 | if self.opt.vis_pose: 215 | visualize_poses(self.poses, bound=self.opt.bound) 216 | 217 | self.poses = torch.from_numpy(self.poses.astype(np.float32)) # [N, 4, 4] 218 | 219 | # perspective projection matrix 220 | self.near = self.opt.min_near 221 | self.far = 1000 # infinite 222 | aspect = self.W / self.H 223 | 224 | projections = [] 225 | for intrinsic in self.intrinsics: 226 | y = self.H / (2.0 * intrinsic[1].item()) # fl_y 227 | projections.append(np.array([[1/(y*aspect), 0, 0, 0], 228 | [0, -1/y, 0, 0], 229 | [0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)], 230 | [0, 0, -1, 0]], dtype=np.float32)) 231 | self.projections = torch.from_numpy(np.stack(projections)) # [N, 4, 4] 232 | self.mvps = self.projections @ torch.inverse(self.poses) 233 | 234 | # tmp: dodecahedron_cameras for mesh visibility test 235 | dodecahedron_poses = create_dodecahedron_cameras() 236 | # visualize_poses(dodecahedron_poses, bound=self.opt.bound, points=self.pts3d) 237 | self.dodecahedron_poses = torch.from_numpy(dodecahedron_poses.astype(np.float32)) # [N, 4, 4] 238 | self.dodecahedron_mvps = self.projections[[0]] @ torch.inverse(self.dodecahedron_poses) # assume the same intrinsic 239 | 240 | if self.preload: 241 | self.intrinsics = self.intrinsics.to(self.device) 242 | self.poses = self.poses.to(self.device) 243 | if self.images is not None: 244 | self.images = self.images.to(self.device) 245 | self.mvps = self.mvps.to(self.device) 246 | 247 | 248 | def collate(self, index): 249 | 250 | B = len(index) # a list of length 1 251 | 252 | results = {'H': self.H, 'W': self.W} 253 | 254 | if self.training and self.opt.stage == 0: 255 | # randomly sample over images too 256 | num_rays = self.opt.num_rays 257 | 258 | if self.opt.random_image_batch: 259 | index = torch.randint(0, len(self.poses), size=(num_rays,), device=self.device) 260 | 261 | else: 262 | num_rays = -1 263 | 264 | poses = self.poses[index].to(self.device) # [N, 4, 4] 265 | intrinsics = self.intrinsics[index].to(self.device) # [1/N, 4] 266 | 267 | rays = get_rays(poses, intrinsics, self.H, self.W, num_rays, self.opt.patch_size) 268 | 269 | results['rays_o'] = rays['rays_o'] 270 | results['rays_d'] = rays['rays_d'] 271 | results['index'] = index 272 | 273 | if self.opt.stage > 0: 274 | mvp = self.mvps[index].to(self.device) 275 | results['mvp'] = mvp 276 | 277 | if self.images is not None: 278 | 279 | if self.training and self.opt.stage == 0: 280 | images = self.images[index, rays['j'], rays['i']].float().to(self.device) / 255 # [N, 3/4] 281 | else: 282 | images = self.images[index].squeeze(0).float().to(self.device) / 255 # [H, W, 3/4] 283 | 284 | if self.training: 285 | C = self.images.shape[-1] 286 | images = images.view(-1, C) 287 | 288 | results['images'] = images 289 | 290 | return results 291 | 292 | def dataloader(self): 293 | size = len(self.poses) 294 | loader = DataLoader(list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0) 295 | loader._data = self # an ugly fix... we need to access error_map & poses in trainer. 
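        # Each batch yielded by this loader is the dict built in collate() above, not a tensor.
        # A consumer loop might look like this sketch (keys as defined in collate; 'images' is
        # only present when ground-truth frames were loaded):
        #   for data in loader:
        #       rays_o, rays_d = data['rays_o'], data['rays_d']  # sampled ray origins / directions
        #       gt = data.get('images')  # [num_rays, C] during stage-0 training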
296 | loader.has_gt = self.images is not None 297 | return loader -------------------------------------------------------------------------------- /nerf/gui.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import dearpygui.dearpygui as dpg 5 | from scipy.spatial.transform import Rotation as R 6 | 7 | from .utils import * 8 | 9 | 10 | class OrbitCamera: 11 | def __init__(self, W, H, r=2, fovy=60, near=0.1, far=1000): 12 | self.W = W 13 | self.H = H 14 | self.radius = r # camera distance from center 15 | self.fovy = fovy # in degree 16 | self.near = near 17 | self.far = far 18 | self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point 19 | self.rot = R.from_matrix(np.eye(3)) 20 | self.up = np.array([0, 0, 1], dtype=np.float32) # need to be normalized! 21 | 22 | # pose 23 | @property 24 | def pose(self): 25 | # first move camera to radius 26 | res = np.eye(4, dtype=np.float32) 27 | res[2, 3] = self.radius # opengl convention... 28 | # rotate 29 | rot = np.eye(4, dtype=np.float32) 30 | rot[:3, :3] = self.rot.as_matrix() 31 | res = rot @ res 32 | # translate 33 | res[:3, 3] -= self.center 34 | return res 35 | 36 | # view 37 | @property 38 | def view(self): 39 | return np.linalg.inv(self.pose) 40 | 41 | # intrinsics 42 | @property 43 | def intrinsics(self): 44 | focal = self.H / (2 * np.tan(np.radians(self.fovy) / 2)) 45 | return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32) 46 | 47 | # projection (perspective) 48 | @property 49 | def perspective(self): 50 | y = np.tan(np.radians(self.fovy) / 2) 51 | aspect = self.W / self.H 52 | return np.array([[1/(y*aspect), 0, 0, 0], 53 | [ 0, -1/y, 0, 0], 54 | [ 0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)], 55 | [ 0, 0, -1, 0]], dtype=np.float32) 56 | 57 | 58 | def orbit(self, dx, dy): 59 | # rotate along camera up/side axis! 60 | side = self.rot.as_matrix()[:3, 0] # why this is side --> ? # already normalized. 61 | rotvec_x = self.up * np.radians(-0.05 * dx) 62 | rotvec_y = side * np.radians(-0.05 * dy) 63 | self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot 64 | 65 | def scale(self, delta): 66 | self.radius *= 1.1 ** (-delta) 67 | 68 | def pan(self, dx, dy, dz=0): 69 | # pan in camera coordinate system (careful on the sensitivity!) 70 | self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([dx, -dy, dz]) 71 | 72 | 73 | class NeRFGUI: 74 | def __init__(self, opt, trainer, train_loader=None, debug=True): 75 | self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters. 
76 | self.W = opt.W 77 | self.H = opt.H 78 | self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy) 79 | self.debug = debug 80 | self.bg_color = torch.ones(3, dtype=torch.float32) # default white bg 81 | self.training = False 82 | self.step = 0 # training step 83 | 84 | self.trainer = trainer 85 | self.train_loader = train_loader 86 | 87 | self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32) 88 | self.need_update = True # camera moved, should reset accumulation 89 | self.spp = 1 # sample per pixel 90 | self.mode = 'image' # choose from ['image', 'depth'] 91 | self.shading = 'full' 92 | 93 | self.dynamic_resolution = True if self.opt.stage == 0 else False 94 | self.downscale = 1 95 | self.train_steps = 16 96 | 97 | dpg.create_context() 98 | self.register_dpg() 99 | self.test_step() 100 | 101 | 102 | def __del__(self): 103 | dpg.destroy_context() 104 | 105 | 106 | def train_step(self): 107 | 108 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) 109 | starter.record() 110 | 111 | outputs = self.trainer.train_gui(self.train_loader, step=self.train_steps) 112 | 113 | ender.record() 114 | torch.cuda.synchronize() 115 | t = starter.elapsed_time(ender) 116 | 117 | self.step += self.train_steps 118 | self.need_update = True 119 | 120 | dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)') 121 | dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}') 122 | 123 | # dynamic train steps 124 | # max allowed train time per-frame is 500 ms 125 | full_t = t / self.train_steps * 16 126 | train_steps = min(16, max(4, int(16 * 500 / full_t))) 127 | if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8: 128 | self.train_steps = train_steps 129 | 130 | def prepare_buffer(self, outputs): 131 | if self.mode == 'image': 132 | return outputs['image'].astype(np.float32) 133 | else: 134 | depth = outputs['depth'].astype(np.float32) 135 | depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6) 136 | return np.expand_dims(depth, -1).repeat(3, -1) 137 | 138 | 139 | def test_step(self): 140 | 141 | if self.need_update or self.spp < self.opt.max_spp: 142 | 143 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) 144 | starter.record() 145 | 146 | # mvp 147 | mv = torch.from_numpy(self.cam.view).cuda() # [4, 4] 148 | proj = torch.from_numpy(self.cam.perspective).cuda() # [4, 4] 149 | mvp = proj @ mv 150 | 151 | outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, mvp, self.W, self.H, self.bg_color, self.spp, self.downscale, self.shading) 152 | 153 | ender.record() 154 | torch.cuda.synchronize() 155 | t = starter.elapsed_time(ender) 156 | 157 | # update dynamic resolution 158 | if self.dynamic_resolution: 159 | # max allowed infer time per-frame is 200 ms 160 | full_t = t / (self.downscale ** 2) 161 | downscale = min(1, max(1/4, math.sqrt(200 / full_t))) 162 | if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8: 163 | self.downscale = downscale 164 | 165 | if self.need_update: 166 | self.render_buffer = self.prepare_buffer(outputs) 167 | self.spp = 1 168 | self.need_update = False 169 | else: 170 | self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1) 171 | self.spp += 1 172 | 173 | dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)') 174 | dpg.set_value("_log_resolution", 
f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}') 175 | dpg.set_value("_log_spp", self.spp) 176 | dpg.set_value("_texture", self.render_buffer) 177 | 178 | 179 | def register_dpg(self): 180 | 181 | ### register texture 182 | 183 | with dpg.texture_registry(show=False): 184 | dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture") 185 | 186 | ### register window 187 | 188 | # the rendered image, as the primary window 189 | with dpg.window(tag="_primary_window", width=self.W, height=self.H): 190 | 191 | # add the texture 192 | dpg.add_image("_texture") 193 | 194 | dpg.set_primary_window("_primary_window", True) 195 | 196 | # control window 197 | with dpg.window(label="Control", tag="_control_window", width=400, height=300): 198 | 199 | # button theme 200 | with dpg.theme() as theme_button: 201 | with dpg.theme_component(dpg.mvButton): 202 | dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18)) 203 | dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47)) 204 | dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83)) 205 | dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5) 206 | dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3) 207 | 208 | # time 209 | if not self.opt.test: 210 | with dpg.group(horizontal=True): 211 | dpg.add_text("Train time: ") 212 | dpg.add_text("no data", tag="_log_train_time") 213 | 214 | with dpg.group(horizontal=True): 215 | dpg.add_text("Infer time: ") 216 | dpg.add_text("no data", tag="_log_infer_time") 217 | 218 | with dpg.group(horizontal=True): 219 | dpg.add_text("SPP: ") 220 | dpg.add_text("1", tag="_log_spp") 221 | 222 | # train button 223 | if not self.opt.test: 224 | with dpg.collapsing_header(label="Train", default_open=True): 225 | 226 | # train / stop 227 | with dpg.group(horizontal=True): 228 | dpg.add_text("Train: ") 229 | 230 | def callback_train(sender, app_data): 231 | if self.training: 232 | self.training = False 233 | dpg.configure_item("_button_train", label="start") 234 | else: 235 | self.training = True 236 | dpg.configure_item("_button_train", label="stop") 237 | 238 | dpg.add_button(label="start", tag="_button_train", callback=callback_train) 239 | dpg.bind_item_theme("_button_train", theme_button) 240 | 241 | def callback_reset(sender, app_data): 242 | @torch.no_grad() 243 | def weight_reset(m: nn.Module): 244 | reset_parameters = getattr(m, "reset_parameters", None) 245 | if callable(reset_parameters): 246 | m.reset_parameters() 247 | self.trainer.model.apply(fn=weight_reset) 248 | self.trainer.model.reset_extra_state() # for cuda_ray density_grid and step_counter 249 | self.need_update = True 250 | 251 | dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset) 252 | dpg.bind_item_theme("_button_reset", theme_button) 253 | 254 | # save ckpt 255 | with dpg.group(horizontal=True): 256 | dpg.add_text("Checkpoint: ") 257 | 258 | def callback_save(sender, app_data): 259 | self.trainer.save_checkpoint(full=True, best=False) 260 | dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1])) 261 | self.trainer.epoch += 1 # use epoch to indicate different calls. 
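                    # The controls in this window all follow the same dearpygui pattern, sketched
                    # here with a placeholder attribute self.flag (illustrative only): a nested
                    # callback taking (sender, app_data) closes over self, is attached to a widget
                    # via callback=, and sets self.need_update = True so the next test_step() re-renders.
                    #   def callback_toggle(sender, app_data):
                    #       self.flag = app_data
                    #       self.need_update = True
                    #   dpg.add_checkbox(label="toggle", default_value=False, callback=callback_toggle)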
262 | 263 | dpg.add_button(label="save", tag="_button_save", callback=callback_save) 264 | dpg.bind_item_theme("_button_save", theme_button) 265 | 266 | dpg.add_text("", tag="_log_ckpt") 267 | 268 | with dpg.group(horizontal=True): 269 | dpg.add_text("", tag="_log_train_log") 270 | 271 | 272 | # rendering options 273 | with dpg.collapsing_header(label="Options", default_open=True): 274 | 275 | # # binary 276 | # def callback_set_binary(sender, app_data): 277 | # if self.opt.binary: 278 | # self.opt.binary = False 279 | # else: 280 | # self.opt.binary = True 281 | # self.need_update = True 282 | 283 | # dpg.add_checkbox(label="binary", default_value=self.opt.binary, callback=callback_set_binary) 284 | 285 | 286 | # dynamic rendering resolution 287 | with dpg.group(horizontal=True): 288 | 289 | def callback_set_dynamic_resolution(sender, app_data): 290 | if self.dynamic_resolution: 291 | self.dynamic_resolution = False 292 | self.downscale = 1 293 | else: 294 | self.dynamic_resolution = True 295 | self.need_update = True 296 | 297 | dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution) 298 | dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution") 299 | 300 | # mode combo 301 | def callback_change_mode(sender, app_data): 302 | self.mode = app_data 303 | self.need_update = True 304 | 305 | dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode) 306 | 307 | # shading combo 308 | def callback_change_shading(sender, app_data): 309 | self.shading = app_data 310 | self.need_update = True 311 | 312 | dpg.add_combo(('full', 'diffuse', 'specular'), label='shading', default_value=self.shading, callback=callback_change_shading) 313 | 314 | # bg_color picker 315 | def callback_change_bg(sender, app_data): 316 | self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32) # only need RGB in [0, 1] 317 | self.need_update = True 318 | 319 | dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg) 320 | 321 | # fov slider 322 | def callback_set_fovy(sender, app_data): 323 | self.cam.fovy = app_data 324 | self.need_update = True 325 | 326 | dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy) 327 | 328 | # dt_gamma slider 329 | def callback_set_dt_gamma(sender, app_data): 330 | self.opt.dt_gamma = app_data 331 | self.need_update = True 332 | 333 | dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma) 334 | 335 | # max_steps slider 336 | def callback_set_max_steps(sender, app_data): 337 | self.opt.max_steps = app_data 338 | self.need_update = True 339 | 340 | dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps) 341 | 342 | # aabb slider 343 | def callback_set_aabb(sender, app_data, user_data): 344 | # user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax) 345 | self.trainer.model.aabb_infer[user_data] = app_data 346 | 347 | # also change train aabb ? [better not...] 
348 | #self.trainer.model.aabb_train[user_data] = app_data 349 | 350 | self.need_update = True 351 | 352 | dpg.add_separator() 353 | dpg.add_text("Axis-aligned bounding box:") 354 | 355 | with dpg.group(horizontal=True): 356 | dpg.add_slider_float(label="x", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0) 357 | dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3) 358 | 359 | with dpg.group(horizontal=True): 360 | dpg.add_slider_float(label="y", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1) 361 | dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4) 362 | 363 | with dpg.group(horizontal=True): 364 | dpg.add_slider_float(label="z", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2) 365 | dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5) 366 | 367 | 368 | # debug info 369 | if self.debug: 370 | with dpg.collapsing_header(label="Debug"): 371 | # pose 372 | dpg.add_separator() 373 | dpg.add_text("Camera Pose:") 374 | dpg.add_text(str(self.cam.pose), tag="_log_pose") 375 | 376 | 377 | ### register camera handler 378 | 379 | def callback_camera_drag_rotate(sender, app_data): 380 | 381 | if not dpg.is_item_focused("_primary_window"): 382 | return 383 | 384 | dx = app_data[1] 385 | dy = app_data[2] 386 | 387 | self.cam.orbit(dx, dy) 388 | self.need_update = True 389 | 390 | if self.debug: 391 | dpg.set_value("_log_pose", str(self.cam.pose)) 392 | 393 | 394 | def callback_camera_wheel_scale(sender, app_data): 395 | 396 | if not dpg.is_item_focused("_primary_window"): 397 | return 398 | 399 | delta = app_data 400 | 401 | self.cam.scale(delta) 402 | self.need_update = True 403 | 404 | if self.debug: 405 | dpg.set_value("_log_pose", str(self.cam.pose)) 406 | 407 | 408 | def callback_camera_drag_pan(sender, app_data): 409 | 410 | if not dpg.is_item_focused("_primary_window"): 411 | return 412 | 413 | dx = app_data[1] 414 | dy = app_data[2] 415 | 416 | self.cam.pan(dx, dy) 417 | self.need_update = True 418 | 419 | if self.debug: 420 | dpg.set_value("_log_pose", str(self.cam.pose)) 421 | 422 | 423 | with dpg.handler_registry(): 424 | dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate) 425 | dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale) 426 | dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Right, callback=callback_camera_drag_pan) 427 | 428 | 429 | dpg.create_viewport(title='torch-ngp', width=self.W, height=self.H, resizable=False) 430 | 431 | ### global theme 432 | with dpg.theme() as theme_no_padding: 433 | with dpg.theme_component(dpg.mvAll): 434 | # set all padding to 0 to avoid scroll bar 435 | dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core) 436 | dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core) 437 | dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core) 438 | 439 | dpg.bind_item_theme("_primary_window", theme_no_padding) 440 | 441 
| dpg.setup_dearpygui() 442 | 443 | #dpg.show_metrics() 444 | 445 | dpg.show_viewport() 446 | 447 | 448 | def render(self): 449 | 450 | while dpg.is_dearpygui_running(): 451 | # update texture every frame 452 | if self.training: 453 | self.train_step() 454 | self.test_step() 455 | dpg.render_dearpygui_frame() -------------------------------------------------------------------------------- /nerf/network.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from encoding import get_encoder 7 | from activation import trunc_exp 8 | from .renderer import NeRFRenderer 9 | 10 | class MLP(nn.Module): 11 | def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True, geom_init=False, weight_norm=False): 12 | super().__init__() 13 | self.dim_in = dim_in 14 | self.dim_out = dim_out 15 | self.dim_hidden = dim_hidden 16 | self.num_layers = num_layers 17 | self.geom_init = geom_init 18 | 19 | net = [] 20 | for l in range(num_layers): 21 | 22 | in_dim = self.dim_in if l == 0 else self.dim_hidden 23 | out_dim = self.dim_out if l == num_layers - 1 else self.dim_hidden 24 | 25 | net.append(nn.Linear(in_dim, out_dim, bias=bias)) 26 | 27 | if geom_init: 28 | if l == num_layers - 1: 29 | torch.nn.init.normal_(net[l].weight, mean=math.sqrt(math.pi) / math.sqrt(in_dim), std=1e-4) 30 | if bias: torch.nn.init.constant_(net[l].bias, -0.5) # sphere init (very important for hashgrid encoding!) 31 | 32 | elif l == 0: 33 | torch.nn.init.normal_(net[l].weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(out_dim)) 34 | torch.nn.init.constant_(net[l].weight[:, 3:], 0.0) 35 | if bias: torch.nn.init.constant_(net[l].bias, 0.0) 36 | 37 | else: 38 | torch.nn.init.normal_(net[l].weight, 0.0, math.sqrt(2) / math.sqrt(out_dim)) 39 | if bias: torch.nn.init.constant_(net[l].bias, 0.0) 40 | 41 | if weight_norm: 42 | net[l] = nn.utils.weight_norm(net[l]) 43 | 44 | self.net = nn.ModuleList(net) 45 | 46 | def forward(self, x): 47 | for l in range(self.num_layers): 48 | x = self.net[l](x) 49 | if l != self.num_layers - 1: 50 | if self.geom_init: 51 | x = F.softplus(x, beta=100) 52 | else: 53 | x = F.relu(x, inplace=True) 54 | return x 55 | 56 | 57 | class NeRFNetwork(NeRFRenderer): 58 | def __init__(self, 59 | opt, 60 | specular_dim=3, 61 | ): 62 | 63 | super().__init__(opt) 64 | 65 | # density network 66 | self.encoder, self.in_dim_density = get_encoder("hashgrid_tcnn" if self.opt.tcnn else "hashgrid", level_dim=1, desired_resolution=2048 * self.bound, interpolation='linear') 67 | # self.sigma_net = MLP(3 + self.in_dim_density, 1, 32, 2, bias=self.opt.sdf, geom_init=self.opt.sdf, weight_norm=self.opt.sdf) 68 | self.sigma_net = MLP(3 + self.in_dim_density, 1, 32, 2, bias=False) 69 | 70 | # color network 71 | self.encoder_color, self.in_dim_color = get_encoder("hashgrid_tcnn" if self.opt.tcnn else "hashgrid", level_dim=2, desired_resolution=2048 * self.bound, interpolation='linear') 72 | self.color_net = MLP(3 + self.in_dim_color + self.individual_dim, 3 + specular_dim, 64, 3, bias=False) 73 | 74 | self.encoder_dir, self.in_dim_dir = get_encoder("None") 75 | self.specular_net = MLP(specular_dim + self.in_dim_dir, 3, 32, 2, bias=False) 76 | 77 | # sdf 78 | if self.opt.sdf: 79 | self.register_parameter('variance', nn.Parameter(torch.tensor(0.3, dtype=torch.float32))) 80 | 81 | def forward(self, x, d, c=None, shading='full'): 82 | # x: [N, 3], in [-bound, bound] 83 | # d: [N, 3], nomalized in [-1, 1] 84 | 
# c: [1/N, individual_dim] 85 | 86 | sigma = self.density(x)['sigma'] 87 | color, specular = self.rgb(x, d, c, shading) 88 | 89 | return sigma, color, specular 90 | 91 | 92 | def density(self, x): 93 | 94 | # sigma 95 | h = self.encoder(x, bound=self.bound, max_level=self.max_level) 96 | h = torch.cat([x, h], dim=-1) 97 | h = self.sigma_net(h) 98 | 99 | results = {} 100 | 101 | if self.opt.sdf: 102 | sigma = h[..., 0].float() # sdf 103 | else: 104 | sigma = trunc_exp(h[..., 0]) 105 | 106 | results['sigma'] = sigma 107 | 108 | return results 109 | 110 | # init the sdf to two spheres by pretraining, assume view cameras fall between the spheres 111 | def init_double_sphere(self, r1=0.5, r2=1.5, iters=8192, batch_size=8192): 112 | # sphere init is only for sdf mode! 113 | if not self.opt.sdf: 114 | return 115 | # import kiui 116 | import tqdm 117 | loss_fn = torch.nn.MSELoss() 118 | optimizer = torch.optim.Adam(list(self.parameters()), lr=1e-3) 119 | pbar = tqdm.trange(iters, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') 120 | for _ in range(iters): 121 | # random points inside [-b, b]^3 122 | xyzs = torch.rand(batch_size, 3, device='cuda') * 2 * self.bound - self.bound 123 | d = torch.norm(xyzs, p=2, dim=-1) 124 | gt_sdf = torch.where(d < (r1 + r2) / 2, d - r1, r2 - d) 125 | # kiui.lo(xyzs, gt_sdf) 126 | pred_sdf = self.density(xyzs)['sigma'] 127 | loss = loss_fn(pred_sdf, gt_sdf) 128 | optimizer.zero_grad() 129 | loss.backward() 130 | optimizer.step() 131 | pbar.set_description(f'pretrain sdf loss={loss.item():.8f}') 132 | pbar.update(1) 133 | 134 | # finite difference 135 | def normal(self, x, epsilon=1e-4): 136 | 137 | if self.opt.tcnn: 138 | with torch.enable_grad(): 139 | x.requires_grad_(True) 140 | sigma = self.density(x)['sigma'] 141 | normal = torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3] 142 | else: 143 | dx_pos = self.density((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))['sigma'] 144 | dx_neg = self.density((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))['sigma'] 145 | dy_pos = self.density((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))['sigma'] 146 | dy_neg = self.density((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))['sigma'] 147 | dz_pos = self.density((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound))['sigma'] 148 | dz_neg = self.density((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound))['sigma'] 149 | 150 | normal = torch.stack([ 151 | 0.5 * (dx_pos - dx_neg) / epsilon, 152 | 0.5 * (dy_pos - dy_neg) / epsilon, 153 | 0.5 * (dz_pos - dz_neg) / epsilon 154 | ], dim=-1) 155 | 156 | return normal 157 | 158 | 159 | def geo_feat(self, x, c=None): 160 | 161 | h = self.encoder_color(x, bound=self.bound, max_level=self.max_level) 162 | h = torch.cat([x, h], dim=-1) 163 | if c is not None: 164 | h = torch.cat([h, c.repeat(x.shape[0], 1) if c.shape[0] == 1 else c], dim=-1) 165 | h = self.color_net(h) 166 | geo_feat = torch.sigmoid(h) 167 | 168 | return geo_feat 169 | 170 | 171 | def rgb(self, x, d, c=None, shading='full'): 172 | 173 | # color 174 | geo_feat = self.geo_feat(x, c) 175 | diffuse = geo_feat[..., :3] 176 | 177 | if shading == 'diffuse': 178 | color = diffuse 179 | specular = None 180 | else: 181 | d = self.encoder_dir(d) 182 | specular 
= self.specular_net(torch.cat([d, geo_feat[..., 3:]], dim=-1)) 183 | specular = torch.sigmoid(specular) 184 | if shading == 'specular': 185 | color = specular 186 | else: # full 187 | color = (specular + diffuse).clamp(0, 1) # specular + albedo 188 | 189 | return color, specular 190 | 191 | 192 | # optimizer utils 193 | def get_params(self, lr): 194 | 195 | params = super().get_params(lr) 196 | 197 | params.extend([ 198 | {'params': self.encoder.parameters(), 'lr': lr}, 199 | {'params': self.encoder_color.parameters(), 'lr': lr}, 200 | {'params': self.sigma_net.parameters(), 'lr': lr}, 201 | {'params': self.color_net.parameters(), 'lr': lr}, 202 | {'params': self.specular_net.parameters(), 'lr': lr}, 203 | ]) 204 | 205 | if self.opt.sdf: 206 | params.append({'params': self.variance, 'lr': lr * 0.1}) 207 | 208 | return params -------------------------------------------------------------------------------- /nerf/provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import json 5 | import tqdm 6 | import numpy as np 7 | from scipy.spatial.transform import Slerp, Rotation 8 | 9 | import trimesh 10 | 11 | import torch 12 | from torch.utils.data import DataLoader 13 | 14 | from .utils import get_rays, create_dodecahedron_cameras 15 | 16 | def nerf_matrix_to_ngp(pose, scale=0.33, offset=[0, 0, 0]): 17 | pose[:3, 3] = pose[:3, 3] * scale + np.array(offset) 18 | pose = pose.astype(np.float32) 19 | return pose 20 | 21 | def visualize_poses(poses, size=0.1, bound=1): 22 | # poses: [B, 4, 4] 23 | 24 | axes = trimesh.creation.axis(axis_length=4) 25 | box = trimesh.primitives.Box(extents=[2*bound]*3).as_outline() 26 | box.colors = np.array([[128, 128, 128]] * len(box.entities)) 27 | objects = [axes, box] 28 | 29 | if bound > 1: 30 | unit_box = trimesh.primitives.Box(extents=[2]*3).as_outline() 31 | unit_box.colors = np.array([[128, 128, 128]] * len(unit_box.entities)) 32 | objects.append(unit_box) 33 | 34 | for pose in poses: 35 | # a camera is visualized with 8 line segments. 36 | pos = pose[:3, 3] 37 | a = pos + size * pose[:3, 0] + size * pose[:3, 1] - size * pose[:3, 2] 38 | b = pos - size * pose[:3, 0] + size * pose[:3, 1] - size * pose[:3, 2] 39 | c = pos - size * pose[:3, 0] - size * pose[:3, 1] - size * pose[:3, 2] 40 | d = pos + size * pose[:3, 0] - size * pose[:3, 1] - size * pose[:3, 2] 41 | 42 | dir = (a + b + c + d) / 4 - pos 43 | dir = dir / (np.linalg.norm(dir) + 1e-8) 44 | o = pos + dir * 3 45 | 46 | segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a], [pos, o]]) 47 | segs = trimesh.load_path(segs) 48 | objects.append(segs) 49 | 50 | trimesh.Scene(objects).show() 51 | 52 | 53 | def rand_poses(size, device, radius=1, theta_range=[np.pi/3, 2*np.pi/3], phi_range=[0, 2*np.pi]): 54 | ''' generate random poses from an orbit camera 55 | Args: 56 | size: batch size of generated poses. 57 | device: where to allocate the output. 
58 | radius: camera radius 59 | theta_range: [min, max], should be in [0, \pi] 60 | phi_range: [min, max], should be in [0, 2\pi] 61 | Return: 62 | poses: [size, 4, 4] 63 | ''' 64 | 65 | def normalize(vectors): 66 | return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10) 67 | 68 | thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0] 69 | phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0] 70 | 71 | centers = torch.stack([ 72 | radius * torch.sin(thetas) * torch.sin(phis), 73 | radius * torch.cos(thetas), 74 | radius * torch.sin(thetas) * torch.cos(phis), 75 | ], dim=-1) # [B, 3] 76 | 77 | # lookat 78 | forward_vector = - normalize(centers) 79 | up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1) # confused at the coordinate system... 80 | right_vector = normalize(torch.cross(forward_vector, up_vector, dim=-1)) 81 | up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1)) 82 | 83 | poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1) 84 | poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1) 85 | poses[:, :3, 3] = centers 86 | 87 | return poses 88 | 89 | 90 | class NeRFDataset: 91 | def __init__(self, opt, device, type='train', n_test=10): 92 | super().__init__() 93 | 94 | self.opt = opt 95 | self.device = device 96 | self.type = type # train, val, test 97 | self.downscale = opt.downscale 98 | self.root_path = opt.path 99 | self.preload = opt.preload # preload data into GPU 100 | self.scale = opt.scale # camera radius scale to make sure camera are inside the bounding box. 101 | self.offset = opt.offset # camera offset 102 | self.bound = opt.bound # bounding box half length, also used as the radius to random sample poses. 103 | self.fp16 = opt.fp16 # if preload, load into fp16. 104 | 105 | if self.scale == -1: 106 | print(f'[WARN] --data_format nerf cannot auto-choose --scale, use 1 as default.') 107 | self.scale = 1 108 | 109 | self.training = self.type in ['train', 'all', 'trainval'] 110 | 111 | # auto-detect transforms.json and split mode. 112 | if os.path.exists(os.path.join(self.root_path, 'transforms.json')): 113 | self.mode = 'colmap' # manually split, use view-interpolation for test. 114 | elif os.path.exists(os.path.join(self.root_path, 'transforms_train.json')): 115 | self.mode = 'blender' # provided split 116 | else: 117 | raise NotImplementedError(f'[NeRFDataset] Cannot find transforms*.json under {self.root_path}') 118 | 119 | # load nerf-compatible format data. 120 | if self.mode == 'colmap': 121 | with open(os.path.join(self.root_path, 'transforms.json'), 'r') as f: 122 | transform = json.load(f) 123 | elif self.mode == 'blender': 124 | # load all splits (train/valid/test), this is what instant-ngp in fact does... 
125 | if type == 'all': 126 | transform_paths = glob.glob(os.path.join(self.root_path, '*.json')) 127 | transform = None 128 | for transform_path in transform_paths: 129 | with open(transform_path, 'r') as f: 130 | tmp_transform = json.load(f) 131 | if transform is None: 132 | transform = tmp_transform 133 | else: 134 | transform['frames'].extend(tmp_transform['frames']) 135 | # load train and val split 136 | elif type == 'trainval': 137 | with open(os.path.join(self.root_path, f'transforms_train.json'), 'r') as f: 138 | transform = json.load(f) 139 | with open(os.path.join(self.root_path, f'transforms_val.json'), 'r') as f: 140 | transform_val = json.load(f) 141 | transform['frames'].extend(transform_val['frames']) 142 | # only load one specified split 143 | else: 144 | with open(os.path.join(self.root_path, f'transforms_{type}.json'), 'r') as f: 145 | transform = json.load(f) 146 | 147 | else: 148 | raise NotImplementedError(f'unknown dataset mode: {self.mode}') 149 | 150 | # load image size 151 | if 'h' in transform and 'w' in transform: 152 | self.H = int(transform['h']) // self.downscale 153 | self.W = int(transform['w']) // self.downscale 154 | else: 155 | # we have to actually read an image to get H and W later. 156 | self.H = self.W = None 157 | 158 | # read images 159 | frames = np.array(transform["frames"]) 160 | 161 | # tmp: if time in frames (dynamic scene), only load time == 0 162 | if 'time' in frames[0]: 163 | frames = np.array([f for f in frames if f['time'] == 0]) 164 | print(f'[INFO] selecting time == 0 frames: {len(transform["frames"])} --> {len(frames)}') 165 | 166 | 167 | # for colmap, manually interpolate a test set. 168 | if self.mode == 'colmap' and type == 'test': 169 | 170 | # choose two random poses, and interpolate between. 171 | f0, f1 = np.random.choice(frames, 2, replace=False) 172 | pose0 = nerf_matrix_to_ngp(np.array(f0['transform_matrix'], dtype=np.float32), scale=self.scale, offset=self.offset) # [4, 4] 173 | pose1 = nerf_matrix_to_ngp(np.array(f1['transform_matrix'], dtype=np.float32), scale=self.scale, offset=self.offset) # [4, 4] 174 | rots = Rotation.from_matrix(np.stack([pose0[:3, :3], pose1[:3, :3]])) 175 | slerp = Slerp([0, 1], rots) 176 | 177 | self.poses = [] 178 | self.images = None 179 | for i in range(n_test + 1): 180 | ratio = np.sin(((i / n_test) - 0.5) * np.pi) * 0.5 + 0.5 181 | pose = np.eye(4, dtype=np.float32) 182 | pose[:3, :3] = slerp(ratio).as_matrix() 183 | pose[:3, 3] = (1 - ratio) * pose0[:3, 3] + ratio * pose1[:3, 3] 184 | self.poses.append(pose) 185 | 186 | else: 187 | # for colmap, manually split a valid set (the first frame). 188 | if self.mode == 'colmap': 189 | if type == 'train': 190 | frames = frames[1:] 191 | elif type == 'val': 192 | frames = frames[:1] 193 | # else 'all' or 'trainval' : use all frames 194 | 195 | self.poses = [] 196 | self.images = [] 197 | for f in tqdm.tqdm(frames, desc=f'Loading {type} data'): 198 | f_path = os.path.join(self.root_path, f['file_path']) 199 | if self.mode == 'blender' and '.' not in os.path.basename(f_path): 200 | f_path += '.png' # so silly... 201 | 202 | # there are non-exist paths in fox... 
203 | if not os.path.exists(f_path): 204 | print(f'[WARN] {f_path} not exists!') 205 | continue 206 | 207 | pose = np.array(f['transform_matrix'], dtype=np.float32) # [4, 4] 208 | pose = nerf_matrix_to_ngp(pose, scale=self.scale, offset=self.offset) 209 | 210 | image = cv2.imread(f_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] o [H, W, 4] 211 | if self.H is None or self.W is None: 212 | self.H = image.shape[0] // self.downscale 213 | self.W = image.shape[1] // self.downscale 214 | 215 | # add support for the alpha channel as a mask. 216 | if image.shape[-1] == 3: 217 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 218 | else: 219 | image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) 220 | 221 | # if mask is available, load as the alpha channel 222 | m_path = f_path.replace('images', 'mask') 223 | if os.path.exists(m_path): 224 | mask = cv2.imread(m_path, cv2.IMREAD_UNCHANGED) # [H, W] 225 | if len(mask.shape) == 2: 226 | mask = mask[..., None] 227 | image = np.concatenate([image, mask[..., :1]], axis=-1) 228 | 229 | if image.shape[0] != self.H or image.shape[1] != self.W: 230 | image = cv2.resize(image, (self.W, self.H), interpolation=cv2.INTER_AREA) 231 | 232 | self.poses.append(pose) 233 | self.images.append(image) 234 | 235 | self.poses = torch.from_numpy(np.stack(self.poses, axis=0)) # [N, 4, 4] 236 | if self.images is not None: 237 | self.images = torch.from_numpy(np.stack(self.images, axis=0).astype(np.uint8)) # [N, H, W, C] 238 | 239 | # calculate mean radius of all camera poses 240 | self.radius = self.poses[:, :3, 3].norm(dim=-1).mean(0).item() 241 | #print(f'[INFO] dataset camera poses: radius = {self.radius:.4f}, bound = {self.bound}') 242 | 243 | # [debug] uncomment to view all training poses. 244 | if self.opt.vis_pose: 245 | visualize_poses(self.poses.numpy(), bound=self.opt.bound) 246 | 247 | # load intrinsics 248 | if 'fl_x' in transform or 'fl_y' in transform: 249 | fl_x = (transform['fl_x'] if 'fl_x' in transform else transform['fl_y']) / self.downscale 250 | fl_y = (transform['fl_y'] if 'fl_y' in transform else transform['fl_x']) / self.downscale 251 | elif 'camera_angle_x' in transform or 'camera_angle_y' in transform: 252 | # blender, assert in radians. 
already downscaled since we use H/W 253 | fl_x = self.W / (2 * np.tan(transform['camera_angle_x'] / 2)) if 'camera_angle_x' in transform else None 254 | fl_y = self.H / (2 * np.tan(transform['camera_angle_y'] / 2)) if 'camera_angle_y' in transform else None 255 | if fl_x is None: fl_x = fl_y 256 | if fl_y is None: fl_y = fl_x 257 | else: 258 | raise RuntimeError('Failed to load focal length, please check the transforms.json!') 259 | 260 | cx = (transform['cx'] / self.downscale) if 'cx' in transform else (self.W / 2.0) 261 | cy = (transform['cy'] / self.downscale) if 'cy' in transform else (self.H / 2.0) 262 | 263 | self.intrinsics = np.array([fl_x, fl_y, cx, cy]) 264 | 265 | # perspective projection matrix 266 | self.near = self.opt.min_near 267 | self.far = 1000 # infinite 268 | y = self.H / (2.0 * fl_y) 269 | aspect = self.W / self.H 270 | self.projection = np.array([[1/(y*aspect), 0, 0, 0], 271 | [0, -1/y, 0, 0], 272 | [0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)], 273 | [0, 0, -1, 0]], dtype=np.float32) 274 | 275 | self.projection = torch.from_numpy(self.projection) 276 | self.mvps = self.projection.unsqueeze(0) @ torch.inverse(self.poses) 277 | 278 | # tmp: dodecahedron_cameras for mesh visibility test 279 | dodecahedron_poses = create_dodecahedron_cameras() 280 | # visualize_poses(dodecahedron_poses, bound=self.opt.bound, points=self.pts3d) 281 | self.dodecahedron_poses = torch.from_numpy(dodecahedron_poses.astype(np.float32)) # [N, 4, 4] 282 | self.dodecahedron_mvps = self.projection.unsqueeze(0) @ torch.inverse(self.dodecahedron_poses) 283 | 284 | if self.preload: 285 | self.poses = self.poses.to(self.device) 286 | if self.images is not None: 287 | self.images = self.images.to(self.device) 288 | self.projection = self.projection.to(self.device) 289 | self.mvps = self.mvps.to(self.device) 290 | 291 | 292 | def collate(self, index): 293 | 294 | B = len(index) # a list of length 1 295 | 296 | results = {'H': self.H, 'W': self.W} 297 | 298 | if self.training and self.opt.stage == 0: 299 | # randomly sample over images too 300 | num_rays = self.opt.num_rays 301 | 302 | if self.opt.random_image_batch: 303 | index = torch.randint(0, len(self.poses), size=(num_rays,), device=self.device) 304 | 305 | else: 306 | num_rays = -1 307 | 308 | poses = self.poses[index].to(self.device) # [N, 4, 4] 309 | 310 | rays = get_rays(poses, self.intrinsics, self.H, self.W, num_rays, self.opt.patch_size) 311 | 312 | results['rays_o'] = rays['rays_o'] 313 | results['rays_d'] = rays['rays_d'] 314 | results['index'] = index 315 | 316 | if self.opt.stage > 0: 317 | mvp = self.mvps[index].to(self.device) 318 | results['mvp'] = mvp 319 | 320 | if self.images is not None: 321 | 322 | if self.training and self.opt.stage == 0: 323 | images = self.images[index, rays['j'], rays['i']].float().to(self.device) / 255 # [N, 3/4] 324 | else: 325 | images = self.images[index].squeeze(0).float().to(self.device) / 255 # [H, W, 3/4] 326 | 327 | if self.training: 328 | C = self.images.shape[-1] 329 | images = images.view(-1, C) 330 | 331 | results['images'] = images 332 | 333 | return results 334 | 335 | def dataloader(self): 336 | size = len(self.poses) 337 | loader = DataLoader(list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0) 338 | loader._data = self # an ugly fix... we need to access error_map & poses in trainer. 
339 | loader.has_gt = self.images is not None 340 | return loader -------------------------------------------------------------------------------- /raymarching/__init__.py: -------------------------------------------------------------------------------- 1 | from .raymarching import * -------------------------------------------------------------------------------- /raymarching/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | from packaging import version 4 | import torch 5 | 6 | torch_version = torch.__version__ 7 | 8 | _src_path = os.path.dirname(os.path.abspath(__file__)) 9 | 10 | nvcc_flags = [ 11 | '-O3', '-std=c++14', 12 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 13 | ] 14 | 15 | if os.name == "posix": 16 | c_flags = ['-O3', '-std=c++14'] 17 | if version.parse(torch_version) >= version.parse("2.1"): 18 | nvcc_flags = [ 19 | '-O3', '-std=c++17', 20 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 21 | '-use_fast_math' 22 | ] 23 | c_flags = ['-O3', '-std=c++17'] 24 | elif os.name == "nt": 25 | c_flags = ['/O2', '/std:c++17'] 26 | 27 | # find cl.exe 28 | def find_cl_path(): 29 | import glob 30 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 31 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 32 | if paths: 33 | return paths[0] 34 | 35 | # If cl.exe is not on path, try to find it. 36 | if os.system("where cl.exe >nul 2>nul") != 0: 37 | cl_path = find_cl_path() 38 | if cl_path is None: 39 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 40 | os.environ["PATH"] += ";" + cl_path 41 | 42 | _backend = load(name='_raymarching_mob', 43 | extra_cflags=c_flags, 44 | extra_cuda_cflags=nvcc_flags, 45 | sources=[os.path.join(_src_path, 'src', f) for f in [ 46 | 'raymarching.cu', 47 | 'bindings.cpp', 48 | ]], 49 | ) 50 | 51 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /raymarching/raymarching.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Function 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _raymarching_mob as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | 15 | # ---------------------------------------- 16 | # utils 17 | # ---------------------------------------- 18 | 19 | class _near_far_from_aabb(Function): 20 | @staticmethod 21 | @custom_fwd(cast_inputs=torch.float32) 22 | def forward(ctx, rays_o, rays_d, aabb, min_near=0.2): 23 | ''' near_far_from_aabb, CUDA implementation 24 | Calculate rays' intersection time (near and far) with aabb 25 | Args: 26 | rays_o: float, [N, 3] 27 | rays_d: float, [N, 3] 28 | aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax) 29 | min_near: float, scalar 30 | Returns: 31 | nears: float, [N] 32 | fars: float, [N] 33 | ''' 34 | if not rays_o.is_cuda: rays_o = rays_o.cuda() 35 | if not rays_d.is_cuda: rays_d = rays_d.cuda() 36 | 37 | rays_o = rays_o.contiguous().view(-1, 3) 38 | rays_d = rays_d.contiguous().view(-1, 3) 39 | 40 | N = rays_o.shape[0] # num rays 41 | 42 | nears = torch.empty(N, dtype=rays_o.dtype, 
device=rays_o.device) 43 | fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) 44 | 45 | _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars) 46 | 47 | return nears, fars 48 | 49 | near_far_from_aabb = _near_far_from_aabb.apply 50 | 51 | 52 | class _sph_from_ray(Function): 53 | @staticmethod 54 | @custom_fwd(cast_inputs=torch.float32) 55 | def forward(ctx, rays_o, rays_d, radius): 56 | ''' sph_from_ray, CUDA implementation 57 | get spherical coordinate on the background sphere from rays. 58 | Assume rays_o are inside the Sphere(radius). 59 | Args: 60 | rays_o: [N, 3] 61 | rays_d: [N, 3] 62 | radius: scalar, float 63 | Return: 64 | coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface) 65 | ''' 66 | if not rays_o.is_cuda: rays_o = rays_o.cuda() 67 | if not rays_d.is_cuda: rays_d = rays_d.cuda() 68 | 69 | rays_o = rays_o.contiguous().view(-1, 3) 70 | rays_d = rays_d.contiguous().view(-1, 3) 71 | 72 | N = rays_o.shape[0] # num rays 73 | 74 | coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device) 75 | 76 | _backend.sph_from_ray(rays_o, rays_d, radius, N, coords) 77 | 78 | return coords 79 | 80 | sph_from_ray = _sph_from_ray.apply 81 | 82 | class _morton3D(Function): 83 | @staticmethod 84 | def forward(ctx, coords): 85 | ''' morton3D, CUDA implementation 86 | Args: 87 | coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...) 88 | ENHANCE: check if the coord range is valid! (current 128 is safe) 89 | Returns: 90 | indices: [N], int32, in [0, 128^3) 91 | 92 | ''' 93 | if not coords.is_cuda: coords = coords.cuda() 94 | 95 | N = coords.shape[0] 96 | 97 | indices = torch.empty(N, dtype=torch.int32, device=coords.device) 98 | 99 | _backend.morton3D(coords.int(), N, indices) 100 | 101 | return indices 102 | 103 | morton3D = _morton3D.apply 104 | 105 | class _morton3D_invert(Function): 106 | @staticmethod 107 | def forward(ctx, indices): 108 | ''' morton3D_invert, CUDA implementation 109 | Args: 110 | indices: [N], int32, in [0, 128^3) 111 | Returns: 112 | coords: [N, 3], int32, in [0, 128) 113 | 114 | ''' 115 | if not indices.is_cuda: indices = indices.cuda() 116 | 117 | N = indices.shape[0] 118 | 119 | coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device) 120 | 121 | _backend.morton3D_invert(indices.int(), N, coords) 122 | 123 | return coords 124 | 125 | morton3D_invert = _morton3D_invert.apply 126 | 127 | 128 | class _packbits(Function): 129 | @staticmethod 130 | @custom_fwd(cast_inputs=torch.float32) 131 | def forward(ctx, grid, thresh, bitfield=None): 132 | ''' packbits, CUDA implementation 133 | Pack up the density grid into a bit field to accelerate ray marching. 134 | Args: 135 | grid: float, [C, H * H * H], assume H % 2 == 0 136 | thresh: float, threshold 137 | Returns: 138 | bitfield: uint8, [C, H * H * H / 8] 139 | ''' 140 | if not grid.is_cuda: grid = grid.cuda() 141 | grid = grid.contiguous() 142 | 143 | C = grid.shape[0] 144 | H3 = grid.shape[1] 145 | N = C * H3 // 8 146 | 147 | if bitfield is None: 148 | bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device) 149 | 150 | _backend.packbits(grid, N, thresh, bitfield) 151 | 152 | return bitfield 153 | 154 | packbits = _packbits.apply 155 | 156 | 157 | class _flatten_rays(Function): 158 | @staticmethod 159 | def forward(ctx, rays, M): 160 | ''' flatten rays 161 | Args: 162 | rays: [N, 2], all rays' (point_offset, point_count), 163 | M: scalar, int, count of points (we cannot get this info from rays unfortunately...) 
164 | Returns: 165 | res: [M], flattened ray index. 166 | ''' 167 | if not rays.is_cuda: rays = rays.cuda() 168 | rays = rays.contiguous() 169 | 170 | N = rays.shape[0] 171 | 172 | res = torch.zeros(M, dtype=torch.int, device=rays.device) 173 | 174 | _backend.flatten_rays(rays, N, M, res) 175 | 176 | return res 177 | 178 | flatten_rays = _flatten_rays.apply 179 | 180 | # ---------------------------------------- 181 | # train functions 182 | # ---------------------------------------- 183 | 184 | class _march_rays_train(Function): 185 | @staticmethod 186 | @custom_fwd(cast_inputs=torch.float32) 187 | def forward(ctx, rays_o, rays_d, bound, contract, density_bitfield, C, H, nears, fars, perturb=False, dt_gamma=0, max_steps=1024): 188 | ''' march rays to generate points (forward only) 189 | Args: 190 | rays_o/d: float, [N, 3] 191 | bound: float, scalar 192 | density_bitfield: uint8: [CHHH // 8] 193 | C: int 194 | H: int 195 | nears/fars: float, [N] 196 | step_counter: int32, (2), used to count the actual number of generated points. 197 | mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.) 198 | perturb: bool 199 | align: int, pad output so its size is dividable by align, set to -1 to disable. 200 | force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays. 201 | dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) 202 | max_steps: int, max number of sampled points along each ray, also affect min_stepsize. 203 | Returns: 204 | xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray) 205 | dirs: float, [M, 3], all generated points' view dirs. 206 | ts: float, [M, 2], all generated points' ts. 
207 | rays: int32, [N, 2], all rays' (point_offset, point_count), e.g., xyzs[rays[i, 0]:(rays[i, 0] + rays[i, 1])] --> points belonging to rays[i, 0] 208 | ''' 209 | 210 | if not rays_o.is_cuda: rays_o = rays_o.cuda() 211 | if not rays_d.is_cuda: rays_d = rays_d.cuda() 212 | if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda() 213 | 214 | rays_o = rays_o.float().contiguous().view(-1, 3) 215 | rays_d = rays_d.float().contiguous().view(-1, 3) 216 | density_bitfield = density_bitfield.contiguous() 217 | 218 | N = rays_o.shape[0] # num rays 219 | 220 | step_counter = torch.zeros(1, dtype=torch.int32, device=rays_o.device) # point counter, ray counter 221 | 222 | if perturb: 223 | noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device) 224 | else: 225 | noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device) 226 | 227 | # first pass: write rays, get total number of points M to render 228 | rays = torch.empty(N, 2, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps 229 | _backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, contract, dt_gamma, max_steps, N, C, H, nears, fars, None, None, None, rays, step_counter, noises) 230 | 231 | # allocate based on M 232 | M = step_counter.item() 233 | # print(M, N) 234 | # print(rays[:, 0].max()) 235 | 236 | xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) 237 | dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) 238 | ts = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) 239 | 240 | # second pass: write outputs 241 | _backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, contract, dt_gamma, max_steps, N, C, H, nears, fars, xyzs, dirs, ts, rays, step_counter, noises) 242 | 243 | return xyzs, dirs, ts, rays 244 | 245 | march_rays_train = _march_rays_train.apply 246 | 247 | 248 | class _composite_rays_train(Function): 249 | @staticmethod 250 | @custom_fwd(cast_inputs=torch.float32) 251 | def forward(ctx, sigmas, rgbs, ts, rays, T_thresh=1e-4, alpha_mode=False): 252 | ''' composite rays' rgbs, according to the ray marching formula. 253 | Args: 254 | rgbs: float, [M, 3] 255 | sigmas: float, [M,] 256 | ts: float, [M, 2] 257 | rays: int32, [N, 3] 258 | alpha_mode: bool, sigmas are treated as alphas instead 259 | Returns: 260 | weights: float, [M] 261 | weights_sum: float, [N,], the alpha channel 262 | depth: float, [N, ], the Depth 263 | image: float, [N, 3], the RGB channel (after multiplying alpha!) 
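        Reference (a minimal per-ray PyTorch sketch of the same compositing rule; the CUDA
        kernel is the actual implementation, and this sketch assumes ts[:, 0] holds the
        sample distance t and ts[:, 1] the step size delta):
            alphas = sigmas if alpha_mode else 1 - torch.exp(-sigmas * ts[:, 1])
            T = torch.cumprod(torch.cat([torch.ones(1, device=sigmas.device), 1 - alphas + 1e-10]), dim=0)[:-1]
            weights = alphas * T                        # per-sample contribution
            weights_sum = weights.sum()                 # alpha channel
            depth = (weights * ts[:, 0]).sum()          # expected termination distance
            image = (weights[:, None] * rgbs).sum(0)    # alpha-premultiplied RGB
        Compositing along a ray stops early once the transmittance falls below T_thresh.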
264 | ''' 265 | 266 | sigmas = sigmas.float().contiguous() 267 | rgbs = rgbs.float().contiguous() 268 | 269 | M = sigmas.shape[0] 270 | N = rays.shape[0] 271 | 272 | weights = torch.zeros(M, dtype=sigmas.dtype, device=sigmas.device) # may leave unmodified, so init with 0 273 | weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) 274 | 275 | depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) 276 | image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device) 277 | 278 | _backend.composite_rays_train_forward(sigmas, rgbs, ts, rays, M, N, T_thresh, alpha_mode, weights, weights_sum, depth, image) 279 | 280 | ctx.save_for_backward(sigmas, rgbs, ts, rays, weights_sum, depth, image) 281 | ctx.dims = [M, N, T_thresh, alpha_mode] 282 | 283 | return weights, weights_sum, depth, image 284 | 285 | @staticmethod 286 | @custom_bwd 287 | def backward(ctx, grad_weights, grad_weights_sum, grad_depth, grad_image): 288 | 289 | grad_weights = grad_weights.contiguous() 290 | grad_weights_sum = grad_weights_sum.contiguous() 291 | grad_depth = grad_depth.contiguous() 292 | grad_image = grad_image.contiguous() 293 | 294 | sigmas, rgbs, ts, rays, weights_sum, depth, image = ctx.saved_tensors 295 | M, N, T_thresh, alpha_mode = ctx.dims 296 | 297 | grad_sigmas = torch.zeros_like(sigmas) 298 | grad_rgbs = torch.zeros_like(rgbs) 299 | 300 | _backend.composite_rays_train_backward(grad_weights, grad_weights_sum, grad_depth, grad_image, sigmas, rgbs, ts, rays, weights_sum, depth, image, M, N, T_thresh, alpha_mode, grad_sigmas, grad_rgbs) 301 | 302 | return grad_sigmas, grad_rgbs, None, None, None, None 303 | 304 | 305 | composite_rays_train = _composite_rays_train.apply 306 | 307 | # ---------------------------------------- 308 | # infer functions 309 | # ---------------------------------------- 310 | 311 | class _march_rays(Function): 312 | @staticmethod 313 | @custom_fwd(cast_inputs=torch.float32) 314 | def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, contract, density_bitfield, C, H, near, far, perturb=False, dt_gamma=0, max_steps=1024): 315 | ''' march rays to generate points (forward only, for inference) 316 | Args: 317 | n_alive: int, number of alive rays 318 | n_step: int, how many steps we march 319 | rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive) 320 | rays_t: float, [N], the alive rays' time, we only use the first n_alive. 321 | rays_o/d: float, [N, 3] 322 | bound: float, scalar 323 | density_bitfield: uint8: [CHHH // 8] 324 | C: int 325 | H: int 326 | nears/fars: float, [N] 327 | align: int, pad output so its size is dividable by align, set to -1 to disable. 328 | perturb: bool/int, int > 0 is used as the random seed. 329 | dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) 330 | max_steps: int, max number of sampled points along each ray, also affect min_stepsize. 331 | Returns: 332 | xyzs: float, [n_alive * n_step, 3], all generated points' coords 333 | dirs: float, [n_alive * n_step, 3], all generated points' view dirs. 
334 | ts: float, [n_alive * n_step, 2], all generated points' ts 335 | ''' 336 | 337 | if not rays_o.is_cuda: rays_o = rays_o.cuda() 338 | if not rays_d.is_cuda: rays_d = rays_d.cuda() 339 | 340 | rays_o = rays_o.float().contiguous().view(-1, 3) 341 | rays_d = rays_d.float().contiguous().view(-1, 3) 342 | 343 | M = n_alive * n_step 344 | 345 | xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) 346 | dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) 347 | ts = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth 348 | 349 | if perturb: 350 | # torch.manual_seed(perturb) # test_gui uses spp index as seed 351 | noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device) 352 | else: 353 | noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device) 354 | 355 | _backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, contract, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, ts, noises) 356 | 357 | return xyzs, dirs, ts 358 | 359 | march_rays = _march_rays.apply 360 | 361 | 362 | class _composite_rays(Function): 363 | @staticmethod 364 | @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float 365 | def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, ts, weights_sum, depth, image, T_thresh=1e-2, alpha_mode=False): 366 | ''' composite rays' rgbs, according to the ray marching formula. (for inference) 367 | Args: 368 | n_alive: int, number of alive rays 369 | n_step: int, how many steps we march 370 | rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive) 371 | rays_t: float, [N], the alive rays' time 372 | sigmas: float, [n_alive * n_step,] 373 | rgbs: float, [n_alive * n_step, 3] 374 | ts: float, [n_alive * n_step, 2] 375 | In-place Outputs: 376 | weights_sum: float, [N,], the alpha channel 377 | depth: float, [N,], the depth value 378 | image: float, [N, 3], the RGB channel (after multiplying alpha!) 
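        Typical use (a rough sketch; `model` is a placeholder for the radiance field,
        the other names refer to functions and buffers in this file):
            while n_alive > 0:
                xyzs, dirs, ts = march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, ...)
                sigmas, rgbs = model(xyzs, dirs)
                composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, ts,
                               weights_sum, depth, image)
                # rays whose transmittance fell below T_thresh are flagged in rays_alive;
                # drop them and update n_alive before the next iteration.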
379 | ''' 380 | sigmas = sigmas.float().contiguous() 381 | rgbs = rgbs.float().contiguous() 382 | _backend.composite_rays(n_alive, n_step, T_thresh, alpha_mode, rays_alive, rays_t, sigmas, rgbs, ts, weights_sum, depth, image) 383 | return tuple() 384 | 385 | 386 | composite_rays = _composite_rays.apply -------------------------------------------------------------------------------- /raymarching/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | from packaging import version 5 | import torch 6 | 7 | torch_version = torch.__version__ 8 | 9 | _src_path = os.path.dirname(os.path.abspath(__file__)) 10 | 11 | nvcc_flags = [ 12 | '-O3', '-std=c++14', 13 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 14 | ] 15 | 16 | if os.name == "posix": 17 | c_flags = ['-O3', '-std=c++14'] 18 | if version.parse(torch_version) >= version.parse("2.1"): 19 | nvcc_flags = [ 20 | '-O3', '-std=c++17', 21 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 22 | '-use_fast_math' 23 | ] 24 | c_flags = ['-O3', '-std=c++17'] 25 | elif os.name == "nt": 26 | c_flags = ['/O2', '/std:c++17'] 27 | 28 | # find cl.exe 29 | def find_cl_path(): 30 | import glob 31 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 32 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 33 | if paths: 34 | return paths[0] 35 | 36 | # If cl.exe is not on path, try to find it. 37 | if os.system("where cl.exe >nul 2>nul") != 0: 38 | cl_path = find_cl_path() 39 | if cl_path is None: 40 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 41 | os.environ["PATH"] += ";" + cl_path 42 | 43 | ''' 44 | Usage: 45 | 46 | python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory) 47 | 48 | python setup.py install # build extensions and install (copy) to PATH. 49 | pip install . # ditto but better (e.g., dependency & metadata handling) 50 | 51 | python setup.py develop # build extensions and install (symbolic) to PATH. 52 | pip install -e . 
# ditto but better (e.g., dependency & metadata handling) 53 | 54 | ''' 55 | setup( 56 | name='raymarching_mob', # package name, import this to use python API 57 | ext_modules=[ 58 | CUDAExtension( 59 | name='_raymarching_mob', # extension name, import this to use CUDA API 60 | sources=[os.path.join(_src_path, 'src', f) for f in [ 61 | 'raymarching.cu', 62 | 'bindings.cpp', 63 | ]], 64 | extra_compile_args={ 65 | 'cxx': c_flags, 66 | 'nvcc': nvcc_flags, 67 | } 68 | ), 69 | ], 70 | cmdclass={ 71 | 'build_ext': BuildExtension, 72 | } 73 | ) -------------------------------------------------------------------------------- /raymarching/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "raymarching.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | // utils 7 | m.def("flatten_rays", &flatten_rays, "flatten_rays (CUDA)"); 8 | m.def("packbits", &packbits, "packbits (CUDA)"); 9 | m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)"); 10 | m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)"); 11 | m.def("morton3D", &morton3D, "morton3D (CUDA)"); 12 | m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)"); 13 | // train 14 | m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)"); 15 | m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)"); 16 | m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)"); 17 | // infer 18 | m.def("march_rays", &march_rays, "march rays (CUDA)"); 19 | m.def("composite_rays", &composite_rays, "composite rays (CUDA)"); 20 | } -------------------------------------------------------------------------------- /raymarching/src/raymarching.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | 7 | void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars); 8 | void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords); 9 | void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices); 10 | void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords); 11 | void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield); 12 | void flatten_rays(const at::Tensor rays, const uint32_t N, const uint32_t M, at::Tensor res); 13 | 14 | void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const at::Tensor nears, const at::Tensor fars, at::optional xyzs, at::optional dirs, at::optional ts, at::Tensor rays, at::Tensor counter, at::Tensor noises); 15 | void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, const bool alpha_mode, at::Tensor weights, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); 16 | void composite_rays_train_backward(const at::Tensor grad_weights, const at::Tensor grad_weights_sum, const at::Tensor grad_depth, const at::Tensor grad_image, const at::Tensor sigmas, const 
at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor depth, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, const bool alpha_mode, at::Tensor grad_sigmas, at::Tensor grad_rgbs); 17 | 18 | void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor ts, at::Tensor noises); 19 | void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, const bool alpha_mode, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor ts, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # nerf2mesh 2 | 3 | 4 | This repository contains a PyTorch re-implementation of the paper: [Delicate Textured Mesh Recovery from NeRF via Adaptive Surface Refinement](https://arxiv.org/abs/2303.02091). 5 | 6 | ### [Project Page](https://ashawkey.github.io/nerf2mesh/) | [Arxiv](https://arxiv.org/abs/2303.02091) | [Paper](https://huggingface.co/ashawkey/nerf2mesh/resolve/main/paper.pdf) | [Models](https://huggingface.co/ashawkey/nerf2mesh/tree/main/scenes) 7 | 8 | **News (2023.5.3)**: support [background removal](https://github.com/OPHoperHPO/image-background-remove-tool) and [SDF](https://github.com/Totoro97/NeuS) mode for stage 0, which produces more robust and smooth mesh for single-object reconstruction: 9 | 10 | ![](assets/teaser2.jpg) 11 | 12 | ![](assets/teaser.jpg) 13 | 14 | # Install 15 | 16 | ```bash 17 | git clone https://github.com/ashawkey/nerf2mesh.git 18 | cd nerf2mesh 19 | ``` 20 | 21 | ### Install with pip 22 | ```bash 23 | pip install -r requirements.txt 24 | 25 | # tiny-cuda-nn 26 | pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch 27 | 28 | # nvdiffrast 29 | pip install git+https://github.com/NVlabs/nvdiffrast/ 30 | 31 | # pytorch3d 32 | pip install git+https://github.com/facebookresearch/pytorch3d.git 33 | ``` 34 | 35 | ### Build extension (optional) 36 | By default, we use [`load`](https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.load) to build the extension at runtime. 37 | However, this may be inconvenient sometimes. 38 | Therefore, we also provide the `setup.py` to build each extension: 39 | ```bash 40 | # install all extension modules 41 | bash scripts/install_ext.sh 42 | 43 | # if you want to install manually, here is an example: 44 | cd raymarching 45 | python setup.py build_ext --inplace # build ext only, do not install (only can be used in the parent directory) 46 | pip install . # install to python path (you still need the raymarching/ folder, since this only install the built extension.) 47 | ``` 48 | 49 | ### Tested environments 50 | * Ubuntu 22 with torch 1.12 & CUDA 11.6 on a V100. 51 | 52 | # Usage 53 | 54 | We support the original NeRF data format like [nerf-synthetic](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1), and COLMAP dataset like [Mip-NeRF 360](http://storage.googleapis.com/gresearch/refraw360/360_v2.zip). 
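As a reference, the commands below assume a layout roughly like this under `./data` (the scene names are just the examples used below):

```
data/
├── nerf_synthetic/lego/   # nerf format: transforms_{train,val,test}.json + image folders
└── garden/                # colmap format (e.g., a Mip-NeRF 360 scene)
```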
55 | Please download and put them under `./data`. 56 | 57 | The first run will take some time to compile the CUDA extensions. 58 | 59 | ### Basics 60 | ```bash 61 | ### Stage0 (NeRF, continuous, volumetric rendering), this stage exports a coarse mesh under /mesh_stage0/ 62 | 63 | # nerf 64 | python main.py data/nerf_synthetic/lego/ --workspace trial_syn_lego/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 0 --lambda_tv 1e-8 65 | 66 | # colmap 67 | python main.py data/garden/ --workspace trial_360_garden -O --data_format colmap --bound 16 --enable_cam_center --enable_cam_near_far --scale 0.3 --downscale 4 --stage 0 --lambda_entropy 1e-3 --clean_min_f 16 --clean_min_d 10 --lambda_tv 2e-8 --visibility_mask_dilation 50 68 | 69 | ### Stage1 (Mesh, binarized, rasterization), this stage exports a fine mesh with textures under /mesh_stage1/ 70 | 71 | # nerf 72 | python main.py data/nerf_synthetic/lego/ --workspace trial_syn_lego/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 1 73 | 74 | # colmap 75 | python main.py data/garden/ --workspace trial_360_garden -O --data_format colmap --bound 16 --enable_cam_center --enable_cam_near_far --scale 0.3 --downscale 4 --stage 1 --iters 10000 76 | 77 | ### Web Renderer 78 | # you can simply open /mesh_stage1/mesh.obj with a 3D viewer to visualize the diffuse texture. 79 | # to render full diffuse + specular, you'll need to host this folder (e.g., with the vscode live server), and open renderer.html for further instructions. 80 | ``` 81 | 82 | ### Custom Dataset 83 | 84 | **Tips:** 85 | * To get the best mesh quality, you may need to adjust `--scale` so the object of interest falls inside the unit box `[-1, 1]^3`, which can be visualized by appending `--vis_pose`. 86 | * To better model the background (especially for outdoor scenes), you may need to adjust `--bound` so most of the sparse points fall into the full box `[-bound, bound]^3`, which can also be visualized by appending `--vis_pose`. 87 | * For single-object centered captures focusing on mesh asset quality: 88 | * remove the background with `scripts/remove_bg.py` and only reconstruct the targeted object. 89 | * use `--sdf` to enable the SDF-based stage 0 model. 90 | * use `--diffuse_only` if you only want the diffuse texture. 91 | * adjust `--decimate_target 1e5` to control the number of mesh faces in stage 0, and adjust `--refine_remesh_size 0.01` to control the number of mesh faces in stage 1 (average edge length). 92 | * adjust `--lambda_normal 1e-2` for a smoother surface. 93 | * For forward-facing captures: 94 | * remove `--enable_cam_center` so the scene center is determined by sparse points instead of camera positions. 95 | 96 | ```bash 97 | # prepare your video or images under /data/custom, and run colmap (assuming it is installed): 98 | python scripts/colmap2nerf.py --video ./data/custom/video.mp4 --run_colmap # if using a video 99 | python scripts/colmap2nerf.py --images ./data/custom/images/ --run_colmap # if using images 100 | 101 | # generate downscaled images if the resolution is very high and causes OOM (save to `data//images_{downscale}`) 102 | python scripts/downscale.py data/ --downscale 4 103 | # NOTE: remember to append `--downscale 4` as well when running main.py 104 | 105 | # perform background removal for single-object 360 captures (save to 'data//mask') 106 | python scripts/remove_bg.py data//images 107 | # NOTE: the mask quality depends on background complexity, do check the mask!
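# (optional) quick sanity check of the layout produced by the steps above, roughly:
#   data/custom/images/                       # original / extracted frames
#   data/custom/images_4/                     # only if downscale.py was run with --downscale 4
#   data/custom/mask/                         # only if remove_bg.py was run
#   data/custom/colmap_sparse/, colmap_text/, transforms*.json   # written by colmap2nerf.py
ls data/custom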
108 | 109 | # recommended options for single-object 360 captures 110 | python main.py data/custom/ --workspace trial_custom -O --data_format colmap --bound 1 --dt_gamma 0 --stage 0 --clean_min_f 16 --clean_min_d 10 --visibility_mask_dilation 50 --iters 10000 --decimate_target 1e5 --sdf 111 | # NOTE: for finer faces, try --decimate_target 3e5 112 | 113 | python main.py data/custom/ --workspace trial_custom -O --data_format colmap --bound 1 --dt_gamma 0 --stage 1 --iters 5000 --lambda_normal 1e-2 --refine_remesh_size 0.01 --sdf 114 | # NOTE: for finer faces, try --lambda_normal 1e-1 --refine_remesh_size 0.005 115 | 116 | # recommended options for outdoor 360 inward-facing captures 117 | python main.py data/custom/ --workspace trial_custom -O --data_format colmap --bound 16 --enable_cam_center --enable_cam_near_far --stage 0 --lambda_entropy 1e-3 --clean_min_f 16 --clean_min_d 10 --lambda_tv 2e-8 --visibility_mask_dilation 50 118 | 119 | python main.py data/custom/ --workspace trial_custom -O --data_format colmap --bound 16 --enable_cam_center --enable_cam_near_far --stage 1 --iters 10000 --lambda_normal 1e-3 120 | 121 | # recommended options for forward-facing captures 122 | python main.py data/custom/ --workspace trial_custom -O --data_format colmap --bound 2 --scale 0.1 --stage 0 --clean_min_f 16 --clean_min_d 10 --lambda_tv 2e-8 --visibility_mask_dilation 50 123 | 124 | python main.py data/custom/ --workspace trial_custom -O --data_format colmap --bound 2 --scale 0.1 --stage 1 --iters 10000 --lambda_normal 1e-3 125 | ``` 126 | 127 | ### Advanced Usage 128 | ```bash 129 | ### -O: the recommended setting, equivalent to 130 | --fp16 --preload --mark_untrained --random_image_batch --adaptive_num_rays --refine --mesh_visibility_culling 131 | 132 | ### load checkpoint 133 | --ckpt latest # by default we load the latest checkpoint in the workspace 134 | --ckpt scratch # train from scratch. For stage 1, this will still load the stage 0 model as an initialization. 135 | --ckpt trial/checkpoints/xxx.pth # specify it by path 136 | 137 | ### testing 138 | --test # test, save video and mesh 139 | --test_no_video # do not save video 140 | --test_no_mesh # do not save mesh 141 | 142 | ### dataset related 143 | --data_format [colmap|nerf|dtu] # dataset format 144 | --enable_cam_center # use camera center instead of sparse point center as scene center (colmap dataset only) 145 | --enable_cam_near_far # estimate camera near & far from sparse points (colmap dataset only) 146 | 147 | --bound 16 # scene bound set to [-16, 16]^3, note that only meshes inside the center [-1, 1]^3 will be adaptively refined! 148 | --scale 0.3 # camera scale; if not specified, one is automatically estimated based on camera positions. Important targets should be scaled into the center [-1, 1]^3. 149 | 150 | ### visualization 151 | --vis_pose # visualize camera poses and sparse points (sparse points are colmap dataset only) 152 | --gui # open the GUI (only for testing; training in the GUI is not well supported!)
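# example: append --vis_pose to your stage 0 command to check --scale / --bound before training
# (flags follow the colmap recipes above; adjust paths and values to your scene)
python main.py data/custom/ --workspace trial_custom -O --data_format colmap --bound 16 --enable_cam_center --scale 0.3 --stage 0 --vis_pose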
153 | 154 | ### balance between surface quality / rendering quality 155 | 156 | # increase these weights to get better surface quality but worse rendering quality 157 | --lambda_tv 1e-7 # total variation loss (stage 0) 158 | --lambda_entropy 1e-3 # entropy on rendering weights (transparency, alpha), encourage them to be either 0 or 1 (stage 0) 159 | --lambda_lap 0.001 # laplacian smoothness loss (stage 1) 160 | --lambda_normal 0.001 # normal consistency loss (stage 1) 161 | --lambda_offsets 0.1 # vertex offsets L2 loss (stage 1) 162 | --lambda_edgelen 0.1 # edge length L2 loss (stage 1) 163 | 164 | # set all smoothness regularizations to 0, usually get the best rendering quality 165 | --wo_smooth 166 | 167 | # only use diffuse shading 168 | --diffuse_only 169 | 170 | ### coarse mesh extraction & post-processing 171 | --mcubes_reso 512 # marching cubes resolution 172 | --decimate_target 300000 # decimate raw mesh to this face number 173 | --clean_min_d 5 # isolated floaters with smaller diameter will be removed 174 | --clean_min_f 8 # isolated floaters with fewer faces will be removed 175 | --visibility_mask_dilation 5 # dilate iterations after performing visibility face culling 176 | 177 | ### fine mesh exportation 178 | --texture_size 4096 # max texture image resolution 179 | --ssaa 2 # super-sampling anti-alias ratio 180 | --refine_size 0.01 # finest edge len at subdivision 181 | --refine_decimate_ratio 0.1 # decimate ratio at each refine step 182 | --refine_remesh_size 0.02 # remesh edge len after decimation 183 | 184 | ### Depth supervision (colmap dataset only) 185 | 186 | # download depth checkpoints (omnidata v2) 187 | cd depth_tools 188 | bash download_models.sh 189 | cd .. 190 | 191 | # generate dense depth (save to `data//depths`) 192 | python depth_tools/extract_depth.py data//images_4 193 | 194 | # enable dense depth training 195 | python main.py data/ -O --bound 16 --data_format colmap --enable_dense_depth 196 | ``` 197 | 198 | Please check the `scripts` directory for more examples on common datasets, and check `main.py` for all options. 199 | 200 | # Acknowledgement 201 | 202 | * The NeRF framework is based on [torch-ngp](https://github.com/ashawkey/torch-ngp). 203 | * The GUI is developed with [DearPyGui](https://github.com/hoffstadt/DearPyGui). 204 | 205 | # Citation 206 | 207 | ``` 208 | @article{tang2022nerf2mesh, 209 | title={Delicate Textured Mesh Recovery from NeRF via Adaptive Surface Refinement}, 210 | author={Tang, Jiaxiang and Zhou, Hang and Chen, Xiaokang and Hu, Tianshu and Ding, Errui and Wang, Jingdong and Zeng, Gang}, 211 | journal={arXiv preprint arXiv:2303.02091}, 212 | year={2022} 213 | } 214 | ``` 215 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | rich 2 | tqdm 3 | ninja 4 | torch 5 | numpy 6 | scipy 7 | lpips 8 | pandas 9 | trimesh 10 | PyMCubes 11 | torch-ema 12 | dearpygui 13 | packaging 14 | matplotlib 15 | tensorboardX 16 | opencv-python 17 | imageio 18 | imageio-ffmpeg 19 | pymeshlab 20 | torch-scatter 21 | xatlas 22 | scikit-learn 23 | torchmetrics -------------------------------------------------------------------------------- /scripts/colmap2nerf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. 
4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 10 | 11 | import argparse 12 | import os 13 | from pathlib import Path 14 | 15 | import numpy as np 16 | import json 17 | import sys 18 | import math 19 | import cv2 20 | import os 21 | import shutil 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser(description="convert a text colmap export to nerf format transforms.json; optionally convert video to images, and optionally run colmap in the first place") 25 | 26 | parser.add_argument("--video", default="", help="input path to the video") 27 | parser.add_argument("--images", default="", help="input path to the images folder, ignored if --video is provided") 28 | parser.add_argument("--run_colmap", action="store_true", help="run colmap first on the image folder") 29 | 30 | parser.add_argument("--dynamic", action="store_true", help="for dynamic scene, extraly save time calculated from frame index.") 31 | parser.add_argument("--estimate_affine_shape", action="store_true", help="colmap SiftExtraction option, may yield better results, yet can only be run on CPU.") 32 | parser.add_argument('--hold', type=int, default=8, help="hold out for validation every $ images") 33 | 34 | parser.add_argument("--video_fps", default=3) 35 | parser.add_argument("--time_slice", default="", help="time (in seconds) in the format t1,t2 within which the images should be generated from the video. eg: \"--time_slice '10,300'\" will generate images only from 10th second to 300th second of the video") 36 | 37 | parser.add_argument("--colmap_matcher", default="exhaustive", choices=["exhaustive","sequential","spatial","transitive","vocab_tree"], help="select which matcher colmap should use. sequential for videos, exhaustive for adhoc images") 38 | parser.add_argument("--skip_early", default=0, help="skip this many images from the start") 39 | 40 | parser.add_argument("--colmap_text", default="colmap_text", help="input path to the colmap text files (set automatically if run_colmap is used)") 41 | parser.add_argument("--colmap_db", default="colmap.db", help="colmap database filename") 42 | 43 | args = parser.parse_args() 44 | return args 45 | 46 | def do_system(arg): 47 | print(f"==== running: {arg}") 48 | err = os.system(arg) 49 | if err: 50 | print("FATAL: command failed") 51 | sys.exit(err) 52 | 53 | def run_ffmpeg(args): 54 | video = args.video 55 | images = args.images 56 | fps = float(args.video_fps) or 1.0 57 | 58 | print(f"running ffmpeg with input video file={video}, output image folder={images}, fps={fps}.") 59 | if (input(f"warning! folder '{images}' will be deleted/replaced. continue? 
(Y/n)").lower().strip()+"y")[:1] != "y": 60 | sys.exit(1) 61 | 62 | try: 63 | shutil.rmtree(images) 64 | except: 65 | pass 66 | 67 | do_system(f"mkdir {images}") 68 | 69 | time_slice_value = "" 70 | time_slice = args.time_slice 71 | if time_slice: 72 | start, end = time_slice.split(",") 73 | time_slice_value = f",select='between(t\,{start}\,{end})'" 74 | 75 | do_system(f"ffmpeg -i {video} -qscale:v 1 -qmin 1 -vf \"fps={fps}{time_slice_value}\" {images}/%04d.jpg") 76 | 77 | def run_colmap(args): 78 | db = args.colmap_db 79 | images = args.images 80 | text = args.colmap_text 81 | flag_EAS = int(args.estimate_affine_shape) # 0 / 1 82 | 83 | db_noext = str(Path(db).with_suffix("")) 84 | sparse = db_noext + "_sparse" 85 | 86 | print(f"running colmap with:\n\tdb={db}\n\timages={images}\n\tsparse={sparse}\n\ttext={text}") 87 | if (input(f"warning! folders '{sparse}' and '{text}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y": 88 | sys.exit(1) 89 | if os.path.exists(db): 90 | os.remove(db) 91 | do_system(f"colmap feature_extractor --ImageReader.camera_model OPENCV --SiftExtraction.estimate_affine_shape {flag_EAS} --SiftExtraction.domain_size_pooling {flag_EAS} --ImageReader.single_camera 1 --database_path {db} --image_path {images}") 92 | do_system(f"colmap {args.colmap_matcher}_matcher --SiftMatching.guided_matching {flag_EAS} --database_path {db}") 93 | try: 94 | shutil.rmtree(sparse) 95 | except: 96 | pass 97 | do_system(f"mkdir {sparse}") 98 | do_system(f"colmap mapper --database_path {db} --image_path {images} --output_path {sparse}") 99 | do_system(f"colmap bundle_adjuster --input_path {sparse}/0 --output_path {sparse}/0 --BundleAdjustment.refine_principal_point 1") 100 | try: 101 | shutil.rmtree(text) 102 | except: 103 | pass 104 | do_system(f"mkdir {text}") 105 | do_system(f"colmap model_converter --input_path {sparse}/0 --output_path {text} --output_type TXT") 106 | 107 | def variance_of_laplacian(image): 108 | return cv2.Laplacian(image, cv2.CV_64F).var() 109 | 110 | def sharpness(imagePath): 111 | image = cv2.imread(imagePath) 112 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 113 | fm = variance_of_laplacian(gray) 114 | return fm 115 | 116 | def qvec2rotmat(qvec): 117 | return np.array([ 118 | [ 119 | 1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 120 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 121 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2] 122 | ], [ 123 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 124 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 125 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1] 126 | ], [ 127 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 128 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 129 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2 130 | ] 131 | ]) 132 | 133 | def rotmat(a, b): 134 | a, b = a / np.linalg.norm(a), b / np.linalg.norm(b) 135 | v = np.cross(a, b) 136 | c = np.dot(a, b) 137 | # handle exception for the opposite direction input 138 | if c < -1 + 1e-10: 139 | return rotmat(a + np.random.uniform(-1e-2, 1e-2, 3), b) 140 | s = np.linalg.norm(v) 141 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) 142 | return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10)) 143 | 144 | def closest_point_2_lines(oa, da, ob, db): # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel 145 | da = da / np.linalg.norm(da) 146 | db = db / np.linalg.norm(db) 147 | c = np.cross(da, db) 148 | denom = np.linalg.norm(c)**2 149 | t = ob - oa 150 | ta = 
np.linalg.det([t, db, c]) / (denom + 1e-10) 151 | tb = np.linalg.det([t, da, c]) / (denom + 1e-10) 152 | if ta > 0: 153 | ta = 0 154 | if tb > 0: 155 | tb = 0 156 | return (oa+ta*da+ob+tb*db) * 0.5, denom 157 | 158 | if __name__ == "__main__": 159 | args = parse_args() 160 | 161 | if args.video != "": 162 | root_dir = os.path.dirname(args.video) 163 | args.images = os.path.join(root_dir, "images") # override args.images 164 | run_ffmpeg(args) 165 | else: 166 | args.images = args.images[:-1] if args.images[-1] == '/' else args.images # remove trailing / (./a/b/ --> ./a/b) 167 | root_dir = os.path.dirname(args.images) 168 | 169 | args.colmap_db = os.path.join(root_dir, args.colmap_db) 170 | args.colmap_text = os.path.join(root_dir, args.colmap_text) 171 | 172 | if args.run_colmap: 173 | run_colmap(args) 174 | 175 | SKIP_EARLY = int(args.skip_early) 176 | TEXT_FOLDER = args.colmap_text 177 | 178 | with open(os.path.join(TEXT_FOLDER, "cameras.txt"), "r") as f: 179 | angle_x = math.pi / 2 180 | for line in f: 181 | # 1 SIMPLE_RADIAL 2048 1536 1580.46 1024 768 0.0045691 182 | # 1 OPENCV 3840 2160 3178.27 3182.09 1920 1080 0.159668 -0.231286 -0.00123982 0.00272224 183 | # 1 RADIAL 1920 1080 1665.1 960 540 0.0672856 -0.0761443 184 | if line[0] == "#": 185 | continue 186 | els = line.split(" ") 187 | w = float(els[2]) 188 | h = float(els[3]) 189 | fl_x = float(els[4]) 190 | fl_y = float(els[4]) 191 | k1 = 0 192 | k2 = 0 193 | p1 = 0 194 | p2 = 0 195 | cx = w / 2 196 | cy = h / 2 197 | if els[1] == "SIMPLE_PINHOLE": 198 | cx = float(els[5]) 199 | cy = float(els[6]) 200 | elif els[1] == "PINHOLE": 201 | fl_y = float(els[5]) 202 | cx = float(els[6]) 203 | cy = float(els[7]) 204 | elif els[1] == "SIMPLE_RADIAL": 205 | cx = float(els[5]) 206 | cy = float(els[6]) 207 | k1 = float(els[7]) 208 | elif els[1] == "RADIAL": 209 | cx = float(els[5]) 210 | cy = float(els[6]) 211 | k1 = float(els[7]) 212 | k2 = float(els[8]) 213 | elif els[1] == "OPENCV": 214 | fl_y = float(els[5]) 215 | cx = float(els[6]) 216 | cy = float(els[7]) 217 | k1 = float(els[8]) 218 | k2 = float(els[9]) 219 | p1 = float(els[10]) 220 | p2 = float(els[11]) 221 | else: 222 | print("unknown camera model ", els[1]) 223 | # fl = 0.5 * w / tan(0.5 * angle_x); 224 | angle_x = math.atan(w / (fl_x * 2)) * 2 225 | angle_y = math.atan(h / (fl_y * 2)) * 2 226 | fovx = angle_x * 180 / math.pi 227 | fovy = angle_y * 180 / math.pi 228 | 229 | print(f"camera:\n\tres={w,h}\n\tcenter={cx,cy}\n\tfocal={fl_x,fl_y}\n\tfov={fovx,fovy}\n\tk={k1,k2} p={p1,p2} ") 230 | 231 | with open(os.path.join(TEXT_FOLDER, "images.txt"), "r") as f: 232 | i = 0 233 | 234 | bottom = np.array([0.0, 0.0, 0.0, 1.0]).reshape([1, 4]) 235 | 236 | frames = [] 237 | 238 | up = np.zeros(3) 239 | for line in f: 240 | line = line.strip() 241 | 242 | if line[0] == "#": 243 | continue 244 | 245 | i = i + 1 246 | if i < SKIP_EARLY*2: 247 | continue 248 | 249 | if i % 2 == 1: 250 | elems = line.split(" ") # 1-4 is quat, 5-7 is trans, 9ff is filename (9, if filename contains no spaces) 251 | 252 | name = '_'.join(elems[9:]) 253 | full_name = os.path.join(args.images, name) 254 | rel_name = full_name[len(root_dir) + 1:] 255 | 256 | b = sharpness(full_name) 257 | # print(name, "sharpness =",b) 258 | 259 | image_id = int(elems[0]) 260 | qvec = np.array(tuple(map(float, elems[1:5]))) 261 | tvec = np.array(tuple(map(float, elems[5:8]))) 262 | R = qvec2rotmat(-qvec) 263 | t = tvec.reshape([3, 1]) 264 | m = np.concatenate([np.concatenate([R, t], 1), bottom], 0) 265 | c2w = np.linalg.inv(m) 266 | 267 
| c2w[0:3, 2] *= -1 # flip the y and z axis 268 | c2w[0:3, 1] *= -1 269 | c2w = c2w[[1, 0, 2, 3],:] # swap y and z 270 | c2w[2, :] *= -1 # flip whole world upside down 271 | 272 | up += c2w[0:3, 1] 273 | 274 | frame = { 275 | "file_path": rel_name, 276 | "sharpness": b, 277 | "transform_matrix": c2w 278 | } 279 | 280 | frames.append(frame) 281 | 282 | N = len(frames) 283 | up = up / np.linalg.norm(up) 284 | 285 | print("[INFO] up vector was", up) 286 | 287 | R = rotmat(up, [0, 0, 1]) # rotate up vector to [0,0,1] 288 | R = np.pad(R, [0, 1]) 289 | R[-1, -1] = 1 290 | 291 | for f in frames: 292 | f["transform_matrix"] = np.matmul(R, f["transform_matrix"]) # rotate up to be the z axis 293 | 294 | # find a central point they are all looking at 295 | print("[INFO] computing center of attention...") 296 | totw = 0.0 297 | totp = np.array([0.0, 0.0, 0.0]) 298 | for f in frames: 299 | mf = f["transform_matrix"][0:3,:] 300 | for g in frames: 301 | mg = g["transform_matrix"][0:3,:] 302 | p, weight = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2]) 303 | if weight > 0.01: 304 | totp += p * weight 305 | totw += weight 306 | totp /= totw 307 | for f in frames: 308 | f["transform_matrix"][0:3,3] -= totp 309 | avglen = 0. 310 | for f in frames: 311 | avglen += np.linalg.norm(f["transform_matrix"][0:3,3]) 312 | avglen /= N 313 | print("[INFO] avg camera distance from origin", avglen) 314 | for f in frames: 315 | f["transform_matrix"][0:3,3] *= 4.0 / avglen # scale to "nerf sized" 316 | 317 | # sort frames by id 318 | frames.sort(key=lambda d: d['file_path']) 319 | 320 | # add time if scene is dynamic 321 | if args.dynamic: 322 | for i, f in enumerate(frames): 323 | f['time'] = i / N 324 | 325 | for f in frames: 326 | f["transform_matrix"] = f["transform_matrix"].tolist() 327 | 328 | # construct frames 329 | 330 | def write_json(filename, frames): 331 | 332 | out = { 333 | "camera_angle_x": angle_x, 334 | "camera_angle_y": angle_y, 335 | "fl_x": fl_x, 336 | "fl_y": fl_y, 337 | "k1": k1, 338 | "k2": k2, 339 | "p1": p1, 340 | "p2": p2, 341 | "cx": cx, 342 | "cy": cy, 343 | "w": w, 344 | "h": h, 345 | "frames": frames, 346 | } 347 | 348 | output_path = os.path.join(root_dir, filename) 349 | print(f"[INFO] writing {len(frames)} frames to {output_path}") 350 | with open(output_path, "w") as outfile: 351 | json.dump(out, outfile, indent=2) 352 | 353 | # just one transforms.json, don't do data split 354 | if args.hold <= 0: 355 | 356 | write_json('transforms.json', frames) 357 | 358 | else: 359 | all_ids = np.arange(N) 360 | test_ids = all_ids[::args.hold] 361 | train_ids = np.array([i for i in all_ids if i not in test_ids]) 362 | 363 | frames_train = [f for i, f in enumerate(frames) if i in train_ids] 364 | frames_test = [f for i, f in enumerate(frames) if i in test_ids] 365 | 366 | write_json('transforms_train.json', frames_train) 367 | write_json('transforms_val.json', frames_test[::10]) 368 | write_json('transforms_test.json', frames_test) -------------------------------------------------------------------------------- /scripts/downscale.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import tqdm 4 | import argparse 5 | 6 | from PIL import Image 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('path', type=str, help="path to the folder that contains `images/`") 10 | parser.add_argument('--downscale', type=int, default=4) 11 | 12 | opt = parser.parse_args() 13 | 14 | in_dir = os.path.join(opt.path, f'images') 15 | 
out_dir = os.path.join(opt.path, f'images_{opt.downscale}') 16 | 17 | os.makedirs(out_dir, exist_ok=True) 18 | 19 | def run_image(img_path): 20 | # img: filepath 21 | img = Image.open(img_path) 22 | W, H = img.size 23 | img = img.resize((W // opt.downscale, H // opt.downscale), Image.Resampling.BILINEAR) 24 | out_path = os.path.join(out_dir, os.path.basename(img_path)) 25 | img.save(out_path) 26 | 27 | img_paths = glob.glob(os.path.join(in_dir, '*')) 28 | for img_path in tqdm.tqdm(img_paths): 29 | run_image(img_path) 30 | -------------------------------------------------------------------------------- /scripts/install_ext.sh: -------------------------------------------------------------------------------- 1 | pip install ./raymarching 2 | pip install ./gridencoder 3 | pip install ./freqencoder 4 | pip install ./shencoder -------------------------------------------------------------------------------- /scripts/remove_bg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import cv2 4 | import glob 5 | import argparse 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torchvision import transforms 13 | from PIL import Image 14 | 15 | class BackgroundRemoval(): 16 | def __init__(self, device='cuda'): 17 | 18 | from carvekit.api.high import HiInterface 19 | self.interface = HiInterface( 20 | object_type="object", # Can be "object" or "hairs-like". 21 | batch_size_seg=5, 22 | batch_size_matting=1, 23 | device=device, 24 | seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net 25 | matting_mask_size=2048, 26 | trimap_prob_threshold=231, 27 | trimap_dilation=30, 28 | trimap_erosion_iters=5, 29 | fp16=True, 30 | ) 31 | 32 | @torch.no_grad() 33 | def __call__(self, image): 34 | # image: PIL Image 35 | image = self.interface([image])[0] 36 | image = np.array(image) 37 | return image 38 | 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument('path', type=str) 41 | 42 | opt = parser.parse_args() 43 | 44 | if opt.path[-1] == '/': 45 | opt.path = opt.path[:-1] 46 | 47 | out_dir = os.path.join(os.path.dirname(opt.path), f'mask') 48 | os.makedirs(out_dir, exist_ok=True) 49 | 50 | print(f'[INFO] removing background: {opt.path} --> {out_dir}') 51 | 52 | model = BackgroundRemoval() 53 | 54 | def run_image(img_path): 55 | # img: filepath 56 | image = Image.open(img_path) 57 | carved_image = model(image) # [H, W, 4] 58 | mask = (carved_image[..., -1] > 0).astype(np.uint8) * 255 # [H, W] 59 | out_path = os.path.join(out_dir, os.path.splitext(os.path.basename(img_path))[0] + '.png') 60 | cv2.imwrite(out_path, mask) 61 | 62 | img_paths = glob.glob(os.path.join(opt.path, '*')) 63 | for img_path in tqdm.tqdm(img_paths): 64 | run_image(img_path) 65 | -------------------------------------------------------------------------------- /scripts/runall_360_indoor.sh: -------------------------------------------------------------------------------- 1 | # indoor 2 | CUDA_VISIBLE_DEVICES=4 python main.py data/room/ --workspace trial_360_room -O --data_format colmap --bound 8 --enable_cam_center --enable_cam_near_far --scale 0.2 --downscale 4 --stage 0 --lambda_entropy 1e-3 --clean_min_f 16 --clean_min_d 10 --visibility_mask_dilation 50 3 | CUDA_VISIBLE_DEVICES=4 python main.py data/room/ --workspace trial_360_room -O --data_format colmap --bound 8 --enable_cam_center --enable_cam_near_far --scale 0.2 --downscale 4 --stage 1 --iters 10000 
--lambda_lap 1e-3 --lambda_normal 1e-3 4 | 5 | CUDA_VISIBLE_DEVICES=4 python main.py data/bonsai/ --workspace trial_360_bonsai -O --data_format colmap --bound 8 --enable_cam_center --enable_cam_near_far --scale 0.2 --downscale 4 --stage 0 --lambda_entropy 1e-3 --clean_min_f 16 --clean_min_d 10 --visibility_mask_dilation 50 6 | CUDA_VISIBLE_DEVICES=4 python main.py data/bonsai/ --workspace trial_360_bonsai -O --data_format colmap --bound 8 --enable_cam_center --enable_cam_near_far --scale 0.2 --downscale 4 --stage 1 --iters 10000 --lambda_lap 1e-3 --lambda_normal 1e-3 7 | 8 | CUDA_VISIBLE_DEVICES=4 python main.py data/kitchen/ --workspace trial_360_kitchen -O --data_format colmap --bound 8 --enable_cam_center --enable_cam_near_far --scale 0.2 --downscale 4 --stage 0 --lambda_entropy 1e-3 --clean_min_f 16 --clean_min_d 10 --visibility_mask_dilation 50 9 | CUDA_VISIBLE_DEVICES=4 python main.py data/kitchen/ --workspace trial_360_kitchen -O --data_format colmap --bound 8 --enable_cam_center --enable_cam_near_far --scale 0.2 --downscale 4 --stage 1 --iters 10000 --lambda_lap 1e-3 --lambda_normal 1e-3 10 | 11 | CUDA_VISIBLE_DEVICES=4 python main.py data/counter/ --workspace trial_360_counter -O --data_format colmap --bound 8 --enable_cam_center --enable_cam_near_far --scale 0.2 --downscale 4 --stage 0 --lambda_entropy 1e-3 --clean_min_f 16 --clean_min_d 10 --visibility_mask_dilation 50 12 | CUDA_VISIBLE_DEVICES=4 python main.py data/counter/ --workspace trial_360_counter -O --data_format colmap --bound 8 --enable_cam_center --enable_cam_near_far --scale 0.2 --downscale 4 --stage 1 --iters 10000 --lambda_lap 1e-3 --lambda_normal 1e-3 -------------------------------------------------------------------------------- /scripts/runall_360_indoor_sdf.sh: -------------------------------------------------------------------------------- 1 | # indoor 2 | CUDA_VISIBLE_DEVICES=4 python main.py data/room/ --workspace trial_sdf_360_room -O --data_format colmap --bound 8 --enable_cam_center --enable_cam_near_far --scale 0.2 --downscale 4 --stage 0 --lambda_entropy 1e-3 --clean_min_f 16 --clean_min_d 10 --visibility_mask_dilation 50 --sdf 3 | CUDA_VISIBLE_DEVICES=4 python main.py data/room/ --workspace trial_sdf_360_room -O --data_format colmap --bound 8 --enable_cam_center --enable_cam_near_far --scale 0.2 --downscale 4 --stage 1 --iters 10000 --lambda_lap 1e-3 --lambda_normal 1e-3 --sdf 4 | 5 | CUDA_VISIBLE_DEVICES=4 python main.py data/bonsai/ --workspace trial_sdf_360_bonsai -O --data_format colmap --bound 8 --enable_cam_center --enable_cam_near_far --scale 0.2 --downscale 4 --stage 0 --lambda_entropy 1e-3 --clean_min_f 16 --clean_min_d 10 --visibility_mask_dilation 50 --sdf 6 | CUDA_VISIBLE_DEVICES=4 python main.py data/bonsai/ --workspace trial_sdf_360_bonsai -O --data_format colmap --bound 8 --enable_cam_center --enable_cam_near_far --scale 0.2 --downscale 4 --stage 1 --iters 10000 --lambda_lap 1e-3 --lambda_normal 1e-3 --sdf 7 | 8 | # CUDA_VISIBLE_DEVICES=4 python main.py data/kitchen/ --workspace trial_sdf_360_kitchen -O --data_format colmap --bound 8 --enable_cam_center --enable_cam_near_far --scale 0.2 --downscale 4 --stage 0 --lambda_entropy 1e-3 --clean_min_f 16 --clean_min_d 10 --visibility_mask_dilation 50 --sdf 9 | # CUDA_VISIBLE_DEVICES=4 python main.py data/kitchen/ --workspace trial_sdf_360_kitchen -O --data_format colmap --bound 8 --enable_cam_center --enable_cam_near_far --scale 0.2 --downscale 4 --stage 1 --iters 10000 --lambda_lap 1e-3 --lambda_normal 1e-3 --sdf 10 | 11 | # 
CUDA_VISIBLE_DEVICES=4 python main.py data/counter/ --workspace trial_sdf_360_counter -O --data_format colmap --bound 8 --enable_cam_center --enable_cam_near_far --scale 0.2 --downscale 4 --stage 0 --lambda_entropy 1e-3 --clean_min_f 16 --clean_min_d 10 --visibility_mask_dilation 50 --sdf 12 | # CUDA_VISIBLE_DEVICES=4 python main.py data/counter/ --workspace trial_sdf_360_counter -O --data_format colmap --bound 8 --enable_cam_center --enable_cam_near_far --scale 0.2 --downscale 4 --stage 1 --iters 10000 --lambda_lap 1e-3 --lambda_normal 1e-3 --sdf -------------------------------------------------------------------------------- /scripts/runall_360_outdoor.sh: -------------------------------------------------------------------------------- 1 | # outdoor 2 | CUDA_VISIBLE_DEVICES=5 python main.py data/garden/ --workspace trial_360_garden -O --data_format colmap --bound 16 --enable_cam_center --enable_cam_near_far --scale 0.3 --downscale 4 --stage 0 --lambda_entropy 1e-3 --clean_min_f 16 --clean_min_d 10 --visibility_mask_dilation 50 3 | CUDA_VISIBLE_DEVICES=5 python main.py data/garden/ --workspace trial_360_garden -O --data_format colmap --bound 16 --enable_cam_center --enable_cam_near_far --scale 0.3 --downscale 4 --stage 1 --iters 10000 4 | 5 | CUDA_VISIBLE_DEVICES=5 python main.py data/stump/ --workspace trial_360_stump -O --data_format colmap --bound 16 --enable_cam_center --enable_cam_near_far --scale 0.3 --downscale 4 --stage 0 --lambda_entropy 1e-3 --clean_min_f 16 --clean_min_d 10 --visibility_mask_dilation 50 6 | CUDA_VISIBLE_DEVICES=5 python main.py data/stump/ --workspace trial_360_stump -O --data_format colmap --bound 16 --enable_cam_center --enable_cam_near_far --scale 0.3 --downscale 4 --stage 1 --iters 10000 7 | 8 | CUDA_VISIBLE_DEVICES=5 python main.py data/bicycle/ --workspace trial_360_bicycle -O --data_format colmap --bound 16 --enable_cam_center --enable_cam_near_far --scale 0.3 --downscale 4 --stage 0 --lambda_entropy 1e-3 --clean_min_f 16 --clean_min_d 10 --visibility_mask_dilation 50 9 | CUDA_VISIBLE_DEVICES=5 python main.py data/bicycle/ --workspace trial_360_bicycle -O --data_format colmap --bound 16 --enable_cam_center --enable_cam_near_far --scale 0.3 --downscale 4 --stage 1 --iters 10000 -------------------------------------------------------------------------------- /scripts/runall_llff.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=6 python main.py data/nerf_llff_data/fern --workspace trial_llff_fern -O --data_format colmap --bound 4 --downscale 4 --stage 0 --visibility_mask_dilation 50 2 | CUDA_VISIBLE_DEVICES=6 python main.py data/nerf_llff_data/fern --workspace trial_llff_fern -O --data_format colmap --bound 4 --downscale 4 --stage 1 --iters 10000 3 | 4 | CUDA_VISIBLE_DEVICES=6 python main.py data/nerf_llff_data/flower --workspace trial_llff_flower -O --data_format colmap --bound 4 --downscale 4 --stage 0 --visibility_mask_dilation 50 5 | CUDA_VISIBLE_DEVICES=6 python main.py data/nerf_llff_data/flower --workspace trial_llff_flower -O --data_format colmap --bound 4 --downscale 4 --stage 1 --iters 10000 6 | 7 | CUDA_VISIBLE_DEVICES=6 python main.py data/nerf_llff_data/fortress --workspace trial_llff_fortress -O --data_format colmap --bound 4 --downscale 4 --stage 0 --visibility_mask_dilation 50 8 | CUDA_VISIBLE_DEVICES=6 python main.py data/nerf_llff_data/fortress --workspace trial_llff_fortress -O --data_format colmap --bound 4 --downscale 4 --stage 1 --iters 10000 9 | 10 | CUDA_VISIBLE_DEVICES=6 
python main.py data/nerf_llff_data/horns --workspace trial_llff_horns -O --data_format colmap --bound 4 --downscale 4 --stage 0 --visibility_mask_dilation 50 11 | CUDA_VISIBLE_DEVICES=6 python main.py data/nerf_llff_data/horns --workspace trial_llff_horns -O --data_format colmap --bound 4 --downscale 4 --stage 1 --iters 10000 12 | 13 | CUDA_VISIBLE_DEVICES=6 python main.py data/nerf_llff_data/leaves --workspace trial_llff_leaves -O --data_format colmap --bound 4 --downscale 4 --stage 0 --visibility_mask_dilation 50 14 | CUDA_VISIBLE_DEVICES=6 python main.py data/nerf_llff_data/leaves --workspace trial_llff_leaves -O --data_format colmap --bound 4 --downscale 4 --stage 1 --iters 10000 15 | 16 | CUDA_VISIBLE_DEVICES=6 python main.py data/nerf_llff_data/orchids --workspace trial_llff_orchids -O --data_format colmap --bound 4 --downscale 4 --stage 0 --visibility_mask_dilation 50 17 | CUDA_VISIBLE_DEVICES=6 python main.py data/nerf_llff_data/orchids --workspace trial_llff_orchids -O --data_format colmap --bound 4 --downscale 4 --stage 1 --iters 10000 18 | 19 | CUDA_VISIBLE_DEVICES=6 python main.py data/nerf_llff_data/room --workspace trial_llff_room -O --data_format colmap --bound 1 --downscale 4 --stage 0 --visibility_mask_dilation 50 20 | CUDA_VISIBLE_DEVICES=6 python main.py data/nerf_llff_data/room --workspace trial_llff_room -O --data_format colmap --bound 1 --downscale 4 --stage 1 --iters 10000 21 | 22 | CUDA_VISIBLE_DEVICES=6 python main.py data/nerf_llff_data/trex --workspace trial_llff_trex -O --data_format colmap --bound 1 --downscale 4 --stage 0 --visibility_mask_dilation 50 23 | CUDA_VISIBLE_DEVICES=6 python main.py data/nerf_llff_data/trex --workspace trial_llff_trex -O --data_format colmap --bound 1 --downscale 4 --stage 1 --iters 10000 24 | -------------------------------------------------------------------------------- /scripts/runall_outdoor_sdf.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=4 python main.py data/garden/ --workspace trial_sdf_garden_ori -O --data_format colmap --bound 16 --scale 0.3 --enable_cam_center --stage 0 --sdf --downscale 4 --n_eval 1 --iters 15000 --clean_min_f 16 --clean_min_d 10 --visibility_mask_dilation 10 --decimate_target 1e5 --enable_dense_depth 2 | CUDA_VISIBLE_DEVICES=4 python main.py data/garden/ --workspace trial_sdf_garden_ori -O --data_format colmap --bound 16 --scale 0.3 --enable_cam_center --stage 1 --sdf --downscale 4 --n_eval 1 --iters 5000 --lambda_normal 1e-1 --refine_remesh_size 0.01 3 | 4 | # CUDA_VISIBLE_DEVICES=4 python main.py data/room/ --workspace trial_sdf_room -O --data_format colmap --bound 8 --scale 0.2 --enable_cam_center --stage 0 --sdf --downscale 4 --n_eval 1 --iters 15000 --clean_min_f 16 --clean_min_d 10 --visibility_mask_dilation 10 --decimate_target 1e5 --enable_dense_depth 5 | # CUDA_VISIBLE_DEVICES=4 python main.py data/room/ --workspace trial_sdf_room -O --data_format colmap --bound 8 --scale 0.2 --enable_cam_center --stage 1 --sdf --downscale 4 --n_eval 1 --iters 5000 --lambda_normal 1e-1 --refine_remesh_size 0.01 6 | -------------------------------------------------------------------------------- /scripts/runall_syn.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/lego/ --workspace trial_syn_lego/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 0 2 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/lego/ --workspace trial_syn_lego/ -O --bound 1 
--scale 0.8 --dt_gamma 0 --stage 1 3 | 4 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/mic/ --workspace trial_syn_mic/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 0 5 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/mic/ --workspace trial_syn_mic/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 1 6 | 7 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/materials/ --workspace trial_syn_materials/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 0 8 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/materials/ --workspace trial_syn_materials/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 1 9 | 10 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/chair/ --workspace trial_syn_chair/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 0 11 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/chair/ --workspace trial_syn_chair/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 1 12 | 13 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/hotdog/ --workspace trial_syn_hotdog/ -O --bound 1 --scale 0.7 --dt_gamma 0 --stage 0 14 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/hotdog/ --workspace trial_syn_hotdog/ -O --bound 1 --scale 0.7 --dt_gamma 0 --stage 1 15 | 16 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/ficus/ --workspace trial_syn_ficus/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 0 17 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/ficus/ --workspace trial_syn_ficus/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 1 18 | 19 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/drums/ --workspace trial_syn_drums/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 0 20 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/drums/ --workspace trial_syn_drums/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 1 21 | 22 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/ship/ --workspace trial_syn_ship/ -O --bound 1 --scale 0.7 --dt_gamma 0 --stage 0 23 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/ship/ --workspace trial_syn_ship/ -O --bound 1 --scale 0.7 --dt_gamma 0 --stage 1 -------------------------------------------------------------------------------- /scripts/runall_syn_sdf.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/lego/ --workspace trial_syn_sdf_lego/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 0 --sdf 2 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/lego/ --workspace trial_syn_sdf_lego/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 1 3 | 4 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/mic/ --workspace trial_syn_sdf_mic/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 0 --sdf 5 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/mic/ --workspace trial_syn_sdf_mic/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 1 6 | 7 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/materials/ --workspace trial_syn_sdf_materials/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 0 --sdf 8 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/materials/ --workspace trial_syn_sdf_materials/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 1 9 | 10 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/chair/ --workspace trial_syn_sdf_chair/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 0 --sdf 11 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/chair/ --workspace trial_syn_sdf_chair/ -O --bound 1 --scale 0.8 --dt_gamma 0 
--stage 1 12 | 13 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/hotdog/ --workspace trial_syn_sdf_hotdog/ -O --bound 1 --scale 0.7 --dt_gamma 0 --stage 0 --sdf 14 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/hotdog/ --workspace trial_syn_sdf_hotdog/ -O --bound 1 --scale 0.7 --dt_gamma 0 --stage 1 15 | 16 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/ficus/ --workspace trial_syn_sdf_ficus/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 0 --sdf 17 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/ficus/ --workspace trial_syn_sdf_ficus/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 1 18 | 19 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/drums/ --workspace trial_syn_sdf_drums/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 0 --sdf 20 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/drums/ --workspace trial_syn_sdf_drums/ -O --bound 1 --scale 0.8 --dt_gamma 0 --stage 1 21 | 22 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/ship/ --workspace trial_syn_sdf_ship/ -O --bound 1 --scale 0.7 --dt_gamma 0 --stage 0 --sdf 23 | CUDA_VISIBLE_DEVICES=7 python main.py data/nerf_synthetic/ship/ --workspace trial_syn_sdf_ship/ -O --bound 1 --scale 0.7 --dt_gamma 0 --stage 1 -------------------------------------------------------------------------------- /shencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .sphere_harmonics import SHEncoder -------------------------------------------------------------------------------- /shencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | from packaging import version 4 | import torch 5 | 6 | torch_version = torch.__version__ 7 | 8 | _src_path = os.path.dirname(os.path.abspath(__file__)) 9 | 10 | nvcc_flags = [ 11 | '-O3', '-std=c++14', 12 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 13 | ] 14 | 15 | if os.name == "posix": 16 | c_flags = ['-O3', '-std=c++14'] 17 | if version.parse(torch_version) >= version.parse("2.1"): 18 | nvcc_flags = [ 19 | '-O3', '-std=c++17', 20 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 21 | '-use_fast_math' 22 | ] 23 | c_flags = ['-O3', '-std=c++17'] 24 | elif os.name == "nt": 25 | c_flags = ['/O2', '/std:c++17'] 26 | 27 | # find cl.exe 28 | def find_cl_path(): 29 | import glob 30 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 31 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 32 | if paths: 33 | return paths[0] 34 | 35 | # If cl.exe is not on path, try to find it. 
36 | if os.system("where cl.exe >nul 2>nul") != 0: 37 | cl_path = find_cl_path() 38 | if cl_path is None: 39 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 40 | os.environ["PATH"] += ";" + cl_path 41 | 42 | _backend = load(name='_sh_encoder', 43 | extra_cflags=c_flags, 44 | extra_cuda_cflags=nvcc_flags, 45 | sources=[os.path.join(_src_path, 'src', f) for f in [ 46 | 'shencoder.cu', 47 | 'bindings.cpp', 48 | ]], 49 | ) 50 | 51 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /shencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | from packaging import version 5 | import torch 6 | 7 | torch_version = torch.__version__ 8 | 9 | _src_path = os.path.dirname(os.path.abspath(__file__)) 10 | 11 | nvcc_flags = [ 12 | '-O3', '-std=c++14', 13 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 14 | ] 15 | 16 | if os.name == "posix": 17 | c_flags = ['-O3', '-std=c++14'] 18 | if version.parse(torch_version) >= version.parse("2.1"): 19 | nvcc_flags = [ 20 | '-O3', '-std=c++17', 21 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 22 | '-use_fast_math' 23 | ] 24 | c_flags = ['-O3', '-std=c++17'] 25 | elif os.name == "nt": 26 | c_flags = ['/O2', '/std:c++17'] 27 | 28 | # find cl.exe 29 | def find_cl_path(): 30 | import glob 31 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 32 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 33 | if paths: 34 | return paths[0] 35 | 36 | # If cl.exe is not on path, try to find it. 
37 | if os.system("where cl.exe >nul 2>nul") != 0: 38 | cl_path = find_cl_path() 39 | if cl_path is None: 40 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 41 | os.environ["PATH"] += ";" + cl_path 42 | 43 | setup( 44 | name='shencoder', # package name, import this to use python API 45 | ext_modules=[ 46 | CUDAExtension( 47 | name='_shencoder', # extension name, import this to use CUDA API 48 | sources=[os.path.join(_src_path, 'src', f) for f in [ 49 | 'shencoder.cu', 50 | 'bindings.cpp', 51 | ]], 52 | extra_compile_args={ 53 | 'cxx': c_flags, 54 | 'nvcc': nvcc_flags, 55 | } 56 | ), 57 | ], 58 | cmdclass={ 59 | 'build_ext': BuildExtension, 60 | } 61 | ) -------------------------------------------------------------------------------- /shencoder/sphere_harmonics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _shencoder as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | class _sh_encoder(Function): 15 | @staticmethod 16 | @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision 17 | def forward(ctx, inputs, degree, calc_grad_inputs=False): 18 | # inputs: [B, input_dim], float in [-1, 1] 19 | # RETURN: [B, F], float 20 | 21 | inputs = inputs.contiguous() 22 | B, input_dim = inputs.shape # batch size, coord dim 23 | output_dim = degree ** 2 24 | 25 | outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) 26 | 27 | if calc_grad_inputs: 28 | dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device) 29 | else: 30 | dy_dx = None 31 | 32 | _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx) 33 | 34 | ctx.save_for_backward(inputs, dy_dx) 35 | ctx.dims = [B, input_dim, degree] 36 | 37 | return outputs 38 | 39 | @staticmethod 40 | #@once_differentiable 41 | @custom_bwd 42 | def backward(ctx, grad): 43 | # grad: [B, C * C] 44 | 45 | inputs, dy_dx = ctx.saved_tensors 46 | 47 | if dy_dx is not None: 48 | grad = grad.contiguous() 49 | B, input_dim, degree = ctx.dims 50 | grad_inputs = torch.zeros_like(inputs) 51 | _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs) 52 | return grad_inputs, None, None 53 | else: 54 | return None, None, None 55 | 56 | 57 | 58 | sh_encode = _sh_encoder.apply 59 | 60 | 61 | class SHEncoder(nn.Module): 62 | def __init__(self, input_dim=3, degree=4): 63 | super().__init__() 64 | 65 | self.input_dim = input_dim # coord dims, must be 3 66 | self.degree = degree # 0 ~ 4 67 | self.output_dim = degree ** 2 68 | 69 | assert self.input_dim == 3, "SH encoder only support input dim == 3" 70 | assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]" 71 | 72 | def __repr__(self): 73 | return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}" 74 | 75 | def forward(self, inputs, size=1): 76 | # inputs: [..., input_dim], real world positions in [-size, size] 77 | # return: [..., degree^2] 78 | 79 | inputs = inputs / size # [-1, 1] 80 | 81 | # normalize 82 | inputs = inputs / torch.norm(inputs, dim=-1, keepdim=True) 83 | 84 | prefix_shape = list(inputs.shape[:-1]) 85 | inputs = inputs.reshape(-1, self.input_dim) 86 | 87 | outputs = sh_encode(inputs, self.degree, 
inputs.requires_grad) 88 | outputs = outputs.reshape(prefix_shape + [self.output_dim]) 89 | 90 | return outputs -------------------------------------------------------------------------------- /shencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include "shencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)"); 7 | m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)"); 8 | } -------------------------------------------------------------------------------- /shencoder/src/shencoder.h: -------------------------------------------------------------------------------- 1 | # pragma once 2 | 3 | #include <stdint.h> 4 | #include <torch/torch.h> 5 | 6 | // inputs: [B, D], float, in [-1, 1] 7 | // outputs: [B, F], float 8 | 9 | void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx); 10 | void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs); --------------------------------------------------------------------------------
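As a quick orientation for the shencoder package above, the sketch below shows how SHEncoder is typically called once the CUDA extension has been built (e.g. via scripts/install_ext.sh or the JIT backend). It is an illustrative usage sketch under stated assumptions, not part of the repository: the variable names and batch size are hypothetical, and a CUDA device is assumed because the encoding kernels run on the GPU.

import torch
from shencoder import SHEncoder

# Degree-4 spherical harmonics: each 3D direction maps to 4**2 = 16 features.
encoder = SHEncoder(input_dim=3, degree=4)

# Hypothetical batch of view directions; SHEncoder normalizes them internally.
dirs = torch.randn(1024, 3, device='cuda')
feats = encoder(dirs)  # -> shape [1024, 16]
print(feats.shape)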