├── LICENSE
├── README.md
├── configs
│   ├── building.yaml
│   ├── campus.yaml
│   ├── quad.yaml
│   ├── residence.yaml
│   ├── rubble.yaml
│   └── sci-art.yaml
├── eval.sh
├── gp_nerf
│   ├── __pycache__
│   │   ├── opts.cpython-39.pyc
│   │   ├── rendering_gpnerf.cpython-39.pyc
│   │   ├── runner_gpnerf.cpython-39.pyc
│   │   └── sample_bg.cpython-39.pyc
│   ├── eval.py
│   ├── models
│   │   ├── Plane_module.py
│   │   ├── __pycache__
│   │   │   ├── Plane_module.cpython-39.pyc
│   │   │   ├── gp_nerf.cpython-39.pyc
│   │   │   └── model_utils.cpython-39.pyc
│   │   ├── gp_nerf.py
│   │   └── model_utils.py
│   ├── opts.py
│   ├── rendering_gpnerf.py
│   ├── runner_gpnerf.py
│   ├── sample_bg.py
│   ├── torch_ngp
│   │   ├── __pycache__
│   │   │   ├── activation.cpython-39.pyc
│   │   │   └── encoding.cpython-39.pyc
│   │   ├── activation.py
│   │   ├── encoding.py
│   │   ├── gridencoder
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-39.pyc
│   │   │   │   ├── backend.cpython-39.pyc
│   │   │   │   └── grid.cpython-39.pyc
│   │   │   ├── backend.py
│   │   │   ├── build
│   │   │   │   ├── temp.linux-x86_64-3.9
│   │   │   │   │   ├── .ninja_deps
│   │   │   │   │   ├── .ninja_log
│   │   │   │   │   ├── build.ninja
│   │   │   │   │   └── disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src
│   │   │   │   │       ├── bindings.o
│   │   │   │   │       └── gridencoder.o
│   │   │   │   └── temp.linux-x86_64-cpython-39
│   │   │   │       ├── .ninja_deps
│   │   │   │       ├── .ninja_log
│   │   │   │       ├── build.ninja
│   │   │   │       └── disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src
│   │   │   │           ├── bindings.o
│   │   │   │           └── gridencoder.o
│   │   │   ├── dist
│   │   │   │   └── gridencoder-0.0.0-py3.9-linux-x86_64.egg
│   │   │   ├── grid.py
│   │   │   ├── gridencoder.egg-info
│   │   │   │   ├── PKG-INFO
│   │   │   │   ├── SOURCES.txt
│   │   │   │   ├── dependency_links.txt
│   │   │   │   └── top_level.txt
│   │   │   ├── setup.py
│   │   │   └── src
│   │   │       ├── bindings.cpp
│   │   │       ├── gridencoder.cu
│   │   │       └── gridencoder.h
│   │   ├── nerf
│   │   │   └── network.py
│   │   ├── raymarching
│   │   │   ├── __init__.py
│   │   │   ├── backend.py
│   │   │   ├── build
│   │   │   │   └── temp.linux-x86_64-cpython-39
│   │   │   │       └── disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/raymarching/src
│   │   │   │           ├── bindings.o
│   │   │   │           └── raymarching.o
│   │   │   ├── dist
│   │   │   │   └── raymarching-0.0.0-py3.9-linux-x86_64.egg
│   │   │   ├── raymarching.egg-info
│   │   │   │   ├── PKG-INFO
│   │   │   │   ├── SOURCES.txt
│   │   │   │   ├── dependency_links.txt
│   │   │   │   └── top_level.txt
│   │   │   ├── raymarching.py
│   │   │   ├── setup.py
│   │   │   └── src
│   │   │       ├── bindings.cpp
│   │   │       ├── pcg32.h
│   │   │       ├── raymarching.cu
│   │   │       └── raymarching.h
│   │   └── shencoder
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-39.pyc
│   │       │   ├── backend.cpython-39.pyc
│   │       │   └── sphere_harmonics.cpython-39.pyc
│   │       ├── backend.py
│   │       ├── build
│   │       │   ├── temp.linux-x86_64-3.9
│   │       │   │   ├── .ninja_deps
│   │       │   │   ├── .ninja_log
│   │       │   │   ├── build.ninja
│   │       │   │   └── disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/src
│   │       │   │       ├── bindings.o
│   │       │   │       └── shencoder.o
│   │       │   └── temp.linux-x86_64-cpython-39
│   │       │       └── disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/src
│   │       │           ├── bindings.o
│   │       │           └── shencoder.o
│   │       ├── dist
│   │       │   └── shencoder-0.0.0-py3.9-linux-x86_64.egg
│   │       ├── setup.py
│   │       ├── shencoder.egg-info
│   │       │   ├── PKG-INFO
│   │       │   ├── SOURCES.txt
│   │       │   ├── dependency_links.txt
│   │       │   └── top_level.txt
│   │       ├── sphere_harmonics.py
│   │       └── src
│   │           ├── bindings.cpp
│   │           ├── shencoder.cu
│   │           └── shencoder.h
│   └── train.py
├── mega_nerf
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-39.pyc
│   │   ├── image_metadata.cpython-39.pyc
│   │   ├── metrics.cpython-39.pyc
│   │   ├── misc_utils.cpython-39.pyc
│   │   ├── ray_utils.cpython-39.pyc
│   │   └── spherical_harmonics.cpython-39.pyc
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── dataset_utils.cpython-39.pyc
│   │   │   ├── filesystem_dataset.cpython-39.pyc
│   │   │   └── memory_dataset.cpython-39.pyc
│   │   ├── dataset_utils.py
│   │   ├── filesystem_dataset.py
│   │   └── memory_dataset.py
│   ├── image_metadata.py
│   ├── metrics.py
│   ├── misc_utils.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── cascade.py
│   │   ├── mega_nerf.py
│   │   ├── mega_nerf_container.py
│   │   └── nerf.py
│   ├── ray_utils.py
│   └── spherical_harmonics.py
├── requirements.txt
├── scripts
│   ├── colmap_to_mega_nerf.py
│   ├── convert_to_container.py
│   ├── copy_images.py
│   └── create_cluster_masks.py
└── train.sh

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 zhangyuqi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# GP-NeRF

This repository contains the code needed to train GP-NeRF.

**Note:** This is a preliminary release and there may still be outstanding bugs.


## Setup
### Create a new conda env ([pytorch-version](https://pytorch.org/get-started/previous-versions/), [CUDA](https://developer.nvidia.com/cuda-toolkit-archive))
```
conda create -n gpnerf python=3.9
conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge
pip install -r requirements.txt
```


### Install [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn)

```
tiny-cuda-nn$ cd bindings/torch
tiny-cuda-nn/bindings/torch$ python setup.py install
```


## Pretrained Models
Coming soon.


## Datasets
Please refer to [Mega-NeRF](https://github.com/cmusatyalab/mega-nerf#data) for downloading the datasets.
```none
MegaNeRF
├── Mill19
│   ├── building
│   │   ├── building-pixsfm
│   ├── rubble
│   │   ├── rubble-pixsfm
│   ├── building_chunk-1 (generated automatically during training when using "--dataset_type filesystem"; you can choose another path to save it)
│   ├── rubble_chunk-1
├── Quad6k (same layout as above)
│   ├── quad
├── UrbanScene3D (same layout as above)
│   ├── residence
│   ├── sci-art
│   ├── campus
```

## Training

```
python gp_nerf/train.py --config_file configs/${DATASET_NAME}.yaml --exp_name $EXP_PATH --dataset_path $DATASET_PATH --chunk_paths $CHUNK_PATH
```
The first time you run training, it takes some time to write the chunks to disk.

## Evaluation

```
python gp_nerf/eval.py --config_file configs/${DATASET_NAME}.yaml --exp_name $EXP_NAME --dataset_path $DATASET_PATH --ckpt_path $CKPT_PATH
```


## Acknowledgements

Large parts of this codebase are based on existing work in the [Mega-NeRF](https://github.com/cmusatyalab/mega-nerf) and [torch-ngp](https://github.com/ashawkey/torch-ngp) repositories.

--------------------------------------------------------------------------------
/configs/building.yaml:
--------------------------------------------------------------------------------
ray_altitude_range: [8, 50]

--------------------------------------------------------------------------------
/configs/campus.yaml:
--------------------------------------------------------------------------------
ray_altitude_range: [3, 132]

--------------------------------------------------------------------------------
/configs/quad.yaml:
--------------------------------------------------------------------------------
ray_altitude_range: [-55, 10]
val_scale_factor: 1
near: 0
all_val: True
cluster_2d: True

--------------------------------------------------------------------------------
/configs/residence.yaml:
--------------------------------------------------------------------------------
ray_altitude_range: [30, 118]

--------------------------------------------------------------------------------
/configs/rubble.yaml:
--------------------------------------------------------------------------------
ray_altitude_range: [11, 38]

--------------------------------------------------------------------------------
/configs/sci-art.yaml:
--------------------------------------------------------------------------------
ray_altitude_range: [-14, 70]

--------------------------------------------------------------------------------
/eval.sh:
--------------------------------------------------------------------------------
#!/bin/bash
export OMP_NUM_THREADS=4
export CUDA_VISIBLE_DEVICES=4

exp_name=logs/eval
ckpt_path=  # give the checkpoint path
dataset1='Mill19'    # "Mill19" "Quad6k" "UrbanScene3D"
dataset2='building'  # "building" "rubble" "quad" "residence" "sci-art" "campus"
python gp_nerf/eval.py --config_file configs/$dataset2.yaml --dataset_path /data/yuqi/Datasets/MegaNeRF/$dataset1/$dataset2/$dataset2-pixsfm --exp_name $exp_name --ckpt_path $ckpt_path

--------------------------------------------------------------------------------
/gp_nerf/__pycache__/opts.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/__pycache__/opts.cpython-39.pyc

--------------------------------------------------------------------------------
/gp_nerf/__pycache__/rendering_gpnerf.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/__pycache__/rendering_gpnerf.cpython-39.pyc

--------------------------------------------------------------------------------
/gp_nerf/__pycache__/runner_gpnerf.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/__pycache__/runner_gpnerf.cpython-39.pyc

--------------------------------------------------------------------------------
/gp_nerf/__pycache__/sample_bg.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/__pycache__/sample_bg.cpython-39.pyc

--------------------------------------------------------------------------------
/gp_nerf/eval.py:
--------------------------------------------------------------------------------
from argparse import Namespace

import torch
from torch.distributed.elastic.multiprocessing.errors import record

import sys
sys.path.append('.')

from gp_nerf.opts import get_opts_base


def _get_eval_opts() -> Namespace:
    parser = get_opts_base()

    parser.add_argument('--exp_name', type=str, required=True, help='experiment name')
    parser.add_argument('--dataset_path', type=str, required=True)
    parser.add_argument('--centroid_path', type=str)

    return parser.parse_args()


@record
def main(hparams: Namespace) -> None:
    assert hparams.ckpt_path is not None or hparams.container_path is not None
    from gp_nerf.runner_gpnerf import Runner

    print("running the clean version (background NeRF disabled)")
    hparams.bg_nerf = False

    if hparams.detect_anomalies:
        with torch.autograd.detect_anomaly():
            Runner(hparams).eval()
    else:
        Runner(hparams).eval()


if __name__ == '__main__':
    main(_get_eval_opts())

--------------------------------------------------------------------------------
/gp_nerf/models/Plane_module.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import tinycudann as tcnn

# zyq: torch-ngp
import sys
import os
sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..")))


class Plane_v7(nn.Module):
    def __init__(self, hparams,
                 desired_resolution=1024,
                 base_solution=128,
                 n_levels=4,
                 ):
        super(Plane_v7, self).__init__()

        per_level_scale = np.exp2(np.log2(desired_resolution / base_solution) / (int(n_levels) - 1))
        encoding_2d_config = {
            "otype": "Grid",
            "type": "Dense",
            "n_levels": n_levels,
            "n_features_per_level": 2,
            "base_resolution": base_solution,
            "per_level_scale": per_level_scale,
        }
        self.xy = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config)
        self.yz = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config)
        self.xz = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config)
        self.feat_dim = n_levels * 2 * 3  # n_levels * 2 features per level * 3 planes

    def forward(self, x, bound):
        x = (x + bound) / (2 * bound)  # zyq: map to [0, 1]
        xy_feat = self.xy(x[:, [0, 1]])
        yz_feat = self.yz(x[:, [0, 2]])
        xz_feat = self.xz(x[:, [1, 2]])
        return torch.cat([xy_feat, yz_feat, xz_feat], dim=-1)


def get_Plane_encoder(hparams, **kwargs):
    plane_encoder = Plane_v7(hparams)
    plane_feat_dim = plane_encoder.feat_dim
    return plane_encoder, plane_feat_dim
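
# ---------------------------------------------------------------------------
# Illustrative sketch of how the tri-plane encoder above is queried; assumes
# tinycudann is installed and a CUDA device is available. `hparams` is
# accepted but unused by Plane_v7's constructor, so any object works here.
#
#   from argparse import Namespace
#   encoder, feat_dim = get_Plane_encoder(Namespace())
#   xyz = torch.rand(4096, 3, device='cuda') * 2 - 1   # points in [-1, 1]
#   feats = encoder(xyz, bound=1)                      # -> [4096, feat_dim]
#   # feat_dim = n_levels * 2 features * 3 planes = 4 * 2 * 3 = 24
# ---------------------------------------------------------------------------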

--------------------------------------------------------------------------------
/gp_nerf/models/__pycache__/Plane_module.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/models/__pycache__/Plane_module.cpython-39.pyc

--------------------------------------------------------------------------------
/gp_nerf/models/__pycache__/gp_nerf.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/models/__pycache__/gp_nerf.cpython-39.pyc

--------------------------------------------------------------------------------
/gp_nerf/models/__pycache__/model_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/models/__pycache__/model_utils.cpython-39.pyc

--------------------------------------------------------------------------------
/gp_nerf/models/gp_nerf.py:
--------------------------------------------------------------------------------
from typing import List, Optional

import torch
import torch.nn.functional as F
from torch import nn

# zyq: torch-ngp
import sys
import os

sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..")))
from gp_nerf.torch_ngp.encoding import get_encoder
from gp_nerf.torch_ngp.activation import trunc_exp
from gp_nerf.models.Plane_module import get_Plane_encoder


timer = 0


class NeRF(nn.Module):
    def __init__(self, pos_xyz_dim: int,      # 12, positional embedding
                 pos_dir_dim: int,            # 4, positional embedding
                 layers: int,                 # 8
                 skip_layers: List[int],      # [4]
                 layer_dim: int,              # 256
                 appearance_dim: int,         # 48
                 affine_appearance: bool,     # False
                 appearance_count: int,       # number of images (1678 for rubble)
                 rgb_dim: int,                # 3
                 xyz_dim: int,                # fg = 3, bg = 4
                 sigma_activation: nn.Module, hparams):
        super(NeRF, self).__init__()
        self.layer_dim = layer_dim
        print("layer_dim: {}".format(self.layer_dim))
        self.appearance_count = appearance_count
        self.appearance_dim = appearance_dim
        self.num_layers = hparams.num_layers
        self.num_layers_color = hparams.num_layers_color
        self.geo_feat_dim = hparams.geo_feat_dim

        # hash
        base_resolution = hparams.base_resolution
        desired_resolution = hparams.desired_resolution
        log2_hashmap_size = hparams.log2_hashmap_size
        num_levels = hparams.num_levels

        self.fg_bound = 1
        self.bg_bound = 1 + hparams.contract_bg_len
        self.xyz_dim = xyz_dim

        # plane
        self.use_scaling = hparams.use_scaling
        if self.use_scaling:
            if 'quad' in hparams.dataset_path or 'sci' in hparams.dataset_path:
                self.scaling_factor_ground = (abs(hparams.sphere_center[1:]) + abs(hparams.sphere_radius[1:])) / hparams.aabb_bound
                self.scaling_factor_altitude_bottom = 0
                self.scaling_factor_altitude_range = (abs(hparams.sphere_center[0]) + abs(hparams.sphere_radius[0])) / hparams.aabb_bound
            else:
                self.scaling_factor_ground = (abs(hparams.sphere_center[1:]) + abs(hparams.sphere_radius[1:])) / hparams.aabb_bound
                self.scaling_factor_altitude_bottom = 0.5 * (hparams.z_range[0] + hparams.z_range[1]) / hparams.aabb_bound
                self.scaling_factor_altitude_range = (hparams.z_range[1] - hparams.z_range[0]) / (2 * hparams.aabb_bound)

        self.embedding_a = nn.Embedding(self.appearance_count, self.appearance_dim)
        if 'quad' in hparams.dataset_path:
            desired_resolution_fg = desired_resolution * hparams.quad_factor
            print("Quad6k")
        else:
            desired_resolution_fg = desired_resolution
        encoding = "hashgrid"

        print("use two mlp")
        self.encoder, self.in_dim = get_encoder(encoding, base_resolution=base_resolution,
                                                desired_resolution=desired_resolution_fg,
                                                log2_hashmap_size=log2_hashmap_size, num_levels=num_levels)
        self.encoder_bg, _ = get_encoder(encoding, base_resolution=base_resolution,
                                         desired_resolution=desired_resolution,
                                         log2_hashmap_size=19, num_levels=num_levels)

        self.plane_encoder, self.plane_dim = get_Plane_encoder(hparams)
        self.sigma_net, self.color_net, self.encoder_dir = self.get_nerf_mlp()
        self.sigma_net_bg, self.color_net_bg, self.encoder_dir_bg = self.get_nerf_mlp(nerf_type='bg')

    def get_nerf_mlp(self, nerf_type='fg'):
        encoding_dir = "sphere_harmonics"
        geo_feat_dim = self.geo_feat_dim
        sigma_nets = []
        for l in range(self.num_layers):
            if l == 0:
                in_dim = self.in_dim
                print("Hash and Plane")
                if nerf_type == 'fg':
                    in_dim = in_dim + self.plane_dim
            else:
                in_dim = self.layer_dim  # 64
            if l == self.num_layers - 1:
                out_dim = 1 + geo_feat_dim  # 1 sigma + 15 SH features for color
            else:
                out_dim = self.layer_dim
            sigma_nets.append(nn.Linear(in_dim, out_dim, bias=False))

        sigma_net = nn.ModuleList(sigma_nets)
        encoder_dir, in_dim_dir = get_encoder(encoding_dir)
        color_nets = []
        for l in range(self.num_layers_color):
            if l == 0:
                in_dim = in_dim_dir + geo_feat_dim + self.appearance_dim
                if nerf_type == 'fg':
                    in_dim = in_dim + self.plane_dim
            else:
                in_dim = self.layer_dim

            if l == self.num_layers_color - 1:
                out_dim = 3  # rgb
            else:
                out_dim = self.layer_dim

            color_nets.append(nn.Linear(in_dim, out_dim, bias=False))

        color_net = nn.ModuleList(color_nets)
        return sigma_net, color_net, encoder_dir

    def forward(self, point_type, x: torch.Tensor, sigma_only: bool = False,
                sigma_noise: Optional[torch.Tensor] = None, train_iterations=-1) -> torch.Tensor:
        if point_type == 'fg':
            out = self.forward_fg(point_type, x, sigma_only, sigma_noise, train_iterations=train_iterations)
        elif point_type == 'bg':
            out = self.forward_bg(point_type, x, sigma_only, sigma_noise, train_iterations=train_iterations)
        else:
            raise NotImplementedError('Unknown point type')
        return out

    def forward_fg(self, point_type, x: torch.Tensor, sigma_only: bool = False, sigma_noise: Optional[torch.Tensor] = None, train_iterations=-1) -> torch.Tensor:
        position = x[:, :self.xyz_dim]
        h = self.encoder(position, bound=self.fg_bound)

        if self.use_scaling:
            # note: `position` is a view into `x`, so this scaling is applied in place
            position[:, 0] = (position[:, 0] - self.scaling_factor_altitude_bottom) / self.scaling_factor_altitude_range
            position[:, 1:] = position[:, 1:] / self.scaling_factor_ground
        plane_feat = self.plane_encoder(position, bound=self.fg_bound)
        h = torch.cat([h, plane_feat], dim=-1)

        for l in range(self.num_layers):
            h = self.sigma_net[l](h)
            if l != self.num_layers - 1:
                h = F.relu(h, inplace=True)
        sigma = trunc_exp(h[..., 0])
        geo_feat = h[..., 1:]

        # color
        d = x[:, self.xyz_dim:-1]
        d = self.encoder_dir(d)
        a = self.embedding_a(x[:, -1].long())
        h = torch.cat([d, geo_feat, a, plane_feat], dim=-1)
        for l in range(self.num_layers_color):
            h = self.color_net[l](h)
            if l != self.num_layers_color - 1:
                h = F.relu(h, inplace=True)
        # sigmoid activation for rgb
        color = torch.sigmoid(h)
        return torch.cat([color, sigma.unsqueeze(1)], -1)

    def forward_bg(self, point_type, x: torch.Tensor, sigma_only: bool = False, sigma_noise: Optional[torch.Tensor] = None, train_iterations=-1) -> torch.Tensor:
        position = x[:, :self.xyz_dim]
        h = self.encoder_bg(position, bound=self.bg_bound)

        for l in range(self.num_layers):
            h = self.sigma_net_bg[l](h)
            if l != self.num_layers - 1:
                h = F.relu(h, inplace=True)
        sigma = trunc_exp(h[..., 0])
        geo_feat = h[..., 1:]

        # color
        d = x[:, self.xyz_dim:-1]
        d = self.encoder_dir_bg(d)
        a = self.embedding_a(x[:, -1].long())
        h = torch.cat([d, geo_feat, a], dim=-1)
        for l in range(self.num_layers_color):
            h = self.color_net_bg[l](h)
            if l != self.num_layers_color - 1:
                h = F.relu(h, inplace=True)
        # sigmoid activation for rgb
        color = torch.sigmoid(h)

        return torch.cat([color, sigma.unsqueeze(1)], -1)


class Embedding(nn.Module):
    def __init__(self, num_freqs: int, logscale=True):
        """
        Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...)
        """
        super(Embedding, self).__init__()

        if logscale:
            self.freq_bands = 2 ** torch.linspace(0, num_freqs - 1, num_freqs)
        else:
            self.freq_bands = torch.linspace(1, 2 ** (num_freqs - 1), num_freqs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = [x]
        for freq in self.freq_bands:
            out += [torch.sin(freq * x), torch.cos(freq * x)]

        return torch.cat(out, -1)


class ShiftedSoftplus(nn.Module):
    __constants__ = ['beta', 'threshold']
    beta: int
    threshold: int

    def __init__(self, beta: int = 1, threshold: int = 20) -> None:
        super(ShiftedSoftplus, self).__init__()
        self.beta = beta
        self.threshold = threshold

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.softplus(x - 1, self.beta, self.threshold)

    def extra_repr(self) -> str:
        return 'beta={}, threshold={}'.format(self.beta, self.threshold)
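
# ---------------------------------------------------------------------------
# Illustrative example: the frequency embedding above maps a D-dim input to
# D * (1 + 2 * num_freqs) channels, e.g. xyz (D=3) with num_freqs=12 -> 75.
#
#   emb = Embedding(num_freqs=12)
#   x = torch.rand(1024, 3)
#   assert emb(x).shape == (1024, 3 * (1 + 2 * 12))
# ---------------------------------------------------------------------------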

--------------------------------------------------------------------------------
/gp_nerf/models/model_utils.py:
--------------------------------------------------------------------------------
from argparse import Namespace

import torch
from torch import nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present


def get_nerf(hparams: Namespace, appearance_count: int, construct_container: bool = True) -> nn.Module:
    return _get_nerf_inner(hparams, appearance_count, hparams.layer_dim, 3, 'model_state_dict', construct_container)


def get_bg_nerf(hparams: Namespace, appearance_count: int, construct_container: bool = True) -> nn.Module:
    return _get_nerf_inner(hparams, appearance_count, hparams.bg_layer_dim, 4, 'bg_model_state_dict', construct_container)


def _get_nerf_inner(hparams: Namespace, appearance_count: int, layer_dim: int, xyz_dim: int,
                    weight_key: str, construct_container: bool = True) -> nn.Module:
    nerf = _get_single_nerf_inner(hparams, appearance_count, layer_dim, xyz_dim)
    if hparams.ckpt_path is not None:
        state_dict = torch.load(hparams.ckpt_path, map_location='cpu')[weight_key]
        consume_prefix_in_state_dict_if_present(state_dict, prefix='module.')
        model_dict = nerf.state_dict()
        model_dict.update(state_dict)
        nerf.load_state_dict(model_dict)
    return nerf


def _get_single_nerf_inner(hparams: Namespace, appearance_count: int, layer_dim: int, xyz_dim: int) -> nn.Module:
    rgb_dim = 3 * ((hparams.sh_deg + 1) ** 2) if hparams.sh_deg is not None else 3

    from gp_nerf.models.gp_nerf import NeRF, ShiftedSoftplus
    return NeRF(hparams.pos_xyz_dim,
                hparams.pos_dir_dim,
                hparams.layers,
                hparams.skip_layers,
                layer_dim,
                hparams.appearance_dim,
                hparams.affine_appearance,
                appearance_count,
                rgb_dim,
                xyz_dim,
                ShiftedSoftplus() if hparams.shifted_softplus else nn.ReLU(),
                hparams)
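
# ---------------------------------------------------------------------------
# Illustrative example: `get_nerf` builds the foreground model (xyz_dim=3)
# and `get_bg_nerf` the background model (xyz_dim=4); both load weights from
# hparams.ckpt_path when it is set, stripping any DistributedDataParallel
# 'module.' prefix first.
#
#   nerf = get_nerf(hparams, appearance_count=1678)        # e.g. rubble
#   bg_nerf = get_bg_nerf(hparams, appearance_count=1678)
# ---------------------------------------------------------------------------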

--------------------------------------------------------------------------------
/gp_nerf/opts.py:
--------------------------------------------------------------------------------
import configargparse


def get_opts_base():

    parser = configargparse.ArgParser(config_file_parser_class=configargparse.YAMLConfigFileParser)
    parser.add_argument('--gpnerf', default=True, type=eval, choices=[True, False], help='if true use gp-nerf, else mega-nerf')

    # network settings
    parser.add_argument('--num_layers', type=int, default=2, help='number of layers in our sigma MLP')
    parser.add_argument('--num_layers_color', type=int, default=3, help='number of layers in our color MLP')
    parser.add_argument('--layer_dim', type=int, default=64, help='number of channels in foreground MLP')
    parser.add_argument('--appearance_dim', type=int, default=48, help='dimension of appearance embedding vector (set to 0 to disable)')
    parser.add_argument('--geo_feat_dim', type=int, default=15, help='')

    parser.add_argument('--num_levels', type=int, default=16, help='')
    parser.add_argument('--base_resolution', type=int, default=16, help='')
    parser.add_argument('--desired_resolution', type=int, default=2048, help='')
    parser.add_argument('--log2_hashmap_size', type=int, default=19, help='')

    # logger
    parser.add_argument('--writer_log', default=True, type=eval, choices=[True, False], help='')
    parser.add_argument('--wandb_id', default='None', type=str, help='')
    parser.add_argument('--wandb_run_name', default='test', type=str, help='')

    parser.add_argument('--use_scaling', default=True, type=eval, choices=[True, False], help='scale plane features')
    parser.add_argument('--contract_norm', type=str, default='l2', choices=['l2', 'inf'], help='')
    parser.add_argument('--contract_bg_len', default=1, type=float, help='set 0.4 of 1:1')
    parser.add_argument('--aabb_bound', default=1.6, type=float, help='only used when the ellipsoid bound is not used')
    parser.add_argument('--quad_factor', default=6, type=float, help='')

    parser.add_argument('--train_iterations', type=int, default=100000, help='training iterations')
    parser.add_argument('--val_interval', type=int, default=100000, help='validation interval')
    parser.add_argument('--ckpt_interval', type=int, default=100000, help='checkpoint interval')
    parser.add_argument('--model_chunk_size', type=int, default=10*1024*1024, help='chunk size to split the input to avoid OOM')
    # parser.add_argument('--model_chunk_size', type=int, default=32 * 1024, help='chunk size to split the input to avoid OOM')
    parser.add_argument('--batch_size', type=int, default=5120, help='batch size')
    parser.add_argument('--coarse_samples', type=int, default=128, help='number of coarse samples')
    parser.add_argument('--fine_samples', type=int, default=128, help='number of additional fine samples')

    parser.add_argument('--ckpt_path', type=str, default=None, help='path towards serialized model checkpoint')
    parser.add_argument('--config_file', is_config_file=True)
    parser.add_argument('--dataset_type', type=str, default='filesystem', choices=['filesystem', 'memory'],
                        help="""specifies whether to hold all images in CPU memory during training, or whether to write randomized
                        batches of pixels/rays to disk""")
    parser.add_argument('--chunk_paths', type=str, nargs='+', default=None,
                        help="""scratch directory to write shuffled batches to when training using the filesystem dataset.
                        Should be set to a non-existent path when first created, and can then be reused by subsequent training runs once all chunks are written""")
    parser.add_argument('--desired_chunks', type=int, default=60,
                        help='due to the long processing time and huge space consumption, we only keep part of the chunks')
    parser.add_argument('--num_chunks', type=int, default=300,
                        help='number of shuffled chunk files to write to disk. Each chunk should be small enough to fit into CPU memory')

    parser.add_argument('--disk_flush_size', type=int, default=10000000)
    parser.add_argument('--train_every', type=int, default=1,
                        help='if set to larger than 1, subsamples each n training images')
    parser.add_argument('--cluster_mask_path', type=str, default=None,
                        help='directory containing pixel masks for all training images (generated by create_cluster_masks.py)')
    parser.add_argument('--container_path', type=str, default=None,
                        help='path towards merged Mega-NeRF model generated by merge_submodules.py')
    parser.add_argument('--bg_layer_dim', type=int, default=256, help='number of channels in background MLP, NOT used in gpnerf')

    parser.add_argument('--near', type=float, default=1, help='ray near bounds')
    parser.add_argument('--far', type=float, default=None,
                        help='ray far bounds. Will be automatically set if not explicitly set')
    parser.add_argument('--ray_altitude_range', nargs='+', type=float, default=None,
                        help='constrains ray sampling to the given altitude')
    parser.add_argument('--train_scale_factor', type=int, default=1,
                        help='downsamples training images if greater than 1')
    parser.add_argument('--val_scale_factor', type=int, default=4,
                        help='downsamples validation images if greater than 1')

    parser.add_argument('--pos_xyz_dim', type=int, default=12,
                        help='frequency encoding dimension applied to xyz position')
    parser.add_argument('--pos_dir_dim', type=int, default=4,
                        help='frequency encoding dimension applied to view direction (set to 0 to disable)')
    parser.add_argument('--layers', type=int, default=8, help='number of layers in MLP')
    parser.add_argument('--skip_layers', type=int, nargs='+', default=[4], help='indices of the skip connections')
    parser.add_argument('--affine_appearance', default=False, action='store_true',
                        help='set to true to use affine transformation for appearance instead of latent embedding')

    parser.add_argument('--use_cascade', default=False, action='store_true',
                        help='use separate MLPs to query coarse and fine samples')
    parser.add_argument('--train_mega_nerf', type=str, default=None,
                        help='train a Mega-NeRF architecture (point this towards the params.pt file generated by create_cluster_masks.py)')
    parser.add_argument('--boundary_margin', type=float, default=1.15,
                        help='overlap factor between different spatial cells')
    parser.add_argument('--all_val', default=False, action='store_true',
                        help='use all pixels for validation images instead of those specified in cluster masks')
    parser.add_argument('--cluster_2d', default=False, action='store_true', help='cluster without altitude dimension')

    parser.add_argument('--sh_deg', type=int, default=None, help='use spherical harmonics (pos_dir_dim should be set to 0)')
    parser.add_argument('--no_center_pixels', dest='center_pixels', default=True, action='store_false', help='do not shift pixels by +0.5 when computing ray directions')
    parser.add_argument('--no_shifted_softplus', dest='shifted_softplus', default=True, action='store_false', help='use ReLU instead of shifted softplus activation')
    parser.add_argument('--image_pixel_batch_size', type=int, default=16 * 1024, help='number of pixels to evaluate per split when rendering validation images')
    parser.add_argument('--perturb', type=float, default=1.0, help='factor to perturb depth sampling points')
    parser.add_argument('--noise_std', type=float, default=1.0, help='std dev of noise added to regularize sigma')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('--lr_decay_factor', type=float, default=1, help='learning rate decay factor')
    parser.add_argument('--no_bg_nerf', dest='bg_nerf', default=True, action='store_false', help='do not use background MLP')
    parser.add_argument('--ellipse_scale_factor', type=float, default=1.1, help='factor to scale foreground bounds')
    parser.add_argument('--no_ellipse_bounds', dest='ellipse_bounds', default=True, action='store_false', help='use spherical foreground bounds instead of ellipse')
    parser.add_argument('--no_resume_ckpt_state', dest='resume_ckpt_state', default=True, action='store_false')
    parser.add_argument('--no_amp', dest='amp', default=True, action='store_false')
    parser.add_argument('--detect_anomalies', default=False, action='store_true')
    parser.add_argument('--random_seed', type=int, default=42)

    return parser
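
# ---------------------------------------------------------------------------
# Illustrative example: options may come from a YAML config file, the command
# line, or both (configargparse merges them, with the command line winning).
#
#   parser = get_opts_base()
#   hparams = parser.parse_args(['--config_file', 'configs/rubble.yaml'])
#   print(hparams.ray_altitude_range)   # -> [11.0, 38.0]
# ---------------------------------------------------------------------------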

--------------------------------------------------------------------------------
/gp_nerf/sample_bg.py:
--------------------------------------------------------------------------------
import torch


def bg_sample_inv(near, far, point_num, device):
    z = torch.linspace(0, 1, point_num, device=device)
    z_vals = 1. / near * (1 - z) + 1. / far * (z)  # linear combination in the inverse space
    z_vals = 1. / z_vals  # invert back
    return z_vals


# @torch.no_grad()
def contract_to_unisphere(x: torch.Tensor, hparams):

    aabb_bound = hparams.aabb_bound
    aabb = torch.tensor([-aabb_bound, -aabb_bound, -aabb_bound, aabb_bound, aabb_bound, aabb_bound]).to(x.device)

    aabb_min, aabb_max = torch.split(aabb, 3, dim=-1)
    x = (x - aabb_min) / (aabb_max - aabb_min)
    x = x * 2 - 1  # aabb is at [-1, 1]
    if hparams.contract_norm == 'inf':
        mag = x.abs().amax(dim=-1, keepdim=True)
    elif hparams.contract_norm == 'l2':
        mag = x.norm(dim=-1, keepdim=True)
    else:
        print("unknown contract norm!")
        raise NotImplementedError
    mask = mag.squeeze(-1) > 1
    x[mask] = (1 + hparams.contract_bg_len - hparams.contract_bg_len / mag[mask]) * (x[mask] / mag[mask])  # out-of-bound points are contracted into [-(1 + contract_bg_len), 1 + contract_bg_len]
    return x
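
# ---------------------------------------------------------------------------
# Illustrative example: bg_sample_inv places samples uniformly in inverse
# depth (disparity), dense near `near` and sparse towards `far`:
#
#   z = bg_sample_inv(near=1.0, far=100.0, point_num=5, device='cpu')
#   # -> tensor([  1.0000,   1.3289,   1.9802,   3.8835, 100.0000])
#
# contract_to_unisphere first normalizes by aabb_bound, keeps points with
# magnitude <= 1 unchanged, and squashes everything farther away into a shell
# of radius 1 + contract_bg_len (a Mip-NeRF-360-style contraction), so the
# background grid only ever sees coordinates in [-bg_bound, bg_bound].
# ---------------------------------------------------------------------------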

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/__pycache__/activation.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/__pycache__/activation.cpython-39.pyc

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/__pycache__/encoding.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/__pycache__/encoding.cpython-39.pyc

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/activation.py:
--------------------------------------------------------------------------------
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd


class _trunc_exp(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # cast to float32
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):
        x = ctx.saved_tensors[0]
        return g * torch.exp(x.clamp(-15, 15))


trunc_exp = _trunc_exp.apply
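
# ---------------------------------------------------------------------------
# Illustrative note: trunc_exp behaves like exp(x) in the forward pass, but
# its backward clamps x to [-15, 15] before re-exponentiating, keeping
# density gradients finite when raw sigma logits explode:
#
#   x = torch.tensor([20.0], device='cuda', requires_grad=True)
#   trunc_exp(x).backward()
#   # x.grad == exp(15) ~= 3.27e6 instead of exp(20) ~= 4.85e8
# ---------------------------------------------------------------------------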

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/encoding.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class FreqEncoder(nn.Module):
    def __init__(self, input_dim, max_freq_log2, N_freqs,
                 log_sampling=True, include_input=True,
                 periodic_fns=(torch.sin, torch.cos)):

        super().__init__()

        self.input_dim = input_dim
        self.include_input = include_input
        self.periodic_fns = periodic_fns

        self.output_dim = 0
        if self.include_input:
            self.output_dim += self.input_dim

        self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns)

        if log_sampling:
            self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs)
        else:
            self.freq_bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, N_freqs)

        self.freq_bands = self.freq_bands.numpy().tolist()

    def forward(self, input, **kwargs):

        out = []
        if self.include_input:
            out.append(input)

        for i in range(len(self.freq_bands)):
            freq = self.freq_bands[i]
            for p_fn in self.periodic_fns:
                out.append(p_fn(input * freq))

        out = torch.cat(out, dim=-1)

        return out


def get_encoder(encoding, input_dim=3,
                multires=6,
                degree=4,
                num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False,
                **kwargs):

    if encoding == 'None':
        return lambda x, **kwargs: x, input_dim

    elif encoding == 'frequency':
        encoder = FreqEncoder(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True)

    elif encoding == 'sphere_harmonics':
        from gp_nerf.torch_ngp.shencoder import SHEncoder
        encoder = SHEncoder(input_dim=input_dim, degree=degree)

    elif encoding == 'hashgrid':
        from gp_nerf.torch_ngp.gridencoder import GridEncoder
        encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners)

    elif encoding == 'tiledgrid':
        from gp_nerf.torch_ngp.gridencoder import GridEncoder
        encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners)

    # elif encoding == 'ash':
    #     from ashencoder import AshEncoder
    #     encoder = AshEncoder(input_dim=input_dim, output_dim=16, log2_hashmap_size=log2_hashmap_size, resolution=desired_resolution)

    else:
        raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]')

    return encoder, encoder.output_dim
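
# ---------------------------------------------------------------------------
# Illustrative example: get_encoder returns an (encoder, output_dim) pair.
# The frequency encoder is pure PyTorch; 'sphere_harmonics', 'hashgrid' and
# 'tiledgrid' require the compiled CUDA extensions in this package.
#
#   enc, out_dim = get_encoder('frequency', input_dim=3, multires=6)
#   x = torch.rand(8, 3)
#   assert enc(x).shape == (8, out_dim)    # out_dim = 3 + 3 * 6 * 2 = 39
# ---------------------------------------------------------------------------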

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/gridencoder/__init__.py:
--------------------------------------------------------------------------------
from .grid import GridEncoder

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/gridencoder/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/gridencoder/__pycache__/__init__.cpython-39.pyc

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/gridencoder/__pycache__/backend.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/gridencoder/__pycache__/backend.cpython-39.pyc

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/gridencoder/__pycache__/grid.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/gridencoder/__pycache__/grid.cpython-39.pyc

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/gridencoder/backend.py:
--------------------------------------------------------------------------------
import os
from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

_backend = load(name='_grid_encoder',
                extra_cflags=c_flags,
                extra_cuda_cflags=nvcc_flags,
                sources=[os.path.join(_src_path, 'src', f) for f in [
                    'gridencoder.cu',
                    'bindings.cpp',
                ]],
                )

__all__ = ['_backend']

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-3.9/.ninja_deps:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-3.9/.ninja_deps

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-3.9/.ninja_log:
--------------------------------------------------------------------------------
# ninja log v5
0 9844 1660292644788444399 /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-3.9/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/bindings.o 677561ebc163e07e
0 41507 1660292676447984634 /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-3.9/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/gridencoder.o ba7906cda7ca41b1

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-3.9/build.ninja:
--------------------------------------------------------------------------------
ninja_required_version = 1.3
cxx = c++
nvcc = /usr/local/cuda-11.3/bin/nvcc

cflags = -pthread -B /home/yuqi/anaconda3/envs/mega-ingp/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /home/yuqi/anaconda3/envs/mega-ingp/include -I/home/yuqi/anaconda3/envs/mega-ingp/include -fPIC -O2 -isystem /home/yuqi/anaconda3/envs/mega-ingp/include -fPIC -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include/TH -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda-11.3/include -I/home/yuqi/anaconda3/envs/mega-ingp/include/python3.9 -c
post_cflags = -O3 -std=c++14 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_gridencoder -D_GLIBCXX_USE_CXX11_ABI=0
cuda_cflags = -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include/TH -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda-11.3/include -I/home/yuqi/anaconda3/envs/mega-ingp/include/python3.9 -c
cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_gridencoder -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86
ldflags =

rule compile
  command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
  depfile = $out.d
  deps = gcc

rule cuda_compile
  depfile = $out.d
  deps = gcc
  command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags

build /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-3.9/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/bindings.o: compile /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/bindings.cpp
build /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-3.9/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/gridencoder.o: cuda_compile /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/gridencoder.cu

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-3.9/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/bindings.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-3.9/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/bindings.o

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-3.9/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/gridencoder.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-3.9/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/gridencoder.o

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-cpython-39/.ninja_deps:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-cpython-39/.ninja_deps

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-cpython-39/.ninja_log:
--------------------------------------------------------------------------------
# ninja log v5
0 9068 1666772617694144721 /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-cpython-39/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/bindings.o 2f8245e02f4fb6d
0 44580 1666772653208037715 /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-cpython-39/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/gridencoder.o e4c6afb555606523

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-cpython-39/build.ninja:
--------------------------------------------------------------------------------
ninja_required_version = 1.3
cxx = c++
nvcc = /usr/local/cuda-11.3/bin/nvcc

cflags = -pthread -B /home/yuqi/anaconda3/envs/mega-ingp/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /home/yuqi/anaconda3/envs/mega-ingp/include -I/home/yuqi/anaconda3/envs/mega-ingp/include -fPIC -O2 -isystem /home/yuqi/anaconda3/envs/mega-ingp/include -fPIC -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include/TH -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda-11.3/include -I/home/yuqi/anaconda3/envs/mega-ingp/include/python3.9 -c
post_cflags = -O3 -std=c++14 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_gridencoder -D_GLIBCXX_USE_CXX11_ABI=0
cuda_cflags = -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include/TH -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda-11.3/include -I/home/yuqi/anaconda3/envs/mega-ingp/include/python3.9 -c
cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_gridencoder -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86
ldflags =

rule compile
  command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
  depfile = $out.d
  deps = gcc

rule cuda_compile
  depfile = $out.d
  deps = gcc
  command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags

build /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-cpython-39/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/bindings.o: compile /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/bindings.cpp
build /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-cpython-39/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/gridencoder.o: cuda_compile /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/gridencoder.cu

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-cpython-39/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/bindings.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-cpython-39/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/bindings.o

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-cpython-39/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/gridencoder.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/gridencoder/build/temp.linux-x86_64-cpython-39/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/gridencoder.o

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/gridencoder/dist/gridencoder-0.0.0-py3.9-linux-x86_64.egg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/gridencoder/dist/gridencoder-0.0.0-py3.9-linux-x86_64.egg

--------------------------------------------------------------------------------
/gp_nerf/torch_ngp/gridencoder/grid.py:
--------------------------------------------------------------------------------
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import _gridencoder as _backend
except ImportError:
    from .backend import _backend

_gridtype_to_id = {
    'hash': 0,
    'tiled': 1,
}


class _grid_encode(Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False):
        # inputs: [B, D], float in [0, 1]
        # embeddings: [sO, C], float
        # offsets: [L + 1], int
        # RETURN: [B, F], float

        inputs = inputs.contiguous()

        B, D = inputs.shape  # batch size, coord dim
        L = offsets.shape[0] - 1  # level
        C = embeddings.shape[1]  # embedding dim for each level
        S = np.log2(per_level_scale)  # resolution multiplier at each level, apply log2 for later CUDA exp2f
        H = base_resolution  # base resolution

        # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
        # if C % 2 != 0, force float, since half for atomicAdd is very slow.
        if torch.is_autocast_enabled() and C % 2 == 0:
            embeddings = embeddings.to(torch.half)

        # L first, optimize cache for cuda kernel, but needs an extra permute later
        outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)

        if calc_grad_inputs:
            dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
        else:
            dy_dx = torch.empty(1, device=inputs.device, dtype=embeddings.dtype)  # placeholder... TODO: a better way?

        _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners)

        # permute back to [B, L * C]
        outputs = outputs.permute(1, 0, 2).reshape(B, L * C)

        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
        ctx.dims = [B, D, C, L, S, H, gridtype]
        ctx.calc_grad_inputs = calc_grad_inputs
        ctx.align_corners = align_corners

        return outputs

    @staticmethod
    # @once_differentiable
    @custom_bwd
    def backward(ctx, grad):

        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
        B, D, C, L, S, H, gridtype = ctx.dims
        calc_grad_inputs = ctx.calc_grad_inputs
        align_corners = ctx.align_corners

        # grad: [B, L * C] --> [L, B, C]
        grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()

        grad_embeddings = torch.zeros_like(embeddings)

        if calc_grad_inputs:
            grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
        else:
            grad_inputs = torch.zeros(1, device=inputs.device, dtype=embeddings.dtype)

        _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners)

        if calc_grad_inputs:
            grad_inputs = grad_inputs.to(inputs.dtype)
            return grad_inputs, grad_embeddings, None, None, None, None, None, None
        else:
            return None, grad_embeddings, None, None, None, None, None, None


grid_encode = _grid_encode.apply


class GridEncoder(nn.Module):
    def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False):
        super().__init__()

        # the finest resolution desired at the last level; if provided, overrides per_level_scale
        if desired_resolution is not None:
            per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))

        self.input_dim = input_dim  # coord dims, 2 or 3
        self.num_levels = num_levels  # num levels, each level multiply resolution by 2
        self.level_dim = level_dim  # encode channels per level
        self.per_level_scale = per_level_scale  # multiply resolution by this scale at each level
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.output_dim = num_levels * level_dim
        self.gridtype = gridtype
        self.gridtype_id = _gridtype_to_id[gridtype]  # "tiled" or "hash"
        self.align_corners = align_corners

        # allocate parameters
        offsets = []
        offset = 0
        self.max_params = 2 ** log2_hashmap_size
        for i in range(num_levels):
            resolution = int(np.ceil(base_resolution * per_level_scale ** i))
            params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim)  # limit max number
            params_in_level = int(np.ceil(params_in_level / 8) * 8)  # make divisible
            offsets.append(offset)
            offset += params_in_level
        offsets.append(offset)
        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
        self.register_buffer('offsets', offsets)

        self.n_params = offsets[-1] * level_dim

        # parameters
        self.embeddings = nn.Parameter(torch.empty(offset, level_dim))

        self.reset_parameters()

    def reset_parameters(self):
        std = 1e-4
        self.embeddings.data.uniform_(-std, std)

    def __repr__(self):
        return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}"

    def forward(self, inputs, bound=1):
        # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
        # return: [..., num_levels * level_dim]

        inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]

        # print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())

        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.view(-1, self.input_dim)

        outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners)
        outputs = outputs.view(prefix_shape + [self.output_dim])

        # print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())

        return outputs
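
# ---------------------------------------------------------------------------
# Illustrative example: with the defaults (16 levels, 2 features per level)
# the encoder outputs 32 channels; inputs are expected in [-bound, bound].
# Requires the compiled _gridencoder CUDA extension and a CUDA tensor.
#
#   enc = GridEncoder(input_dim=3, num_levels=16, level_dim=2,
#                     log2_hashmap_size=19, desired_resolution=2048).cuda()
#   xyz = torch.rand(4096, 3, device='cuda') * 2 - 1
#   feats = enc(xyz, bound=1)              # -> [4096, 32]
# ---------------------------------------------------------------------------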
105 | self.log2_hashmap_size = log2_hashmap_size 106 | self.base_resolution = base_resolution 107 | self.output_dim = num_levels * level_dim 108 | self.gridtype = gridtype 109 | self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash" 110 | self.align_corners = align_corners 111 | 112 | # allocate parameters 113 | offsets = [] 114 | offset = 0 115 | self.max_params = 2 ** log2_hashmap_size 116 | for i in range(num_levels): 117 | resolution = int(np.ceil(base_resolution * per_level_scale ** i)) 118 | params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number 119 | params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible 120 | offsets.append(offset) 121 | offset += params_in_level 122 | offsets.append(offset) 123 | offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) 124 | self.register_buffer('offsets', offsets) 125 | 126 | self.n_params = offsets[-1] * level_dim 127 | 128 | # parameters 129 | self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) 130 | 131 | self.reset_parameters() 132 | 133 | def reset_parameters(self): 134 | std = 1e-4 135 | self.embeddings.data.uniform_(-std, std) 136 | 137 | def __repr__(self): 138 | return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}" 139 | 140 | def forward(self, inputs, bound=1): 141 | # inputs: [..., input_dim], normalized real world positions in [-bound, bound] 142 | # return: [..., num_levels * level_dim] 143 | 144 | inputs = (inputs + bound) / (2 * bound) # map to [0, 1] 145 | 146 | #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) 147 | 148 | prefix_shape = list(inputs.shape[:-1]) 149 | inputs = inputs.view(-1, self.input_dim) 150 | 151 | outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners) 152 | outputs = outputs.view(prefix_shape + [self.output_dim]) 153 | 154 | #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) 155 | 156 | return outputs -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/gridencoder/gridencoder.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: gridencoder 3 | Version: 0.0.0 4 | -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/gridencoder/gridencoder.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/bindings.cpp 3 | /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/gridencoder/src/gridencoder.cu 4 | gridencoder.egg-info/PKG-INFO 5 | gridencoder.egg-info/SOURCES.txt 6 | gridencoder.egg-info/dependency_links.txt 7 | gridencoder.egg-info/top_level.txt -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/gridencoder/gridencoder.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | 
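The `GridEncoder` defined in grid.py above is the multi-resolution hash/tiled grid encoder from Instant-NGP: each level interpolates `level_dim` features from its own table and the per-level results are concatenated. The snippet below is a minimal usage sketch, not part of the repository; it assumes the `_gridencoder` extension has been built (see `setup.py` below), that the package `__init__.py` re-exports `GridEncoder`, and that a CUDA device is available.

```python
# Hypothetical usage sketch of GridEncoder (assumes _gridencoder is built).
import torch
from gp_nerf.torch_ngp.gridencoder import GridEncoder  # assumed re-export in __init__.py

encoder = GridEncoder(input_dim=3, num_levels=16, level_dim=2,
                      base_resolution=16, log2_hashmap_size=19,
                      desired_resolution=2048, gridtype='hash').cuda()

xyz = torch.rand(4096, 3, device='cuda') * 2 - 1  # positions in [-bound, bound]
feats = encoder(xyz, bound=1)                     # -> [4096, num_levels * level_dim] = [4096, 32]
```

With `desired_resolution` given, `per_level_scale` is derived so the final level reaches 2048, matching the `np.exp2(...)` computation in `__init__` above.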
-------------------------------------------------------------------------------- /gp_nerf/torch_ngp/gridencoder/gridencoder.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | _gridencoder 2 | -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/gridencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | setup( 33 | name='gridencoder', # package name, import this to use python API 34 | ext_modules=[ 35 | CUDAExtension( 36 | name='_gridencoder', # extension name, import this to use CUDA API 37 | sources=[os.path.join(_src_path, 'src', f) for f in [ 38 | 'gridencoder.cu', 39 | 'bindings.cpp', 40 | ]], 41 | extra_compile_args={ 42 | 'cxx': c_flags, 43 | 'nvcc': nvcc_flags, 44 | } 45 | ), 46 | ], 47 | cmdclass={ 48 | 'build_ext': BuildExtension, 49 | } 50 | ) -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/gridencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include "gridencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); 7 | m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); 8 | } -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/gridencoder/src/gridencoder.h: -------------------------------------------------------------------------------- 1 | #ifndef _HASH_ENCODE_H 2 | #define _HASH_ENCODE_H 3 | 4 | #include <stdint.h> 5 | #include <torch/torch.h> 6 | 7 | // inputs: [B, D], float, in [0, 1] 8 | // embeddings: [sO, C], float 9 | // offsets: [L + 1], uint32_t 10 | // outputs: [B, L * C], float 11 | // H: base resolution 12 | void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t gridtype, const bool align_corners); 13 | void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L,
const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype, const bool align_corners); 14 | 15 | #endif -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/nerf/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from encoding import get_encoder 6 | from activation import trunc_exp 7 | from .renderer import NeRFRenderer 8 | 9 | 10 | class NeRFNetwork(NeRFRenderer): 11 | def __init__(self, 12 | encoding="hashgrid", 13 | encoding_dir="sphere_harmonics", 14 | encoding_bg="hashgrid", 15 | num_layers=2, 16 | hidden_dim=64, 17 | geo_feat_dim=15, 18 | num_layers_color=3, 19 | hidden_dim_color=64, 20 | num_layers_bg=2, 21 | hidden_dim_bg=64, 22 | bound=1, 23 | **kwargs, 24 | ): 25 | super().__init__(bound, **kwargs) 26 | 27 | # sigma network 28 | self.num_layers = num_layers #2 29 | self.hidden_dim = hidden_dim #64 30 | self.geo_feat_dim = geo_feat_dim # 15 31 | self.encoder, self.in_dim = get_encoder(encoding, desired_resolution=2048 * bound) # in_dim=32 32 | 33 | sigma_net = [] 34 | for l in range(num_layers): 35 | if l == 0: 36 | in_dim = self.in_dim 37 | else: 38 | in_dim = hidden_dim #64 39 | 40 | if l == num_layers - 1: # last layer 41 | out_dim = 1 + self.geo_feat_dim # 1 sigma + 15 SH features for color 42 | else: 43 | out_dim = hidden_dim 44 | 45 | sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) 46 | 47 | self.sigma_net = nn.ModuleList(sigma_net) # two fully-connected layers 48 | 49 | # color network 50 | self.num_layers_color = num_layers_color #3 51 | self.hidden_dim_color = hidden_dim_color #64 52 | self.encoder_dir, self.in_dim_dir = get_encoder(encoding_dir) #in_dim_dir=16 53 | 54 | color_net = [] 55 | for l in range(num_layers_color): 56 | if l == 0: 57 | in_dim = self.in_dim_dir + self.geo_feat_dim 58 | else: 59 | in_dim = hidden_dim 60 | 61 | if l == num_layers_color - 1: # last layer 62 | out_dim = 3 # 3 rgb 63 | else: 64 | out_dim = hidden_dim 65 | 66 | color_net.append(nn.Linear(in_dim, out_dim, bias=False)) 67 | 68 | self.color_net = nn.ModuleList(color_net) # three fully-connected layers 69 | 70 | # background network 71 | if self.bg_radius > 0: 72 | self.num_layers_bg = num_layers_bg 73 | self.hidden_dim_bg = hidden_dim_bg 74 | self.encoder_bg, self.in_dim_bg = get_encoder(encoding_bg, input_dim=2, num_levels=4, log2_hashmap_size=19, desired_resolution=2048) # much smaller hashgrid 75 | 76 | bg_net = [] 77 | for l in range(num_layers_bg): 78 | if l == 0: 79 | in_dim = self.in_dim_bg + self.in_dim_dir 80 | else: 81 | in_dim = hidden_dim_bg 82 | 83 | if l == num_layers_bg - 1: 84 | out_dim = 3 # 3 rgb 85 | else: 86 | out_dim = hidden_dim_bg 87 | 88 | bg_net.append(nn.Linear(in_dim, out_dim, bias=False)) 89 | 90 | self.bg_net = nn.ModuleList(bg_net) 91 | else: 92 | self.bg_net = None 93 | 94 | 95 | def forward(self, x, d): 96 | # x: [N, 3], in [-bound, bound] 97 | # d: [N, 3], normalized in [-1, 1] 98 | 99 | # sigma 100 | x = self.encoder(x, bound=self.bound) # 3 ->32 101 | 102 | h = x 103 | for l in range(self.num_layers): 104 | h = self.sigma_net[l](h) 105 | if l != self.num_layers - 1: 106 | h = F.relu(h, inplace=True) 107 | 108 | #sigma = F.relu(h[..., 0]) 109 | sigma = trunc_exp(h[..., 0]) 110 | geo_feat = h[..., 1:] 111 | 112 | # color 113 | 114 | d = self.encoder_dir(d) 115 | h = torch.cat([d, geo_feat], dim=-1) 116 | for l in range(self.num_layers_color): 117 | h =
self.color_net[l](h) 118 | if l != self.num_layers_color - 1: 119 | h = F.relu(h, inplace=True) 120 | 121 | # sigmoid activation for rgb 122 | color = torch.sigmoid(h) 123 | 124 | return sigma, color 125 | 126 | def density(self, x): 127 | # x: [N, 3], in [-bound, bound] 128 | 129 | x = self.encoder(x, bound=self.bound) 130 | h = x 131 | for l in range(self.num_layers): 132 | h = self.sigma_net[l](h) 133 | if l != self.num_layers - 1: 134 | h = F.relu(h, inplace=True) 135 | 136 | #sigma = F.relu(h[..., 0]) 137 | sigma = trunc_exp(h[..., 0]) 138 | geo_feat = h[..., 1:] 139 | 140 | return { 141 | 'sigma': sigma, 142 | 'geo_feat': geo_feat, 143 | } 144 | 145 | def background(self, x, d): 146 | # x: [N, 2], in [-1, 1] 147 | 148 | h = self.encoder_bg(x) # [N, C] 149 | d = self.encoder_dir(d) 150 | 151 | h = torch.cat([d, h], dim=-1) 152 | for l in range(self.num_layers_bg): 153 | h = self.bg_net[l](h) 154 | if l != self.num_layers_bg - 1: 155 | h = F.relu(h, inplace=True) 156 | 157 | # sigmoid activation for rgb 158 | rgbs = torch.sigmoid(h) 159 | 160 | return rgbs 161 | 162 | # allow masked inference 163 | def color(self, x, d, mask=None, geo_feat=None, **kwargs): 164 | # x: [N, 3] in [-bound, bound] 165 | # mask: [N,], bool, indicates where we actually need to compute rgb. 166 | 167 | if mask is not None: 168 | rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] 169 | # in case of empty mask 170 | if not mask.any(): 171 | return rgbs 172 | x = x[mask] 173 | d = d[mask] 174 | geo_feat = geo_feat[mask] 175 | 176 | d = self.encoder_dir(d) 177 | h = torch.cat([d, geo_feat], dim=-1) 178 | for l in range(self.num_layers_color): 179 | h = self.color_net[l](h) 180 | if l != self.num_layers_color - 1: 181 | h = F.relu(h, inplace=True) 182 | 183 | # sigmoid activation for rgb 184 | h = torch.sigmoid(h) 185 | 186 | if mask is not None: 187 | rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32 188 | else: 189 | rgbs = h 190 | 191 | return rgbs 192 | 193 | # optimizer utils 194 | def get_params(self, lr): 195 | 196 | params = [ 197 | {'params': self.encoder.parameters(), 'lr': lr}, 198 | {'params': self.sigma_net.parameters(), 'lr': lr}, 199 | {'params': self.encoder_dir.parameters(), 'lr': lr}, 200 | {'params': self.color_net.parameters(), 'lr': lr}, 201 | ] 202 | if self.bg_radius > 0: 203 | params.append({'params': self.encoder_bg.parameters(), 'lr': lr}) 204 | params.append({'params': self.bg_net.parameters(), 'lr': lr}) 205 | 206 | return params 207 | -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/raymarching/__init__.py: -------------------------------------------------------------------------------- 1 | from .raymarching import * -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/raymarching/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 20 | paths = sorted(glob.glob(r"C:\\Program
Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 21 | if paths: 22 | return paths[0] 23 | 24 | # If cl.exe is not on path, try to find it. 25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_raymarching', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'raymarching.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/raymarching/build/temp.linux-x86_64-cpython-39/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/raymarching/src/bindings.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/raymarching/build/temp.linux-x86_64-cpython-39/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/raymarching/src/bindings.o -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/raymarching/build/temp.linux-x86_64-cpython-39/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/raymarching/src/raymarching.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/raymarching/build/temp.linux-x86_64-cpython-39/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/raymarching/src/raymarching.o -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/raymarching/dist/raymarching-0.0.0-py3.9-linux-x86_64.egg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/raymarching/dist/raymarching-0.0.0-py3.9-linux-x86_64.egg -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/raymarching/raymarching.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: raymarching 3 | Version: 0.0.0 4 | -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/raymarching/raymarching.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/raymarching/src/bindings.cpp 3 | /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/raymarching/src/raymarching.cu 4 | raymarching.egg-info/PKG-INFO 5 | raymarching.egg-info/SOURCES.txt 6 | raymarching.egg-info/dependency_links.txt 7 | raymarching.egg-info/top_level.txt -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/raymarching/raymarching.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/raymarching/raymarching.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | _raymarching 2 | 
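The raymarching.py file that follows wraps the CUDA kernels of the `_raymarching` extension as autograd `Function`s. As a quick orientation, here is a hypothetical sketch of the simplest helper, `near_far_from_aabb`; it is not repository code, and it assumes the `_raymarching` extension is built and a GPU is available (the signature follows the docstring in the file below).

```python
# Hypothetical sketch: per-ray near/far intersection with a scene AABB.
import torch
import torch.nn.functional as F
from gp_nerf.torch_ngp.raymarching import near_far_from_aabb

rays_o = torch.zeros(1024, 3, device='cuda')                       # ray origins
rays_d = F.normalize(torch.randn(1024, 3, device='cuda'), dim=-1)  # unit directions
aabb = torch.tensor([-1., -1., -1., 1., 1., 1.], device='cuda')    # (xmin, ymin, zmin, xmax, ymax, zmax)

nears, fars = near_far_from_aabb(rays_o, rays_d, aabb, 0.2)        # each of shape [1024]
```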
-------------------------------------------------------------------------------- /gp_nerf/torch_ngp/raymarching/raymarching.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Function 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _raymarching as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | 15 | # ---------------------------------------- 16 | # utils 17 | # ---------------------------------------- 18 | 19 | class _near_far_from_aabb(Function): 20 | @staticmethod 21 | @custom_fwd(cast_inputs=torch.float32) 22 | def forward(ctx, rays_o, rays_d, aabb, min_near=0.2): 23 | ''' near_far_from_aabb, CUDA implementation 24 | Calculate rays' intersection time (near and far) with aabb 25 | Args: 26 | rays_o: float, [N, 3] 27 | rays_d: float, [N, 3] 28 | aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax) 29 | min_near: float, scalar 30 | Returns: 31 | nears: float, [N] 32 | fars: float, [N] 33 | ''' 34 | if not rays_o.is_cuda: rays_o = rays_o.cuda() 35 | if not rays_d.is_cuda: rays_d = rays_d.cuda() 36 | 37 | rays_o = rays_o.contiguous().view(-1, 3) 38 | rays_d = rays_d.contiguous().view(-1, 3) 39 | 40 | N = rays_o.shape[0] # num rays 41 | 42 | nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) 43 | fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) 44 | 45 | _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars) 46 | 47 | return nears, fars 48 | 49 | near_far_from_aabb = _near_far_from_aabb.apply 50 | 51 | 52 | class _polar_from_ray(Function): 53 | @staticmethod 54 | @custom_fwd(cast_inputs=torch.float32) 55 | def forward(ctx, rays_o, rays_d, radius): 56 | ''' polar_from_ray, CUDA implementation 57 | get polar coordinate on the background sphere from rays. 58 | Assume rays_o are inside the Sphere(radius). 59 | Args: 60 | rays_o: [N, 3] 61 | rays_d: [N, 3] 62 | radius: scalar, float 63 | Return: 64 | coords: [N, 2], in [-1, 1], theta and phi on a sphere. 65 | ''' 66 | if not rays_o.is_cuda: rays_o = rays_o.cuda() 67 | if not rays_d.is_cuda: rays_d = rays_d.cuda() 68 | 69 | rays_o = rays_o.contiguous().view(-1, 3) 70 | rays_d = rays_d.contiguous().view(-1, 3) 71 | 72 | N = rays_o.shape[0] # num rays 73 | 74 | coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device) 75 | 76 | _backend.polar_from_ray(rays_o, rays_d, radius, N, coords) 77 | 78 | return coords 79 | 80 | polar_from_ray = _polar_from_ray.apply 81 | 82 | 83 | class _morton3D(Function): 84 | @staticmethod 85 | def forward(ctx, coords): 86 | ''' morton3D, CUDA implementation 87 | Args: 88 | coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...) 89 | TODO: check if the coord range is valid! 
(current 128 is safe) 90 | Returns: 91 | indices: [N], int32, in [0, 128^3) 92 | 93 | ''' 94 | if not coords.is_cuda: coords = coords.cuda() 95 | 96 | N = coords.shape[0] 97 | 98 | indices = torch.empty(N, dtype=torch.int32, device=coords.device) 99 | 100 | _backend.morton3D(coords.int(), N, indices) 101 | 102 | return indices 103 | 104 | morton3D = _morton3D.apply 105 | 106 | class _morton3D_invert(Function): 107 | @staticmethod 108 | def forward(ctx, indices): 109 | ''' morton3D_invert, CUDA implementation 110 | Args: 111 | indices: [N], int32, in [0, 128^3) 112 | Returns: 113 | coords: [N, 3], int32, in [0, 128) 114 | 115 | ''' 116 | if not indices.is_cuda: indices = indices.cuda() 117 | 118 | N = indices.shape[0] 119 | 120 | coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device) 121 | 122 | _backend.morton3D_invert(indices.int(), N, coords) 123 | 124 | return coords 125 | 126 | morton3D_invert = _morton3D_invert.apply 127 | 128 | 129 | class _packbits(Function): 130 | @staticmethod 131 | @custom_fwd(cast_inputs=torch.float32) 132 | def forward(ctx, grid, thresh, bitfield=None): 133 | ''' packbits, CUDA implementation 134 | Pack up the density grid into a bit field to accelerate ray marching. 135 | Args: 136 | grid: float, [C, H * H * H], assume H % 2 == 0 137 | thresh: float, threshold 138 | Returns: 139 | bitfield: uint8, [C, H * H * H / 8] 140 | ''' 141 | if not grid.is_cuda: grid = grid.cuda() 142 | grid = grid.contiguous() 143 | 144 | C = grid.shape[0] 145 | H3 = grid.shape[1] 146 | N = C * H3 // 8 147 | 148 | if bitfield is None: 149 | bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device) 150 | 151 | _backend.packbits(grid, N, thresh, bitfield) 152 | 153 | return bitfield 154 | 155 | packbits = _packbits.apply 156 | 157 | # ---------------------------------------- 158 | # train functions 159 | # ---------------------------------------- 160 | 161 | class _march_rays_train(Function): 162 | @staticmethod 163 | @custom_fwd(cast_inputs=torch.float32) 164 | def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024): 165 | ''' march rays to generate points (forward only) 166 | Args: 167 | rays_o/d: float, [N, 3] 168 | bound: float, scalar 169 | density_bitfield: uint8: [CHHH // 8] 170 | C: int 171 | H: int 172 | nears/fars: float, [N] 173 | step_counter: int32, (2), used to count the actual number of generated points. 174 | mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeds this threshold.) 175 | perturb: bool 176 | align: int, pad output so its size is divisible by align, set to -1 to disable. 177 | force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays. 178 | dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerates ray marching if > 0. (very significant effect, but generally leads to worse performance) 179 | max_steps: int, max number of sampled points along each ray, also affects min_stepsize. 180 | Returns: 181 | xyzs: float, [M, 3], all generated points' coords. (all rays concatenated; use `rays` to extract points belonging to each ray) 182 | dirs: float, [M, 3], all generated points' view dirs. 183 | deltas: float, [M, 2], all generated points' deltas.
(first for RGB, second for Depth) 184 | rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 2]] --> points belonging to rays[i, 0] 185 | ''' 186 | 187 | if not rays_o.is_cuda: rays_o = rays_o.cuda() 188 | if not rays_d.is_cuda: rays_d = rays_d.cuda() 189 | if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda() 190 | 191 | rays_o = rays_o.contiguous().view(-1, 3) 192 | rays_d = rays_d.contiguous().view(-1, 3) 193 | density_bitfield = density_bitfield.contiguous() 194 | 195 | N = rays_o.shape[0] # num rays 196 | M = N * max_steps # init max points number in total 197 | 198 | # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp) 199 | # It estimates the max points number to enable faster training, but will lead to randomly ignored rays if underestimated. 200 | if not force_all_rays and mean_count > 0: 201 | if align > 0: 202 | mean_count += align - mean_count % align 203 | M = mean_count 204 | 205 | xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) 206 | dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) 207 | deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) 208 | rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps 209 | 210 | if step_counter is None: 211 | step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter 212 | 213 | _backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, perturb) # m is the actually used points number 214 | 215 | #print(step_counter, M) 216 | 217 | # only used at the first (few) epochs. 218 | if force_all_rays or mean_count <= 0: 219 | m = step_counter[0].item() # D2H copy 220 | if align > 0: 221 | m += align - m % align 222 | xyzs = xyzs[:m] 223 | dirs = dirs[:m] 224 | deltas = deltas[:m] 225 | 226 | torch.cuda.empty_cache() 227 | 228 | return xyzs, dirs, deltas, rays 229 | 230 | march_rays_train = _march_rays_train.apply 231 | 232 | 233 | class _composite_rays_train(Function): 234 | @staticmethod 235 | @custom_fwd(cast_inputs=torch.float32) 236 | def forward(ctx, sigmas, rgbs, deltas, rays): 237 | ''' composite rays' rgbs, according to the ray marching formula. 238 | Args: 239 | rgbs: float, [M, 3] 240 | sigmas: float, [M,] 241 | deltas: float, [M, 2] 242 | rays: int32, [N, 3] 243 | Returns: 244 | weights_sum: float, [N,], the alpha channel 245 | depth: float, [N, ], the depth 246 | image: float, [N, 3], the RGB channel (after multiplying alpha!) 247 | ''' 248 | 249 | sigmas = sigmas.contiguous() 250 | rgbs = rgbs.contiguous() 251 | 252 | M = sigmas.shape[0] 253 | N = rays.shape[0] 254 | 255 | weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) 256 | depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) 257 | image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device) 258 | 259 | _backend.composite_rays_train_forward(sigmas, rgbs, deltas, rays, M, N, weights_sum, depth, image) 260 | 261 | ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, depth, image) 262 | ctx.dims = [M, N] 263 | 264 | return weights_sum, depth, image 265 | 266 | @staticmethod 267 | @custom_bwd 268 | def backward(ctx, grad_weights_sum, grad_depth, grad_image): 269 | 270 | # NOTE: grad_depth is not used now! It won't be propagated to sigmas.
271 | 272 | grad_weights_sum = grad_weights_sum.contiguous() 273 | grad_image = grad_image.contiguous() 274 | 275 | sigmas, rgbs, deltas, rays, weights_sum, depth, image = ctx.saved_tensors 276 | M, N = ctx.dims 277 | 278 | grad_sigmas = torch.zeros_like(sigmas) 279 | grad_rgbs = torch.zeros_like(rgbs) 280 | 281 | _backend.composite_rays_train_backward(grad_weights_sum, grad_image, sigmas, rgbs, deltas, rays, weights_sum, image, M, N, grad_sigmas, grad_rgbs) 282 | 283 | return grad_sigmas, grad_rgbs, None, None 284 | 285 | 286 | composite_rays_train = _composite_rays_train.apply 287 | 288 | # ---------------------------------------- 289 | # infer functions 290 | # ---------------------------------------- 291 | 292 | class _march_rays(Function): 293 | @staticmethod 294 | @custom_fwd(cast_inputs=torch.float32) 295 | def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024): 296 | ''' march rays to generate points (forward only, for inference) 297 | Args: 298 | n_alive: int, number of alive rays 299 | n_step: int, how many steps we march 300 | rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive) 301 | rays_t: float, [N], the alive rays' time, we only use the first n_alive. 302 | rays_o/d: float, [N, 3] 303 | bound: float, scalar 304 | density_bitfield: uint8: [CHHH // 8] 305 | C: int 306 | H: int 307 | nears/fars: float, [N] 308 | align: int, pad output so its size is divisible by align, set to -1 to disable. 309 | perturb: bool/int, int > 0 is used as the random seed. 310 | dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerates ray marching if > 0. (very significant effect, but generally leads to worse performance) 311 | max_steps: int, max number of sampled points along each ray, also affects min_stepsize. 312 | Returns: 313 | xyzs: float, [n_alive * n_step, 3], all generated points' coords 314 | dirs: float, [n_alive * n_step, 3], all generated points' view dirs. 315 | deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). 316 | ''' 317 | 318 | if not rays_o.is_cuda: rays_o = rays_o.cuda() 319 | if not rays_d.is_cuda: rays_d = rays_d.cuda() 320 | 321 | rays_o = rays_o.contiguous().view(-1, 3) 322 | rays_d = rays_d.contiguous().view(-1, 3) 323 | 324 | M = n_alive * n_step 325 | 326 | if align > 0: 327 | M += align - (M % align) 328 | 329 | xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) 330 | dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) 331 | deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth 332 | 333 | _backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, perturb) 334 | 335 | return xyzs, dirs, deltas 336 | 337 | march_rays = _march_rays.apply 338 | 339 | 340 | class _composite_rays(Function): 341 | @staticmethod 342 | @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float 343 | def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image): 344 | ''' composite rays' rgbs, according to the ray marching formula.
(for inference) 345 | Args: 346 | n_alive: int, number of alive rays 347 | n_step: int, how many steps we march 348 | rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive) 349 | rays_t: float, [N], the alive rays' time, we only use the first n_alive. 350 | sigmas: float, [n_alive * n_step,] 351 | rgbs: float, [n_alive * n_step, 3] 352 | deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). 353 | In-place Outputs: 354 | weights_sum: float, [N,], the alpha channel 355 | depth: float, [N,], the depth value 356 | image: float, [N, 3], the RGB channel (after multiplying alpha!) 357 | ''' 358 | _backend.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image) 359 | return tuple() 360 | 361 | 362 | composite_rays = _composite_rays.apply 363 | 364 | 365 | class _compact_rays(Function): 366 | @staticmethod 367 | @custom_fwd(cast_inputs=torch.float32) 368 | def forward(ctx, n_alive, rays_alive, rays_alive_old, rays_t, rays_t_old, alive_counter): 369 | ''' compact rays, remove dead rays and reallocate alive rays, to accelerate the next ray marching. 370 | Args: 371 | n_alive: int, number of alive rays 372 | rays_alive_old: int, [N] 373 | rays_t_old: float, [N], dead rays are marked by rays_t < 0 374 | alive_counter: int, [1], used to count the remaining alive rays. 375 | In-place Outputs: 376 | rays_alive: int, [N] 377 | rays_t: float, [N] 378 | ''' 379 | _backend.compact_rays(n_alive, rays_alive, rays_alive_old, rays_t, rays_t_old, alive_counter) 380 | return tuple() 381 | 382 | compact_rays = _compact_rays.apply -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/raymarching/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | ''' 33 | Usage: 34 | 35 | python setup.py build_ext --inplace # build extensions locally, do not install (can only be used from the parent directory) 36 | 37 | python setup.py install # build extensions and install (copy) to PATH. 38 | pip install . # ditto but better (e.g., dependency & metadata handling) 39 | 40 | python setup.py develop # build extensions and install (symbolic) to PATH. 41 | pip install -e .
# ditto but better (e.g., dependency & metadata handling) 42 | 43 | ''' 44 | setup( 45 | name='raymarching', # package name, import this to use python API 46 | ext_modules=[ 47 | CUDAExtension( 48 | name='_raymarching', # extension name, import this to use CUDA API 49 | sources=[os.path.join(_src_path, 'src', f) for f in [ 50 | 'raymarching.cu', 51 | 'bindings.cpp', 52 | ]], 53 | extra_compile_args={ 54 | 'cxx': c_flags, 55 | 'nvcc': nvcc_flags, 56 | } 57 | ), 58 | ], 59 | cmdclass={ 60 | 'build_ext': BuildExtension, 61 | } 62 | ) -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/raymarching/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include "raymarching.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | // utils 7 | m.def("packbits", &packbits, "packbits (CUDA)"); 8 | m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)"); 9 | m.def("polar_from_ray", &polar_from_ray, "polar_from_ray (CUDA)"); 10 | m.def("morton3D", &morton3D, "morton3D (CUDA)"); 11 | m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)"); 12 | // train 13 | m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)"); 14 | m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)"); 15 | m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)"); 16 | // infer 17 | m.def("march_rays", &march_rays, "march rays (CUDA)"); 18 | m.def("composite_rays", &composite_rays, "composite rays (CUDA)"); 19 | m.def("compact_rays", &compact_rays, "compact rays (CUDA)"); 20 | } -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/raymarching/src/pcg32.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tiny self-contained version of the PCG Random Number Generation for C++ 3 | * put together from pieces of the much larger C/C++ codebase. 4 | * Wenzel Jakob, February 2015 5 | * 6 | * The PCG random number generator was developed by Melissa O'Neill 7 | * <oneill@pcg-random.org> 8 | * 9 | * Licensed under the Apache License, Version 2.0 (the "License"); 10 | * you may not use this file except in compliance with the License. 11 | * You may obtain a copy of the License at 12 | * 13 | * http://www.apache.org/licenses/LICENSE-2.0 14 | * 15 | * Unless required by applicable law or agreed to in writing, software 16 | * distributed under the License is distributed on an "AS IS" BASIS, 17 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | * See the License for the specific language governing permissions and 19 | * limitations under the License. 20 | * 21 | * For additional information about the PCG random number generation scheme, 22 | * including its license and other licensing options, visit 23 | * 24 | * http://www.pcg-random.org 25 | * 26 | * Note: This code was modified to work with CUDA by the tiny-cuda-nn authors.
27 | */ 28 | 29 | #pragma once 30 | 31 | #define PCG32_DEFAULT_STATE 0x853c49e6748fea9bULL 32 | #define PCG32_DEFAULT_STREAM 0xda3e39cb94b95bdbULL 33 | #define PCG32_MULT 0x5851f42d4c957f2dULL 34 | 35 | #include <inttypes.h> 36 | #include <cmath> 37 | #include <cassert> 38 | 39 | #include <algorithm> 40 | #include <cstdint> 41 | #include <cuda_runtime.h> 42 | 43 | /// PCG32 Pseudorandom number generator 44 | struct pcg32 { 45 | /// Initialize the pseudorandom number generator with default seed 46 | __host__ __device__ pcg32() : state(PCG32_DEFAULT_STATE), inc(PCG32_DEFAULT_STREAM) {} 47 | 48 | /// Initialize the pseudorandom number generator with the \ref seed() function 49 | __host__ __device__ pcg32(uint64_t initstate, uint64_t initseq = 1u) { seed(initstate, initseq); } 50 | 51 | /** 52 | * \brief Seed the pseudorandom number generator 53 | * 54 | * Specified in two parts: a state initializer and a sequence selection 55 | * constant (a.k.a. stream id) 56 | */ 57 | __host__ __device__ void seed(uint64_t initstate, uint64_t initseq = 1) { 58 | state = 0U; 59 | inc = (initseq << 1u) | 1u; 60 | next_uint(); 61 | state += initstate; 62 | next_uint(); 63 | } 64 | 65 | /// Generate a uniformly distributed unsigned 32-bit random number 66 | __host__ __device__ uint32_t next_uint() { 67 | uint64_t oldstate = state; 68 | state = oldstate * PCG32_MULT + inc; 69 | uint32_t xorshifted = (uint32_t) (((oldstate >> 18u) ^ oldstate) >> 27u); 70 | uint32_t rot = (uint32_t) (oldstate >> 59u); 71 | return (xorshifted >> rot) | (xorshifted << ((~rot + 1u) & 31)); 72 | } 73 | 74 | /// Generate a uniformly distributed number, r, where 0 <= r < bound 75 | __host__ __device__ uint32_t next_uint(uint32_t bound) { 76 | // To avoid bias, we need to make the range of the RNG a multiple of 77 | // bound, which we do by dropping output less than a threshold. 78 | // A naive scheme to calculate the threshold would be to do 79 | // 80 | // uint32_t threshold = 0x100000000ull % bound; 81 | // 82 | // but 64-bit div/mod is slower than 32-bit div/mod (especially on 83 | // 32-bit platforms). In essence, we do 84 | // 85 | // uint32_t threshold = (0x100000000ull-bound) % bound; 86 | // 87 | // because this version will calculate the same modulus, but the LHS 88 | // value is less than 2^32. 89 | 90 | uint32_t threshold = (~bound+1u) % bound; 91 | 92 | // Uniformity guarantees that this loop will terminate. In practice, it 93 | // should usually terminate quickly; on average (assuming all bounds are 94 | // equally likely), 82.25% of the time, we can expect it to require just 95 | // one iteration. In the worst case, someone passes a bound of 2^31 + 1 96 | // (i.e., 2147483649), which invalidates almost 50% of the range. In 97 | // practice, bounds are typically small and only a tiny amount of the range 98 | // is eliminated. 99 | for (;;) { 100 | uint32_t r = next_uint(); 101 | if (r >= threshold) 102 | return r % bound; 103 | } 104 | } 105 | 106 | /// Generate a single precision floating point value on the interval [0, 1) 107 | __host__ __device__ float next_float() { 108 | /* Trick from MTGP: generate a uniformly distributed 109 | single precision number in [1,2) and subtract 1.
*/ 110 | union { 111 | uint32_t u; 112 | float f; 113 | } x; 114 | x.u = (next_uint() >> 9) | 0x3f800000u; 115 | return x.f - 1.0f; 116 | } 117 | 118 | /** 119 | * \brief Generate a double precision floating point value on the interval [0, 1) 120 | * 121 | * \remark Since the underlying random number generator produces 32 bit output, 122 | * only the first 32 mantissa bits will be filled (however, the resolution is still 123 | * finer than in \ref next_float(), which only uses 23 mantissa bits) 124 | */ 125 | __host__ __device__ double next_double() { 126 | /* Trick from MTGP: generate a uniformly distributed 127 | double precision number in [1,2) and subtract 1. */ 128 | union { 129 | uint64_t u; 130 | double d; 131 | } x; 132 | x.u = ((uint64_t) next_uint() << 20) | 0x3ff0000000000000ULL; 133 | return x.d - 1.0; 134 | } 135 | 136 | /** 137 | * \brief Multi-step advance function (jump-ahead, jump-back) 138 | * 139 | * The method used here is based on Brown, "Random Number Generation 140 | * with Arbitrary Stride", Transactions of the American Nuclear 141 | * Society (Nov. 1994). The algorithm is very similar to fast 142 | * exponentiation. 143 | * 144 | * The default value of 2^32 ensures that the PRNG is advanced 145 | * sufficiently far that there is (likely) no overlap with 146 | * previously drawn random numbers, even if small advancements 147 | * are made in between. 148 | */ 149 | __host__ __device__ void advance(int64_t delta_ = (1ll<<32)) { 150 | uint64_t 151 | cur_mult = PCG32_MULT, 152 | cur_plus = inc, 153 | acc_mult = 1u, 154 | acc_plus = 0u; 155 | 156 | /* Even though delta is an unsigned integer, we can pass a signed 157 | integer to go backwards; it just goes "the long way round". */ 158 | uint64_t delta = (uint64_t) delta_; 159 | 160 | while (delta > 0) { 161 | if (delta & 1) { 162 | acc_mult *= cur_mult; 163 | acc_plus = acc_plus * cur_mult + cur_plus; 164 | } 165 | cur_plus = (cur_mult + 1) * cur_plus; 166 | cur_mult *= cur_mult; 167 | delta /= 2; 168 | } 169 | state = acc_mult * state + acc_plus; 170 | } 171 | 172 | /// Compute the distance between two PCG32 pseudorandom number generators 173 | __host__ __device__ int64_t operator-(const pcg32 &other) const { 174 | assert(inc == other.inc); 175 | 176 | uint64_t 177 | cur_mult = PCG32_MULT, 178 | cur_plus = inc, 179 | cur_state = other.state, 180 | the_bit = 1u, 181 | distance = 0u; 182 | 183 | while (state != cur_state) { 184 | if ((state & the_bit) != (cur_state & the_bit)) { 185 | cur_state = cur_state * cur_mult + cur_plus; 186 | distance |= the_bit; 187 | } 188 | assert((state & the_bit) == (cur_state & the_bit)); 189 | the_bit <<= 1; 190 | cur_plus = (cur_mult + 1ULL) * cur_plus; 191 | cur_mult *= cur_mult; 192 | } 193 | 194 | return (int64_t) distance; 195 | } 196 | 197 | /// Equality operator 198 | __host__ __device__ bool operator==(const pcg32 &other) const { return state == other.state && inc == other.inc; } 199 | 200 | /// Inequality operator 201 | __host__ __device__ bool operator!=(const pcg32 &other) const { return state != other.state || inc != other.inc; } 202 | 203 | uint64_t state; // RNG state. All values are possible. 204 | uint64_t inc; // Controls which RNG sequence (stream) is selected. Must *always* be odd.
205 | }; -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/raymarching/src/raymarching.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <stdint.h> 4 | #include <torch/torch.h> 5 | 6 | 7 | void near_far_from_aabb(at::Tensor rays_o, at::Tensor rays_d, at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars); 8 | void polar_from_ray(at::Tensor rays_o, at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords); 9 | void morton3D(at::Tensor coords, const uint32_t N, at::Tensor indices); 10 | void morton3D_invert(at::Tensor indices, const uint32_t N, at::Tensor coords); 11 | void packbits(at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield); 12 | 13 | void march_rays_train(at::Tensor rays_o, at::Tensor rays_d, at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, at::Tensor nears, at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, const uint32_t perturb); 14 | void composite_rays_train_forward(at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, const uint32_t M, const uint32_t N, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); 15 | void composite_rays_train_backward(at::Tensor grad_weights_sum, at::Tensor grad_image, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, at::Tensor weights_sum, at::Tensor image, const uint32_t M, const uint32_t N, at::Tensor grad_sigmas, at::Tensor grad_rgbs); 16 | 17 | void march_rays(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor rays_o, at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, at::Tensor grid, at::Tensor nears, at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, const uint32_t perturb); 18 | void composite_rays(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); 19 | void compact_rays(const uint32_t n_alive, at::Tensor rays_alive, at::Tensor rays_alive_old, at::Tensor rays_t, at::Tensor rays_t_old, at::Tensor alive_counter); -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .sphere_harmonics import SHEncoder -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/shencoder/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/__pycache__/backend.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/shencoder/__pycache__/backend.cpython-39.pyc
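To connect the declarations in raymarching.h above with the Python wrappers in raymarching.py, here is a hedged sketch of one training-style march/composite round trip. It is illustrative only, not repository code: the density bitfield is set to fully occupied, the NeRF MLP is replaced by random stand-in values, and it assumes the `_raymarching` extension is built and a CUDA device is available. Argument order follows the wrapper signatures.

```python
# Hypothetical sketch: march rays through a (fully occupied) density bitfield,
# then composite stand-in sigmas/rgbs into per-ray alpha, depth and color.
import torch
import torch.nn.functional as F
from gp_nerf.torch_ngp.raymarching import (
    near_far_from_aabb, march_rays_train, composite_rays_train)

C, H, bound = 1, 128, 1.0
density_bitfield = torch.full((C * H * H * H // 8,), 255,
                              dtype=torch.uint8, device='cuda')  # every cell marked occupied

rays_o = torch.zeros(512, 3, device='cuda')
rays_d = F.normalize(torch.randn(512, 3, device='cuda'), dim=-1)
aabb = torch.tensor([-bound] * 3 + [bound] * 3, device='cuda')
nears, fars = near_far_from_aabb(rays_o, rays_d, aabb)

# With the default mean_count=-1, all sampled points are kept and trimmed to
# the step counter, matching the "first (few) epochs" path in the wrapper.
xyzs, dirs, deltas, rays = march_rays_train(
    rays_o, rays_d, bound, density_bitfield, C, H, nears, fars)

sigmas = torch.rand(xyzs.shape[0], device='cuda')      # stand-ins for an MLP's outputs
rgbs = torch.rand(xyzs.shape[0], 3, device='cuda')
weights_sum, depth, image = composite_rays_train(sigmas, rgbs, deltas, rays)
```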
-------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/__pycache__/sphere_harmonics.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/shencoder/__pycache__/sphere_harmonics.cpython-39.pyc -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 21 | if paths: 22 | return paths[0] 23 | 24 | # If cl.exe is not on path, try to find it. 25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_sh_encoder', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'shencoder.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/build/temp.linux-x86_64-3.9/.ninja_deps: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/shencoder/build/temp.linux-x86_64-3.9/.ninja_deps -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/build/temp.linux-x86_64-3.9/.ninja_log: -------------------------------------------------------------------------------- 1 | # ninja log v5 2 | 0 10412 1657003623710230171 /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/build/temp.linux-x86_64-3.9/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/src/bindings.o f2e21cb208573b76 3 | 0 27415 1657003640709877492 /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/build/temp.linux-x86_64-3.9/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/src/shencoder.o f8d6814930df1069 4 | -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/build/temp.linux-x86_64-3.9/build.ninja: -------------------------------------------------------------------------------- 1 | ninja_required_version = 1.3 2 | cxx = c++ 3 | nvcc = /usr/local/cuda-11.3/bin/nvcc 4 | 5 | cflags = -pthread -B /home/yuqi/anaconda3/envs/mega-ingp/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /home/yuqi/anaconda3/envs/mega-ingp/include -I/home/yuqi/anaconda3/envs/mega-ingp/include -fPIC -O2 -isystem 
/home/yuqi/anaconda3/envs/mega-ingp/include -fPIC -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include/TH -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda-11.3/include -I/home/yuqi/anaconda3/envs/mega-ingp/include/python3.9 -c 6 | post_cflags = -O3 -std=c++14 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_shencoder -D_GLIBCXX_USE_CXX11_ABI=0 7 | cuda_cflags = -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include/TH -I/home/yuqi/anaconda3/envs/mega-ingp/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda-11.3/include -I/home/yuqi/anaconda3/envs/mega-ingp/include/python3.9 -c 8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_shencoder -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 9 | ldflags = 10 | 11 | rule compile 12 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags 13 | depfile = $out.d 14 | deps = gcc 15 | 16 | rule cuda_compile 17 | depfile = $out.d 18 | deps = gcc 19 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags 20 | 21 | 22 | 23 | build /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/build/temp.linux-x86_64-3.9/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/src/bindings.o: compile /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/src/bindings.cpp 24 | build /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/build/temp.linux-x86_64-3.9/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/src/shencoder.o: cuda_compile /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/src/shencoder.cu 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/build/temp.linux-x86_64-3.9/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/src/bindings.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/shencoder/build/temp.linux-x86_64-3.9/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/src/bindings.o -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/build/temp.linux-x86_64-3.9/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/src/shencoder.o: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/shencoder/build/temp.linux-x86_64-3.9/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/src/shencoder.o -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/build/temp.linux-x86_64-cpython-39/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/src/bindings.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/shencoder/build/temp.linux-x86_64-cpython-39/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/src/bindings.o -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/build/temp.linux-x86_64-cpython-39/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/src/shencoder.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/shencoder/build/temp.linux-x86_64-cpython-39/disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/src/shencoder.o -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/dist/shencoder-0.0.0-py3.9-linux-x86_64.egg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/gp_nerf/torch_ngp/shencoder/dist/shencoder-0.0.0-py3.9-linux-x86_64.egg -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 
26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | setup( 33 | name='shencoder', # package name, import this to use python API 34 | ext_modules=[ 35 | CUDAExtension( 36 | name='_shencoder', # extension name, import this to use CUDA API 37 | sources=[os.path.join(_src_path, 'src', f) for f in [ 38 | 'shencoder.cu', 39 | 'bindings.cpp', 40 | ]], 41 | extra_compile_args={ 42 | 'cxx': c_flags, 43 | 'nvcc': nvcc_flags, 44 | } 45 | ), 46 | ], 47 | cmdclass={ 48 | 'build_ext': BuildExtension, 49 | } 50 | ) -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/shencoder.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: shencoder 3 | Version: 0.0.0 4 | -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/shencoder.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/src/bindings.cpp 3 | /disk1/yuqi/code/mega-nerf-zyq/mega_nerf/torch_ngp/shencoder/src/shencoder.cu 4 | shencoder.egg-info/PKG-INFO 5 | shencoder.egg-info/SOURCES.txt 6 | shencoder.egg-info/dependency_links.txt 7 | shencoder.egg-info/top_level.txt -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/shencoder.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/shencoder.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | _shencoder 2 | -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/sphere_harmonics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _shencoder as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | class _sh_encoder(Function): 15 | @staticmethod 16 | @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision 17 | def forward(ctx, inputs, degree, calc_grad_inputs=False): 18 | # inputs: [B, input_dim], float in [-1, 1] 19 | # RETURN: [B, F], float 20 | 21 | inputs = inputs.contiguous() 22 | B, input_dim = inputs.shape # batch size, coord dim 23 | output_dim = degree ** 2 24 | 25 | outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) 26 | 27 | if calc_grad_inputs: 28 | dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device) 29 | else: 30 | dy_dx = torch.empty(1, dtype=inputs.dtype, device=inputs.device) 31 | 32 | _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, calc_grad_inputs, dy_dx) 33 | 34 | ctx.save_for_backward(inputs, dy_dx) 35 | ctx.dims = [B, input_dim, degree] 36 | ctx.calc_grad_inputs = calc_grad_inputs 37 | 38 | return outputs 39 | 40 | @staticmethod 41 | 
#@once_differentiable 42 | @custom_bwd 43 | def backward(ctx, grad): 44 | # grad: [B, degree ** 2] 45 | 46 | if ctx.calc_grad_inputs: 47 | grad = grad.contiguous() 48 | inputs, dy_dx = ctx.saved_tensors 49 | B, input_dim, degree = ctx.dims 50 | grad_inputs = torch.zeros_like(inputs) 51 | _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs) 52 | return grad_inputs, None, None 53 | else: 54 | return None, None, None 55 | 56 | 57 | 58 | sh_encode = _sh_encoder.apply 59 | 60 | 61 | class SHEncoder(nn.Module): 62 | def __init__(self, input_dim=3, degree=4): 63 | super().__init__() 64 | 65 | self.input_dim = input_dim # coord dims, must be 3 66 | self.degree = degree # 1 ~ 8 67 | self.output_dim = degree ** 2 68 | 69 | assert self.input_dim == 3, "SH encoder only supports input dim == 3" 70 | assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]" 71 | 72 | def __repr__(self): 73 | return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}" 74 | 75 | def forward(self, inputs, size=1): 76 | # inputs: [..., input_dim], normalized real world positions in [-size, size] 77 | # return: [..., degree^2] 78 | 79 | inputs = inputs / size # [-1, 1] 80 | 81 | prefix_shape = list(inputs.shape[:-1]) 82 | inputs = inputs.reshape(-1, self.input_dim) 83 | 84 | outputs = sh_encode(inputs, self.degree, inputs.requires_grad) 85 | outputs = outputs.reshape(prefix_shape + [self.output_dim]) 86 | 87 | return outputs -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include "shencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)"); 7 | m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)"); 8 | } -------------------------------------------------------------------------------- /gp_nerf/torch_ngp/shencoder/src/shencoder.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <stdint.h> 4 | #include <torch/torch.h> 5 | 6 | // inputs: [B, D], float, in [-1, 1] 7 | // outputs: [B, F], float 8 | 9 | // sh_encode_forward(inputs, outputs, B, input_dim, degree, calc_grad_inputs, dy_dx) 10 | void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const bool calc_grad_inputs, at::Tensor dy_dx); 11 | 12 | // sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs) 13 | void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs); -------------------------------------------------------------------------------- /gp_nerf/train.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | import torch 4 | from torch.distributed.elastic.multiprocessing.errors import record 5 | 6 | import sys 7 | sys.path.append('.') 8 | 9 | from gp_nerf.opts import get_opts_base 10 | 11 | 12 | 13 | def _get_train_opts() -> Namespace: 14 | parser = get_opts_base() 15 | 16 | parser.add_argument('--exp_name', type=str, required=True, help='experiment name') 17 | parser.add_argument('--dataset_path', type=str, required=True) 18 | 19 | 20 | return parser.parse_args() 21 | 22 | @record 23 | def main(hparams: Namespace)
-> None: 24 | from gp_nerf.runner_gpnerf import Runner 25 | 26 | print("run clean version, remove the bg nerf") 27 | hparams.bg_nerf = False 28 | 29 | 30 | if hparams.detect_anomalies: 31 | with torch.autograd.detect_anomaly(): 32 | Runner(hparams).train() 33 | else: 34 | Runner(hparams).train() 35 | 36 | 37 | if __name__ == '__main__': 38 | main(_get_train_opts()) 39 | -------------------------------------------------------------------------------- /mega_nerf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/mega_nerf/__init__.py -------------------------------------------------------------------------------- /mega_nerf/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/mega_nerf/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /mega_nerf/__pycache__/image_metadata.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/mega_nerf/__pycache__/image_metadata.cpython-39.pyc -------------------------------------------------------------------------------- /mega_nerf/__pycache__/metrics.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/mega_nerf/__pycache__/metrics.cpython-39.pyc -------------------------------------------------------------------------------- /mega_nerf/__pycache__/misc_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/mega_nerf/__pycache__/misc_utils.cpython-39.pyc -------------------------------------------------------------------------------- /mega_nerf/__pycache__/ray_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/mega_nerf/__pycache__/ray_utils.cpython-39.pyc -------------------------------------------------------------------------------- /mega_nerf/__pycache__/spherical_harmonics.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/mega_nerf/__pycache__/spherical_harmonics.cpython-39.pyc -------------------------------------------------------------------------------- /mega_nerf/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/mega_nerf/datasets/__init__.py -------------------------------------------------------------------------------- /mega_nerf/datasets/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/mega_nerf/datasets/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- 
/mega_nerf/datasets/__pycache__/dataset_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/mega_nerf/datasets/__pycache__/dataset_utils.cpython-39.pyc -------------------------------------------------------------------------------- /mega_nerf/datasets/__pycache__/filesystem_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/mega_nerf/datasets/__pycache__/filesystem_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /mega_nerf/datasets/__pycache__/memory_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/mega_nerf/datasets/__pycache__/memory_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /mega_nerf/datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional 2 | 3 | import torch 4 | 5 | from mega_nerf.image_metadata import ImageMetadata 6 | 7 | 8 | def get_rgb_index_mask(metadata: ImageMetadata) -> Optional[ 9 | Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]: 10 | rgbs = metadata.load_image().view(-1, 3) 11 | 12 | keep_mask = metadata.load_mask() 13 | 14 | # if keep_mask is not None: 15 | # import numpy as np 16 | # import cv2 17 | # mask_shape = keep_mask[keep_mask == True].shape[0] 18 | # if mask_shape != keep_mask.shape[0] * keep_mask.shape[1] and mask_shape !=0: 19 | # visual_rgb = metadata.load_image().view(keep_mask.shape[0], keep_mask.shape[1], 3) 20 | # cv2.imshow('2', visual_rgb.numpy()[:,:, ::-1]) 21 | # cv2.waitKey() 22 | # cv2.destroyAllWindows() 23 | # visual_rgb = (visual_rgb * keep_mask.unsqueeze(-1).repeat(1,1,3)) 24 | # cv2.imshow('3', visual_rgb.numpy()[:,:, ::-1]) 25 | # cv2.waitKey() 26 | # cv2.destroyAllWindows() 27 | # print(mask_shape) 28 | 29 | if metadata.is_val: 30 | if keep_mask is None: 31 | keep_mask = torch.ones(metadata.H, metadata.W, dtype=torch.bool) 32 | else: 33 | # Get how many pixels we're discarding that would otherwise be added 34 | discard_half = keep_mask[:, metadata.W // 2:] 35 | discard_pos_count = discard_half[discard_half == True].shape[0] 36 | 37 | candidates_to_add = torch.arange(metadata.H * metadata.W).view(metadata.H, metadata.W)[:, :metadata.W // 2] 38 | keep_half = keep_mask[:, :metadata.W // 2] 39 | candidates_to_add = candidates_to_add[keep_half == False].reshape(-1) 40 | to_add = candidates_to_add[torch.randperm(candidates_to_add.shape[0])[:discard_pos_count]] 41 | 42 | keep_mask.view(-1).scatter_(0, to_add, torch.ones_like(to_add, dtype=torch.bool)) 43 | 44 | keep_mask[:, metadata.W // 2:] = False 45 | 46 | if keep_mask is not None: 47 | if keep_mask[keep_mask == True].shape[0] == 0: 48 | return None 49 | 50 | keep_mask = keep_mask.view(-1) 51 | rgbs = rgbs[keep_mask == True] 52 | 53 | assert metadata.image_index <= torch.iinfo(torch.int32).max 54 | return rgbs, metadata.image_index * torch.ones(rgbs.shape[0], dtype=torch.int32), keep_mask 55 | -------------------------------------------------------------------------------- /mega_nerf/datasets/filesystem_dataset.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import shutil 4 | from concurrent.futures import Future, ThreadPoolExecutor 5 | from itertools import cycle 6 | from pathlib import Path 7 | from typing import List, Optional, Dict, Tuple, Union, Type 8 | 9 | import numpy as np 10 | import pyarrow as pa 11 | import pyarrow.parquet as pq 12 | import torch 13 | from torch.utils.data import Dataset 14 | 15 | from mega_nerf.datasets.dataset_utils import get_rgb_index_mask 16 | from mega_nerf.image_metadata import ImageMetadata 17 | from mega_nerf.misc_utils import main_tqdm, main_print 18 | from mega_nerf.ray_utils import get_ray_directions, get_rays, get_rays_batch 19 | 20 | #RAY_CHUNK_SIZE = 64 21 | RAY_CHUNK_SIZE = 64 * 1024 22 | 23 | 24 | class FilesystemDataset(Dataset): 25 | 26 | def __init__(self, metadata_items: List[ImageMetadata], near: float, far: float, ray_altitude_range: List[float], 27 | center_pixels: bool, device: torch.device, chunk_paths: List[Path], num_chunks: int, 28 | scale_factor: int, disk_flush_size: int, desired_chunks=1000): 29 | super(FilesystemDataset, self).__init__() 30 | self._device = device 31 | self._c2ws = torch.cat([x.c2w.unsqueeze(0) for x in metadata_items]) 32 | self._near = near 33 | self._far = far 34 | self._ray_altitude_range = ray_altitude_range 35 | 36 | intrinsics = torch.cat( 37 | [torch.cat([torch.FloatTensor([x.W, x.H]), x.intrinsics]).unsqueeze(0) for x in metadata_items]) 38 | if (intrinsics - intrinsics[0]).abs().max() == 0: 39 | main_print( 40 | 'All intrinsics identical: W: {} H: {}, intrinsics: {}'.format(metadata_items[0].W, metadata_items[0].H, 41 | metadata_items[0].intrinsics)) 42 | 43 | self._directions = get_ray_directions(metadata_items[0].W, 44 | metadata_items[0].H, 45 | metadata_items[0].intrinsics[0], 46 | metadata_items[0].intrinsics[1], 47 | metadata_items[0].intrinsics[2], 48 | metadata_items[0].intrinsics[3], 49 | center_pixels, 50 | device).view(-1, 3) 51 | else: 52 | main_print('Differing intrinsics') 53 | self._directions = None 54 | 55 | parquet_paths = self._check_existing_paths(chunk_paths, center_pixels, scale_factor, 56 | len(metadata_items)) 57 | if parquet_paths is not None: 58 | main_print('Reusing {} chunks from previous run'.format(len(parquet_paths))) 59 | self._parquet_paths = parquet_paths 60 | else: 61 | self._parquet_paths = [] 62 | self._write_chunks(metadata_items, center_pixels, device, chunk_paths, num_chunks, scale_factor, 63 | disk_flush_size,desired_chunks) 64 | 65 | self._parquet_paths.sort(key=lambda x: x.name) 66 | 67 | self._chunk_index = cycle(range(len(self._parquet_paths))) 68 | self._loaded_rgbs = None 69 | self._loaded_rays = None 70 | self._loaded_img_indices = None 71 | self._chunk_load_executor = ThreadPoolExecutor(max_workers=1) 72 | self._chunk_future = self._chunk_load_executor.submit(self._load_chunk_inner) 73 | self._chosen = None 74 | 75 | def load_chunk(self) -> None: 76 | chosen, self._loaded_rgbs, self._loaded_rays, self._loaded_img_indices = self._chunk_future.result() 77 | self._chosen = chosen 78 | self._chunk_future = self._chunk_load_executor.submit(self._load_chunk_inner) 79 | 80 | def get_state(self) -> str: 81 | return self._chosen 82 | 83 | def set_state(self, chosen: str) -> None: 84 | while self._chosen != chosen: 85 | self.load_chunk() 86 | 87 | def __len__(self) -> int: 88 | return self._loaded_rgbs.shape[0] 89 | 90 | #@torch.no_grad() 91 | def __getitem__(self, idx) -> Dict[str, torch.Tensor]: 92 | 
#zyq only load one batch 93 | #idx = 1 94 | return { 95 | 'rgbs': self._loaded_rgbs[idx], 96 | 'rays': self._loaded_rays[idx], 97 | 'img_indices': self._loaded_img_indices[idx] 98 | } 99 | 100 | def _load_chunk_inner(self) -> Tuple[str, torch.FloatTensor, torch.FloatTensor, torch.ShortTensor]: 101 | if 'RANK' in os.environ: 102 | torch.cuda.set_device(int(os.environ['LOCAL_RANK'])) 103 | 104 | next_index = next(self._chunk_index) 105 | chosen = self._parquet_paths[next_index] 106 | loaded_chunk = pq.read_table(chosen) 107 | loaded_img_indices = torch.IntTensor(loaded_chunk['img_indices'].to_numpy().astype('int32')) 108 | 109 | if self._directions is not None: 110 | loaded_pixel_indices = torch.IntTensor(loaded_chunk['pixel_indices'].to_numpy()) 111 | 112 | loaded_rays = [] 113 | for i in range(0, loaded_pixel_indices.shape[0], RAY_CHUNK_SIZE): 114 | img_indices = loaded_img_indices[i:i + RAY_CHUNK_SIZE] 115 | unique_img_indices, inverse_img_indices = torch.unique(img_indices, return_inverse=True) 116 | c2ws = self._c2ws[unique_img_indices.long()].to(self._device) 117 | 118 | pixel_indices = loaded_pixel_indices[i:i + RAY_CHUNK_SIZE] 119 | unique_pixel_indices, inverse_pixel_indices = torch.unique(pixel_indices, return_inverse=True) 120 | 121 | # (#unique images, w*h, 8) 122 | image_rays = get_rays_batch(self._directions[unique_pixel_indices.long()], 123 | c2ws, self._near, self._far, 124 | self._ray_altitude_range).cpu() 125 | 126 | del c2ws 127 | 128 | loaded_rays.append(image_rays[inverse_img_indices, inverse_pixel_indices]) 129 | 130 | loaded_rays = torch.cat(loaded_rays) 131 | else: 132 | loaded_rays = torch.FloatTensor( 133 | loaded_chunk.to_pandas()[['rays_{}'.format(i) for i in range(8)]].to_numpy()) 134 | 135 | rgbs = torch.FloatTensor(loaded_chunk.to_pandas()[['rgbs_{}'.format(i) for i in range(3)]].to_numpy()) / 255. 
136 | return str(chosen), rgbs, loaded_rays, loaded_img_indices 137 | 138 | def _write_chunks(self, metadata_items: List[ImageMetadata], center_pixels: bool, device: torch.device, 139 | chunk_paths: List[Path], num_chunks: int, scale_factor: int, disk_flush_size: int, desired_chunks=1000) -> None: 140 | assert ('RANK' not in os.environ) or int(os.environ['LOCAL_RANK']) == 0 141 | 142 | path_frees = [] 143 | total_free = 0 144 | 145 | for chunk_path in chunk_paths: 146 | chunk_path.mkdir(parents=True) 147 | 148 | _, _, free = shutil.disk_usage(chunk_path) 149 | total_free += free 150 | path_frees.append(free) 151 | 152 | parquet_writers = [] 153 | 154 | index = 0 155 | 156 | max_index = max(metadata_items, key=lambda x: x.image_index).image_index 157 | if max_index <= np.iinfo(np.uint16).max: 158 | img_indices_dtype = np.uint16 159 | else: 160 | assert max_index <= np.iinfo(np.int32).max # Can support int64 if need be 161 | img_indices_dtype = np.int32 162 | 163 | main_print('Max image index is {}: using dtype: {}'.format(max_index, img_indices_dtype)) 164 | 165 | for chunk_path, path_free in zip(chunk_paths, path_frees): 166 | allocated = int(path_free / total_free * num_chunks) 167 | main_print('Allocating {} chunks to dataset path {}'.format(allocated, chunk_path)) 168 | for j in range(allocated): 169 | parquet_path = chunk_path / '{0:06d}.parquet'.format(index) 170 | self._parquet_paths.append(parquet_path) 171 | 172 | dtypes = [('img_indices', pa.from_numpy_dtype(img_indices_dtype))] 173 | 174 | for i in range(3): 175 | dtypes.append(('rgbs_{}'.format(i), pa.uint8())) 176 | 177 | if self._directions is not None: 178 | dtypes.append(('pixel_indices', pa.int32())) 179 | else: 180 | for i in range(8): 181 | dtypes.append(('rays_{}'.format(i), pa.float32())) 182 | 183 | parquet_writers.append(pq.ParquetWriter(parquet_path, pa.schema(dtypes), compression='BROTLI')) 184 | 185 | index += 1 186 | 187 | main_print('{} chunks allocated'.format(index)) 188 | 189 | write_futures = [] 190 | rgbs = [] 191 | rays = [] 192 | indices = [] 193 | in_memory_count = 0 194 | 195 | if self._directions is not None: 196 | all_pixel_indices = torch.arange(self._directions.shape[0], dtype=torch.int) 197 | 198 | with ThreadPoolExecutor(max_workers=len(parquet_writers)) as executor: 199 | for metadata_item in main_tqdm(metadata_items): 200 | image_data = get_rgb_index_mask(metadata_item) 201 | 202 | if image_data is None: 203 | continue 204 | 205 | image_rgbs, img_indices, image_keep_mask = image_data 206 | rgbs.append(image_rgbs) 207 | indices.append(img_indices) 208 | in_memory_count += len(image_rgbs) 209 | 210 | if self._directions is not None: 211 | image_pixel_indices = all_pixel_indices 212 | if image_keep_mask is not None: 213 | image_pixel_indices = image_pixel_indices[image_keep_mask == True] 214 | 215 | rays.append(image_pixel_indices) 216 | else: 217 | directions = get_ray_directions(metadata_item.W, 218 | metadata_item.H, 219 | metadata_item.intrinsics[0], 220 | metadata_item.intrinsics[1], 221 | metadata_item.intrinsics[2], 222 | metadata_item.intrinsics[3], 223 | center_pixels, 224 | device) 225 | image_rays = get_rays(directions, metadata_item.c2w.to(device), self._near, self._far, 226 | self._ray_altitude_range).view(-1, 8).cpu() 227 | 228 | if image_keep_mask is not None: 229 | image_rays = image_rays[image_keep_mask == True] 230 | 231 | rays.append(image_rays) 232 | 233 | if in_memory_count >= disk_flush_size: 234 | for write_future in write_futures: 235 | write_future.result() 236 | 237 | 
write_futures = self._write_to_disk(executor, torch.cat(rgbs), torch.cat(rays), torch.cat(indices), 238 | parquet_writers, img_indices_dtype, desired_chunks) 239 | 240 | rgbs = [] 241 | rays = [] 242 | indices = [] 243 | in_memory_count = 0 244 | 245 | for write_future in write_futures: 246 | write_future.result() 247 | 248 | if in_memory_count > 0: 249 | write_futures = self._write_to_disk(executor, torch.cat(rgbs), torch.cat(rays), torch.cat(indices), 250 | parquet_writers, img_indices_dtype, desired_chunks) 251 | 252 | for write_future in write_futures: 253 | write_future.result() 254 | for chunk_path in chunk_paths: 255 | chunk_metadata = { 256 | 'images': len(metadata_items), 257 | 'scale_factor': scale_factor 258 | } 259 | 260 | if self._directions is None: 261 | chunk_metadata['near'] = self._near 262 | chunk_metadata['far'] = self._far 263 | chunk_metadata['center_pixels'] = center_pixels 264 | chunk_metadata['ray_altitude_range'] = self._ray_altitude_range 265 | 266 | torch.save(chunk_metadata, chunk_path / 'metadata.pt') 267 | 268 | for parquet_writer in parquet_writers: 269 | parquet_writer.close() 270 | 271 | main_print('Finished writing chunks to dataset paths') 272 | 273 | def _check_existing_paths(self, chunk_paths: List[Path], center_pixels: bool, scale_factor: int, images: int) -> \ 274 | Optional[List[Path]]: 275 | parquet_files = [] 276 | 277 | num_exist = 0 278 | for chunk_path in chunk_paths: 279 | if chunk_path.exists(): 280 | assert (chunk_path / 'metadata.pt').exists(), \ 281 | "Could not find metadata file (did previous writing to this directory not complete successfully?)" 282 | dataset_metadata = torch.load(chunk_path / 'metadata.pt', map_location='cpu') 283 | assert dataset_metadata['images'] == images 284 | assert dataset_metadata['scale_factor'] == scale_factor 285 | 286 | if self._directions is None: 287 | assert dataset_metadata['near'] == self._near 288 | assert dataset_metadata['far'] == self._far 289 | assert dataset_metadata['center_pixels'] == center_pixels 290 | 291 | if self._ray_altitude_range is not None: 292 | assert (torch.allclose(torch.FloatTensor(dataset_metadata['ray_altitude_range']), 293 | torch.FloatTensor(self._ray_altitude_range))) 294 | else: 295 | assert dataset_metadata['ray_altitude_range'] is None 296 | 297 | for child in list(chunk_path.iterdir()): 298 | if child.name != 'metadata.pt': 299 | parquet_files.append(child) 300 | num_exist += 1 301 | 302 | if num_exist > 0: 303 | assert num_exist == len(chunk_paths) 304 | return parquet_files 305 | else: 306 | return None 307 | 308 | def _write_to_disk(self, executor: ThreadPoolExecutor, rgbs: torch.Tensor, rays: torch.FloatTensor, 309 | img_indices: torch.Tensor, parquet_writers: List[pq.ParquetWriter], 310 | img_indices_dtype: Type[Union[np.uint16, np.int32]], desired_chunks=1000) -> List[Future[None]]: 311 | indices = torch.randperm(rgbs.shape[0]) 312 | shuffled_rgbs = rgbs[indices] 313 | shuffled_rays = rays[indices] 314 | shuffled_img_indices = img_indices[indices] 315 | 316 | num_chunks = len(parquet_writers) 317 | chunk_size = math.ceil(rgbs.shape[0] / num_chunks) 318 | 319 | futures = [] 320 | 321 | def append(index: int) -> None: 322 | columns = { 323 | 'img_indices': shuffled_img_indices[index * chunk_size:(index + 1) * chunk_size].numpy().astype( 324 | img_indices_dtype) 325 | } 326 | 327 | for i in range(rgbs.shape[1]): 328 | columns['rgbs_{}'.format(i)] = shuffled_rgbs[index * chunk_size:(index + 1) * chunk_size, i].numpy() 329 | 330 | if self._directions is not None: 331 |
columns['pixel_indices'] = shuffled_rays[index * chunk_size:(index + 1) * chunk_size].numpy() 332 | else: 333 | for i in range(rays.shape[1]): 334 | columns['rays_{}'.format(i)] = shuffled_rays[index * chunk_size:(index + 1) * chunk_size, i].numpy() 335 | 336 | parquet_writers[index].write_table(pa.table(columns)) 337 | 338 | for chunk_index in range(num_chunks): 339 | if chunk_index > desired_chunks: 340 | break 341 | future = executor.submit(append, chunk_index) 342 | futures.append(future) 343 | 344 | 345 | return futures 346 | -------------------------------------------------------------------------------- /mega_nerf/datasets/memory_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | from mega_nerf.datasets.dataset_utils import get_rgb_index_mask 7 | from mega_nerf.image_metadata import ImageMetadata 8 | from mega_nerf.misc_utils import main_tqdm, main_print 9 | from mega_nerf.ray_utils import get_rays, get_ray_directions 10 | 11 | 12 | class MemoryDataset(Dataset): 13 | 14 | def __init__(self, metadata_items: List[ImageMetadata], near: float, far: float, ray_altitude_range: List[float], 15 | center_pixels: bool, device: torch.device): 16 | super(MemoryDataset, self).__init__() 17 | 18 | rgbs = [] 19 | rays = [] 20 | indices = [] 21 | 22 | main_print('Loading data') 23 | 24 | for metadata_item in main_tqdm(metadata_items): 25 | image_data = get_rgb_index_mask(metadata_item) 26 | 27 | if image_data is None: 28 | continue 29 | 30 | image_rgbs, image_indices, image_keep_mask = image_data 31 | # print("image index: {}, fx: {}, fy: {}".format(metadata_item.image_index, metadata_item.intrinsics[0], metadata_item.intrinsics[1])) 32 | directions = get_ray_directions(metadata_item.W, 33 | metadata_item.H, 34 | metadata_item.intrinsics[0], 35 | metadata_item.intrinsics[1], 36 | metadata_item.intrinsics[2], 37 | metadata_item.intrinsics[3], 38 | center_pixels, 39 | device) 40 | image_rays = get_rays(directions, metadata_item.c2w.to(device), near, far, ray_altitude_range).view(-1, 41 | 8).cpu() 42 | if image_keep_mask is not None: 43 | image_rays = image_rays[image_keep_mask == True] 44 | 45 | rgbs.append(image_rgbs.float() / 255.) 
46 | rays.append(image_rays) 47 | indices.append(image_indices) 48 | 49 | main_print('Finished loading data') 50 | 51 | self._rgbs = torch.cat(rgbs) 52 | self._rays = torch.cat(rays) 53 | self._img_indices = torch.cat(indices) 54 | 55 | def __len__(self) -> int: 56 | return self._rgbs.shape[0] 57 | 58 | def __getitem__(self, idx) -> Dict[str, torch.Tensor]: 59 | return { 60 | 'rgbs': self._rgbs[idx], 61 | 'rays': self._rays[idx], 62 | 'img_indices': self._img_indices[idx] 63 | } 64 | -------------------------------------------------------------------------------- /mega_nerf/image_metadata.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional 3 | from zipfile import ZipFile 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from PIL import Image 9 | 10 | 11 | class ImageMetadata: 12 | def __init__(self, image_path: Path, c2w: torch.Tensor, W: int, H: int, intrinsics: torch.Tensor, image_index: int, 13 | mask_path: Optional[Path], is_val: bool, metadata_label=None): 14 | self.image_path = image_path 15 | self.c2w = c2w 16 | self.W = W 17 | self.H = H 18 | self.intrinsics = intrinsics 19 | self.image_index = image_index 20 | self._mask_path = mask_path 21 | self.is_val = is_val 22 | self.label = metadata_label 23 | 24 | def load_image(self) -> torch.Tensor: 25 | rgbs = Image.open(self.image_path).convert('RGB') 26 | size = rgbs.size 27 | 28 | if size[0] != self.W or size[1] != self.H: 29 | rgbs = rgbs.resize((self.W, self.H), Image.LANCZOS) 30 | 31 | return torch.ByteTensor(np.asarray(rgbs)) 32 | 33 | def load_mask(self) -> Optional[torch.Tensor]: 34 | if self._mask_path is None: 35 | return None 36 | 37 | with ZipFile(self._mask_path) as zf: 38 | with zf.open(self._mask_path.name) as f: 39 | keep_mask = torch.load(f, map_location='cpu') 40 | 41 | if keep_mask.shape[0] != self.H or keep_mask.shape[1] != self.W: 42 | keep_mask = F.interpolate(keep_mask.unsqueeze(0).unsqueeze(0).float(), 43 | size=(self.H, self.W)).bool().squeeze() 44 | 45 | return keep_mask 46 | -------------------------------------------------------------------------------- /mega_nerf/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import lpips as plips 6 | 7 | 8 | def psnr(rgbs: torch.Tensor, target_rgbs: torch.Tensor) -> float: 9 | mse = torch.mean((rgbs - target_rgbs) ** 2) 10 | return -10 * torch.log10(mse).item() 11 | 12 | 13 | def lpips(rgbs: torch.Tensor, target_rgbs: torch.Tensor) -> Dict[str, float]: 14 | gt = target_rgbs.permute([2, 0, 1]).contiguous() 15 | pred = rgbs.permute([2, 0, 1]).contiguous() 16 | 17 | lpips_vgg = plips.LPIPS(net='vgg').eval().to(rgbs.device) 18 | lpips_vgg_i = lpips_vgg(gt, pred, normalize=True) 19 | 20 | lpips_alex = plips.LPIPS(net='alex').eval().to(rgbs.device) 21 | lpips_alex_i = lpips_alex(gt, pred, normalize=True) 22 | 23 | lpips_squeeze = plips.LPIPS(net='squeeze').eval().to(rgbs.device) 24 | lpips_squeeze_i = lpips_squeeze(gt, pred, normalize=True) 25 | 26 | return {'vgg': lpips_vgg_i.item(), 'alex': lpips_alex_i.item(), 'squeeze': lpips_squeeze_i.item()} 27 | 28 | 29 | # Copyright 2021 The PlenOctree Authors. 30 | # Redistribution and use in source and binary forms, with or without 31 | # modification, are permitted provided that the following conditions are met: 32 | # 33 | # 1. 
Redistributions of source code must retain the above copyright notice, 34 | # this list of conditions and the following disclaimer. 35 | # 36 | # 2. Redistributions in binary form must reproduce the above copyright notice, 37 | # this list of conditions and the following disclaimer in the documentation 38 | # and/or other materials provided with the distribution. 39 | # 40 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 41 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 42 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 43 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 44 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 45 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 46 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 47 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 48 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 49 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 50 | # POSSIBILITY OF SUCH DAMAGE. 51 | def ssim( 52 | rgbs: torch.Tensor, 53 | target_rgbs: torch.Tensor, 54 | max_val: float, 55 | filter_size: int = 11, 56 | filter_sigma: float = 1.5, 57 | k1: float = 0.01, 58 | k2: float = 0.03, 59 | ) -> float: 60 | """Computes SSIM from two images. 61 | This function was modeled after tf.image.ssim, and should produce comparable 62 | output. 63 | Args: 64 | rgbs: torch.Tensor. An image of size [..., width, height, num_channels]. 65 | target_rgbs: torch.Tensor. An image of size [..., width, height, num_channels]. 66 | max_val: float > 0. The maximum magnitude that `rgbs` or `target_rgbs` can have. 67 | filter_size: int >= 1. Window size. 68 | filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering. 69 | k1: float > 0. One of the SSIM dampening parameters. 70 | k2: float > 0. One of the SSIM dampening parameters. 71 | Returns: 72 | Each image's mean SSIM. 73 | """ 74 | device = rgbs.device 75 | ori_shape = rgbs.size() 76 | width, height, num_channels = ori_shape[-3:] 77 | rgbs = rgbs.view(-1, width, height, num_channels).permute(0, 3, 1, 2) 78 | target_rgbs = target_rgbs.view(-1, width, height, num_channels).permute(0, 3, 1, 2) 79 | 80 | # Construct a 1D Gaussian blur filter. 81 | hw = filter_size // 2 82 | shift = (2 * hw - filter_size + 1) / 2 83 | f_i = ((torch.arange(filter_size, device=device) - hw + shift) / filter_sigma) ** 2 84 | filt = torch.exp(-0.5 * f_i) 85 | filt /= torch.sum(filt) 86 | 87 | # Blur in x and y (faster than the 2D convolution). 88 | # z is a tensor of size [B, C, H, W] 89 | filt_fn1 = lambda z: F.conv2d( 90 | z, filt.view(1, 1, -1, 1).repeat(num_channels, 1, 1, 1), 91 | padding=[hw, 0], groups=num_channels) 92 | filt_fn2 = lambda z: F.conv2d( 93 | z, filt.view(1, 1, 1, -1).repeat(num_channels, 1, 1, 1), 94 | padding=[0, hw], groups=num_channels) 95 | 96 | # Apply the two 1D blurs in sequence to get the 2D blur. 97 | filt_fn = lambda z: filt_fn1(filt_fn2(z)) 98 | mu0 = filt_fn(rgbs) 99 | mu1 = filt_fn(target_rgbs) 100 | mu00 = mu0 * mu0 101 | mu11 = mu1 * mu1 102 | mu01 = mu0 * mu1 103 | sigma00 = filt_fn(rgbs ** 2) - mu00 104 | sigma11 = filt_fn(target_rgbs ** 2) - mu11 105 | sigma01 = filt_fn(rgbs * target_rgbs) - mu01 106 | 107 | # Clip the variances and covariances to valid values.
108 | # Variance must be non-negative: 109 | sigma00 = torch.clamp(sigma00, min=0.0) 110 | sigma11 = torch.clamp(sigma11, min=0.0) 111 | sigma01 = torch.sign(sigma01) * torch.min( 112 | torch.sqrt(sigma00 * sigma11), torch.abs(sigma01) 113 | ) 114 | 115 | c1 = (k1 * max_val) ** 2 116 | c2 = (k2 * max_val) ** 2 117 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2) 118 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) 119 | ssim_map = numer / denom 120 | 121 | return torch.mean(ssim_map.reshape([-1, num_channels * width * height]), dim=-1).item() 122 | -------------------------------------------------------------------------------- /mega_nerf/misc_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tqdm import tqdm 4 | 5 | 6 | def main_print(log) -> None: 7 | if ('LOCAL_RANK' not in os.environ) or int(os.environ['LOCAL_RANK']) == 0: 8 | print(log) 9 | 10 | 11 | def main_tqdm(inner): 12 | if ('LOCAL_RANK' not in os.environ) or int(os.environ['LOCAL_RANK']) == 0: 13 | return tqdm(inner) 14 | else: 15 | return inner 16 | -------------------------------------------------------------------------------- /mega_nerf/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyqz97/GP-NeRF/968791d162f2f29b82ba8c4c7dc7757e6374a811/mega_nerf/models/__init__.py -------------------------------------------------------------------------------- /mega_nerf/models/cascade.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class Cascade(nn.Module): 8 | def __init__(self, coarse: nn.Module, fine: nn.Module): 9 | super(Cascade, self).__init__() 10 | self.coarse = coarse 11 | self.fine = fine 12 | 13 | def forward(self, use_coarse: bool, x: torch.Tensor, sigma_only: bool = False, 14 | sigma_noise: Optional[torch.Tensor] = None) -> torch.Tensor: 15 | if use_coarse: 16 | return self.coarse(x, sigma_only, sigma_noise) 17 | else: 18 | return self.fine(x, sigma_only, sigma_noise) 19 | 20 | 21 | class Cascade_2model(nn.Module): 22 | def __init__(self, coarse: nn.Module, fine: nn.Module): 23 | super(Cascade_2model, self).__init__() 24 | self.coarse = coarse 25 | self.fine = fine 26 | 27 | def forward(self, use_coarse: bool, point_type, x: torch.Tensor, sigma_only: bool = False, 28 | sigma_noise: Optional[torch.Tensor] = None) -> torch.Tensor: 29 | if use_coarse: 30 | return self.coarse(point_type, x, sigma_only, sigma_noise) 31 | else: 32 | return self.fine(point_type, x, sigma_only, sigma_noise) 33 | 34 | -------------------------------------------------------------------------------- /mega_nerf/models/mega_nerf.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class MegaNeRF(nn.Module): 8 | def __init__(self, sub_modules: List[nn.Module], centroids: torch.Tensor, boundary_margin: float, xyz_real: bool, 9 | cluster_2d: bool, joint_training: bool = False): 10 | super(MegaNeRF, self).__init__() 11 | assert boundary_margin >= 1 12 | self.sub_modules = nn.ModuleList(sub_modules) 13 | self.register_buffer('centroids', centroids) 14 | self.boundary_margin = boundary_margin 15 | self.xyz_real = xyz_real 16 | self.cluster_dim_start = 1 if cluster_2d else 0 17 | self.joint_training = joint_training 18 | 19 | def forward(self, x: 
torch.Tensor, sigma_only: bool = False, 20 | sigma_noise: Optional[torch.Tensor] = None) -> torch.Tensor: 21 | if self.boundary_margin > 1: 22 | cluster_distances = torch.cdist(x[:, self.cluster_dim_start:3], self.centroids[:, self.cluster_dim_start:]) 23 | inverse_cluster_distances = 1 / (cluster_distances + 1e-8) 24 | 25 | min_cluster_distances = cluster_distances.min(dim=1)[0].unsqueeze(-1).repeat(1, cluster_distances.shape[1]) 26 | inverse_cluster_distances[cluster_distances > self.boundary_margin * min_cluster_distances] = 0 27 | weights = inverse_cluster_distances / inverse_cluster_distances.sum(dim=-1).unsqueeze(-1) 28 | else: 29 | cluster_assignments = torch.cdist(x[:, self.cluster_dim_start:3], 30 | self.centroids[:, self.cluster_dim_start:]).argmin(dim=1) 31 | 32 | results = torch.empty(0) 33 | 34 | for i, child in enumerate(self.sub_modules): 35 | cluster_mask = cluster_assignments == i if self.boundary_margin == 1 else weights[:, i] > 0 36 | sub_input = x[cluster_mask, 3:] if self.xyz_real else x[cluster_mask] 37 | 38 | if sub_input.shape[0] > 0: 39 | sub_result = child(sub_input, sigma_only, 40 | sigma_noise[cluster_mask] if sigma_noise is not None else None) 41 | 42 | if results.shape[0] == 0: 43 | results = torch.zeros(x.shape[0], sub_result.shape[1], device=sub_result.device, 44 | dtype=sub_result.dtype) 45 | 46 | if self.boundary_margin == 1: 47 | results[cluster_mask] = sub_result 48 | else: 49 | results[cluster_mask] += sub_result * weights[cluster_mask, i].unsqueeze(-1) 50 | 51 | elif self.joint_training: # Hack to make distributed training happy 52 | sub_result = child(x[:0, 3:] if self.xyz_real else x[:0], sigma_only, 53 | sigma_noise[:0] if sigma_noise is not None else None) 54 | 55 | if results.shape[0] == 0: 56 | results = torch.zeros(x.shape[0], sub_result.shape[1], device=sub_result.device, 57 | dtype=sub_result.dtype) 58 | 59 | results[:0] += 0 * sub_result 60 | 61 | return results 62 | -------------------------------------------------------------------------------- /mega_nerf/models/mega_nerf_container.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class MegaNeRFContainer(nn.Module): 8 | def __init__(self, sub_modules: List[nn.Module], bg_sub_modules: List[nn.Module], centroids: torch.Tensor, 9 | grid_dim: torch.Tensor, min_position: torch.Tensor, max_position: torch.Tensor, need_viewdir: bool, 10 | need_appearance_embedding: bool, cluster_2d: bool): 11 | super(MegaNeRFContainer, self).__init__() 12 | 13 | for i, sub_module in enumerate(sub_modules): 14 | setattr(self, 'sub_module_{}'.format(i), sub_module) 15 | 16 | for i, bg_sub_module in enumerate(bg_sub_modules): 17 | setattr(self, 'bg_sub_module_{}'.format(i), bg_sub_module) 18 | 19 | self.centroids = centroids 20 | self.grid_dim = grid_dim 21 | self.min_position = min_position 22 | self.max_position = max_position 23 | self.need_viewdir = need_viewdir 24 | self.need_appearance_embedding = need_appearance_embedding 25 | self.cluster_2d = cluster_2d 26 | -------------------------------------------------------------------------------- /mega_nerf/models/nerf.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | 7 | 8 | class Embedding(nn.Module): 9 | def __init__(self, num_freqs: int, logscale=True): 10 | """ 11 | Defines a function 
that embeds x to (x, sin(2^k x), cos(2^k x), ...) 12 | """ 13 | super(Embedding, self).__init__() 14 | 15 | if logscale: 16 | self.freq_bands = 2 ** torch.linspace(0, num_freqs - 1, num_freqs) 17 | else: 18 | self.freq_bands = torch.linspace(1, 2 ** (num_freqs - 1), num_freqs) 19 | 20 | def forward(self, x: torch.Tensor) -> torch.Tensor: 21 | out = [x] 22 | for freq in self.freq_bands: 23 | out += [torch.sin(freq * x), torch.cos(freq * x)] 24 | 25 | return torch.cat(out, -1) 26 | 27 | 28 | class ShiftedSoftplus(nn.Module): 29 | __constants__ = ['beta', 'threshold'] 30 | beta: int 31 | threshold: int 32 | 33 | def __init__(self, beta: int = 1, threshold: int = 20) -> None: 34 | super(ShiftedSoftplus, self).__init__() 35 | self.beta = beta 36 | self.threshold = threshold 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | return F.softplus(x - 1, self.beta, self.threshold) 40 | 41 | def extra_repr(self) -> str: 42 | return 'beta={}, threshold={}'.format(self.beta, self.threshold) 43 | 44 | 45 | class NeRF(nn.Module): 46 | def __init__(self, pos_xyz_dim: int, pos_dir_dim: int, layers: int, skip_layers: List[int], layer_dim: int, 47 | appearance_dim: int, affine_appearance: bool, appearance_count: int, rgb_dim: int, xyz_dim: int, 48 | sigma_activation: nn.Module, hparams): 49 | super(NeRF, self).__init__() 50 | print('original Mega Nerf') 51 | self.xyz_dim = xyz_dim 52 | 53 | if rgb_dim > 3: 54 | assert pos_dir_dim == 0 55 | 56 | self.embedding_xyz = Embedding(pos_xyz_dim) 57 | in_channels_xyz = xyz_dim + xyz_dim * pos_xyz_dim * 2 58 | 59 | self.skip_layers = skip_layers 60 | 61 | xyz_encodings = [] 62 | 63 | # xyz encoding layers 64 | for i in range(layers): 65 | if i == 0: 66 | layer = nn.Linear(in_channels_xyz, layer_dim) 67 | elif i in skip_layers: 68 | layer = nn.Linear(layer_dim + in_channels_xyz, layer_dim) 69 | else: 70 | layer = nn.Linear(layer_dim, layer_dim) 71 | layer = nn.Sequential(layer, nn.ReLU(True)) 72 | xyz_encodings.append(layer) 73 | 74 | self.xyz_encodings = nn.ModuleList(xyz_encodings) 75 | 76 | if pos_dir_dim > 0: 77 | self.embedding_dir = Embedding(pos_dir_dim) 78 | in_channels_dir = 3 + 3 * pos_dir_dim * 2 79 | else: 80 | self.embedding_dir = None 81 | in_channels_dir = 0 82 | 83 | if appearance_dim > 0: 84 | self.embedding_a = nn.Embedding(appearance_count, appearance_dim) 85 | else: 86 | self.embedding_a = None 87 | 88 | if affine_appearance: 89 | assert appearance_dim > 0 90 | self.affine = nn.Linear(appearance_dim, 12) 91 | else: 92 | self.affine = None 93 | 94 | if pos_dir_dim > 0 or (appearance_dim > 0 and not affine_appearance): 95 | self.xyz_encoding_final = nn.Linear(layer_dim, layer_dim) 96 | # direction and appearance encoding layers 97 | self.dir_a_encoding = nn.Sequential( 98 | nn.Linear(layer_dim + in_channels_dir + (appearance_dim if not affine_appearance else 0), 99 | layer_dim // 2), 100 | nn.ReLU(True)) 101 | else: 102 | self.xyz_encoding_final = None 103 | 104 | # output layers 105 | self.sigma = nn.Linear(layer_dim, 1) 106 | self.sigma_activation = sigma_activation 107 | 108 | self.rgb = nn.Linear( 109 | layer_dim // 2 if (pos_dir_dim > 0 or (appearance_dim > 0 and not affine_appearance)) else layer_dim, 110 | rgb_dim) 111 | if rgb_dim == 3: 112 | self.rgb_activation = nn.Sigmoid() # = nn.Sequential(rgb, nn.Sigmoid()) 113 | else: 114 | self.rgb_activation = None # We're using spherical harmonics and will convert to sigmoid in rendering.py 115 | 116 | def forward(self, x: torch.Tensor, sigma_only: bool = False, 117 | sigma_noise: 
Optional[torch.Tensor] = None,train_iterations: int =-1) -> torch.Tensor: 118 | expected = self.xyz_dim \ 119 | + (0 if (sigma_only or self.embedding_dir is None) else 3) \ 120 | + (0 if (sigma_only or self.embedding_a is None) else 1) 121 | 122 | if x.shape[1] != expected: 123 | raise Exception( 124 | 'Unexpected input shape: {} (expected: {}, xyz_dim: {})'.format(x.shape, expected, self.xyz_dim)) 125 | 126 | input_xyz = self.embedding_xyz(x[:, :self.xyz_dim]) 127 | xyz_ = input_xyz 128 | for i, xyz_encoding in enumerate(self.xyz_encodings): 129 | if i in self.skip_layers: 130 | xyz_ = torch.cat([input_xyz, xyz_], -1) 131 | xyz_ = xyz_encoding(xyz_) 132 | 133 | sigma = self.sigma(xyz_) 134 | if sigma_noise is not None: 135 | sigma += sigma_noise 136 | 137 | sigma = self.sigma_activation(sigma) 138 | 139 | if sigma_only: 140 | return sigma 141 | 142 | if self.xyz_encoding_final is not None: 143 | xyz_encoding_final = self.xyz_encoding_final(xyz_) 144 | dir_a_encoding_input = [xyz_encoding_final] 145 | 146 | if self.embedding_dir is not None: 147 | dir_a_encoding_input.append(self.embedding_dir(x[:, -4:-1])) 148 | 149 | if self.embedding_a is not None and self.affine is None: 150 | dir_a_encoding_input.append(self.embedding_a(x[:, -1].long())) 151 | 152 | dir_a_encoding = self.dir_a_encoding(torch.cat(dir_a_encoding_input, -1)) 153 | rgb = self.rgb(dir_a_encoding) 154 | else: 155 | rgb = self.rgb(xyz_) 156 | 157 | if self.affine is not None and self.embedding_a is not None: 158 | affine_transform = self.affine(self.embedding_a(x[:, -1].long())).view(-1, 3, 4) 159 | rgb = (affine_transform[:, :, :3] @ rgb.unsqueeze(-1) + affine_transform[:, :, 3:]).squeeze(-1) 160 | 161 | return torch.cat([self.rgb_activation(rgb) if self.rgb_activation is not None else rgb, sigma], -1) 162 | -------------------------------------------------------------------------------- /mega_nerf/ray_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | 5 | 6 | def get_ray_directions(W: int, H: int, fx: float, fy: float, cx: float, cy: float, center_pixels: bool, 7 | device: torch.device) -> torch.Tensor: 8 | i, j = torch.meshgrid(torch.arange(W, dtype=torch.float32, device=device), 9 | torch.arange(H, dtype=torch.float32, device=device), indexing='xy') 10 | if center_pixels: 11 | i = i.clone() + 0.5 12 | j = j.clone() + 0.5 13 | 14 | directions = \ 15 | torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1) # (H, W, 3) 16 | directions = directions / torch.linalg.norm(directions, dim=-1, keepdim=True) 17 | 18 | return directions 19 | 20 | 21 | def get_rays(directions: torch.Tensor, c2w: torch.Tensor, near: float, far: float, 22 | ray_altitude_range: List[float]) -> torch.Tensor: 23 | # Rotate ray directions from camera coordinate to the world coordinate 24 | rays_d = directions @ c2w[:, :3].T # (H, W, 3) 25 | rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 26 | 27 | # The origin of all rays is the camera origin in world coordinate 28 | rays_o = c2w[:, 3].expand(rays_d.shape) # (H, W, 3) 29 | 30 | return _get_rays_inner(rays_o, rays_d, near, far, ray_altitude_range) 31 | 32 | 33 | def get_rays_batch(directions: torch.Tensor, c2w: torch.Tensor, near: float, far: float, 34 | ray_altitude_range: List[float]) -> torch.Tensor: 35 | # Rotate ray directions from camera coordinate to the world coordinate 36 | rays_d = directions @ c2w[:, :, :3].transpose(1, 2) # (n, H*W, 3) 37 | rays_d = rays_d / torch.norm(rays_d, 
dim=-1, keepdim=True) 38 | # The origin of all rays is the camera origin in world coordinate 39 | rays_o = c2w[:, :, 3].unsqueeze(1).expand(rays_d.shape) # (n, H*W, 3) 40 | 41 | return _get_rays_inner(rays_o, rays_d, near, far, ray_altitude_range) 42 | 43 | 44 | def _get_rays_inner(rays_o: torch.Tensor, rays_d: torch.Tensor, near: float, far: float, 45 | ray_altitude_range: List[float]) -> torch.Tensor: 46 | # c2w is drb, ray_altitude_range is max_altitude (neg), min_altitude (neg) 47 | near_bounds = near * torch.ones_like(rays_o[..., :1]) 48 | far_bounds = far * torch.ones_like(rays_o[..., :1]) 49 | 50 | if ray_altitude_range is not None: 51 | _truncate_with_plane_intersection(rays_o, rays_d, ray_altitude_range[0], near_bounds) 52 | near_bounds = torch.clamp(near_bounds, min=near) 53 | _truncate_with_plane_intersection(rays_o, rays_d, ray_altitude_range[1], far_bounds) 54 | far_bounds = torch.clamp(far_bounds, max=far) 55 | far_bounds = torch.maximum(near_bounds, far_bounds) 56 | 57 | return torch.cat([rays_o, 58 | rays_d, 59 | near_bounds, 60 | far_bounds], 61 | -1) # (h, w, 8) 62 | 63 | 64 | def _truncate_with_plane_intersection(rays_o: torch.Tensor, rays_d: torch.Tensor, altitude: float, 65 | default_bounds: torch.Tensor) -> None: 66 | starts_before = rays_o[:, :, 0] < altitude 67 | goes_down = rays_d[:, :, 0] > 0 68 | boundable_rays = torch.minimum(starts_before, goes_down) 69 | 70 | ray_points = rays_o[boundable_rays] 71 | if ray_points.shape[0] == 0: 72 | return 73 | 74 | ray_directions = rays_d[boundable_rays] 75 | 76 | plane_normal = torch.FloatTensor([-1, 0, 0]).to(rays_o.device).unsqueeze(1) 77 | ndotu = ray_directions.mm(plane_normal) 78 | 79 | plane_point = torch.FloatTensor([altitude, 0, 0]).to(rays_o.device) 80 | w = ray_points - plane_point 81 | si = -w.mm(plane_normal) / ndotu 82 | plane_intersection = w + si * ray_directions + plane_point 83 | default_bounds[boundable_rays] = (ray_points - plane_intersection).norm(dim=-1).unsqueeze(1) 84 | -------------------------------------------------------------------------------- /mega_nerf/spherical_harmonics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | import torch 24 | 25 | C0 = 0.28209479177387814 26 | C1 = 0.4886025119029199 27 | C2 = [ 28 | 1.0925484305920792, 29 | -1.0925484305920792, 30 | 0.31539156525252005, 31 | -1.0925484305920792, 32 | 0.5462742152960396 33 | ] 34 | C3 = [ 35 | -0.5900435899266435, 36 | 2.890611442640554, 37 | -0.4570457994644658, 38 | 0.3731763325901154, 39 | -0.4570457994644658, 40 | 1.445305721320277, 41 | -0.5900435899266435 42 | ] 43 | C4 = [ 44 | 2.5033429417967046, 45 | -1.7701307697799304, 46 | 0.9461746957575601, 47 | -0.6690465435572892, 48 | 0.10578554691520431, 49 | -0.6690465435572892, 50 | 0.47308734787878004, 51 | -1.7701307697799304, 52 | 0.6258357354491761, 53 | ] 54 | 55 | def eval_sh(deg: int, sh: torch.Tensor, dirs: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Evaluate spherical harmonics at unit directions 58 | using hardcoded SH polynomials. 59 | Works with torch/np/jnp. 60 | ... Can be 0 or more batch dimensions. 61 | Args: 62 | deg: int SH deg. Currently, 0-3 supported 63 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 64 | dirs: jnp.ndarray unit directions [..., 3] 65 | Returns: 66 | [..., C] 67 | """ 68 | assert deg <= 4 and deg >= 0 69 | assert (deg + 1) ** 2 == sh.shape[-1] 70 | 71 | result = C0 * sh[..., 0] 72 | if deg > 0: 73 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 74 | result = (result - 75 | C1 * y * sh[..., 1] + 76 | C1 * z * sh[..., 2] - 77 | C1 * x * sh[..., 3]) 78 | if deg > 1: 79 | xx, yy, zz = x * x, y * y, z * z 80 | xy, yz, xz = x * y, y * z, x * z 81 | result = (result + 82 | C2[0] * xy * sh[..., 4] + 83 | C2[1] * yz * sh[..., 5] + 84 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 85 | C2[3] * xz * sh[..., 7] + 86 | C2[4] * (xx - yy) * sh[..., 8]) 87 | 88 | if deg > 2: 89 | result = (result + 90 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 91 | C3[1] * xy * z * sh[..., 10] + 92 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 93 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 94 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 95 | C3[5] * z * (xx - yy) * sh[..., 14] + 96 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 97 | if deg > 3: 98 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 99 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 100 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 101 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 102 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 103 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 104 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 105 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 106 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 107 | return result -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | configargparse 2 | PyYAML 3 | opencv-python 4 | tensorboard 5 | tqdm 6 | pyarrow 7 | lpips 8 | wandb 9 | setuptools==56.1.0 10 | pandas 11 | 
-------------------------------------------------------------------------------- /scripts/colmap_to_mega_nerf.py: -------------------------------------------------------------------------------- 1 | # All of the model reading methods are taken from https://github.com/cvg/Hierarchical-Localization, original 2 | # license listed below: 3 | # 4 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 5 | # All rights reserved. 6 | # 7 | # Redistribution and use in source and binary forms, with or without 8 | # modification, are permitted provided that the following conditions are met: 9 | # 10 | # * Redistributions of source code must retain the above copyright 11 | # notice, this list of conditions and the following disclaimer. 12 | # 13 | # * Redistributions in binary form must reproduce the above copyright 14 | # notice, this list of conditions and the following disclaimer in the 15 | # documentation and/or other materials provided with the distribution. 16 | # 17 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 18 | # its contributors may be used to endorse or promote products derived 19 | # from this software without specific prior written permission. 20 | # 21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 25 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 31 | # POSSIBILITY OF SUCH DAMAGE. 32 | # 33 | # Author: Johannes L. 
Schoenberger (jsch-at-demuc-dot-de) 34 | 35 | import argparse 36 | import collections 37 | import os 38 | import struct 39 | from argparse import Namespace 40 | from pathlib import Path 41 | 42 | import cv2 43 | import numpy as np 44 | import torch 45 | from tqdm import tqdm 46 | 47 | import zipfile 48 | from zipfile import ZipFile 49 | 50 | RDF_TO_DRB = torch.FloatTensor([[0, 1, 0], 51 | [1, 0, 0], 52 | [0, 0, -1]]) 53 | 54 | CameraModel = collections.namedtuple( 55 | "CameraModel", ["model_id", "model_name", "num_params"]) 56 | Camera = collections.namedtuple( 57 | "Camera", ["id", "model", "width", "height", "params"]) 58 | BaseImage = collections.namedtuple( 59 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 60 | Point3D = collections.namedtuple( 61 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 62 | 63 | CAMERA_MODELS = { 64 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 65 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 66 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 67 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 68 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 69 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 70 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 71 | CameraModel(model_id=7, model_name="FOV", num_params=5), 72 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 73 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 74 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 75 | } 76 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 77 | for camera_model in CAMERA_MODELS]) 78 | 79 | 80 | def qvec2rotmat(qvec): 81 | return np.array([ 82 | [1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, 83 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 84 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 85 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 86 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, 87 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 88 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 89 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 90 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2]]) 91 | 92 | 93 | class Image(BaseImage): 94 | def qvec2rotmat(self): 95 | return qvec2rotmat(self.qvec) 96 | 97 | 98 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 99 | """Read and unpack the next bytes from a binary file. 100 | :param fid: 101 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 102 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 103 | :param endian_character: Any of {@, =, <, >, !} 104 | :return: Tuple of read and unpacked values. 
105 | """ 106 | data = fid.read(num_bytes) 107 | return struct.unpack(endian_character + format_char_sequence, data) 108 | 109 | 110 | def read_points3D_binary(path_to_model_file): 111 | """ 112 | see: src/base/reconstruction.cc 113 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 114 | void Reconstruction::WritePoints3DBinary(const std::string& path) 115 | """ 116 | points3D = {} 117 | with open(path_to_model_file, "rb") as fid: 118 | num_points = read_next_bytes(fid, 8, "Q")[0] 119 | for _ in range(num_points): 120 | binary_point_line_properties = read_next_bytes( 121 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 122 | point3D_id = binary_point_line_properties[0] 123 | xyz = np.array(binary_point_line_properties[1:4]) 124 | rgb = np.array(binary_point_line_properties[4:7]) 125 | error = np.array(binary_point_line_properties[7]) 126 | track_length = read_next_bytes( 127 | fid, num_bytes=8, format_char_sequence="Q")[0] 128 | track_elems = read_next_bytes( 129 | fid, num_bytes=8 * track_length, 130 | format_char_sequence="ii" * track_length) 131 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 132 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) 133 | points3D[point3D_id] = Point3D( 134 | id=point3D_id, xyz=xyz, rgb=rgb, 135 | error=error, image_ids=image_ids, 136 | point2D_idxs=point2D_idxs) 137 | return points3D 138 | 139 | 140 | def read_images_binary(path_to_model_file): 141 | """ 142 | see: src/base/reconstruction.cc 143 | void Reconstruction::ReadImagesBinary(const std::string& path) 144 | void Reconstruction::WriteImagesBinary(const std::string& path) 145 | """ 146 | images = {} 147 | with open(path_to_model_file, "rb") as fid: 148 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 149 | for _ in range(num_reg_images): 150 | binary_image_properties = read_next_bytes( 151 | fid, num_bytes=64, format_char_sequence="idddddddi") 152 | image_id = binary_image_properties[0] 153 | qvec = np.array(binary_image_properties[1:5]) 154 | tvec = np.array(binary_image_properties[5:8]) 155 | camera_id = binary_image_properties[8] 156 | image_name = "" 157 | current_char = read_next_bytes(fid, 1, "c")[0] 158 | while current_char != b"\x00": # look for the ASCII 0 entry 159 | image_name += current_char.decode("utf-8") 160 | current_char = read_next_bytes(fid, 1, "c")[0] 161 | num_points2D = read_next_bytes(fid, num_bytes=8, 162 | format_char_sequence="Q")[0] 163 | x_y_id_s = read_next_bytes(fid, num_bytes=24 * num_points2D, 164 | format_char_sequence="ddq" * num_points2D) 165 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 166 | tuple(map(float, x_y_id_s[1::3]))]) 167 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 168 | images[image_id] = Image( 169 | id=image_id, qvec=qvec, tvec=tvec, 170 | camera_id=camera_id, name=image_name, 171 | xys=xys, point3D_ids=point3D_ids) 172 | return images 173 | 174 | 175 | def read_cameras_binary(path_to_model_file): 176 | """ 177 | see: src/base/reconstruction.cc 178 | void Reconstruction::WriteCamerasBinary(const std::string& path) 179 | void Reconstruction::ReadCamerasBinary(const std::string& path) 180 | """ 181 | cameras = {} 182 | with open(path_to_model_file, "rb") as fid: 183 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 184 | for _ in range(num_cameras): 185 | camera_properties = read_next_bytes( 186 | fid, num_bytes=24, format_char_sequence="iiQQ") 187 | camera_id = camera_properties[0] 188 | model_id = camera_properties[1] 189 | model_name = 
CAMERA_MODEL_IDS[camera_properties[1]].model_name 190 | width = camera_properties[2] 191 | height = camera_properties[3] 192 | num_params = CAMERA_MODEL_IDS[model_id].num_params 193 | params = read_next_bytes(fid, num_bytes=8 * num_params, 194 | format_char_sequence="d" * num_params) 195 | cameras[camera_id] = Camera(id=camera_id, 196 | model=model_name, 197 | width=width, 198 | height=height, 199 | params=np.array(params)) 200 | assert len(cameras) == num_cameras 201 | return cameras 202 | 203 | 204 | def read_points3D_text(path): 205 | """ 206 | see: src/base/reconstruction.cc 207 | void Reconstruction::ReadPoints3DText(const std::string& path) 208 | void Reconstruction::WritePoints3DText(const std::string& path) 209 | """ 210 | points3D = {} 211 | with open(path, "r") as fid: 212 | while True: 213 | line = fid.readline() 214 | if not line: 215 | break 216 | line = line.strip() 217 | if len(line) > 0 and line[0] != "#": 218 | elems = line.split() 219 | point3D_id = int(elems[0]) 220 | xyz = np.array(tuple(map(float, elems[1:4]))) 221 | rgb = np.array(tuple(map(int, elems[4:7]))) 222 | error = float(elems[7]) 223 | image_ids = np.array(tuple(map(int, elems[8::2]))) 224 | point2D_idxs = np.array(tuple(map(int, elems[9::2]))) 225 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, 226 | error=error, image_ids=image_ids, 227 | point2D_idxs=point2D_idxs) 228 | return points3D 229 | 230 | 231 | def read_images_text(path): 232 | """ 233 | see: src/base/reconstruction.cc 234 | void Reconstruction::ReadImagesText(const std::string& path) 235 | void Reconstruction::WriteImagesText(const std::string& path) 236 | """ 237 | images = {} 238 | with open(path, "r") as fid: 239 | while True: 240 | line = fid.readline() 241 | if not line: 242 | break 243 | line = line.strip() 244 | if len(line) > 0 and line[0] != "#": 245 | elems = line.split() 246 | image_id = int(elems[0]) 247 | qvec = np.array(tuple(map(float, elems[1:5]))) 248 | tvec = np.array(tuple(map(float, elems[5:8]))) 249 | camera_id = int(elems[8]) 250 | image_name = elems[9] 251 | elems = fid.readline().split() 252 | xys = np.column_stack([tuple(map(float, elems[0::3])), 253 | tuple(map(float, elems[1::3]))]) 254 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 255 | images[image_id] = Image( 256 | id=image_id, qvec=qvec, tvec=tvec, 257 | camera_id=camera_id, name=image_name, 258 | xys=xys, point3D_ids=point3D_ids) 259 | return images 260 | 261 | 262 | def read_cameras_text(path): 263 | """ 264 | see: src/base/reconstruction.cc 265 | void Reconstruction::WriteCamerasText(const std::string& path) 266 | void Reconstruction::ReadCamerasText(const std::string& path) 267 | """ 268 | cameras = {} 269 | with open(path, "r") as fid: 270 | while True: 271 | line = fid.readline() 272 | if not line: 273 | break 274 | line = line.strip() 275 | if len(line) > 0 and line[0] != "#": 276 | elems = line.split() 277 | camera_id = int(elems[0]) 278 | model = elems[1] 279 | width = int(elems[2]) 280 | height = int(elems[3]) 281 | params = np.array(tuple(map(float, elems[4:]))) 282 | cameras[camera_id] = Camera(id=camera_id, model=model, 283 | width=width, height=height, 284 | params=params) 285 | return cameras 286 | 287 | 288 | def detect_model_format(path, ext): 289 | if os.path.isfile(os.path.join(path, "cameras" + ext)) and \ 290 | os.path.isfile(os.path.join(path, "images" + ext)) and \ 291 | os.path.isfile(os.path.join(path, "points3D" + ext)): 292 | return True 293 | 294 | return False 295 | 296 | 297 | def read_model(path, ext=""): 
298 |     # try to detect the extension automatically
299 |     if ext == "":
300 |         if detect_model_format(path, ".bin"):
301 |             ext = ".bin"
302 |         elif detect_model_format(path, ".txt"):
303 |             ext = ".txt"
304 |         else:
305 |             try:
306 |                 cameras, images, points3D = read_model(os.path.join(path, "model/"))
307 |                 return cameras, images, points3D
308 |             except FileNotFoundError:
309 |                 raise FileNotFoundError(
310 |                     f"Could not find binary or text COLMAP model at {path}")
311 | 
312 |     if ext == ".txt":
313 |         cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
314 |         images = read_images_text(os.path.join(path, "images" + ext))
315 |         points3D = read_points3D_text(os.path.join(path, "points3D" + ext))
316 |     else:
317 |         cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
318 |         images = read_images_binary(os.path.join(path, "images" + ext))
319 |         points3D = read_points3D_binary(os.path.join(path, "points3D" + ext))
320 |     return cameras, images, points3D
321 | 
322 | 
323 | def _get_opts() -> Namespace:
324 |     parser = argparse.ArgumentParser()
325 | 
326 |     parser.add_argument('--model_path', type=str, required=True, help='Path to PixSFM/COLMAP model')
327 |     parser.add_argument('--images_path', type=str, required=True, help='Path to images')
328 |     parser.add_argument('--output_path', type=str, required=True, help='Path to write converted dataset to')
329 |     parser.add_argument('--scale', type=float, required=True,
330 |                         help='Scale all poses by this factor. Choose it large enough that the scaled camera positions fall within [-1, 1]')
331 |     parser.add_argument('--num_val', type=int, default=20, help='Number of images to hold out in validation set')
332 | 
333 |     return parser.parse_args()
334 | 
335 | 
336 | def main(hparams: Namespace) -> None:
337 |     cameras, images, _ = read_model(hparams.model_path)
338 | 
339 |     c2ws = {}
340 |     for image in images.values():
341 |         w2c = torch.eye(4)
342 |         w2c[:3, :3] = torch.FloatTensor(qvec2rotmat(image.qvec))
343 |         w2c[:3, 3] = torch.FloatTensor(image.tvec)
344 |         c2w = torch.inverse(w2c)
345 | 
346 |         c2w = torch.hstack((
347 |             RDF_TO_DRB @ c2w[:3, :3] @ torch.inverse(RDF_TO_DRB),
348 |             RDF_TO_DRB @ c2w[:3, 3:]
349 |         ))
350 | 
351 |         c2ws[image.id] = c2w
352 | 
353 |     positions = torch.cat([c2w[:3, 3].unsqueeze(0) for c2w in c2ws.values()])
354 |     print('{} images'.format(positions.shape[0]))
355 |     max_values = positions.max(0)[0]
356 |     min_values = positions.min(0)[0]
357 |     origin = ((max_values + min_values) * 0.5)
358 |     dist = (positions - origin).norm(dim=-1)
359 |     diagonal = dist.max()
360 | 
361 |     print(origin, diagonal, max_values, min_values)
362 |     coordinates = {
363 |         'origin_drb': origin,
364 |         'pose_scale_factor': hparams.scale
365 |     }
366 | 
367 |     output_path = Path(hparams.output_path)
368 |     output_path.mkdir(parents=True)
369 |     (output_path / 'train' / 'metadata').mkdir(parents=True)
370 |     (output_path / 'val' / 'metadata').mkdir(parents=True)
371 | 
372 |     (output_path / 'train' / 'rgbs').mkdir(parents=True)
373 |     (output_path / 'val' / 'rgbs').mkdir(parents=True)
374 | 
375 |     images_path = Path(hparams.images_path)
376 | 
377 |     with (output_path / 'mappings.txt').open('w') as f:
378 |         for i, image in enumerate(tqdm(sorted(images.values(), key=lambda x: x.name))):
379 |             if i % int(positions.shape[0] / hparams.num_val) == 0:
380 |                 split_dir = output_path / 'val'
381 |             else:
382 |                 split_dir = output_path / 'train'
383 | 
384 |             distorted = cv2.imread(str(images_path / image.name))
385 | 
386 |             camera = cameras[image.camera_id]
387 | 
388 |             # TODO: make camera model more flexible - should
mainly involve changing the camera matrix accordingly 389 | assert camera.model == 'SIMPLE_RADIAL', camera.model 390 | 391 | camera_matrix = np.array([[camera.params[0], 0, camera.params[1]], 392 | [0, camera.params[0], camera.params[2]], 393 | [0, 0, 1]]) 394 | 395 | distortion = np.array([camera.params[3], 0, 0, 0]) 396 | undistorted = cv2.undistort(distorted, camera_matrix, distortion) 397 | cv2.imwrite(str(split_dir / 'rgbs' / '{0:06d}.jpg'.format(i)), undistorted) 398 | 399 | camera_in_drb = c2ws[image.id] 400 | camera_in_drb[:, 3] = (camera_in_drb[:, 3] - origin) / hparams.scale 401 | 402 | assert np.logical_and(camera_in_drb >= -1, camera_in_drb <= 1).all() 403 | 404 | metadata_name = '{0:06d}.pt'.format(i) 405 | torch.save({ 406 | 'H': distorted.shape[0], 407 | 'W': distorted.shape[1], 408 | 'c2w': torch.cat( 409 | [camera_in_drb[:, 1:2], -camera_in_drb[:, :1], camera_in_drb[:, 2:4]], 410 | -1), 411 | 'intrinsics': torch.FloatTensor( 412 | [camera_matrix[0][0], camera_matrix[1][1], camera_matrix[0][2], camera_matrix[1][2]]), 413 | 'distortion': torch.FloatTensor(distortion) 414 | }, split_dir / 'metadata' / metadata_name) 415 | 416 | f.write('{},{}\n'.format(image.name, metadata_name)) 417 | 418 | torch.save(coordinates, output_path / 'coordinates.pt') 419 | 420 | 421 | if __name__ == '__main__': 422 | main(_get_opts()) 423 | -------------------------------------------------------------------------------- /scripts/convert_to_container.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present 6 | 7 | from mega_nerf.models.mega_nerf import MegaNeRF 8 | from mega_nerf.models.mega_nerf_container import MegaNeRFContainer 9 | from gp_nerf.models.model_utils import get_nerf, get_bg_nerf 10 | from gp_nerf.opts import get_opts_base 11 | 12 | 13 | def _get_merge_opts() -> Namespace: 14 | parser = get_opts_base() 15 | parser.add_argument('--output', type=str, required=True) 16 | 17 | return parser.parse_known_args()[0] 18 | 19 | 20 | @torch.inference_mode() 21 | def main(hparams: Namespace) -> None: 22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 23 | 24 | centroids = torch.zeros(1, 3) 25 | 26 | loaded = torch.load(hparams.ckpt_path, map_location='cpu') 27 | consume_prefix_in_state_dict_if_present(loaded['model_state_dict'], prefix='module.') 28 | 29 | if hparams.appearance_dim > 0: 30 | appearance_count = len(loaded['model_state_dict']['embedding_a.weight']) 31 | else: 32 | appearance_count = 0 33 | 34 | sub_module = get_nerf(hparams, appearance_count) 35 | model_dict = sub_module.state_dict() 36 | model_dict.update(loaded['model_state_dict']) 37 | sub_module.load_state_dict(model_dict) 38 | 39 | if 'bg_model_state_dict' in loaded: 40 | consume_prefix_in_state_dict_if_present(loaded['bg_model_state_dict'], prefix='module.') 41 | bg_sub_module = get_bg_nerf(hparams, appearance_count) 42 | model_dict = bg_sub_module.state_dict() 43 | model_dict.update(loaded['bg_model_state_dict']) 44 | bg_sub_module.load_state_dict(model_dict) 45 | 46 | container = MegaNeRFContainer([sub_module], [bg_sub_module] if 'bg_model_state_dict' in loaded else [], centroids, 47 | torch.IntTensor([1, 1]), 48 | torch.zeros(3), 49 | torch.ones(3), 50 | hparams.pos_dir_dim > 0, 51 | hparams.appearance_dim > 0, 52 | False) 53 | torch.jit.save(torch.jit.script(container.eval()), hparams.output) 54 | container 
= torch.jit.load(hparams.output, map_location='cpu') 55 | 56 | # Test container 57 | nerf = MegaNeRF([getattr(container, 'sub_module_{}'.format(i)) for i in range(len(container.centroids))], 58 | container.centroids, hparams.boundary_margin, False, container.cluster_2d).to(device) 59 | 60 | width = 3 61 | if hparams.pos_dir_dim > 0: 62 | width += 3 63 | if hparams.appearance_dim > 0: 64 | width += 1 65 | 66 | print('fg test eval: {}'.format(nerf(torch.ones(1, width, device=device)))) 67 | 68 | if 'bg_model_state_dict' in loaded: 69 | bg_nerf = MegaNeRF([getattr(container, 'bg_sub_module_{}'.format(i)) for i in range(len(container.centroids))], 70 | container.centroids, hparams.boundary_margin, True, container.cluster_2d).to(device) 71 | 72 | width += 4 73 | print('bg test eval: {}'.format(bg_nerf(torch.ones(1, width, device=device)))) 74 | 75 | 76 | if __name__ == '__main__': 77 | main(_get_merge_opts()) 78 | -------------------------------------------------------------------------------- /scripts/copy_images.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from argparse import Namespace 3 | from pathlib import Path 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from tqdm import tqdm 9 | 10 | 11 | def _get_images_opts(): 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument('--image_path', type=str, required=True) 15 | parser.add_argument('--dataset_path', type=str, required=True) 16 | 17 | return parser.parse_args() 18 | 19 | 20 | def main(hparams: Namespace) -> None: 21 | image_path = Path(hparams.image_path) 22 | dataset_path = Path(hparams.dataset_path) 23 | (dataset_path / 'train' / 'rgbs').mkdir() 24 | (dataset_path / 'val' / 'rgbs').mkdir() 25 | 26 | with (Path(hparams.dataset_path) / 'mappings.txt').open() as f: 27 | for line in tqdm(f): 28 | image_name, metadata_name = line.strip().split(',') 29 | metadata_path = dataset_path / 'train' / 'metadata' / metadata_name 30 | if not metadata_path.exists(): 31 | metadata_path = dataset_path / 'val' / 'metadata' / metadata_name 32 | assert metadata_path.exists() 33 | 34 | distorted = cv2.imread(str(image_path / image_name)) 35 | metadata = torch.load(metadata_path, map_location='cpu') 36 | intrinsics = metadata['intrinsics'] 37 | camera_matrix = np.array([[intrinsics[0], 0, intrinsics[2]], 38 | [0, intrinsics[1], intrinsics[3]], 39 | [0, 0, 1]]) 40 | 41 | undistorted = cv2.undistort(distorted, camera_matrix, metadata['distortion'].numpy()) 42 | assert undistorted.shape[0] == metadata['H'] 43 | assert undistorted.shape[1] == metadata['W'] 44 | 45 | cv2.imwrite(str(dataset_path / metadata_path.parent.parent / 'rgbs' / '{}.{}'.format(metadata_path.stem, 46 | image_name.split('.')[ 47 | -1])), 48 | undistorted) 49 | 50 | 51 | if __name__ == '__main__': 52 | main(_get_images_opts()) 53 | -------------------------------------------------------------------------------- /scripts/create_cluster_masks.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import traceback 4 | import zipfile 5 | from argparse import Namespace 6 | from pathlib import Path 7 | from zipfile import ZipFile 8 | 9 | import numpy as np 10 | import torch 11 | import torch.distributed as dist 12 | from torch.distributed.elastic.multiprocessing.errors import record 13 | 14 | import sys 15 | sys.path.append('.') 16 | 17 | from mega_nerf.misc_utils import main_tqdm, main_print 18 | from gp_nerf.opts import get_opts_base 
19 | from mega_nerf.ray_utils import get_ray_directions, get_rays 20 | 21 | 22 | def _get_mask_opts() -> Namespace: 23 | parser = get_opts_base() 24 | 25 | parser.add_argument('--dataset_path', type=str, required=True) 26 | parser.add_argument('--segmentation_path', type=str, default=None) 27 | parser.add_argument('--output', type=str, required=True) 28 | parser.add_argument('--grid_dim', nargs='+', type=int, required=True) 29 | parser.add_argument('--ray_samples', type=int, default=1000) 30 | parser.add_argument('--ray_chunk_size', type=int, default=48 * 1024) 31 | parser.add_argument('--dist_chunk_size', type=int, default=64 * 1024 * 1024) 32 | parser.add_argument('--resume', default=False, action='store_true') 33 | 34 | return parser.parse_known_args()[0] 35 | 36 | 37 | @record 38 | @torch.inference_mode() 39 | def main(hparams: Namespace) -> None: 40 | assert hparams.ray_altitude_range is not None 41 | output_path = Path(hparams.output) 42 | 43 | if 'RANK' in os.environ: 44 | dist.init_process_group(backend='nccl', timeout=datetime.timedelta(0, hours=24)) 45 | torch.cuda.set_device(int(os.environ['LOCAL_RANK'])) 46 | rank = int(os.environ['RANK']) 47 | if rank == 0: 48 | output_path.mkdir(parents=True, exist_ok=hparams.resume) 49 | dist.barrier() 50 | world_size = int(os.environ['WORLD_SIZE']) 51 | else: 52 | output_path.mkdir(parents=True, exist_ok=hparams.resume) 53 | rank = 0 54 | world_size = 1 55 | 56 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 57 | 58 | dataset_path = Path(hparams.dataset_path) 59 | coordinate_info = torch.load(dataset_path / 'coordinates.pt', map_location='cpu') 60 | origin_drb = coordinate_info['origin_drb'] 61 | pose_scale_factor = coordinate_info['pose_scale_factor'] 62 | 63 | ray_altitude_range = [(x - origin_drb[0]) / pose_scale_factor for x in hparams.ray_altitude_range] 64 | 65 | metadata_paths = list((dataset_path / 'train' / 'metadata').iterdir()) \ 66 | + list((dataset_path / 'val' / 'metadata').iterdir()) 67 | 68 | camera_positions = torch.cat([torch.load(x, map_location='cpu')['c2w'][:3, 3].unsqueeze(0) for x in metadata_paths]) 69 | main_print('Number of images in dir: {}'.format(camera_positions.shape)) 70 | 71 | min_position = camera_positions.min(dim=0)[0] 72 | max_position = camera_positions.max(dim=0)[0] 73 | 74 | main_print('Coord range: {} {}'.format(min_position, max_position)) 75 | 76 | ranges = max_position[1:] - min_position[1:] 77 | offsets = [torch.arange(s) * ranges[i] / s + ranges[i] / (s * 2) for i, s in enumerate(hparams.grid_dim)] 78 | centroids = torch.stack((torch.zeros(hparams.grid_dim[0], hparams.grid_dim[1]), # Ignore altitude dimension 79 | torch.ones(hparams.grid_dim[0], hparams.grid_dim[1]) * min_position[1], 80 | torch.ones(hparams.grid_dim[0], hparams.grid_dim[1]) * min_position[2])).permute(1, 2, 0) 81 | centroids[:, :, 1] += offsets[0].unsqueeze(1) 82 | centroids[:, :, 2] += offsets[1] 83 | centroids = centroids.view(-1, 3) 84 | 85 | main_print('Centroids: {}'.format(centroids)) 86 | 87 | near = hparams.near / pose_scale_factor 88 | 89 | if hparams.far is not None: 90 | far = hparams.far / pose_scale_factor 91 | else: 92 | far = 2 93 | 94 | torch.save({ 95 | 'origin_drb': origin_drb, 96 | 'pose_scale_factor': pose_scale_factor, 97 | 'ray_altitude_range': ray_altitude_range, 98 | 'near': near, 99 | 'far': far, 100 | 'centroids': centroids, 101 | 'grid_dim': (hparams.grid_dim), 102 | 'min_position': min_position, 103 | 'max_position': max_position, 104 | 'cluster_2d': hparams.cluster_2d 105 | 
}, output_path / 'params.pt') 106 | 107 | z_steps = torch.linspace(0, 1, hparams.ray_samples, device=device) # (N_samples) 108 | centroids = centroids.to(device) 109 | 110 | if rank == 0 and not hparams.resume: 111 | for i in range(centroids.shape[0]): 112 | (output_path / str(i)).mkdir(parents=True) 113 | 114 | if 'RANK' in os.environ: 115 | dist.barrier() 116 | 117 | cluster_dim_start = 1 if hparams.cluster_2d else 0 118 | for subdir in ['train', 'val']: 119 | metadata_paths = list((dataset_path / subdir / 'metadata').iterdir()) 120 | for i in main_tqdm(np.arange(rank, len(metadata_paths), world_size)): 121 | metadata_path = metadata_paths[i] 122 | 123 | if hparams.resume: 124 | # Check to see if mask has been generated already 125 | all_valid = True 126 | filename = metadata_path.stem + '.pt' 127 | for j in range(centroids.shape[0]): 128 | mask_path = output_path / str(j) / filename 129 | if not mask_path.exists(): 130 | all_valid = False 131 | break 132 | else: 133 | try: 134 | with ZipFile(mask_path) as zf: 135 | with zf.open(filename) as f: 136 | torch.load(f, map_location='cpu') 137 | except: 138 | traceback.print_exc() 139 | all_valid = False 140 | break 141 | 142 | if all_valid: 143 | continue 144 | 145 | metadata = torch.load(metadata_path, map_location='cpu') 146 | 147 | c2w = metadata['c2w'].to(device) 148 | intrinsics = metadata['intrinsics'] 149 | directions = get_ray_directions(metadata['W'], 150 | metadata['H'], 151 | intrinsics[0], 152 | intrinsics[1], 153 | intrinsics[2], 154 | intrinsics[3], 155 | hparams.center_pixels, 156 | device) 157 | 158 | rays = get_rays(directions, c2w, near, far, ray_altitude_range).view(-1, 8) 159 | 160 | min_dist_ratios = [] 161 | for j in range(0, rays.shape[0], hparams.ray_chunk_size): 162 | rays_o = rays[j:j + hparams.ray_chunk_size, :3] 163 | rays_d = rays[j:j + hparams.ray_chunk_size, 3:6] 164 | 165 | near_bounds, far_bounds = rays[j:j + hparams.ray_chunk_size, 6:7], \ 166 | rays[j:j + hparams.ray_chunk_size, 7:8] # both (N_rays, 1) 167 | z_vals = near_bounds * (1 - z_steps) + far_bounds * z_steps 168 | 169 | xyz = rays_o.unsqueeze(1) + rays_d.unsqueeze(1) * z_vals.unsqueeze(-1) 170 | del rays_d 171 | del z_vals 172 | xyz = xyz.view(-1, 3) 173 | 174 | min_distances = [] 175 | cluster_distances = [] 176 | for k in range(0, xyz.shape[0], hparams.dist_chunk_size): 177 | distances = torch.cdist(xyz[k:k + hparams.dist_chunk_size, cluster_dim_start:], 178 | centroids[:, cluster_dim_start:]) 179 | cluster_distances.append(distances) 180 | min_distances.append(distances.min(dim=1)[0]) 181 | 182 | del xyz 183 | 184 | cluster_distances = torch.cat(cluster_distances).view(rays_o.shape[0], -1, 185 | centroids.shape[0]) # (rays, samples, clusters) 186 | min_distances = torch.cat(min_distances).view(rays_o.shape[0], -1) # (rays, samples) 187 | min_dist_ratio = (cluster_distances / (min_distances.unsqueeze(-1) + 1e-8)).min(dim=1)[0] 188 | del min_distances 189 | del cluster_distances 190 | del rays_o 191 | min_dist_ratios.append(min_dist_ratio) # (rays, clusters) 192 | 193 | min_dist_ratios = torch.cat(min_dist_ratios).view(metadata['H'], metadata['W'], centroids.shape[0]) 194 | 195 | filename = (metadata_path.stem + '.pt') 196 | 197 | if hparams.segmentation_path is not None: 198 | with ZipFile(Path(hparams.segmentation_path) / filename) as zf: 199 | with zf.open(filename) as zf2: 200 | segmentation_mask = torch.load(zf2, map_location='cpu') 201 | 202 | for j in range(centroids.shape[0]): 203 | cluster_ratios = min_dist_ratios[:, :, j] 204 | 
            ray_in_cluster = cluster_ratios <= hparams.boundary_margin
205 | 
206 |             with ZipFile(output_path / str(j) / filename, compression=zipfile.ZIP_DEFLATED, mode='w') as zf:
207 |                 with zf.open(filename, 'w') as f:
208 |                     cluster_mask = ray_in_cluster.cpu()
209 | 
210 |                     if hparams.segmentation_path is not None:
211 |                         cluster_mask = torch.logical_and(cluster_mask, segmentation_mask)
212 | 
213 |                     torch.save(cluster_mask, f)
214 | 
215 |             del ray_in_cluster
216 | 
217 | 
218 | if __name__ == '__main__':
219 |     main(_get_mask_opts())
220 | 
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export OMP_NUM_THREADS=4
3 | export CUDA_VISIBLE_DEVICES=4
4 | 
5 | 
6 | exp_name='logs/test'
7 | 
8 | dataset1='Mill19'    # "Mill19" "Quad6k" "UrbanScene3D"
9 | dataset2='building'  # "building" "rubble" "quad" "residence" "sci-art" "campus"
10 | python gp_nerf/train.py --config_file configs/$dataset2.yaml --dataset_path /data/yuqi/Datasets/MegaNeRF/$dataset1/$dataset2/$dataset2-pixsfm --chunk_paths /data/yuqi/Datasets/MegaNeRF/$dataset1/${dataset2}_chunk-1 --exp_name $exp_name
11 | 
12 | 
13 | 
14 | 
15 | 
16 | 
17 | 
18 | 
19 | 
20 | 
21 | 
22 | 
23 | 
24 | 
--------------------------------------------------------------------------------
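For reference, a hedged sketch of how the pieces above chain together on a custom capture: convert a COLMAP reconstruction with scripts/colmap_to_mega_nerf.py, then point train.sh at the result. Every path and the `--scale` value below are placeholders to adapt to your own data, and the converter currently asserts a `SIMPLE_RADIAL` camera model:

```
# Sketch only: placeholder paths and an illustrative --scale value.
python scripts/colmap_to_mega_nerf.py \
    --model_path /path/to/colmap/sparse/0 \
    --images_path /path/to/images \
    --output_path /path/to/MegaNeRF/custom/custom-pixsfm \
    --scale 100 \
    --num_val 20

# Edit the dataset_path/chunk_paths variables inside train.sh to point at the
# converted output, then launch training:
bash train.sh
```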