├── INSTALL.md
├── README.md
├── SPARFDataset.py
├── Svox2
├── opt
│ ├── autotune.py
│ ├── calc_metrics.py
│ ├── opt.py
│ ├── reflect.py
│ ├── render_imgs.py
│ ├── render_imgs_circle.py
│ ├── scripts
│ │ ├── colmap2nsvf.py
│ │ ├── create_split.py
│ │ ├── proc_record3d.py
│ │ ├── run_colmap.py
│ │ ├── unsplit.py
│ │ └── view_data.py
│ ├── to_svox1.py
│ └── util
│ │ ├── __init__.py
│ │ ├── co3d_dataset.py
│ │ ├── config_util.py
│ │ ├── dataset.py
│ │ ├── dataset_base.py
│ │ ├── llff_dataset.py
│ │ ├── load_llff.py
│ │ ├── nerf_dataset.py
│ │ ├── nsvf_dataset.py
│ │ └── util.py
├── setup.py
├── svox2
│ ├── __init__.py
│ ├── csrc
│ │ ├── .ccls
│ │ ├── CMakeLists.txt
│ │ ├── include
│ │ │ ├── cubemap_util.cuh
│ │ │ ├── cuda_util.cuh
│ │ │ ├── data_spec.hpp
│ │ │ ├── data_spec_packed.cuh
│ │ │ ├── random_util.cuh
│ │ │ ├── render_util.cuh
│ │ │ └── util.hpp
│ │ ├── loss_kernel.cu
│ │ ├── misc_kernel.cu
│ │ ├── optim_kernel.cu
│ │ ├── render_lerp_kernel_cuvol.cu
│ │ ├── render_lerp_kernel_nvol.cu
│ │ ├── render_svox1_kernel.cu
│ │ ├── svox2.cpp
│ │ └── svox2_kernel.cu
│ ├── defs.py
│ ├── svox2.py
│ ├── utils.py
│ └── version.py
└── test
│ ├── prof.py
│ ├── sanity.py
│ ├── test_render_gradcheck.py
│ ├── test_render_timing.py
│ ├── test_render_timing_smallbat.py
│ ├── test_render_visual.py
│ ├── test_sample.py
│ └── util.py
├── data
├── nerf_datasets
│ └── rs_dtu_4
│ │ ├── check_same.py
│ │ ├── proc.py
│ │ ├── resize_cams.py
│ │ └── resize_imgs.py
└── source_meshes
│ ├── alien.obj
│ ├── candle.obj
│ ├── horse.obj
│ ├── lamp.obj
│ ├── person.obj
│ ├── shoe.obj
│ └── vase.obj
├── datasets.py
├── examples
└── mvimage_load.ipynb
├── extra_utils.py
├── models.py
├── nerf_ops.py
├── ops.py
├── plt_ops.py
└── run_sparf.py
/INSTALL.md: --------------------------------------------------------------------------------
1 | ## Installation
2 |
3 |
4 | ## Requirements
5 |
6 | You need CUDA `>= 11.1` installed on your system (check with `nvcc --version`). You also need `gcc` version `>= 8.2.0` (check with `gcc --version`). Then follow these steps:
7 |
8 | 1. Install `Minkowski Engine` and `Plenoxels` (depending on your system, from [here](https://nvidia.github.io/MinkowskiEngine/quick_start.html) and [here](https://github.com/sxyu/svox2)).
9 | Alternatively, you can follow the steps below, where `CONDA_PREFIX` is the path to your conda environment:
10 | ```bash
11 | conda create --name sparf0 python=3.6
12 | conda activate sparf0
13 | conda install numpy pytorch=1.9.0 torchvision cudatoolkit=11.1 openblas-devel open3d=0.9.0 pytorch-lightning pyyaml jupyterlab -c open3d-admin -c conda-forge -c anaconda -c pytorch -c nvidia
14 | git clone https://github.com/NVIDIA/MinkowskiEngine
15 | cd MinkowskiEngine
16 | python setup.py install --blas=openblas --blas_include_dirs=${CONDA_PREFIX}/include
17 | cd ..
18 | pip install imageio imageio-ffmpeg ipdb lpips opencv-python Pillow pyyaml tensorboard PyMCubes moviepy matplotlib scipy wandb pandas trimesh pyglet einops pyhocon ConfigArgParse timm dotmap pretrainedmodels scikit-image tqdm ipyplot
19 | cd Svox2
20 | pip install .
21 | python3 -m pip install pyvirtualdisplay # optional
22 | ```
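After the install completes, a quick way to confirm the environment is wired up (a minimal sketch; it assumes the `sparf0` environment above is active and a GPU is visible):

```python
# Minimal environment check for the setup above (assumption: run inside `sparf0`).
import torch
import MinkowskiEngine as ME
import svox2  # installed from the Svox2/ directory above

print(torch.__version__, torch.version.cuda)   # expect 1.9.0 and 11.1
print("CUDA available:", torch.cuda.is_available())
print("MinkowskiEngine:", ME.__version__)
```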
23 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # SPARF: Large-Scale Learning of 3D Sparse Radiance Fields from Few Input Images (ICCV 2023)
2 | By [Abdullah Hamdi](https://abdullahamdi.com/), [Bernard Ghanem](http://www.bernardghanem.com/), [Matthias Nießner](https://niessnerlab.org/members/matthias_niessner/profile.html)
3 | ### [Paper](https://openaccess.thecvf.com/content/ICCV2023W/AI3DCC/html/Hamdi_SPARF_Large-Scale_Learning_of_3D_Sparse_Radiance_Fields_from_Few_ICCVW_2023_paper.html) | [Video](https://youtu.be/VcjypZ0hp4w) | [Website](https://abdullahamdi.com/sparf/) | [Dataset (code: sparf)](https://exrcsdrive.kaust.edu.sa/index.php/s/AzPfy0k45X01ql3)
4 | 5 | 6 | 7 | 8 |
9 |
10 | The official PyTorch code of the paper [SPARF: Large-Scale Learning of 3D Sparse Radiance Fields from Few Input Images](https://arxiv.org/abs/2212.09100). SPARF is a large-scale sparse radiance field dataset consisting of ~1 million SRFs with multiple voxel resolutions (32, 128, and 512) and 17 million posed images with a resolution of 400 x 400. Furthermore, we propose SuRFNet, a pipeline to generate SRFs conditioned on input images, achieving SOTA on ShapeNet novel view synthesis from one or few input images.
11 |
12 | # Environment setup
13 |
14 | Follow the instructions in [INSTALL.md](https://github.com/ajhamdi/sparf_pytorch/blob/main/INSTALL.md) to set up the conda environment.
15 |
16 | ## SPARF Posed Multi-View Image Dataset
17 | The dataset is released at the [link (code: sparf)](https://exrcsdrive.kaust.edu.sa/index.php/s/AzPfy0k45X01ql3). Each of SPARF's classes has the same structure as the [NeRF-synthetic](https://github.com/sxyu/pixel-nerf) dataset and can be loaded similarly. Download all content from the link and place it inside `data/SPARF_images`. Then you can run the [notebook example](https://github.com/ajhamdi/sparf_pytorch/blob/main/examples/mvimage_load.ipynb).
18 |
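For quick experimentation outside the notebook, a minimal loading sketch using the repository's `SPARFDataset` class (the paths and class choice below are placeholders):

```python
# Minimal sketch: load a few posed views of one instance.
# Assumes data/SPARF_images (with SNRL_splits.csv) is in place as described above.
from SPARFDataset import SPARFDataset

dset = SPARFDataset("data/SPARF_images", views_split="train",
                    object_class="car", n_views=10)
sample = dset[0]               # dict with "images", "masks", "bbox", "poses", "focal"
print(sample["images"].shape)  # (n_views, 3, H, W) RGB tensors on a white background
print(sample["poses"].shape)   # (n_views, 4, 4) camera-to-world matrices
```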
19 |
20 | ## SPARF Radiance Field Dataset
21 | The dataset is released at the [link](https://drive.google.com/drive/folders/1Qd_hBrRKR1vlCacOSyK_FN4igkHSbPSM?usp=sharing). Each of SPARF's instances has (besides the posed images above) two directories: `STF` (RGB voxels) and `SRF` (Spherical Harmonics voxels). The full radiance fields are available under `<instance_path>/SRF/vox_<res>/full`, where `<res>` is the resolution (32, 128, or 512). The partial SRFs are stored in `<instance_path>/STF/vox_<res>/partial` similarly. The partitioning (shards) and splits of the dataset are available in the file `SNRL_splits.csv` in the root of the dataset. The voxel information is stored as sparse voxels in `data_0.npz` as coords and values.
22 |
23 | Download all content from the link and place it inside `data/SPARF_srf`. Then you can run the [main training code](https://github.com/ajhamdi/sparf_pytorch/blob/main/run_sparf.py).
24 |
25 | ## Script for rendering ShapeNet images used in creating SPARF
26 | Make sure that `ShapeNetCore.v2` is downloaded and placed in `data/ShapeNetCore.v2`. Then run the following script to render the images used in creating SPARF.
27 | ```bash
28 | python run_sparf.py --run render --data_dir data/SPARF_srf/ --nb_views 400 --object_class car
29 | ```
30 | ## Script for extracting SPARF Radiance Fields (full SRFs with voxel res=128 and SH dim=4)
31 | Make sure that `SPARF_images` is downloaded and placed in `data/SPARF_images`. Then run the following script to extract the SRFs.
32 | ```bash
33 | python run_sparf.py --run extract --vox_res 128 --sh_dim 4 --object_class airplane --data_dir data/SPARF_images/ --visualize --evaluate
34 | ```
35 |
36 | ## Script for extracting SPARF Radiance Fields (partial SRFs with voxel res=512 and SH dim=1, nb_views=3)
37 | Make sure that `SPARF_images` is downloaded and placed in `data/SPARF_images`. Then run the following script to extract the SRFs.
38 | ```bash
39 | python run_sparf.py --run preprocess --vox_res 512 --sh_dim 1 --rf_variant 0 --object_class airplane --nb_views 3 --data_dir data/SPARF_images/ --randomized_views
40 | ```
41 | ## Training and Inference pipeline on SPARF Radiance Fields
42 | Make sure that `SPARF_srf` is downloaded and placed in `data/SPARF_srf`. Then run the following script to train on SRFs.
43 | ```bash
44 | python run_sparf.py --vox_res 128 --nb_views 3 --nb_rf_variants 4 --input_quantization_size 1.0 --strides 2 --lr_decay 0.99 --batch_size 6 --lr 1e-2 --visualize --normalize_input const --lambda_cls 30.0 --lambda_main 2.0 --augment_type none --mask_type densepoints --ignore_loss_mask --nb_frames 200 --validate_training --data_dir data/SPARF_srf/ --run train --object_class airplane
45 | ```
46 |
47 | ## Citation
48 | If you find our work useful in your research, please consider citing:
49 | ```bibtex
50 | @InProceedings{Hamdi_2023_ICCV,
51 | author = {Hamdi, Abdullah and Ghanem, Bernard and Nie{\ss}ner, Matthias},
52 | title = {SPARF: Large-Scale Learning of 3D Sparse Radiance Fields from Few Input Images},
53 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV) Workshops},
54 | month = {October},
55 | year = {2023},
56 | pages = {2930-2940}
57 | }
58 | ```
59 |
60 |
-------------------------------------------------------------------------------- /SPARFDataset.py: --------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn.functional as F
4 | import glob
5 | import imageio
6 | import numpy as np
7 | import pandas as pd
8 | import json
9 | from torchvision import transforms
10 |
11 | def get_image_to_tensor_balanced(image_size=0):
12 | ops = []
13 | if image_size > 0:
14 | ops.append(transforms.Resize(image_size))
15 | ops.extend(
16 | [transforms.ToTensor(), transforms.Normalize(
17 | (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]
18 | )
19 | return transforms.Compose(ops)
20 |
21 |
22 | def get_mask_to_tensor():
23 | return transforms.Compose(
24 | [transforms.ToTensor(), transforms.Normalize((0.0,), (1.0,))]
25 | )
26 |
27 | class SPARFDataset(torch.utils.data.Dataset):
28 | """SPARF: Posed multi-view image dataset, Hamdi et al., 2022
29 | A class for loading multi-view posed images of the SPARF dataset.
30 |
31 | Parameters:
32 | -----------
33 | path : str
34 | The path to the directory containing the dataset.
35 | views_split : str, optional
36 | The split of views to be loaded. Can be "train", "test", or "hard".
37 | Default is "train".
38 | object_class : str, optional
39 | The category of objects to be loaded. Default is "car". Possible classes: ["watercraft", "rifle", "display", "lamp", "speaker", "cabinet", "chair", "bench", "car", "airplane", "sofa", "table", "phone"]
40 | n_views : int or None, optional
41 | The number of views to be loaded. If None, all available views are loaded.
42 | Default is None.
43 | dset_partition : int, optional
44 | The partition of the dataset to be loaded. Can be -1 to include all 20 partitions
45 | of the data of that class [0:19], otherwise takes just a portion of the data.
46 | Default is -1.
47 | z_near : float, optional
48 | The minimum depth value of the camera. Default is 0.01.
49 | z_far : float, optional
50 | The maximum depth value of the camera. Default is 1000.0.
51 | return_as_lists : bool, optional
52 | Whether to return the data as a dictionary of lists (each list has length `n_views`)
53 | or as a dictionary of tensors. Default is False.
54 | 55 | Examples: 56 | --------- 57 | # Initialize the dataset object 58 | data_dir = "/path/to/data" 59 | dset = SPARFDataset(data_dir) 60 | 61 | # Get the first data sample 62 | data = dset[0] 63 | 64 | 65 | # Get the images and masks of the first data sample 66 | images = data["images"] 67 | masks = data["masks"] 68 | ipyplot.plot_images(images,) 69 | """ 70 | 71 | def __init__(self, path, views_split="train", z_near=0.01, z_far=1000.0, n_views=None, object_class="car", dset_partition=-1,return_as_lists=False): 72 | super().__init__() 73 | self.views_split = views_split 74 | self.data_splits = ["train","test"] 75 | self.return_as_lists = return_as_lists 76 | 77 | self.object_class = object_class 78 | self.dset_partition = dset_partition 79 | self.object_class_dir = os.path.join(os.getcwd(),path, self.object_class) 80 | 81 | splits = pd.read_csv(os.path.join(os.getcwd(),path, "SNRL_splits.csv"), sep=",", dtype=str) 82 | avail_files = sorted(list(os.listdir(self.object_class_dir))) 83 | if self.dset_partition == -1: 84 | splits = splits[splits.file.isin(avail_files) & splits.classlabel.isin([ 85 | str(self.object_class)])] 86 | 87 | else: 88 | splits = splits[splits.file.isin(avail_files) & splits.partition.isin([str(x) for x in range( 89 | self.dset_partition+1)]) & splits.classlabel.isin([str(self.object_class)])] 90 | 91 | self.model_ids = list(splits[splits.split.isin(self.data_splits)]["file"]) 92 | # print(len(self.model_ids)) 93 | self.synset_ids = [ 94 | self.object_class for _ in range(len(self.model_ids))] 95 | 96 | # path = os.path.join(path, views_split) 97 | self.base_path = self.object_class_dir 98 | print("Loading NeRF synthetic dataset", self.base_path) 99 | # trans_files = [] 100 | # TRANS_FILE = "transforms_{}.json".format(self.views_split) 101 | # for root, directories, filenames in os.walk(self.base_path): 102 | # if TRANS_FILE in filenames: 103 | # trans_files.append(os.path.join(root, TRANS_FILE)) 104 | trans_files = [os.path.join(self.base_path, c_id, "transforms_{}.json".format(self.views_split))for c_id in self.model_ids] 105 | self.trans_files = trans_files 106 | self.image_to_tensor = get_image_to_tensor_balanced() 107 | self.mask_to_tensor = get_mask_to_tensor() 108 | 109 | self.z_near = z_near 110 | self.z_far = z_far 111 | self.lindisp = False 112 | self.n_views = n_views 113 | 114 | print("{} instances in split {}".format(len(self.trans_files), views_split)) 115 | 116 | def __len__(self): 117 | return len(self.trans_files) 118 | 119 | def _check_valid(self, index): 120 | if self.n_views is None: 121 | return True 122 | trans_file = self.trans_files[index] 123 | dir_path = os.path.dirname(trans_file) 124 | try: 125 | with open(trans_file, "r") as f: 126 | transform = json.load(f) 127 | except Exception as e: 128 | print("Problematic transforms.json file", trans_file) 129 | print("JSON loading exception", e) 130 | return False 131 | if len(transform["frames"]) < self.n_views: 132 | print("requested number of views ({}) is more than available {} views".format(self.n_views,len(transform["frames"]))) 133 | return False 134 | # if len(glob.glob(os.path.join(dir_path, "*.png"))) != self.n_views: 135 | # return False 136 | return True 137 | 138 | def __getitem__(self, index): 139 | if not self._check_valid(index): 140 | return {} 141 | 142 | trans_file = self.trans_files[index] 143 | dir_path = os.path.dirname(trans_file) 144 | with open(trans_file, "r") as f: 145 | transform = json.load(f) 146 | 147 | imgs = [] 148 | bboxes = [] 149 | masks = [] 150 | poses = [] 
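# For each frame listed in transforms_{split}.json: read the RGBA image, turn the
# alpha channel into a binary mask, compute a tight bounding box from the alpha,
# composite the RGB onto a solid white background, and collect the 4x4
# camera-to-world pose. (Descriptive note added for clarity.)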
151 | for frame in transform["frames"]: 152 | fpath = frame["file_path"] 153 | basename = os.path.splitext(os.path.basename(fpath))[0] 154 | obj_path = os.path.join(dir_path,self.views_split, "{}.png".format(basename)) 155 | img = imageio.imread(obj_path) 156 | mask = self.mask_to_tensor(img[..., 3]) 157 | rows = np.any(img[..., 3], axis=1) 158 | cols = np.any(img[..., 3], axis=0) 159 | rnz = np.where(rows)[0] 160 | cnz = np.where(cols)[0] 161 | if len(rnz) == 0: 162 | cmin = rmin = 0 163 | cmax = mask.shape[-1] 164 | rmax = mask.shape[-2] 165 | else: 166 | rmin, rmax = rnz[[0, -1]] 167 | cmin, cmax = cnz[[0, -1]] 168 | bbox = torch.tensor([cmin, rmin, cmax, rmax], dtype=torch.float32) 169 | 170 | img_tensor = self.image_to_tensor(img[..., :3]) 171 | img = img_tensor * mask + ( 172 | 1.0 - mask 173 | ) # solid white background where transparent 174 | imgs.append(img) 175 | bboxes.append(bbox) 176 | masks.append(mask) 177 | poses.append(torch.tensor(frame["transform_matrix"])) 178 | if not self.return_as_lists: 179 | imgs = torch.stack(imgs) 180 | masks = torch.stack(masks) 181 | bboxes = torch.stack(bboxes) 182 | poses = torch.stack(poses) 183 | 184 | H, W = imgs[0].shape[-2:] 185 | camera_angle_x = transform.get("camera_angle_x") 186 | focal = 0.5 * W / np.tan(0.5 * camera_angle_x) 187 | # print(bboxes.mean(), masks.mean()) 188 | result = { 189 | "path": dir_path, 190 | "img_id": index, 191 | "focal": focal, 192 | "images": imgs[:self.n_views], 193 | "masks": masks[:self.n_views], 194 | "bbox": bboxes[:self.n_views], 195 | "poses": poses[:self.n_views], 196 | } 197 | return result 198 | 199 | -------------------------------------------------------------------------------- /Svox2/opt/autotune.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Alex Yu 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import random 7 | from multiprocessing import Process, Queue 8 | import os 9 | from os import path, listdir 10 | import argparse 11 | import json 12 | import subprocess 13 | import sys 14 | from typing import List, Dict 15 | import itertools 16 | from warnings import warn 17 | from datetime import datetime 18 | import numpy as np 19 | from glob import glob 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("task_json", type=str) 23 | parser.add_argument("--gpus", "-g", type=str, required=True, 24 | help="space delimited GPU id list (global id in nvidia-smi, " 25 | "not considering CUDA_VISIBLE_DEVICES)") 26 | parser.add_argument('--eval', action='store_true', default=False, 27 | help='evaluation mode (run the render_imgs script)') 28 | parser.add_argument('--render', action='store_true', default=False, 29 | help='also run render_imgs.py with --render_path to render a rotating trajectory (forward-facing case)') 30 | args = parser.parse_args() 31 | 32 | PSNR_FILE_NAME = 'test_psnr.txt' 33 | 34 | def run_exp(env, eval_mode:bool, enable_render:bool, train_dir, data_dir, config, flags, eval_flags, common_flags): 35 | opt_base_cmd = [ "python", "opt.py", "--tune_mode" ] 36 | 37 | if not eval_mode: 38 | opt_base_cmd += ["--tune_nosave"] 39 | opt_base_cmd += [ 40 | "-t", train_dir, 41 | data_dir 42 | ] 43 | if config != '': 44 | opt_base_cmd += ['-c', config] 45 | log_file_path = path.join(train_dir, 'log') 46 | psnr_file_path = path.join(train_dir, PSNR_FILE_NAME) 47 | ckpt_path = path.join(train_dir, 'ckpt.npz') 48 | if path.isfile(psnr_file_path): 49 | print('! 
SKIP', train_dir) 50 | return 51 | print('********************************************') 52 | if eval_mode: 53 | print('EVAL MODE') 54 | 55 | if eval_mode and path.isfile(ckpt_path): 56 | print('! SKIP training because ckpt exists', ckpt_path) 57 | opt_ret = "" # Silence 58 | else: 59 | print('! RUN opt.py -t', train_dir) 60 | opt_cmd = ' '.join(opt_base_cmd + flags + common_flags) 61 | print(opt_cmd) 62 | try: 63 | opt_ret = subprocess.check_output(opt_cmd, shell=True, env=env).decode( 64 | sys.stdout.encoding) 65 | except subprocess.CalledProcessError: 66 | print('Error occurred while running OPT for exp', train_dir, 'on', env["CUDA_VISIBLE_DEVICES"]) 67 | return 68 | with open(log_file_path, 'w') as f: 69 | f.write(opt_ret) 70 | 71 | if eval_mode: 72 | eval_base_cmd = [ 73 | "python", "render_imgs.py", 74 | ckpt_path, 75 | data_dir 76 | ] 77 | if config != '': 78 | eval_base_cmd += ['-c', config] 79 | psnr_file_path = path.join(train_dir, 'test_renders', 'psnr.txt') 80 | if not path.exists(psnr_file_path): 81 | eval_cmd = ' '.join(eval_base_cmd + eval_flags + common_flags) 82 | print('! RUN render_imgs.py', ckpt_path) 83 | print(eval_cmd) 84 | try: 85 | eval_ret = subprocess.check_output(eval_cmd, shell=True, env=env).decode( 86 | sys.stdout.encoding) 87 | except subprocess.CalledProcessError: 88 | print('Error occurred while running EVAL for exp', train_dir, 'on', env["CUDA_VISIBLE_DEVICES"]) 89 | return 90 | else: 91 | print('! SKIP eval because psnr.txt exists', psnr_file_path) 92 | 93 | if enable_render: 94 | eval_base_cmd += ['--render_path'] 95 | render_cmd = ' '.join(eval_base_cmd + eval_flags + common_flags) 96 | try: 97 | render_ret = subprocess.check_output(render_cmd, shell=True, env=env).decode( 98 | sys.stdout.encoding) 99 | except subprocess.CalledProcessError: 100 | print('Error occurred while running RENDER for exp', train_dir, 'on', env["CUDA_VISIBLE_DEVICES"]) 101 | return 102 | else: 103 | test_stats = [eval(x.split('eval stats:')[-1].strip()) 104 | for x in opt_ret.split('\n') if 105 | x.startswith('eval stats: ')] 106 | if len(test_stats) == 0: 107 | print('note: invalid config or crash') 108 | final_test_psnr = 0.0 109 | else: 110 | test_psnrs = [stats['psnr'] for stats in test_stats if 'psnr' in stats.keys()] 111 | print('final psnrs', test_psnrs[-5:]) 112 | final_test_psnr = test_psnrs[-1] 113 | with open(psnr_file_path, 'w') as f: 114 | f.write(str(final_test_psnr)) 115 | 116 | def process_main(device, eval_mode:bool, enable_render:bool, queue): 117 | # Set CUDA_VISIBLE_DEVICES programmatically 118 | env = os.environ.copy() 119 | env["CUDA_VISIBLE_DEVICES"] = str(device) 120 | while True: 121 | task = queue.get() 122 | if len(task) == 0: 123 | break 124 | run_exp(env, eval_mode, enable_render, **task) 125 | 126 | # Variable value list generation helpers 127 | def lin(start, stop, num): 128 | return np.linspace(start, stop, num).tolist() 129 | 130 | def randlin(start, stop, num): 131 | lst = np.linspace(start, stop, num + 1)[:-1] 132 | lst += np.random.uniform(low=0.0, high=(lst[1] - lst[0]), size=lst.shape) 133 | return lst.tolist() 134 | 135 | def loglin(start, stop, num): 136 | return np.exp(np.linspace(np.log(start), np.log(stop), num)).tolist() 137 | 138 | def randloglin(start, stop, num): 139 | lst = np.linspace(np.log(start), np.log(stop), num + 1)[:-1] 140 | lst += np.random.uniform(low=0.0, high=(lst[1] - lst[0]), size=lst.shape) 141 | return np.exp(lst).tolist() 142 | # End variable value list generation helpers 143 | 144 | def 
create_prodvars(variables, noise_stds={}):
145 | """
146 | Create a dict for each setting of variable values
147 | (product across lists)
148 | """
149 |
150 | def auto_list(x):
151 | if isinstance(x, list):
152 | return x
153 | elif isinstance(x, dict) or isinstance(x, set):
154 | return [x]
155 | elif isinstance(x, str):
156 | return eval(x)
157 | else:
158 | raise NotImplementedError('variable value must be list of values, or str generator')
159 |
160 | variables = {varname:auto_list(variables[varname]) for varname in variables}
161 | print('variables (prod)', variables)
162 | varnames = list(variables.keys())
163 | noise_stds = np.array([noise_stds.get(varname, 0.0) for varname in varnames])
164 | variables = [[(i, val) for val in variables[varname]] for i, varname in enumerate(varnames)]
165 | prodvars = list(itertools.product(*variables))
166 | noise_vals = np.random.randn(len(prodvars), len(varnames)) * noise_stds
167 | prodvars = [{varnames[i]:((val + n) if n != 0.0 else val) for (i, val), n in zip(sample, noise_vals_samp)} for sample, noise_vals_samp in zip(prodvars, noise_vals)]
168 | return prodvars
169 |
170 |
171 | def recursive_replace(data, variables):
172 | if isinstance(data, str):
173 | return data.format(**variables)
174 | elif isinstance(data, list):
175 | return [recursive_replace(d, variables) for d in data]
176 | elif isinstance(data, dict):
177 | return {k:recursive_replace(data[k], variables) for k in data.keys()}
178 | else:
179 | return data
180 |
181 |
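# Example of the expansion performed above (hypothetical values): with
#   variables = {"lr": "loglin(1e-2, 1e-1, 2)", "seed": [0, 1]}
# the string generator is eval'd to [0.01, 0.1] and the product yields 4 configs:
#   {"lr": 0.01, "seed": 0}, {"lr": 0.01, "seed": 1},
#   {"lr": 0.1,  "seed": 0}, {"lr": 0.1,  "seed": 1}
# Per-variable Gaussian noise from `noises` (if given) is then added to each value.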
182 | if __name__ == '__main__':
183 | with open(args.task_json, 'r') as f:
184 | tasks_file = json.load(f)
185 | assert isinstance(tasks_file, dict), 'Root of json must be dict'
186 | all_tasks_templ = tasks_file.get('tasks', [])
187 | all_tasks = []
188 | data_root = path.expanduser(tasks_file['data_root']) # Required
189 | train_root = path.expanduser(tasks_file['train_root']) # Required
190 | base_flags = tasks_file.get('base_flags', [])
191 | base_eval_flags = tasks_file.get('base_eval_flags', [])
192 | base_common_flags = tasks_file.get('base_common_flags', [])
193 | default_config = tasks_file.get('config', '')
194 |
195 | if 'eval' in tasks_file:
196 | args.eval = tasks_file['eval']
197 | print('Eval mode?', args.eval)
198 | if 'render' in tasks_file:
199 | args.render = tasks_file['render']
200 | print('Render traj?', args.render)
201 | pqueue = Queue()
202 |
203 | leaderboard_path = path.join(train_root, 'results.txt' if args.eval else 'leaderboard.txt')
204 | print('Leaderboard path:', leaderboard_path)
205 |
206 | variables : Dict = tasks_file.get('variables', {})
207 | noises : Dict = tasks_file.get('noises', {})
208 | assert isinstance(variables, dict), 'var must be dict'
209 |
210 | prodvars : List[Dict] = create_prodvars(variables, noises)
211 | del variables
212 |
213 | for task_templ in all_tasks_templ:
214 | for variables in prodvars:
215 | task : Dict = recursive_replace(task_templ, variables)
216 | task['train_dir'] = path.join(train_root, task['train_dir']) # Required
217 | task['data_dir'] = path.join(data_root, task.get('data_dir', '')).rstrip('/')
218 | task['flags'] = task.get('flags', []) + base_flags
219 | task['eval_flags'] = task.get('eval_flags', []) + base_eval_flags
220 | task['common_flags'] = task.get('common_flags', []) + base_common_flags
221 | task['config'] = task.get('config', default_config)
222 | os.makedirs(task['train_dir'], exist_ok=True)
223 | # sanity check
224 | assert path.exists(task['train_dir']), task['train_dir'] + ' does not exist'
225 | assert path.exists(task['data_dir']), task['data_dir'] + ' does not exist'
226 | all_tasks.append(task)
227 | task = None
228 | # Shuffle the tasks
229 | if not args.eval:
230 | random.shuffle(all_tasks)
231 |
232 | for task in all_tasks:
233 | pqueue.put(task)
234 |
235 | args.gpus = list(map(int, args.gpus.split()))
236 | print('GPUS:', args.gpus)
237 |
238 | for _ in args.gpus:
239 | pqueue.put({})
240 |
241 | all_procs = []
242 | for i, gpu in enumerate(args.gpus):
243 | process = Process(target=process_main, args=(gpu, args.eval, args.render, pqueue))
244 | process.daemon = True
245 | process.start()
246 | all_procs.append(process)
247 |
248 | for i, gpu in enumerate(args.gpus):
249 | all_procs[i].join()
250 |
251 | if args.eval:
252 | print('Done')
253 | with open(leaderboard_path, 'w') as leaderboard_file:
254 | lines = [f'dir\tPSNR\tSSIM\tLPIPS\tminutes\n']
255 | all_tasks = sorted(all_tasks, key=lambda task:task['train_dir'])
256 | all_psnr = []
257 | all_ssim = []
258 | all_lpips = []
259 | all_times = []
260 | for task in all_tasks:
261 | train_dir = task['train_dir']
262 | psnr_file_path = path.join(train_dir, 'test_renders', 'psnr.txt')
263 | ssim_file_path = path.join(train_dir, 'test_renders', 'ssim.txt')
264 | lpips_file_path = path.join(train_dir, 'test_renders', 'lpips.txt')
265 | time_file_path = path.join(train_dir, 'time_mins.txt')
266 |
267 | if path.isfile(psnr_file_path):
268 | with open(psnr_file_path, 'r') as f:
269 | psnr = float(f.read())
270 | all_psnr.append(psnr)
271 | psnr_txt = f'{psnr:.10f}'
272 | else:
273 | psnr_txt = 'ERR'
274 | if path.isfile(ssim_file_path):
275 | with open(ssim_file_path, 'r') as f:
276 | ssim = float(f.read())
277 | all_ssim.append(ssim)
278 | ssim_txt = f'{ssim:.10f}'
279 | else:
280 | ssim_txt = 'ERR'
281 | if path.isfile(lpips_file_path):
282 | with open(lpips_file_path, 'r') as f:
283 | lpips = float(f.read())
284 | all_lpips.append(lpips)
285 | lpips_txt = f'{lpips:.10f}'
286 | else:
287 | lpips_txt = 'ERR'
288 | if path.isfile(time_file_path):
289 | with open(time_file_path, 'r') as f:
290 | time_mins = float(f.read())
291 | all_times.append(time_mins)
292 | time_txt = f'{time_mins:.10f}'
293 | else:
294 | time_txt = 'ERR'
295 | line = f'{path.basename(train_dir.rstrip("/"))}\t{psnr_txt}\t{ssim_txt}\t{lpips_txt}\t{time_txt}\n'
296 | lines.append(line)
297 | lines.append('---------\n')
298 | if len(all_psnr):
299 | lines.append('Average PSNR: ' + str(sum(all_psnr) / len(all_psnr)) + '\n')
300 | if len(all_ssim):
301 | lines.append('Average SSIM: ' + str(sum(all_ssim) / len(all_ssim)) + '\n')
302 | if len(all_lpips):
303 | lines.append('Average LPIPS: ' + str(sum(all_lpips) / len(all_lpips)) + '\n')
304 | if len(all_times):
305 | lines.append('Average Time (mins): ' + str(sum(all_times) / len(all_times)) + '\n')
306 | leaderboard_file.writelines(lines)
307 |
308 | else:
309 | with open(leaderboard_path, 'w') as leaderboard_file:
310 | exps = []
311 | for task in all_tasks:
312 | train_dir = task['train_dir']
313 | psnr_file_path = path.join(train_dir, PSNR_FILE_NAME)
314 |
315 | with open(psnr_file_path, 'r') as f:
316 | test_psnr = float(f.read())
317 | print(train_dir, test_psnr)
318 | exps.append((test_psnr, train_dir))
319 | exps = sorted(exps, key = lambda x: -x[0])
320 | lines = [f'{psnr:.10f}\t{train_dir}\n' for psnr, train_dir in exps]
321 | leaderboard_file.writelines(lines)
322 | print('Wrote', leaderboard_path)
323 |
324 |
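A task file for `autotune.py` is plain JSON; the keys read by the `__main__` block above are `data_root`, `train_root`, `tasks`, `variables`, and optionally `noises`, `base_flags`, `base_eval_flags`, `base_common_flags`, `config`, `eval`, and `render`. A minimal sketch (all names and values below are hypothetical):

```python
# Writes a hypothetical tune.json for: python autotune.py tune.json -g "0 1"
import json

task_spec = {
    "data_root": "~/data/nerf_synthetic",    # prepended to each task's data_dir
    "train_root": "~/ckpt/tune",             # prepended to each task's train_dir
    "variables": {"lr_sigma": "loglin(3e1, 3e2, 3)"},  # eval'd by auto_list()
    "noises": {"lr_sigma": 0.0},             # optional per-variable Gaussian std
    "tasks": [{
        "train_dir": "lego_lr{lr_sigma}",    # {lr_sigma} filled by recursive_replace
        "data_dir": "lego",
        "flags": ["--lr_sigma", "{lr_sigma}"],
    }],
}
with open("tune.json", "w") as f:
    json.dump(task_spec, f, indent=2)
```
--------------------------------------------------------------------------------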
/Svox2/opt/calc_metrics.py: --------------------------------------------------------------------------------
1 | # Calculate metrics on saved images
2 |
3 | # Usage: python calc_metrics.py <render_dir>
4 | # Where <render_dir> is ckpt_dir/test_renders
5 | # or a jaxnerf test renders dir
6 |
7 | from util.dataset import datasets
8 | from util.util import compute_ssim, viridis_cmap
9 | from util import config_util
10 | from os import path
11 | from glob import glob
12 | import imageio
13 | import math
14 | import argparse
15 | import torch
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('render_dir', type=str)
19 | parser.add_argument('--crop', type=float, default=1.0, help='center crop')
20 | config_util.define_common_args(parser)
21 | args = parser.parse_args()
22 |
23 | if path.isfile(args.render_dir):
24 | print('please give the test_renders directory (not checkpoint) in the future')
25 | args.render_dir = path.join(path.dirname(args.render_dir), 'test_renders')
26 |
27 | device = 'cuda:0'
28 |
29 | import lpips
30 | lpips_vgg = lpips.LPIPS(net="vgg").eval().to(device)
31 |
32 | dset = datasets[args.dataset_type](args.data_dir, split="test",
33 | **config_util.build_data_options(args))
34 |
35 |
36 | im_files = sorted(glob(path.join(args.render_dir, "*.png")))
37 | im_files = [x for x in im_files if not path.basename(x).startswith('disp_')] # Remove depths
38 | assert len(im_files) == dset.n_images, \
39 | f'number of images found {len(im_files)} differs from test set images:{dset.n_images}'
40 |
41 | avg_psnr = 0.0
42 | avg_ssim = 0.0
43 | avg_lpips = 0.0
44 | n_images_gen = 0
45 | for i, im_path in enumerate(im_files):
46 | im = torch.from_numpy(imageio.imread(im_path))
47 | im_gt = dset.gt[i]
48 | if im.shape[1] >= im_gt.shape[1] * 2:
49 | # Assume we have some gt/baselines on the left
50 | im = im[:, -im_gt.shape[1]:]
51 | im = im.float() / 255
52 | if args.crop != 1.0:
53 | del_tb = int(im.shape[0] * (1.0 - args.crop) * 0.5)
54 | del_lr = int(im.shape[1] * (1.0 - args.crop) * 0.5)
55 | im = im[del_tb:-del_tb, del_lr:-del_lr]
56 | im_gt = im_gt[del_tb:-del_tb, del_lr:-del_lr]
57 |
58 | mse = (im - im_gt) ** 2
59 | mse_num : float = mse.mean().item()
60 | psnr = -10.0 * math.log10(mse_num)
61 | ssim = compute_ssim(im_gt, im).item()
62 | lpips_i = lpips_vgg(im_gt.permute([2, 0, 1]).cuda().contiguous(),
63 | im.permute([2, 0, 1]).cuda().contiguous(),
64 | normalize=True).item()
65 |
66 | print(i, 'of', len(im_files), '; PSNR', psnr, 'SSIM', ssim, 'LPIPS', lpips_i)
67 | avg_psnr += psnr
68 | avg_ssim += ssim
69 | avg_lpips += lpips_i
70 | n_images_gen += 1 # Just to be sure
71 |
72 | avg_psnr /= n_images_gen
73 | avg_ssim /= n_images_gen
74 | avg_lpips /= n_images_gen
75 | print('AVERAGES')
76 | print('PSNR:', avg_psnr)
77 | print('SSIM:', avg_ssim)
78 | print('LPIPS:', avg_lpips)
79 | postfix = '_cropped' if args.crop != 1.0 else ''
80 | # with open(path.join(args.render_dir, f'psnr{postfix}.txt'), 'w') as f:
81 | # f.write(str(avg_psnr))
82 | # with open(path.join(args.render_dir, f'ssim{postfix}.txt'), 'w') as f:
83 | # f.write(str(avg_ssim))
84 | # with open(path.join(args.render_dir, f'lpips{postfix}.txt'), 'w') as f:
85 | # f.write(str(avg_lpips))
86 |
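The PSNR used throughout these scripts is the standard image metric derived from the MSE of pixels in [0, 1]: PSNR = -10 log10(MSE). A tiny worked check (illustrative values only):

```python
import math

mse = 1e-3                      # mean squared error on [0, 1] images
psnr = -10.0 * math.log10(mse)  # = 30.0 dB; mse = 1e-4 would give 40 dB
print(psnr)
```
--------------------------------------------------------------------------------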
/Svox2/opt/render_imgs.py: --------------------------------------------------------------------------------
1 | # Copyright 2021 Alex Yu
2 | # Eval
3 |
4 | import torch
5 | import svox2
6 | import svox2.utils
7 | import math
8 | import argparse
9 | import json
10 | import numpy as np
11 | import os
12 | from os import path
13 | from util.dataset import datasets
14 | from util.util import Timing, compute_ssim, viridis_cmap
15 | from util import config_util
16 | from reflect import SparseRadianceFields
17 |
18 | import imageio
19 | import cv2
20 | from tqdm import tqdm
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('ckpt', type=str)
23 |
24 | config_util.define_common_args(parser)
25 |
26 | parser.add_argument('--n_eval', '-n', type=int, default=100000, help='images to evaluate (equal interval), at most evals every image')
27 | parser.add_argument('--train', action='store_true', default=False, help='render train set')
28 | parser.add_argument('--hard', action='store_true',
29 | default=False, help='render hard set')
30 |
31 | parser.add_argument('--render_path',
32 | action='store_true',
33 | default=False,
34 | help="Render path instead of test images (no metrics will be given)")
35 | parser.add_argument('--timing',
36 | action='store_true',
37 | default=False,
38 | help="Run only for timing (do not save images or use LPIPS/SSIM; "
39 | "still computes PSNR to make sure images are being generated)")
40 | parser.add_argument('--no_lpips',
41 | action='store_true',
42 | default=False,
43 | help="Disable LPIPS (faster load)")
44 | parser.add_argument('--no_vid',
45 | action='store_true',
46 | default=False,
47 | help="Disable video generation")
48 | parser.add_argument('--no_imsave',
49 | action='store_true',
50 | default=False,
51 | help="Disable image saving (can still save video; MUCH faster)")
52 | parser.add_argument('--fps',
53 | type=int,
54 | default=30,
55 | help="FPS of video")
56 |
57 | # Camera adjustment
58 | parser.add_argument('--crop',
59 | type=float,
60 | default=1.0,
61 | help="Crop (0, 1], 1.0 = full image")
62 | parser.add_argument("--density_threshold", type=float, default=-10000.0,
63 | help="density threshold for pruning voxels (used by the commented-out pruning section below)")
64 | # Foreground/background only
65 | parser.add_argument('--nofg',
66 | action='store_true',
67 | default=False,
68 | help="Do not render foreground (if using BG model)")
69 | parser.add_argument('--nobg',
70 | action='store_true',
71 | default=False,
72 | help="Do not render background (if using BG model)")
73 |
74 | # Random debugging features
75 | parser.add_argument('--blackbg',
76 | action='store_true',
77 | default=False,
78 | help="Force a black BG (behind BG model) color; useful for debugging 'clouds'")
79 | parser.add_argument('--ray_len',
80 | action='store_true',
81 | default=False,
82 | help="Render the ray lengths")
83 |
84 | args = parser.parse_args()
85 | config_util.maybe_merge_config_file(args, allow_invalid=True)
86 | device = 'cuda:0'
87 |
88 | if args.timing:
89 | args.no_lpips = True
90 | args.no_vid = True
91 | args.ray_len = False
92 |
93 | if not args.no_lpips:
94 | import lpips
95 | lpips_vgg = lpips.LPIPS(net="vgg").eval().to(device)
96 | config_file_name = os.path.join(args.ckpt, "meta.json")
97 | with open(config_file_name, "r") as config_file:
98 | old_configs = json.load(config_file)
99 | if not path.isfile(args.ckpt):
100 | args.ckpt = path.join(args.ckpt, "data_{}.npz".format(old_configs["rf_variant"]))
101 |
102 |
103 | split_render = 'train_metrics' if args.train else 'test_metrics'
104 | if args.hard:
105 | split_render = "hard_metrics"
106 | render_dir = path.join(path.dirname(args.ckpt), split_render)
107 | want_metrics = True
108 | if args.render_path:
109 | assert not args.train
110 | render_dir += '_path'
111 |
want_metrics = False 112 | 113 | # Handle various image transforms 114 | if not args.render_path: 115 | # Do not crop if not render_path 116 | args.crop = 1.0 117 | if args.crop != 1.0: 118 | render_dir += f'_crop{args.crop}' 119 | if args.ray_len: 120 | render_dir += f'_raylen' 121 | want_metrics = False 122 | 123 | split = "test_train" if args.train else "test" 124 | if args.hard: 125 | split = "hard" 126 | dset = datasets[args.dataset_type](args.data_dir, split=split, 127 | **config_util.build_data_options(args)) 128 | 129 | # grid = svox2.SparseGrid.load(args.ckpt, device=device) 130 | config_file_name = os.path.join(path.join(path.dirname(args.ckpt), "meta.json")) 131 | with open(config_file_name, "r") as config_file: 132 | old_configs = json.load(config_file) 133 | partial_alias = path.split(path.dirname(args.ckpt))[1] 134 | vox_resolution = json.loads(old_configs["reso"])[-1][0] 135 | srf = SparseRadianceFields( 136 | vox_res=vox_resolution, sh_dim=old_configs["sh_dim"], partial_alias=partial_alias, normalize="none", dataset_type=args.dataset_type) 137 | coords, feats = srf.load_coords_and_feats(path.join(path.dirname( 138 | args.ckpt), "data_{}.npz".format(old_configs["rf_variant"])), device=device) 139 | if args.dataset_type != "co3d": 140 | grid = srf.construct_grid(args.data_dir, coords, feats) 141 | else: 142 | grid = srf.construct_grid(path.split(path.split(path.split( 143 | path.split(config_file_name)[0])[0])[0])[0], coords, feats) 144 | 145 | 146 | 147 | ############################################################## 148 | # coords, feats = coords_and_feats_from_grid(grid) 149 | # density_threshold = args.density_threshold # -10_000.0 # 0.0 150 | # # print(coords.shape[0]) 151 | # s_alias = "p" if density_threshold >= 0 else "n" 152 | # rf_alias = s_alias + str(int(abs(density_threshold))) 153 | # render_dir += rf_alias 154 | # coords, feats = prune_sparse_voxels(coords, feats, density_threshold) 155 | # save_coords_and_feats(path.join(path.dirname(args.ckpt), 156 | # "data_{}.npz".format(rf_alias)), coords, feats) 157 | # # extract_mesh_from_sparse_voxels(coords, feats[:, 0], path.join(path.dirname(args.ckpt), "rf_{}_mesh.obj".format(rf_alias)), vox_res=512, 158 | # # smooth=False, level_set=density_threshold, clean=False) 159 | # coords, feats = load_coords_and_feats(path.join(path.dirname( 160 | # args.ckpt), "data_{}.npz".format(rf_alias)), device=device) 161 | # # print(coords.shape[0]) 162 | # grid = construct_grid(os.path.split(path.dirname(args.ckpt))[0], coords, feats, 163 | # resolution=511, denormalize=False, device=device, sh_dim=1) 164 | 165 | ############################################################## 166 | 167 | # print(grid.use_background,grid.basis_type,grid.sh_data.shape,grid.density_data.shape,grid.capacity,(torch.unique(grid.links)).max()) 168 | # raise Exception("STOP HERE ") 169 | if grid.use_background: 170 | if args.nobg: 171 | # grid.background_cubemap.data = grid.background_cubemap.data.cuda() 172 | grid.background_data.data[..., -1] = 0.0 173 | render_dir += '_nobg' 174 | if args.nofg: 175 | grid.density_data.data[:] = 0.0 176 | # grid.sh_data.data[..., 0] = 1.0 / svox2.utils.SH_C0 177 | # grid.sh_data.data[..., 9] = 1.0 / svox2.utils.SH_C0 178 | # grid.sh_data.data[..., 18] = 1.0 / svox2.utils.SH_C0 179 | render_dir += '_nofg' 180 | 181 | # DEBUG 182 | # grid.links.data[grid.links.size(0)//2:] = -1 183 | # render_dir += "_chopx2" 184 | 185 | config_util.setup_render_opts(grid.opt, args) 186 | 187 | if args.blackbg: 188 | print('Forcing black 
bg')
189 | render_dir += '_blackbg'
190 | grid.opt.background_brightness = 0.0
191 |
192 | print('Writing to', render_dir)
193 | os.makedirs(render_dir, exist_ok=True)
194 |
195 | if not args.no_imsave:
196 | print('Will write out all frames as PNG (this takes most of the time)')
197 |
198 | # NOTE: no_grad enables the fast image-level rendering kernel for cuvol backend only
199 | # other backends will manually generate rays per frame (slow)
200 | with torch.no_grad():
201 | n_images = dset.render_c2w.size(0) if args.render_path else dset.n_images
202 | img_eval_interval = max(n_images // args.n_eval, 1)
203 | avg_psnr = 0.0
204 | avg_ssim = 0.0
205 | avg_lpips = 0.0
206 | n_images_gen = 0
207 | c2ws = dset.render_c2w.to(device=device) if args.render_path else dset.c2w.to(device=device)
208 | # DEBUGGING
209 | # rad = [1.496031746031746, 1.6613756613756614, 1.0]
210 | # half_sz = [grid.links.size(0) // 2, grid.links.size(1) // 2]
211 | # pad_size_x = int(half_sz[0] - half_sz[0] / 1.496031746031746)
212 | # pad_size_y = int(half_sz[1] - half_sz[1] / 1.6613756613756614)
213 | # print(pad_size_x, pad_size_y)
214 | # grid.links[:pad_size_x] = -1
215 | # grid.links[-pad_size_x:] = -1
216 | # grid.links[:, :pad_size_y] = -1
217 | # grid.links[:, -pad_size_y:] = -1
218 | # grid.links[:, :, -8:] = -1
219 |
220 | # LAYER = -16
221 | # grid.links[:, :, :LAYER] = -1
222 | # grid.links[:, :, LAYER+1:] = -1
223 |
224 | frames = []
225 | # im_gt_all = dset.gt.to(device=device)
226 |
227 | for img_id in tqdm(range(0, n_images, img_eval_interval)):
228 | dset_h, dset_w = dset.get_image_size(img_id)
229 | im_size = dset_h * dset_w
230 | w = dset_w if args.crop == 1.0 else int(dset_w * args.crop)
231 | h = dset_h if args.crop == 1.0 else int(dset_h * args.crop)
232 |
233 | cam = svox2.Camera(c2ws[img_id],
234 | dset.intrins.get('fx', img_id),
235 | dset.intrins.get('fy', img_id),
236 | dset.intrins.get('cx', img_id) + (w - dset_w) * 0.5,
237 | dset.intrins.get('cy', img_id) + (h - dset_h) * 0.5,
238 | w, h,
239 | ndc_coeffs=dset.ndc_coeffs)
240 | im = grid.volume_render_image(cam, use_kernel=True, return_raylen=args.ray_len)
241 | if args.ray_len:
242 | minv, meanv, maxv = im.min().item(), im.mean().item(), im.max().item()
243 | im = viridis_cmap(im.cpu().numpy())
244 | cv2.putText(im, "{:.4f} {:.4f} {:.4f}".format(minv,meanv,maxv), (10, 20),
245 | 0, 0.5, [255, 0, 0])
246 | im = torch.from_numpy(im).to(device=device)
247 | im.clamp_(0.0, 1.0)
248 |
249 | if not args.render_path:
250 | im_gt = dset.gt[img_id].to(device=device)
251 | mse = (im - im_gt) ** 2
252 | mse_num : float = mse.mean().item()
253 | psnr = -10.0 * math.log10(mse_num)
254 | avg_psnr += psnr
255 | if not args.timing:
256 | ssim = compute_ssim(im_gt, im).item()
257 | avg_ssim += ssim
258 | if not args.no_lpips:
259 | lpips_i = lpips_vgg(im_gt.permute([2, 0, 1]).contiguous(),
260 | im.permute([2, 0, 1]).contiguous(), normalize=True).item()
261 | avg_lpips += lpips_i
262 | print(img_id, 'PSNR', psnr, 'SSIM', ssim, 'LPIPS', lpips_i)
263 | else:
264 | print(img_id, 'PSNR', psnr, 'SSIM', ssim)
265 | img_path = path.join(render_dir, f'{img_id:04d}.png')
266 | im = im.cpu().numpy()
267 | if not args.render_path:
268 | im_gt = dset.gt[img_id].numpy()
269 | im = np.concatenate([im_gt, im], axis=1)
270 | if not args.timing:
271 | im = (im * 255).astype(np.uint8)
272 | if not args.no_imsave:
273 | imageio.imwrite(img_path, im)
274 | if not args.no_vid:
275 | frames.append(im)
276 | im = None
277 | n_images_gen += 1
278 | if want_metrics:
279 | print('AVERAGES')
280 |
281 | avg_psnr /= n_images_gen
282 | with open(path.join(render_dir, 'psnr.txt'), 'w') as f:
283 | f.write(str(avg_psnr))
284 | print('PSNR:', avg_psnr)
285 | if not args.timing:
286 | avg_ssim /= n_images_gen
287 | print('SSIM:', avg_ssim)
288 | with open(path.join(render_dir, 'ssim.txt'), 'w') as f:
289 | f.write(str(avg_ssim))
290 | if not args.no_lpips:
291 | avg_lpips /= n_images_gen
292 | print('LPIPS:', avg_lpips)
293 | with open(path.join(render_dir, 'lpips.txt'), 'w') as f:
294 | f.write(str(avg_lpips))
295 | if not args.no_vid and len(frames):
296 | vid_path = render_dir + '.mp4'
297 | imageio.mimwrite(vid_path, frames, fps=args.fps, macro_block_size=8) # pip install imageio-ffmpeg
298 |
299 |
300 |
-------------------------------------------------------------------------------- /Svox2/opt/render_imgs_circle.py: --------------------------------------------------------------------------------
1 | # Copyright 2021 Alex Yu
2 | # Render 360 circle path
3 |
4 | import torch
5 | import svox2
6 | import json
7 | import svox2.utils
8 | import math
9 | import argparse
10 | import numpy as np
11 | import os
12 | from os import path
13 | from util.dataset import datasets
14 | from util.util import Timing, compute_ssim, viridis_cmap, pose_spherical
15 | from util import config_util
16 |
17 | import imageio
18 | import cv2
19 | from tqdm import tqdm
20 | from reflect import SparseRadianceFields
21 |
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('ckpt', type=str)
24 |
25 | config_util.define_common_args(parser)
26 |
27 | parser.add_argument('--n_eval', '-n', type=int, default=100000, help='images to evaluate (equal interval), at most evals every image')
28 | # parser.add_argument('--rf_variant', type=int, default=0,
29 | # help='the variant of the rf used in rendering')
30 | parser.add_argument('--traj_type',
31 | choices=['spiral', 'circle',"zoom","vertical"],
32 | default='spiral',
33 | help="Render a spiral (doubles length, using 2 elevations), or just a circle")
34 | parser.add_argument('--fps',
35 | type=int,
36 | default=30,
37 | help="FPS of video")
38 | parser.add_argument(
39 | "--width", "-W", type=float, default=None, help="Rendering image width (only if not --traj)"
40 | )
41 | parser.add_argument(
42 | "--height", "-H", type=float, default=None, help="Rendering image height (only if not --traj)"
43 | )
44 | parser.add_argument(
45 | "--num_views", "-N", type=int, default=300,
46 | help="Number of frames to render"
47 | )
48 |
49 | # Path adjustment
50 | parser.add_argument(
51 | "--offset", type=str, default="0,0,0", help="Center point to rotate around (only if not --traj)"
52 | )
53 | parser.add_argument("--radius", type=float, default=0.85, help="Radius of orbit (only if not --traj)")
54 | parser.add_argument("--closeup_factor", type=float, default=0.5,
55 | help="smaller radius of orbit (only if --traj_type == `zoom`)")
56 | parser.add_argument("--density_threshold", type=float, default=-10000.0,
57 | help="density threshold for pruning voxels (used by the commented-out pruning section below)")
58 | parser.add_argument(
59 | "--elevation",
60 | type=float,
61 | default=-45.0,
62 | help="Elevation of orbit in deg, negative is above",
63 | )
64 | parser.add_argument(
65 | "--elevation2",
66 | type=float,
67 | default=20.0,
68 | help="Max elevation, only for spiral",
69 | )
70 | parser.add_argument(
71 | "--vec_up",
72 | type=str,
73 | default=None,
74 | help="up axis for camera views (only if not --traj);"
75 | "3 floats separated by ','; if not given automatically
determined", 76 | ) 77 | parser.add_argument( 78 | "--vert_shift", 79 | type=float, 80 | default=0.0, 81 | help="vertical shift by up axis" 82 | ) 83 | 84 | # Camera adjustment 85 | parser.add_argument('--crop', 86 | type=float, 87 | default=1.0, 88 | help="Crop (0, 1], 1.0 = full image") 89 | 90 | # Foreground/background only 91 | parser.add_argument('--nofg', 92 | action='store_true', 93 | default=False, 94 | help="Do not render foreground (if using BG model)") 95 | parser.add_argument('--nobg', 96 | action='store_true', 97 | default=False, 98 | help="Do not render background (if using BG model)") 99 | 100 | # Random debugging features 101 | parser.add_argument('--blackbg', 102 | action='store_true', 103 | default=False, 104 | help="Force a black BG (behind BG model) color; useful for debugging 'clouds'") 105 | 106 | args = parser.parse_args() 107 | config_util.maybe_merge_config_file(args, allow_invalid=True) 108 | device = 'cuda:0' 109 | 110 | 111 | dset = datasets[args.dataset_type](args.data_dir, split="test", 112 | **config_util.build_data_options(args)) 113 | 114 | if args.vec_up is None: 115 | up_rot = dset.c2w[:, :3, :3].cpu().numpy() 116 | ups = np.matmul(up_rot, np.array([0, -1.0, 0])[None, :, None])[..., 0] 117 | args.vec_up = np.mean(ups, axis=0) 118 | args.vec_up /= np.linalg.norm(args.vec_up) 119 | print(' Auto vec_up', args.vec_up) 120 | else: 121 | args.vec_up = np.array(list(map(float, args.vec_up.split(",")))) 122 | 123 | 124 | args.offset = np.array(list(map(float, args.offset.split(",")))) 125 | if args.traj_type == 'spiral': 126 | angles = np.linspace(-180, 180, args.num_views + 1)[:-1] 127 | elevations = np.linspace(args.elevation, args.elevation2, args.num_views) 128 | c2ws = [ 129 | pose_spherical( 130 | angle, 131 | ele, 132 | args.radius, 133 | args.offset, 134 | vec_up=args.vec_up, 135 | ) 136 | for ele, angle in zip(elevations, angles) 137 | ] 138 | c2ws += [ 139 | pose_spherical( 140 | angle, 141 | ele, 142 | args.radius, 143 | args.offset, 144 | vec_up=args.vec_up, 145 | ) 146 | for ele, angle in zip(reversed(elevations), angles) 147 | ] 148 | elif args.traj_type == 'zoom': 149 | angles = np.linspace(-180, 180, args.num_views + 1)[:-1] 150 | elevations = np.linspace(args.elevation, args.elevation2, args.num_views) 151 | distances = np.linspace(args.radius, args.radius * args.closeup_factor, args.num_views) 152 | c2ws = [ 153 | pose_spherical( 154 | angle, 155 | ele, 156 | dist, 157 | args.offset, 158 | vec_up=args.vec_up, 159 | ) 160 | for ele, angle, dist in zip(elevations, angles, distances) 161 | ] 162 | c2ws += [ 163 | pose_spherical( 164 | angle, 165 | ele, 166 | dist, 167 | args.offset, 168 | vec_up=args.vec_up, 169 | ) 170 | for ele, angle, dist in zip(reversed(elevations), angles, reversed(distances)) 171 | ] 172 | 173 | elif args.traj_type == 'circle': 174 | c2ws = [ 175 | pose_spherical( 176 | angle, 177 | args.elevation, 178 | args.radius, 179 | args.offset, 180 | vec_up=args.vec_up, 181 | ) 182 | for angle in np.linspace(-180, 180, args.num_views + 1)[:-1] 183 | ] 184 | elif args.traj_type == 'vertical': 185 | c2ws = [ 186 | pose_spherical( 187 | 0, 188 | angle, 189 | args.radius, 190 | args.offset, 191 | vec_up=args.vec_up, 192 | ) 193 | for angle in np.linspace(-90, 90, args.num_views + 1)[:-1] 194 | ] 195 | c2ws = np.stack(c2ws, axis=0) 196 | if args.vert_shift != 0.0: 197 | c2ws[:, :3, 3] += np.array(args.vec_up) * args.vert_shift 198 | c2ws = torch.from_numpy(c2ws).to(device=device) 199 | 200 | config_file_name = os.path.join(args.ckpt, 
"meta.json") 201 | with open(config_file_name, "r") as config_file: 202 | old_configs = json.load(config_file) 203 | if not path.isfile(args.ckpt): 204 | args.ckpt = path.join(args.ckpt, "data_{}.npz".format(old_configs["rf_variant"])) 205 | 206 | render_out_path = path.join(path.dirname(args.ckpt), "{}_renders".format(args.traj_type)) 207 | 208 | # Handle various image transforms 209 | if args.crop != 1.0: 210 | render_out_path += f'_crop{args.crop}' 211 | if args.vert_shift != 0.0: 212 | render_out_path += f'_vshift{args.vert_shift}' 213 | 214 | # grid = svox2.SparseGrid.load(args.ckpt, device=device) 215 | # print(grid.center, grid.radius) 216 | partial_alias = path.split(path.dirname(args.ckpt))[1] 217 | vox_resolution = json.loads(old_configs["reso"])[-1][0] 218 | srf = SparseRadianceFields(vox_res=vox_resolution, sh_dim=old_configs["sh_dim"], partial_alias=partial_alias, normalize="none", dataset_type=args.dataset_type) 219 | coords, feats = srf.load_coords_and_feats(path.join(path.dirname(args.ckpt), "data_{}.npz".format(old_configs["rf_variant"])), device=device) 220 | if args.dataset_type != "co3d": 221 | grid = srf.construct_grid(args.data_dir, coords, feats) 222 | else: 223 | grid = srf.construct_grid(path.split(path.split(path.split(path.split(config_file_name)[0])[0])[0])[0], coords, feats) 224 | 225 | 226 | ############################################################## 227 | # coords, feats = coords_and_feats_from_grid(grid) 228 | # density_threshold = args.density_threshold # -10_000.0 # 0.0 229 | # # print(coords.shape[0]) 230 | # s_alias = "p" if density_threshold >=0 else "n" 231 | # rf_alias = s_alias + str(int(abs(density_threshold))) 232 | # render_out_path += rf_alias 233 | # coords, feats = prune_sparse_voxels(coords, feats, density_threshold) 234 | # save_coords_and_feats(path.join(path.dirname(args.ckpt), 235 | # "data_{}.npz".format(rf_alias)), coords, feats) 236 | # # extract_mesh_from_sparse_voxels(coords, feats[:, 0], path.join(path.dirname(args.ckpt), "rf_{}_mesh.obj".format(rf_alias)), vox_res=512, 237 | # # smooth=False, level_set=density_threshold, clean=False) 238 | # coords, feats = load_coords_and_feats(path.join(path.dirname( 239 | # args.ckpt), "data_{}.npz".format(rf_alias)), device=device) 240 | # # print(coords.shape[0]) 241 | # grid = construct_grid(os.path.split(path.dirname(args.ckpt))[0], coords, feats, 242 | # resolution=511, denormalize=False, device=device,sh_dim=1) 243 | 244 | 245 | ####################################################### 246 | # DEBUG 247 | # grid.background_data.data[:, 32:, -1] = 0.0 248 | # render_out_path += '_front' 249 | 250 | if grid.use_background: 251 | if args.nobg: 252 | grid.background_data.data[..., -1] = 0.0 253 | render_out_path += '_nobg' 254 | if args.nofg: 255 | grid.density_data.data[:] = 0.0 256 | # grid.sh_data.data[..., 0] = 1.0 / svox2.utils.SH_C0 257 | # grid.sh_data.data[..., 9] = 1.0 / svox2.utils.SH_C0 258 | # grid.sh_data.data[..., 18] = 1.0 / svox2.utils.SH_C0 259 | render_out_path += '_nofg' 260 | 261 | # # DEBUG 262 | # grid.background_data.data[..., -1] = 100.0 263 | # a1 = torch.linspace(0, 1, grid.background_data.size(0) // 2, dtype=torch.float32, device=device)[:, None] 264 | # a2 = torch.linspace(1, 0, (grid.background_data.size(0) - 1) // 2 + 1, dtype=torch.float32, device=device)[:, None] 265 | # a = torch.cat([a1, a2], dim=0) 266 | # c = torch.stack([a, 1-a, torch.zeros_like(a)], dim=-1) 267 | # grid.background_data.data[..., :-1] = c 268 | # render_out_path += "_gradient" 269 | 270 | 
config_util.setup_render_opts(grid.opt, args) 271 | 272 | if args.blackbg: 273 | print('Forcing black bg') 274 | render_out_path += '_blackbg' 275 | grid.opt.background_brightness = 0.0 276 | 277 | render_out_path += '.mp4' 278 | print('Writing to', render_out_path) 279 | 280 | # NOTE: no_grad enables the fast image-level rendering kernel for cuvol backend only 281 | # other backends will manually generate rays per frame (slow) 282 | with torch.no_grad(): 283 | n_images = c2ws.size(0) 284 | img_eval_interval = max(n_images // args.n_eval, 1) 285 | avg_psnr = 0.0 286 | avg_ssim = 0.0 287 | avg_lpips = 0.0 288 | n_images_gen = 0 289 | frames = [] 290 | # if args.near_clip >= 0.0: 291 | grid.opt.near_clip = 0.0 #args.near_clip 292 | if args.width is None: 293 | args.width = dset.get_image_size(0)[1] 294 | if args.height is None: 295 | args.height = dset.get_image_size(0)[0] 296 | 297 | for img_id in tqdm(range(0, n_images, img_eval_interval)): 298 | dset_h, dset_w = args.height, args.width 299 | im_size = dset_h * dset_w 300 | w = dset_w if args.crop == 1.0 else int(dset_w * args.crop) 301 | h = dset_h if args.crop == 1.0 else int(dset_h * args.crop) 302 | 303 | cam = svox2.Camera(c2ws[img_id], 304 | dset.intrins.get('fx', 0), 305 | dset.intrins.get('fy', 0), 306 | w * 0.5, 307 | h * 0.5, 308 | w, h, 309 | ndc_coeffs=(-1.0, -1.0)) 310 | torch.cuda.synchronize() 311 | im = grid.volume_render_image(cam, use_kernel=True) 312 | torch.cuda.synchronize() 313 | im.clamp_(0.0, 1.0) 314 | 315 | im = im.cpu().numpy() 316 | im = (im * 255).astype(np.uint8) 317 | frames.append(im) 318 | im = None 319 | n_images_gen += 1 320 | if len(frames): 321 | vid_path = render_out_path 322 | imageio.mimwrite(vid_path, frames, fps=args.fps, macro_block_size=8) # pip install imageio-ffmpeg 323 | 324 | 325 | -------------------------------------------------------------------------------- /Svox2/opt/scripts/create_split.py: -------------------------------------------------------------------------------- 1 | """ 2 | Splits dataset using NSVF conventions. 
3 | Every `--every`-th image (default: 16) is used as a test image (1_ prefix) and the other images are train (0_ prefix)
4 |
5 | Usage:
6 | python create_split.py <data_set_root>
7 | <data_set_root> should contain directories like images/, pose/
8 | """
9 | # Copyright 2021 Alex Yu
10 | import os
11 | import os.path as osp
12 | from typing import NamedTuple, List
13 | import argparse
14 | import random
15 |
16 | parser = argparse.ArgumentParser("Automatic dataset splitting")
17 | parser.add_argument('root_dir', type=str, help="COLMAP dataset root dir")
18 | parser.add_argument('--every', type=int, default=16, help="Every x images used for testing")
19 | parser.add_argument('--dry_run', action='store_true', help="Dry run, prints renames without modifying any files")
20 | parser.add_argument('--yes', '-y', action='store_true', help="Answer yes")
21 | parser.add_argument('--random', action='store_true', help="If set, chooses the split randomly rather than at a fixed interval "
22 | "(but number of images in train/test set is same)")
23 | args = parser.parse_args()
24 |
25 | class Dir(NamedTuple):
26 | name: str
27 | valid_exts: List[str]
28 |
29 | def list_filter_dirs(base):
30 | all_dirs = [x for x in os.listdir(base) if osp.isdir(osp.join(base, x))]
31 | image_exts = [".png", ".jpg", ".jpeg", ".gif", ".tif", ".tiff", ".bmp"]
32 | depth_exts = [".exr", ".pfm", ".png", ".npy"]
33 | dirs_prefixes = [Dir(name="pose", valid_exts=[".txt"]),
34 | Dir(name="poses", valid_exts=[".txt"]),
35 | Dir(name="feature", valid_exts=[".npz"]),
36 | Dir(name="rgb", valid_exts=image_exts),
37 | Dir(name="images", valid_exts=image_exts),
38 | Dir(name="image", valid_exts=image_exts),
39 | Dir(name="c2w", valid_exts=image_exts),
40 | Dir(name="depths", valid_exts=depth_exts)]
41 | dirs = []
42 | dir_idx = 0
43 | for pfx in dirs_prefixes:
44 | for d in all_dirs:
45 | if d.startswith(pfx.name):
46 | if d == "pose":
47 | dir_idx = len(dirs)
48 | dirs.append(Dir(name=osp.join(base, d), valid_exts=pfx.valid_exts))
49 | return dirs, dir_idx
50 |
51 | dirs, dir_idx = list_filter_dirs(args.root_dir)
52 |
53 | refdir = dirs[dir_idx]
54 | print("going to split", [x.name for x in dirs], "reference", refdir.name)
55 | do_proceed = args.dry_run or args.yes
56 | if not do_proceed:
57 | import click
58 | do_proceed = click.confirm("Continue?", default=True)
59 | if do_proceed:
60 | filedata = {}
61 | base_files = [osp.splitext(x)[0] for x in sorted(os.listdir(refdir.name))
62 | if osp.splitext(x)[1].lower() in refdir.valid_exts]
63 | if args.random:
64 | print('random enabled')
65 | random.shuffle(base_files)
66 | base_files_map = {x: f"{int(i % args.every == 0)}_" + x for i, x in enumerate(base_files)}
67 |
68 | for dir_obj in dirs:
69 | dirname = dir_obj.name
70 | files = sorted(os.listdir(dirname))
71 | for filename in files:
72 | full_filename = osp.join(dirname, filename)
73 | if filename.startswith("0_") or filename.startswith("1_"):
74 | continue
75 | if not osp.isfile(full_filename):
76 | continue
77 | base_file, ext = osp.splitext(filename)
78 | if ext.lower() not in dir_obj.valid_exts:
79 | print('SKIP ', full_filename, ' Since it has an unsupported extension')
80 | continue
81 | if base_file not in base_files_map:
82 | print('SKIP ', full_filename, ' Since it does not match any reference file')
83 | continue
84 | new_base_file = base_files_map[base_file]
85 | new_full_filename = osp.join(dirname, new_base_file + ext)
86 | print('rename', full_filename, 'to', new_full_filename)
87 | if not args.dry_run:
88 | os.rename(full_filename, new_full_filename)
89 |
if args.dry_run: 90 | print('(dry run complete)') 91 | else: 92 | print('use unsplit.py to undo this operation') 93 | -------------------------------------------------------------------------------- /Svox2/opt/scripts/proc_record3d.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import os 4 | from os import path 5 | import glob 6 | import numpy as np 7 | import cv2 8 | from tqdm import tqdm 9 | from scipy.spatial.transform import Rotation 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('data_dir', type=str) 13 | parser.add_argument('--every', type=int, default=15) 14 | parser.add_argument('--factor', type=int, default=2, help='downsample') 15 | args = parser.parse_args() 16 | 17 | video_file = glob.glob(args.data_dir + '/*.mp4')[0] 18 | print('Video file:', video_file) 19 | json_meta = path.join(args.data_dir, 'metadata.json') 20 | meta = json.load(open(json_meta, 'r')) 21 | 22 | K_3 = np.array(meta['K']).reshape(3, 3) 23 | K = np.eye(4) 24 | K[:3, :3] = K_3.T / args.factor 25 | output_intrin_file = path.join(args.data_dir, 'intrinsics.txt') 26 | np.savetxt(output_intrin_file, K) 27 | 28 | poses = np.array(meta['poses']) 29 | 30 | t = poses[:, 4:] 31 | q = poses[:, :4] 32 | R = Rotation.from_quat(q).as_matrix() 33 | 34 | # Recenter the poses 35 | center = np.mean(t, axis=0) 36 | print('Scene center', center) 37 | t -= center 38 | 39 | all_poses = np.zeros((q.shape[0], 4, 4)) 40 | all_poses[:, -1, -1] = 1 41 | 42 | Rt = np.concatenate([R, t[:, :, None]], axis=2) 43 | all_poses[:, :3] = Rt 44 | all_poses = all_poses @ np.diag([1, -1, -1, 1]) 45 | video = cv2.VideoCapture(str(video_file)) 46 | print(Rt.shape) 47 | 48 | fps = video.get(cv2.CAP_PROP_FPS) 49 | img_wh = ori_w, ori_h = ( 50 | int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) // 2, 51 | int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)), 52 | ) 53 | 54 | print('image size', img_wh) 55 | pose_dir = path.join(args.data_dir, 'pose') 56 | os.makedirs(pose_dir, exist_ok=True) 57 | 58 | image_dir = path.join(args.data_dir, 'rgb') 59 | os.makedirs(image_dir, exist_ok=True) 60 | video_length = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 61 | print('length', video_length) 62 | 63 | idx = 0 64 | for i in tqdm(range(0, video_length, args.every)): 65 | video.set(cv2.CAP_PROP_POS_FRAMES, i) 66 | ret, frame = video.read() 67 | if not ret or frame is None: 68 | print('skip', i) 69 | continue 70 | assert frame.shape[1] == img_wh[0] * 2 71 | assert frame.shape[0] == img_wh[1] 72 | frame = frame[:, img_wh[0]:] 73 | image_path = path.join(image_dir, f"{idx:05d}.png") 74 | pose_path = path.join(pose_dir, f"{idx:05d}.txt") 75 | 76 | if args.factor != 1: 77 | frame = cv2.resize(frame, (img_wh[0] // args.factor, img_wh[1] // args.factor), cv2.INTER_AREA) 78 | 79 | cv2.imwrite(image_path, frame) 80 | np.savetxt(pose_path, all_poses[i]) 81 | idx += 1 82 | -------------------------------------------------------------------------------- /Svox2/opt/scripts/unsplit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inverse of create_split.py 3 | """ 4 | # Copyright 2021 Alex Yu 5 | import os 6 | import os.path as osp 7 | import click 8 | from typing import NamedTuple, List 9 | import argparse 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('root_dir', type=str, help="COLMAP dataset root dir") 13 | parser.add_argument('--dry_run', action='store_true', help="Dry run, prints renames without modifying any files") 14 | 
parser.add_argument('--yes', '-y', action='store_true', help="Answer yes") 15 | args = parser.parse_args() 16 | 17 | class Dir(NamedTuple): 18 | name: str 19 | valid_exts: List[str] 20 | 21 | def list_filter_dirs(base): 22 | all_dirs = [x for x in os.listdir(base) if osp.isdir(osp.join(base, x))] 23 | image_exts = [".png", ".jpg", ".jpeg", ".gif", ".tif", ".tiff", ".bmp"] 24 | depth_exts = [".exr", ".pfm", ".png", ".npy"] 25 | dirs_prefixes = [Dir(name="pose", valid_exts=[".txt"]), 26 | Dir(name="feature", valid_exts=[".npz"]), 27 | Dir(name="rgb", valid_exts=image_exts), 28 | Dir(name="images", valid_exts=image_exts), 29 | Dir(name="depths", valid_exts=depth_exts)] 30 | dirs = [] 31 | dir_idx = 0 32 | for pfx in dirs_prefixes: 33 | for d in all_dirs: 34 | if d.startswith(pfx.name): 35 | if d == "pose": 36 | dir_idx = len(dirs) 37 | dirs.append(Dir(name=osp.join(base, d), valid_exts=pfx.valid_exts)) 38 | return dirs, dir_idx 39 | 40 | dirs, dir_idx = list_filter_dirs(args.root_dir) 41 | 42 | refdir = dirs[dir_idx] 43 | print("going to unsplit", [x.name for x in dirs], "reference", dirs[dir_idx].name) 44 | do_proceed = args.dry_run or args.yes 45 | if not do_proceed: 46 | import click 47 | do_proceed = click.confirm("Continue?", default=True) 48 | if do_proceed: 49 | filedata = {} 50 | base_files = [osp.splitext(x)[0] for x in sorted(os.listdir(refdir.name)) 51 | if osp.splitext(x)[1] in refdir.valid_exts and 52 | (x.startswith('0_') or x.startswith('1_'))] 53 | base_files_map = {x: '_'.join(x.split('_')[1:]) for x in base_files} 54 | 55 | for dir_obj in dirs: 56 | dirname = dir_obj.name 57 | files = sorted(os.listdir(dirname)) 58 | for filename in files: 59 | full_filename = osp.join(dirname, filename) 60 | if not osp.isfile(full_filename): 61 | continue 62 | base_file, ext = osp.splitext(filename) 63 | if ext.lower() not in dir_obj.valid_exts: 64 | print('SKIP ', full_filename, ' Since it has an unsupported extension') 65 | continue 66 | if base_file not in base_files_map: 67 | print('SKIP ', full_filename, ' Since it does not match any reference file') 68 | continue 69 | new_base_file = base_files_map[base_file] 70 | new_full_filename = osp.join(dirname, new_base_file + ext) 71 | print('rename', full_filename, 'to', new_full_filename) 72 | if not args.dry_run: 73 | os.rename(full_filename, new_full_filename) 74 | if args.dry_run: 75 | print('(dry run complete)') 76 | else: 77 | print('use create_split.py to split again') 78 | -------------------------------------------------------------------------------- /Svox2/opt/to_svox1.py: -------------------------------------------------------------------------------- 1 | import svox2 2 | import svox 3 | import math 4 | import argparse 5 | from os import path 6 | from tqdm import tqdm 7 | import torch 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('ckpt', type=str) 11 | args = parser.parse_args() 12 | 13 | grid = svox2.SparseGrid.load(args.ckpt) 14 | t = grid.to_svox1() 15 | print(t) 16 | 17 | out_path = path.splitext(args.ckpt)[0] + '_svox1.npz' 18 | print('Saving', out_path) 19 | t.save(out_path) 20 | -------------------------------------------------------------------------------- /Svox2/opt/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajhamdi/sparf_pytorch/2083419e6afa7d171ada02e7e1e41a8ab8613f7e/Svox2/opt/util/__init__.py -------------------------------------------------------------------------------- /Svox2/opt/util/co3d_dataset.py: 
-------------------------------------------------------------------------------- 1 | # CO3D dataset loader 2 | # https://github.com/facebookresearch/co3d/ 3 | # 4 | # Adapted from basenerf 5 | # Copyright 2021 Alex Yu 6 | import torch 7 | import torch.nn.functional as F 8 | import numpy as np 9 | import os 10 | import cv2 11 | from tqdm import tqdm 12 | from os import path 13 | import json 14 | import gzip 15 | 16 | from scipy.spatial.transform import Rotation 17 | from typing import NamedTuple, Optional, List, Union 18 | from .util import Rays, Intrin, similarity_from_cameras 19 | from .dataset_base import DatasetBase 20 | 21 | 22 | class CO3DDataset(DatasetBase): 23 | """ 24 | CO3D Dataset 25 | Preloads all images for an object. 26 | Will create a data index on first load, to make later loads faster. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | root, 32 | split, 33 | seq_id: Optional[int] = None, 34 | epoch_size: Optional[int] = None, 35 | permutation: bool = True, 36 | device: Union[str, torch.device] = "cpu", 37 | max_image_dim: int = 800, 38 | max_pose_dist: float = 5.0, 39 | cam_scale_factor: float = 0.95, 40 | hold_every=8, 41 | **kwargs, 42 | ): 43 | """ 44 | :param root: str dataset root directory 45 | :param device: data prefetch device 46 | """ 47 | super().__init__() 48 | os.makedirs('co3d_tmp', exist_ok=True) 49 | index_file = path.join('co3d_tmp', 'co3d_index.npz') 50 | self.split = split 51 | self.permutation = permutation 52 | self.data_dir = root 53 | self.epoch_size = epoch_size 54 | self.max_image_dim = max_image_dim 55 | self.max_pose_dist = max_pose_dist 56 | self.cam_scale_factor = cam_scale_factor 57 | 58 | self.cats = sorted([x for x in os.listdir(root) if path.isdir( 59 | path.join(root, x))]) 60 | self.gt = [] 61 | self.n_images = 0 62 | self.curr_offset = 0 63 | self.next_offset = 0 64 | self.hold_every = hold_every 65 | self.curr_seq_cat = self.curr_seq_name = '' 66 | self.device = device 67 | if path.exists(index_file): 68 | print(' Using cached CO3D index', index_file) 69 | z = np.load(index_file) 70 | self.seq_cats = z.f.seq_cats 71 | self.seq_names = z.f.seq_names 72 | self.seq_offsets = z.f.seq_offsets 73 | self.all_image_size = z.f.image_size # NOTE: w, h 74 | self.image_path = z.f.image_path 75 | self.image_pose = z.f.pose 76 | self.fxy = z.f.fxy 77 | self.cxy = z.f.cxy 78 | else: 79 | print(' Constructing CO3D index (1st run only), this may take a while') 80 | cam_trans = np.diag(np.array([-1, -1, 1, 1], dtype=np.float32)) 81 | frame_data_by_seq = {} 82 | self.seq_cats = [] 83 | self.seq_names = [] 84 | self.seq_offsets = [] 85 | self.image_path = [] 86 | self.all_image_size = [] 87 | self.image_pose = [] 88 | self.fxy = [] 89 | self.cxy = [] 90 | for i, cat in enumerate(self.cats): 91 | print(cat, '- category', i + 1, 'of', len(self.cats)) 92 | cat_dir = path.join(root, cat) 93 | if not path.isdir(cat_dir): 94 | continue 95 | frame_data_path = path.join(cat_dir, 'frame_annotations.jgz') 96 | with gzip.open(frame_data_path, 'r') as f: 97 | all_frames_data = json.load(f) 98 | for frame_data in tqdm(all_frames_data): 99 | seq_name = cat + '//' + frame_data['sequence_name'] 100 | # frame_number = frame_data['frame_number'] 101 | if seq_name not in frame_data_by_seq: 102 | frame_data_by_seq[seq_name] = [] 103 | pose = np.zeros((4, 4)) 104 | image_size_hw = frame_data['image']['size'] # H, W 105 | H, W = image_size_hw 106 | half_wh = np.array([W * 0.5, H * 0.5], dtype=np.float32) 107 | R = np.array(frame_data['viewpoint']['R']) 108 | T = 
np.array(frame_data['viewpoint']['T']) 109 | fxy = np.array(frame_data['viewpoint']['focal_length']) 110 | cxy = np.array(frame_data['viewpoint']['principal_point']) 111 | focal = fxy * half_wh 112 | prp = -1.0 * (cxy - 1.0) * half_wh 113 | pose[:3, :3] = R 114 | pose[:3, 3:] = -R @ T[..., None] 115 | pose[3, 3] = 1.0 116 | pose = pose @ cam_trans 117 | frame_data_obj = { 118 | 'frame_number': frame_data['frame_number'], 119 | 'image_path': frame_data['image']['path'], 120 | 'image_size': np.array([W, H]), # NOTE: this is w, h 121 | 'pose': pose, 122 | 'fxy': focal, # NOTE: this is x, y 123 | 'cxy': prp, # NOTE: this is x, y 124 | } 125 | frame_data_by_seq[seq_name].append(frame_data_obj) 126 | print(' Sorting by sequence') 127 | for k in frame_data_by_seq: 128 | fd = sorted(frame_data_by_seq[k], 129 | key=lambda x: x['frame_number']) 130 | spl = k.split('//') 131 | self.seq_cats.append(spl[0]) 132 | self.seq_names.append(spl[1]) 133 | self.seq_offsets.append(len(self.image_path)) 134 | self.image_path.extend([x['image_path'] for x in fd]) 135 | self.all_image_size.extend([x['image_size'] for x in fd]) 136 | self.image_pose.extend([x['pose'] for x in fd]) 137 | self.fxy.extend([x['fxy'] for x in fd]) 138 | self.cxy.extend([x['cxy'] for x in fd]) 139 | self.all_image_size = np.stack(self.all_image_size) 140 | self.image_pose = np.stack(self.image_pose) 141 | self.fxy = np.stack(self.fxy) 142 | self.cxy = np.stack(self.cxy) 143 | self.seq_offsets.append(len(self.image_path)) 144 | self.seq_offsets = np.array(self.seq_offsets) 145 | print(' Saving to index') 146 | np.savez(index_file, 147 | seq_cats=self.seq_cats, 148 | seq_names=self.seq_names, 149 | seq_offsets=self.seq_offsets, 150 | image_size=self.all_image_size, 151 | image_path=self.image_path, 152 | pose=self.image_pose, 153 | fxy=self.fxy, 154 | cxy=self.cxy) 155 | self.n_seq = len(self.seq_names) 156 | print( 157 | " Loaded CO3D dataset", 158 | root, 159 | "n_seq", self.n_seq 160 | ) 161 | 162 | if seq_id is not None: 163 | self.load_sequence(seq_id) 164 | 165 | def load_sequence(self, sequence_id: int): 166 | """ 167 | Load a different CO3D sequence 168 | sequence_id should be at least 0 and at most (n_seq - 1) 169 | see co3d_tmp/co3d.txt for sequence ID -> name mappings 170 | """ 171 | print(' Loading single CO3D sequence:', 172 | self.seq_cats[sequence_id], self.seq_names[sequence_id]) 173 | self.curr_seq_cat = self.seq_cats[sequence_id] 174 | self.curr_seq_name = self.seq_names[sequence_id] 175 | self.curr_offset = self.seq_offsets[sequence_id] 176 | self.next_offset = self.seq_offsets[sequence_id + 1] 177 | self.gt = [] 178 | fxs, fys, cxs, cys = [], [], [], [] 179 | image_sizes = [] 180 | c2ws = [] 181 | ref_c2ws = [] 182 | for i in tqdm(range(self.curr_offset, self.next_offset)): 183 | is_train = i % self.hold_every != 0 184 | ref_c2ws.append(self.image_pose[i]) 185 | if self.split.endswith('train') != is_train: 186 | continue 187 | im = cv2.imread(path.join(self.data_dir, self.image_path[i])) 188 | ################################3 189 | mask_path, masks_name = os.path.split(self.image_path[i]) 190 | masks_name = masks_name[:-3]+"png" 191 | mask_path = os.path.split(mask_path)[0] 192 | mask_path = path.join(self.data_dir, mask_path, "masks", masks_name) 193 | mask = (cv2.imread(mask_path)[:,:,0]>100.0).astype(np.float32) 194 | 195 | ####################### 196 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0 197 | 198 | im = im[..., :3] 199 | im = im*mask[..., None] #+ (1.0 - mask[..., None]) # remove BG 
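            # Multiplying by the binary mask zeroes out background pixels for
            # object-centric training; restoring the commented-out
            # (1.0 - mask[..., None]) term would composite a white background
            # instead of a black one.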
200 | h, w, _ = im.shape 201 | max_hw = max(h, w) 202 | approx_scale = self.max_image_dim / max_hw 203 | 204 | if approx_scale < 1.0: 205 | h2 = int(approx_scale * h) 206 | w2 = int(approx_scale * w) 207 | im = cv2.resize(im, (w2, h2), interpolation=cv2.INTER_AREA) 208 | mask = cv2.resize(mask, (w2, h2), interpolation=cv2.INTER_AREA) 209 | 210 | else: 211 | h2 = h 212 | w2 = w 213 | scale = np.array([w2 / w, h2 / h], dtype=np.float32) 214 | image_sizes.append(np.array([h2, w2])) 215 | cxy = self.cxy[i] * scale 216 | fxy = self.fxy[i] * scale 217 | fxs.append(fxy[0]) 218 | fys.append(fxy[1]) 219 | cxs.append(cxy[0]) 220 | cys.append(cxy[1]) 221 | # grid = data_util.gen_grid(h2, w2, cxy.astype(np.float32), normalize_scale=False) 222 | # grid /= fxy.astype(np.float32) 223 | self.gt.append(torch.from_numpy(im)) 224 | c2ws.append(self.image_pose[i]) 225 | c2w = np.stack(c2ws, axis=0) 226 | ref_c2ws = np.stack(ref_c2ws, axis=0) # For rescaling scene 227 | self.image_size = np.stack(image_sizes) 228 | fxs = torch.tensor(fxs) 229 | fys = torch.tensor(fys) 230 | cxs = torch.tensor(cxs) 231 | cys = torch.tensor(cys) 232 | 233 | # Filter out crazy poses 234 | dists = np.linalg.norm( 235 | c2w[:, :3, 3] - np.median(c2w[:, :3, 3], axis=0), axis=-1) 236 | med = np.median(dists) 237 | good_mask = dists < med * self.max_pose_dist 238 | c2w = c2w[good_mask] 239 | self.image_size = self.image_size[good_mask] 240 | good_idx = np.where(good_mask)[0] 241 | self.gt = [self.gt[i] for i in good_idx] 242 | 243 | self.intrins_full = Intrin(fxs[good_mask], fys[good_mask], 244 | cxs[good_mask], cys[good_mask]) 245 | 246 | # Normalize 247 | # c2w[:, :3, 3] -= np.mean(c2w[:, :3, 3], axis=0) 248 | # dists = np.linalg.norm(c2w[:, :3, 3], axis=-1) 249 | # c2w[:, :3, 3] *= self.cam_scale_factor / np.median(dists) 250 | 251 | T, sscale = similarity_from_cameras(ref_c2ws) 252 | c2w = T @ c2w 253 | c2w[:, :3, 3] *= self.cam_scale_factor * sscale 254 | 255 | self.c2w = torch.from_numpy(c2w).float() 256 | self.cam_n_rays = self.image_size[:, 0] * self.image_size[:, 1] 257 | self.n_images = len(self.gt) 258 | self.image_size_full = self.image_size 259 | 260 | if self.split == "train": 261 | self.gen_rays(factor=1) 262 | else: 263 | # Rays are not needed for testing 264 | self.intrins: Intrin = self.intrins_full 265 | 266 | def gen_rays(self, factor=1): 267 | print(" Generating rays, scaling factor", factor) 268 | # Generate rays 269 | self.factor = factor 270 | self.image_size = self.image_size_full // factor 271 | true_factor = self.image_size_full[:, 0] / self.image_size[:, 0] 272 | self.intrins = self.intrins_full.scale(1.0 / true_factor) 273 | 274 | all_origins = [] 275 | all_dirs = [] 276 | all_gts = [] 277 | for i in tqdm(range(self.n_images)): 278 | yy, xx = torch.meshgrid( 279 | torch.arange(self.image_size[i, 0], dtype=torch.float32) + 0.5, 280 | torch.arange(self.image_size[i, 1], dtype=torch.float32) + 0.5, 281 | ) 282 | xx = (xx - self.intrins.get('cx', i)) / self.intrins.get('fx', i) 283 | yy = (yy - self.intrins.get('cy', i)) / self.intrins.get('fy', i) 284 | zz = torch.ones_like(xx) 285 | dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention 286 | dirs /= torch.norm(dirs, dim=-1, keepdim=True) 287 | dirs = dirs.reshape(-1, 3, 1) 288 | del xx, yy, zz 289 | dirs = (self.c2w[i, None, :3, :3] @ dirs)[..., 0] 290 | 291 | if factor != 1: 292 | gt = F.interpolate( 293 | self.gt[i].permute([2, 0, 1])[None], size=(self.image_size[i, 0], 294 | self.image_size[i, 1]), 295 | mode="area" 296 | )[0].permute([1, 2, 0]) 
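                # "area" interpolation averages source pixels when
                # downsampling, keeping the ground-truth colors consistent
                # with the coarser per-image ray grid generated above.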
297 | gt = gt.reshape(-1, 3) 298 | else: 299 | gt = self.gt[i].reshape(-1, 3) 300 | origins = self.c2w[i, None, :3, 3].expand(self.image_size[i, 0] * 301 | self.image_size[i, 1], -1).contiguous() 302 | all_origins.append(origins) 303 | all_dirs.append(dirs) 304 | all_gts.append(gt) 305 | origins = all_origins 306 | dirs = all_dirs 307 | gt = all_gts 308 | 309 | if self.split == "train": 310 | origins = torch.cat([o.view(-1, 3) for o in origins], dim=0) 311 | dirs = torch.cat([o.view(-1, 3) for o in dirs], dim=0) 312 | gt = torch.cat([o.reshape(-1, 3) for o in gt], dim=0) 313 | 314 | self.rays_init = Rays(origins=origins, dirs=dirs, gt=gt) 315 | self.rays = self.rays_init 316 | -------------------------------------------------------------------------------- /Svox2/opt/util/config_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from util.dataset import datasets 4 | import json 5 | 6 | 7 | def define_common_args(parser : argparse.ArgumentParser): 8 | parser.add_argument('data_dir', type=str) 9 | 10 | parser.add_argument('--config', '-c', 11 | type=str, 12 | default=None, 13 | help="Config yaml file (will override args)") 14 | 15 | group = parser.add_argument_group("Data loading") 16 | group.add_argument('--dataset_type', 17 | choices=list(datasets.keys()) + ["auto"], 18 | default="auto", 19 | help="Dataset type (specify type or use auto)") 20 | group.add_argument('--scene_scale', 21 | type=float, 22 | default=None, 23 | help="Global scene scaling (or use dataset default)") 24 | group.add_argument('--scale', 25 | type=float, 26 | default=None, 27 | help="Image scale, e.g. 0.5 for half resolution (or use dataset default)") 28 | group.add_argument('--seq_id', 29 | type=int, 30 | default=1000, 31 | help="Sequence ID (for CO3D only)") 32 | group.add_argument('--epoch_size', 33 | type=int, 34 | default=12800, 35 | help="Pseudo-epoch size in term of batches (to be consistent across datasets)") 36 | group.add_argument('--white_bkgd', 37 | type=bool, 38 | default=True, 39 | help="Whether to use white background (ignored in some datasets)") 40 | group.add_argument('--llffhold', 41 | type=int, 42 | default=8, 43 | help="LLFF holdout every") 44 | group.add_argument('--normalize_by_bbox', 45 | type=bool, 46 | default=False, 47 | help="Normalize by bounding box in bbox.txt, if available (NSVF dataset only); precedes normalize_by_camera") 48 | group.add_argument('--data_bbox_scale', 49 | type=float, 50 | default=1.2, 51 | help="Data bbox scaling (NSVF dataset only)") 52 | group.add_argument('--cam_scale_factor', 53 | type=float, 54 | default=0.95, 55 | help="Camera autoscale factor (NSVF/CO3D dataset only)") 56 | group.add_argument('--normalize_by_camera', 57 | type=bool, 58 | default=True, 59 | help="Normalize using cameras, assuming a 360 capture (NSVF dataset only); only used if not normalize_by_bbox") 60 | group.add_argument('--perm', action='store_true', default=False, 61 | help='sample by permutation of rays (true epoch) instead of ' 62 | 'uniformly random rays') 63 | 64 | group = parser.add_argument_group("Render options") 65 | group.add_argument('--step_size', 66 | type=float, 67 | default=0.5, 68 | help="Render step size (in voxel size units)") 69 | group.add_argument('--sigma_thresh', 70 | type=float, 71 | default=1e-8, 72 | help="Skips voxels with sigma < this") 73 | group.add_argument('--stop_thresh', 74 | type=float, 75 | default=1e-7, 76 | help="Ray march stopping threshold") 77 | 
group.add_argument('--background_brightness', 78 | type=float, 79 | default=1.0, 80 | help="Brightness of the infinite background") 81 | group.add_argument('--renderer_backend', '-B', 82 | choices=['cuvol', 'svox1', 'nvol'], 83 | default='cuvol', 84 | help="Renderer backend") 85 | group.add_argument('--random_sigma_std', 86 | type=float, 87 | default=0.0, 88 | help="Random Gaussian std to add to density values (only if enable_random)") 89 | group.add_argument('--random_sigma_std_background', 90 | type=float, 91 | default=0.0, 92 | help="Random Gaussian std to add to density values for BG (only if enable_random)") 93 | group.add_argument('--near_clip', 94 | type=float, 95 | default=0.00, 96 | help="Near clip distance (in world space distance units, only for FG)") 97 | group.add_argument('--use_spheric_clip', 98 | action='store_true', 99 | default=False, 100 | help="Use spheric ray clipping instead of voxel grid AABB " 101 | "(only for FG; changes near_clip to mean 1-near_intersection_radius; " 102 | "far intersection is always at radius 1)") 103 | group.add_argument('--enable_random', 104 | action='store_true', 105 | default=False, 106 | help="Random Gaussian std to add to density values") 107 | group.add_argument('--last_sample_opaque', 108 | action='store_true', 109 | default=False, 110 | help="Last sample has +1e9 density (used for LLFF)") 111 | 112 | 113 | def build_data_options(args): 114 | """ 115 | Arguments to pass as kwargs to the dataset constructor 116 | """ 117 | return { 118 | 'dataset_type': args.dataset_type, 119 | 'seq_id': args.seq_id, 120 | 'epoch_size': args.epoch_size * args.__dict__.get('batch_size', 5000), 121 | 'scene_scale': args.scene_scale, 122 | 'scale': args.scale, 123 | 'white_bkgd': args.white_bkgd, 124 | 'hold_every': args.llffhold, 125 | 'normalize_by_bbox': args.normalize_by_bbox, 126 | 'data_bbox_scale': args.data_bbox_scale, 127 | 'cam_scale_factor': args.cam_scale_factor, 128 | 'normalize_by_camera': args.normalize_by_camera, 129 | 'permutation': args.perm 130 | } 131 | 132 | def maybe_merge_config_file(args, allow_invalid=False): 133 | """ 134 | Load json config file if specified and merge the arguments 135 | """ 136 | if args.config is not None: 137 | with open(args.config, "r") as config_file: 138 | configs = json.load(config_file) 139 | invalid_args = list(set(configs.keys()) - set(dir(args))) 140 | if invalid_args and not allow_invalid: 141 | raise ValueError(f"Invalid args {invalid_args} in {args.config}.") 142 | args.__dict__.update(configs) 143 | 144 | def setup_render_opts(opt, args): 145 | """ 146 | Pass render arguments to the SparseGrid renderer options 147 | """ 148 | opt.step_size = args.step_size 149 | opt.sigma_thresh = args.sigma_thresh 150 | opt.stop_thresh = args.stop_thresh 151 | opt.background_brightness = args.background_brightness 152 | opt.backend = args.renderer_backend 153 | opt.random_sigma_std = args.random_sigma_std 154 | opt.random_sigma_std_background = args.random_sigma_std_background 155 | opt.last_sample_opaque = args.last_sample_opaque 156 | opt.near_clip = args.near_clip 157 | opt.use_spheric_clip = args.use_spheric_clip 158 | -------------------------------------------------------------------------------- /Svox2/opt/util/dataset.py: -------------------------------------------------------------------------------- 1 | from .nerf_dataset import NeRFDataset, FastNeRFDataset 2 | from .llff_dataset import LLFFDataset 3 | from .nsvf_dataset import NSVFDataset 4 | from .co3d_dataset import CO3DDataset 5 | from os import path 6 
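# Usage sketch (hypothetical path, not part of the original module):
#
#   from util.dataset import datasets
#   dset = datasets['auto']('/data/nerf_synthetic/lego', split='train')
#
# 'auto' resolves to auto_dataset() below, which dispatches on marker files:
# apple/eval_batches_multisequence.json -> CO3D, poses_bounds.npy -> LLFF,
# transforms*.json -> NeRF (Blender), otherwise the extended NSVF loader.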
| 7 | def auto_dataset(root : str, *args, **kwargs): 8 | if path.isfile(path.join(root, 'apple', 'eval_batches_multisequence.json')): 9 | print("Detected CO3D dataset") 10 | return CO3DDataset(root, *args, **kwargs) 11 | elif path.isfile(path.join(root, 'poses_bounds.npy')): 12 | print("Detected LLFF dataset") 13 | return LLFFDataset(root, *args, **kwargs) 14 | elif path.isfile(path.join(root, 'transforms.json')) or \ 15 | path.isfile(path.join(root, 'transforms_train.json')): 16 | print("Detected NeRF (Blender) dataset") 17 | return NeRFDataset(root, *args, **kwargs) 18 | else: 19 | print("Defaulting to extended NSVF dataset") 20 | return NSVFDataset(root, *args, **kwargs) 21 | 22 | datasets = { 23 | 'nerf': NeRFDataset, 24 | 'fastnerf': FastNeRFDataset, 25 | 'llff': LLFFDataset, 26 | 'nsvf': NSVFDataset, 27 | 'co3d': CO3DDataset, 28 | 'auto': auto_dataset 29 | } 30 | -------------------------------------------------------------------------------- /Svox2/opt/util/dataset_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from typing import Union, Optional, List 4 | from .util import select_or_shuffle_rays, Rays, Intrin 5 | 6 | class DatasetBase: 7 | split: str 8 | permutation: bool 9 | epoch_size: Optional[int] 10 | n_images: int 11 | h_full: int 12 | w_full: int 13 | intrins_full: Intrin 14 | c2w: torch.Tensor # C2W OpenCV poses 15 | gt: Union[torch.Tensor, List[torch.Tensor]] # RGB images 16 | device : Union[str, torch.device] 17 | 18 | def __init__(self): 19 | self.ndc_coeffs = (-1, -1) 20 | self.use_sphere_bound = False 21 | self.should_use_background = True # a hint 22 | self.use_sphere_bound = True 23 | self.scene_center = [0.0, 0.0, 0.0] 24 | self.scene_radius = [1.0, 1.0, 1.0] 25 | self.permutation = False 26 | 27 | def shuffle_rays(self): 28 | """ 29 | Shuffle all rays 30 | """ 31 | if self.split == "train": 32 | del self.rays 33 | self.rays = select_or_shuffle_rays(self.rays_init, self.permutation, 34 | self.epoch_size, self.device) 35 | 36 | def gen_rays(self, factor=1): 37 | # print(" Generating rays, scaling factor", factor) 38 | # Generate rays 39 | self.factor = factor 40 | self.h = self.h_full // factor 41 | self.w = self.w_full // factor 42 | true_factor = self.h_full / self.h 43 | self.intrins = self.intrins_full.scale(1.0 / true_factor) 44 | yy, xx = torch.meshgrid( 45 | torch.arange(self.h, dtype=torch.float32) + 0.5, 46 | torch.arange(self.w, dtype=torch.float32) + 0.5, 47 | ) 48 | xx = (xx - self.intrins.cx) / self.intrins.fx 49 | yy = (yy - self.intrins.cy) / self.intrins.fy 50 | zz = torch.ones_like(xx) 51 | dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention 52 | dirs /= torch.norm(dirs, dim=-1, keepdim=True) 53 | dirs = dirs.reshape(1, -1, 3, 1) 54 | del xx, yy, zz 55 | dirs = (self.c2w[:, None, :3, :3] @ dirs)[..., 0] 56 | 57 | if factor != 1: 58 | gt = F.interpolate( 59 | self.gt.permute([0, 3, 1, 2]), size=(self.h, self.w), mode="area" 60 | ).permute([0, 2, 3, 1]) 61 | gt = gt.reshape(self.n_images, -1, 3) 62 | else: 63 | gt = self.gt.reshape(self.n_images, -1, 3) 64 | origins = self.c2w[:, None, :3, 3].expand(-1, self.h * self.w, -1).contiguous() 65 | if self.split == "train": 66 | origins = origins.view(-1, 3) 67 | dirs = dirs.view(-1, 3) 68 | gt = gt.reshape(-1, 3) 69 | 70 | self.rays_init = Rays(origins=origins, dirs=dirs, gt=gt) 71 | self.rays = self.rays_init 72 | 73 | def get_image_size(self, i : int): 74 | # H, W 75 | if hasattr(self, 'image_size'): 
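            # Datasets with per-image resolutions (e.g. CO3D) store an
            # image_size array; fixed-resolution datasets fall through to the
            # global (h, w) below.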
76 | return tuple(self.image_size[i]) 77 | else: 78 | return self.h, self.w 79 | -------------------------------------------------------------------------------- /Svox2/opt/util/load_llff.py: -------------------------------------------------------------------------------- 1 | # Originally from LLFF 2 | # https://github.com/Fyusion/LLFF 3 | # With minor modifications from NeX 4 | # https://github.com/nex-mpi/nex-code 5 | 6 | import numpy as np 7 | import os 8 | import imageio 9 | 10 | def get_image_size(path : str): 11 | """ 12 | Get image size without loading it 13 | """ 14 | from PIL import Image 15 | im = Image.open(path) 16 | return im.size[1], im.size[0] # H, W 17 | 18 | def _minify(basedir, factors=[], resolutions=[]): 19 | needtoload = False 20 | for r in factors: 21 | imgdir = os.path.join(basedir, "images_{}".format(r)) 22 | if not os.path.exists(imgdir): 23 | needtoload = True 24 | for r in resolutions: 25 | imgdir = os.path.join(basedir, "images_{}x{}".format(r[1], r[0])) 26 | if not os.path.exists(imgdir): 27 | needtoload = True 28 | if not needtoload: 29 | return 30 | 31 | from shutil import copy 32 | from subprocess import check_output 33 | 34 | imgdir = os.path.join(basedir, "images") 35 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))] 36 | imgs = [ 37 | f 38 | for f in imgs 39 | if any([f.endswith(ex) for ex in ["JPG", "jpg", "png", "jpeg", "PNG"]]) 40 | ] 41 | imgdir_orig = imgdir 42 | 43 | wd = os.getcwd() 44 | 45 | for r in factors + resolutions: 46 | if isinstance(r, int): 47 | name = "images_{}".format(r) 48 | resizearg = "{}%".format(100.0 / r) 49 | else: 50 | name = "images_{}x{}".format(r[1], r[0]) 51 | resizearg = "{}x{}".format(r[1], r[0]) 52 | imgdir = os.path.join(basedir, name) 53 | if os.path.exists(imgdir): 54 | continue 55 | 56 | print("Minifying", r, basedir) 57 | 58 | os.makedirs(imgdir) 59 | check_output("cp {}/* {}".format(imgdir_orig, imgdir), shell=True) 60 | 61 | ext = imgs[0].split(".")[-1] 62 | args = " ".join( 63 | ["mogrify", "-resize", resizearg, "-format", "png", "*.{}".format(ext)] 64 | ) 65 | print(args) 66 | os.chdir(imgdir) 67 | check_output(args, shell=True) 68 | os.chdir(wd) 69 | 70 | if ext != "png": 71 | check_output("rm {}/*.{}".format(imgdir, ext), shell=True) 72 | print("Removed duplicates") 73 | print("Done") 74 | 75 | 76 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True): 77 | poses_arr = np.load(os.path.join(basedir, "poses_bounds.npy")) 78 | shape = 5 79 | 80 | # poss llff arr [3, 5, images] [R | T | intrinsic] 81 | # intrinsic same for all images 82 | if os.path.isfile(os.path.join(basedir, "hwf_cxcy.npy")): 83 | shape = 4 84 | # h, w, fx, fy, cx, cy 85 | intrinsic_arr = np.load(os.path.join(basedir, "hwf_cxcy.npy")) 86 | 87 | poses = poses_arr[:, :-2].reshape([-1, 3, shape]).transpose([1, 2, 0]) 88 | bds = poses_arr[:, -2:].transpose([1, 0]) 89 | 90 | if not os.path.isfile(os.path.join(basedir, "hwf_cxcy.npy")): 91 | intrinsic_arr = poses[:, 4, 0] 92 | poses = poses[:, :4, :] 93 | 94 | img0 = [ 95 | os.path.join(basedir, "images", f) 96 | for f in sorted(os.listdir(os.path.join(basedir, "images"))) 97 | if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png") 98 | ][0] 99 | sh = get_image_size(img0) 100 | 101 | sfx = "" 102 | if factor is not None: 103 | sfx = "_{}".format(factor) 104 | _minify(basedir, factors=[factor]) 105 | factor = factor 106 | elif height is not None: 107 | factor = sh[0] / float(height) 108 | width = int(sh[1] / factor) 109 | _minify(basedir, 
resolutions=[[height, width]]) 110 | sfx = "_{}x{}".format(width, height) 111 | elif width is not None: 112 | factor = sh[1] / float(width) 113 | height = int(sh[0] / factor) 114 | _minify(basedir, resolutions=[[height, width]]) 115 | sfx = "_{}x{}".format(width, height) 116 | else: 117 | factor = 1 118 | 119 | imgdir = os.path.join(basedir, "images" + sfx) 120 | if not os.path.exists(imgdir): 121 | print(imgdir, "does not exist, returning") 122 | return 123 | 124 | imgfiles = [ 125 | os.path.join(imgdir, f) 126 | for f in sorted(os.listdir(imgdir)) 127 | if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png") 128 | ] 129 | if poses.shape[-1] != len(imgfiles): 130 | print( 131 | "Mismatch between imgs {} and poses {} !!!!".format( 132 | len(imgfiles), poses.shape[-1] 133 | ) 134 | ) 135 | return 136 | 137 | if not load_imgs: 138 | return poses, bds, intrinsic_arr 139 | 140 | def imread(f): 141 | if f.endswith("png"): 142 | return imageio.imread(f, ignoregamma=True) 143 | else: 144 | return imageio.imread(f) 145 | 146 | imgs = imgs = [imread(f)[..., :3] / 255.0 for f in imgfiles] 147 | imgs = np.stack(imgs, -1) 148 | 149 | print("Loaded image data", imgs.shape, poses[:, -1, 0]) 150 | return poses, bds, imgs, intrinsic_arr 151 | 152 | 153 | def normalize(x): 154 | return x / np.linalg.norm(x) 155 | 156 | 157 | def viewmatrix(z, up, pos): 158 | vec2 = normalize(z) 159 | vec1_avg = up 160 | vec0 = normalize(np.cross(vec1_avg, vec2)) 161 | vec1 = normalize(np.cross(vec2, vec0)) 162 | m = np.stack([vec0, vec1, vec2, pos], 1) 163 | return m 164 | 165 | 166 | def ptstocam(pts, c2w): 167 | tt = np.matmul(c2w[:3, :3].T, (pts - c2w[:3, 3])[..., np.newaxis])[..., 0] 168 | return tt 169 | 170 | 171 | def poses_avg(poses): 172 | # poses [images, 3, 4] not [images, 3, 5] 173 | # hwf = poses[0, :3, -1:] 174 | 175 | center = poses[:, :3, 3].mean(0) 176 | vec2 = normalize(poses[:, :3, 2].sum(0)) 177 | up = poses[:, :3, 1].sum(0) 178 | c2w = np.concatenate([viewmatrix(vec2, up, center)], 1) 179 | 180 | return c2w 181 | 182 | 183 | def render_path_axis(c2w, up, ax, rad, focal, N): 184 | render_poses = [] 185 | center = c2w[:, 3] 186 | hwf = c2w[:, 4:5] 187 | v = c2w[:, ax] * rad 188 | for t in np.linspace(-1.0, 1.0, N + 1)[:-1]: 189 | c = center + t * v 190 | z = normalize(c - (center - focal * c2w[:, 2])) 191 | # render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) 192 | render_poses.append(viewmatrix(z, up, c)) 193 | return render_poses 194 | 195 | 196 | def render_path_spiral(c2w, up, rads, focal, zrate, rots, N): 197 | render_poses = [] 198 | rads = np.array(list(rads) + [1.0]) 199 | # hwf = c2w[:,4:5] 200 | 201 | for theta in np.linspace(0.0, 2.0 * np.pi * rots, N + 1)[:-1]: 202 | c = np.dot( 203 | c2w[:3, :4], 204 | np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0]) 205 | * rads, 206 | ) 207 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.0]))) 208 | # render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) 209 | render_poses.append(viewmatrix(z, up, c)) 210 | return render_poses 211 | 212 | 213 | def recenter_poses(poses): 214 | # poses [images, 3, 4] 215 | poses_ = poses + 0 216 | bottom = np.reshape([0, 0, 0, 1.0], [1, 4]) 217 | c2w = poses_avg(poses) 218 | c2w = np.concatenate([c2w[:3, :4], bottom], -2) 219 | 220 | bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1]) 221 | poses = np.concatenate([poses[:, :3, :4], bottom], -2) 222 | 223 | poses = np.linalg.inv(c2w) @ poses 224 | poses_[:, :3, :4] = poses[:, 
:3, :4] 225 | poses = poses_ 226 | return poses 227 | 228 | 229 | def spherify_poses(poses, bds): 230 | p34_to_44 = lambda p: np.concatenate( 231 | [p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1 232 | ) 233 | 234 | rays_d = poses[:, :3, 2:3] 235 | rays_o = poses[:, :3, 3:4] 236 | 237 | def min_line_dist(rays_o, rays_d): 238 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1]) 239 | b_i = -A_i @ rays_o 240 | pt_mindist = np.squeeze( 241 | -np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0) 242 | ) 243 | return pt_mindist 244 | 245 | pt_mindist = min_line_dist(rays_o, rays_d) 246 | 247 | center = pt_mindist 248 | up = (poses[:, :3, 3] - center).mean(0) 249 | 250 | vec0 = normalize(up) 251 | vec1 = normalize(np.cross([0.1, 0.2, 0.3], vec0)) 252 | vec2 = normalize(np.cross(vec0, vec1)) 253 | pos = center 254 | c2w = np.stack([vec1, vec2, vec0, pos], 1) 255 | 256 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4]) 257 | 258 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1))) 259 | 260 | sc = 1.0 / rad 261 | poses_reset[:, :3, 3] *= sc 262 | bds *= sc 263 | rad *= sc 264 | 265 | centroid = np.mean(poses_reset[:, :3, 3], 0) 266 | zh = centroid[2] 267 | radcircle = np.sqrt(rad ** 2 - zh ** 2) 268 | new_poses = [] 269 | 270 | for th in np.linspace(0.0, 2.0 * np.pi, 120): 271 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh]) 272 | up = np.array([0, 0, -1.0]) 273 | 274 | vec2 = normalize(camorigin) 275 | vec0 = normalize(np.cross(vec2, up)) 276 | vec1 = normalize(np.cross(vec2, vec0)) 277 | pos = camorigin 278 | p = np.stack([vec0, vec1, vec2, pos], 1) 279 | 280 | new_poses.append(p) 281 | 282 | new_poses = np.stack(new_poses, 0) 283 | 284 | new_poses = np.concatenate( 285 | [new_poses, np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)], -1 286 | ) 287 | poses_reset = np.concatenate( 288 | [ 289 | poses_reset[:, :3, :4], 290 | np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape), 291 | ], 292 | -1, 293 | ) 294 | 295 | return poses_reset, new_poses, bds 296 | 297 | 298 | def load_llff_data( 299 | basedir, 300 | factor=None, 301 | recenter=True, 302 | bd_factor=0.75, 303 | spherify=False, 304 | # path_zflat=False, 305 | split_train_val=8, 306 | render_style="", 307 | ): 308 | 309 | # poses, bds, imgs = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x 310 | poses, bds, intrinsic = _load_data( 311 | basedir, factor=factor, load_imgs=False 312 | ) # factor=8 downsamples original imgs by 8x 313 | 314 | print("Loaded LLFF data", basedir, bds.min(), bds.max()) 315 | 316 | # Correct rotation matrix ordering and move variable dim to axis 0 317 | # poses [R | T] [3, 4, images] 318 | poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) 319 | # poses [3, 4, images] --> [images, 3, 4] 320 | poses = np.moveaxis(poses, -1, 0).astype(np.float32) 321 | 322 | # imgs = np.moveaxis(imgs, -1, 0).astype(np.float32) 323 | # images = imgs 324 | bds = np.moveaxis(bds, -1, 0).astype(np.float32) 325 | 326 | # Rescale if bd_factor is provided 327 | sc = 1.0 if bd_factor is None else 1.0 / (bds.min() * bd_factor) 328 | poses[:, :3, 3] *= sc 329 | bds *= sc 330 | 331 | if recenter: 332 | poses = recenter_poses(poses) 333 | 334 | if spherify: 335 | poses, render_poses, bds = spherify_poses(poses, bds) 336 | else: 337 | c2w = poses_avg(poses) 338 | print("recentered", c2w.shape) 339 | 340 | ## Get spiral 341 | # 
Get average pose 342 | up = normalize(poses[:, :3, 1].sum(0)) 343 | 344 | close_depth, inf_depth = -1, -1 345 | # Find a reasonable "focus depth" for this dataset 346 | # if os.path.exists(os.path.join(basedir, "planes_spiral.txt")): 347 | # with open(os.path.join(basedir, "planes_spiral.txt"), "r") as fi: 348 | # data = [float(x) for x in fi.readline().split(" ")] 349 | # dmin, dmax = data[:2] 350 | # close_depth = dmin * 0.9 351 | # inf_depth = dmax * 5.0 352 | # elif os.path.exists(os.path.join(basedir, "planes.txt")): 353 | # with open(os.path.join(basedir, "planes.txt"), "r") as fi: 354 | # data = [float(x) for x in fi.readline().split(" ")] 355 | # if len(data) == 3: 356 | # dmin, dmax, invz = data 357 | # elif len(data) == 4: 358 | # dmin, dmax, invz, _ = data 359 | # close_depth = dmin * 0.9 360 | # inf_depth = dmax * 5.0 361 | 362 | prev_close, prev_inf = close_depth, inf_depth 363 | if close_depth < 0 or inf_depth < 0 or render_style == "llff": 364 | close_depth, inf_depth = bds.min() * 0.9, bds.max() * 5.0 365 | 366 | if render_style == "shiny": 367 | close_depth, inf_depth = bds.min() * 0.9, bds.max() * 5.0 368 | if close_depth < prev_close: 369 | close_depth = prev_close 370 | if inf_depth > prev_inf: 371 | inf_depth = prev_inf 372 | 373 | dt = 0.75 374 | mean_dz = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth)) 375 | focal = mean_dz 376 | 377 | # Get radii for spiral path 378 | tt = poses[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T 379 | rads = np.percentile(np.abs(tt), 90, 0) 380 | c2w_path = c2w 381 | N_views = 120 382 | N_rots = 2 383 | # if path_zflat: 384 | # # zloc = np.percentile(tt, 10, 0)[2] 385 | # zloc = -close_depth * 0.1 386 | # c2w_path[:3, 3] = c2w_path[:3, 3] + zloc * c2w_path[:3, 2] 387 | # rads[2] = 0.0 388 | # N_rots = 1 389 | # N_views /= 2 390 | 391 | render_poses = render_path_spiral( 392 | c2w_path, up, rads, focal, zrate=0.5, rots=N_rots, N=N_views 393 | ) 394 | 395 | render_poses = np.array(render_poses).astype(np.float32) 396 | # reference_view_id should stay in train set only 397 | validation_ids = np.arange(poses.shape[0]) 398 | validation_ids[::split_train_val] = -1 399 | validation_ids = validation_ids < 0 400 | train_ids = np.logical_not(validation_ids) 401 | train_poses = poses[train_ids] 402 | train_bds = bds[train_ids] 403 | c2w = poses_avg(train_poses) 404 | 405 | dists = np.sum(np.square(c2w[:3, 3] - train_poses[:, :3, 3]), -1) 406 | reference_view_id = np.argmin(dists) 407 | reference_depth = train_bds[reference_view_id] 408 | print(reference_depth) 409 | 410 | return ( 411 | reference_depth, 412 | reference_view_id, 413 | render_poses, 414 | poses, 415 | intrinsic 416 | ) 417 | -------------------------------------------------------------------------------- /Svox2/opt/util/nerf_dataset.py: -------------------------------------------------------------------------------- 1 | # Standard NeRF Blender dataset loader 2 | from .util import Rays, Intrin, select_or_shuffle_rays 3 | from .dataset_base import DatasetBase 4 | import torch 5 | import torch.nn.functional as F 6 | from typing import NamedTuple, Optional, Union 7 | from os import path 8 | import imageio 9 | from tqdm import tqdm 10 | import cv2 11 | import json 12 | import numpy as np 13 | from timeit import default_timer as timer 14 | 15 | 16 | 17 | class NeRFDataset(DatasetBase): 18 | """ 19 | NeRF dataset loader 20 | """ 21 | 22 | focal: float 23 | c2w: torch.Tensor # (n_images, 4, 4) 24 | gt: torch.Tensor # (n_images, h, w, 3) 25 | h: int 26 | w: int 27 | n_images: int 28 | 
rays: Optional[Rays] 29 | split: str 30 | 31 | def __init__( 32 | self, 33 | root, 34 | split, 35 | epoch_size : Optional[int] = None, 36 | device: Union[str, torch.device] = "cpu", 37 | scene_scale: Optional[float] = None, 38 | factor: int = 1, 39 | scale : Optional[float] = None, 40 | permutation: bool = True, 41 | white_bkgd: bool = True, 42 | n_images = None, 43 | data_split = None, 44 | randomization: bool = False, 45 | verbose: bool = True, 46 | cropout_size = 0, # 0 default no cropput anything from the posed images 47 | 48 | **kwargs 49 | ): 50 | super().__init__() 51 | assert path.isdir(root), f"'{root}' is not a directory" 52 | 53 | if scene_scale is None: 54 | scene_scale = 2/3 55 | if scale is None: 56 | scale = 1.0 57 | self.device = device 58 | self.permutation = permutation 59 | self.epoch_size = epoch_size 60 | all_c2w = [] 61 | all_gt = [] 62 | 63 | 64 | split_name = split if split != "test_train" else "train" 65 | split_name = data_split if data_split else split_name 66 | data_path = path.join(root, split_name) 67 | data_json = path.join(root, "transforms_" + split_name + ".json") 68 | 69 | if verbose: 70 | print("LOAD DATA", data_path) 71 | 72 | j = json.load(open(data_json, "r")) 73 | 74 | # OpenGL -> OpenCV 75 | cam_trans = torch.diag(torch.tensor([1, -1, -1, 1], dtype=torch.float32)) 76 | 77 | all_fr = tqdm(j["frames"]) if verbose else list(j["frames"]) 78 | for frame in all_fr: 79 | fpath = path.join(data_path, path.basename(frame["file_path"]) + ".png") 80 | c2w = torch.tensor(frame["transform_matrix"], dtype=torch.float32) 81 | c2w = c2w @ cam_trans # To OpenCV 82 | im_gt = imageio.imread(fpath) 83 | 84 | if scale < 1.0: 85 | full_size = list(im_gt.shape[:2]) 86 | rsz_h, rsz_w = [round(hw * scale) for hw in full_size] 87 | im_gt = cv2.resize(im_gt, (rsz_w, rsz_h), interpolation=cv2.INTER_AREA) 88 | 89 | all_c2w.append(c2w) 90 | all_gt.append(torch.from_numpy(im_gt)) 91 | 92 | focal = float( 93 | 0.5 * all_gt[0].shape[1] / np.tan(0.5 * j["camera_angle_x"]) 94 | ) 95 | self.c2w = torch.stack(all_c2w) 96 | self.c2w[:, :3, 3] *= scene_scale 97 | 98 | self.gt = torch.stack(all_gt).float() / 255.0 99 | if self.gt.size(-1) == 4: 100 | if white_bkgd: 101 | # Apply alpha channel 102 | self.gt = self.gt[..., :3] * self.gt[..., 3:] + (1.0 - self.gt[..., 3:]) 103 | else: 104 | self.gt = self.gt[..., :3] 105 | 106 | self.n_images, self.h_full, self.w_full, _ = self.gt.shape 107 | # Choose a subset of training images 108 | if n_images is not None: 109 | if n_images > self.n_images: 110 | print(f'using {self.n_images} available training views instead of the requested {n_images}.') 111 | n_images = self.n_images 112 | self.n_images = n_images 113 | if randomization: 114 | arr = list(range(self.gt.shape[0])) 115 | np.random.shuffle(arr) 116 | self.indices = arr[0:n_images] 117 | self.gt = self.gt[self.indices] 118 | self.c2w = self.c2w[self.indices] 119 | else: 120 | self.gt = self.gt[0:n_images,...] 121 | self.c2w = self.c2w[0:n_images,...] 
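            # View-subset behaviour: with randomization=True a random
            # permutation of frames is sampled (recorded in self.indices);
            # otherwise the first n_images frames are taken deterministically.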
122 | if cropout_size and 2 * cropout_size < self.gt.shape[1] and 2 * cropout_size < self.gt.shape[2]: 123 | self.gt = self.gt[:, cropout_size:self.gt.shape[1] - 124 | cropout_size, cropout_size:self.gt.shape[2]-cropout_size, :] 125 | 126 | self.intrins_full : Intrin = Intrin(focal, focal, 127 | self.w_full * 0.5, 128 | self.h_full * 0.5) 129 | 130 | self.split = split 131 | self.scene_scale = scene_scale 132 | if self.split == "train": 133 | self.gen_rays(factor=factor) 134 | else: 135 | # Rays are not needed for testing 136 | self.h, self.w = self.h_full, self.w_full 137 | self.intrins : Intrin = self.intrins_full 138 | 139 | self.should_use_background = False # Give warning 140 | 141 | 142 | class FastNeRFDataset(DatasetBase): 143 | """ 144 | Fast NeRF dataset loader for training on nerf data 145 | """ 146 | 147 | focal: float 148 | c2w: torch.Tensor # (n_images, 4, 4) 149 | gt: torch.Tensor # (n_images, h, w, 3) 150 | h: int 151 | w: int 152 | n_images: int 153 | rays: Optional[Rays] 154 | split: str 155 | 156 | def __init__( 157 | self, 158 | root, 159 | split, 160 | epoch_size: Optional[int] = None, 161 | device: Union[str, torch.device] = "cpu", 162 | scene_scale: Optional[float] = None, 163 | factor: int = 1, 164 | scale: Optional[float] = None, 165 | permutation: bool = True, 166 | white_bkgd: bool = True, 167 | n_images=None, 168 | data_split=None, 169 | indices = [], 170 | randomization: bool = False, 171 | 172 | verbose: bool = True, 173 | cropout_size=0, # 0 default no cropput anything from the posed images 174 | 175 | **kwargs 176 | ): 177 | super().__init__() 178 | assert path.isdir(root), f"'{root}' is not a directory" 179 | 180 | if scene_scale is None: 181 | scene_scale = 2/3 182 | if scale is None: 183 | scale = 1.0 184 | self.device = device 185 | self.permutation = permutation 186 | self.epoch_size = epoch_size 187 | all_c2w = [] 188 | all_gt = [] 189 | # s = timer() 190 | 191 | split_name = split if split != "test_train" else "train" 192 | split_name = data_split if data_split else split_name 193 | data_path = path.join(root, split_name) 194 | data_json = path.join(root, "transforms_" + split_name + ".json") 195 | 196 | if verbose: 197 | print("LOAD DATA", data_path) 198 | 199 | j = json.load(open(data_json, "r")) 200 | 201 | # OpenGL -> OpenCV 202 | cam_trans = torch.diag(torch.tensor( 203 | [1, -1, -1, 1], dtype=torch.float32)) 204 | 205 | all_fr = tqdm(j["frames"]) if verbose else list(j["frames"]) 206 | # print(len(all_fr)) 207 | 208 | for indx, frame in enumerate(all_fr): 209 | if indx not in indices: 210 | continue 211 | fpath = path.join(data_path, path.basename( 212 | frame["file_path"]) + ".png") 213 | c2w = torch.tensor(frame["transform_matrix"], dtype=torch.float32) 214 | c2w = c2w @ cam_trans # To OpenCV 215 | im_gt = imageio.imread(fpath) 216 | 217 | if scale < 1.0: 218 | full_size = list(im_gt.shape[:2]) 219 | rsz_h, rsz_w = [round(hw * scale) for hw in full_size] 220 | im_gt = cv2.resize(im_gt, (rsz_w, rsz_h), 221 | interpolation=cv2.INTER_AREA) 222 | 223 | all_c2w.append(c2w[None,...]) 224 | all_gt.append(im_gt[None, ...]) 225 | 226 | self.c2w = np.concatenate(all_c2w,axis=0) 227 | self.c2w[:, :3, 3] *= scene_scale 228 | 229 | self.gt = np.concatenate(all_gt, axis=0).astype('float32') / 255.0 230 | 231 | self.gt, self.masks = self.gt[..., :3], self.gt[..., 3] 232 | 233 | self.n_images, self.h_full, self.w_full, _ = self.gt.shape 234 | # Choose a subset of training images 235 | 236 | self.should_use_background = False # Give warning 237 | 238 | 239 | 
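# Usage sketch (hypothetical arguments, not part of the original module):
#
#   dset = FastNeRFDataset('/data/nerf_synthetic/lego', split='train',
#                          indices=[0, 5, 10])
#
# Unlike NeRFDataset above, frames are kept only if their index appears in
# `indices`, the alpha channel is split off into self.masks rather than
# composited, and no rays are generated at construction time.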
240 | 241 | -------------------------------------------------------------------------------- /Svox2/opt/util/nsvf_dataset.py: -------------------------------------------------------------------------------- 1 | # Extended NSVF-format dataset loader 2 | # This is a more sane format vs the NeRF formats 3 | 4 | from .util import Rays, Intrin, similarity_from_cameras 5 | from .dataset_base import DatasetBase 6 | import torch 7 | import torch.nn.functional as F 8 | from typing import NamedTuple, Optional, Union 9 | from os import path 10 | import os 11 | import cv2 12 | import imageio 13 | from tqdm import tqdm 14 | import json 15 | import numpy as np 16 | from warnings import warn 17 | 18 | 19 | class NSVFDataset(DatasetBase): 20 | """ 21 | Extended NSVF dataset loader 22 | """ 23 | 24 | focal: float 25 | c2w: torch.Tensor # (n_images, 4, 4) 26 | gt: torch.Tensor # (n_images, h, w, 3) 27 | h: int 28 | w: int 29 | n_images: int 30 | rays: Optional[Rays] 31 | split: str 32 | 33 | def __init__( 34 | self, 35 | root, 36 | split, 37 | epoch_size : Optional[int] = None, 38 | device: Union[str, torch.device] = "cpu", 39 | scene_scale: Optional[float] = None, # Scene scaling 40 | factor: int = 1, # Image scaling (on ray gen; use gen_rays(factor) to dynamically change scale) 41 | scale : Optional[float] = 1.0, # Image scaling (on load) 42 | permutation: bool = True, 43 | white_bkgd: bool = True, 44 | normalize_by_bbox: bool = False, 45 | data_bbox_scale : float = 1.1, # Only used if normalize_by_bbox 46 | cam_scale_factor : float = 0.95, 47 | normalize_by_camera: bool = True, 48 | **kwargs 49 | ): 50 | super().__init__() 51 | assert path.isdir(root), f"'{root}' is not a directory" 52 | 53 | if scene_scale is None: 54 | scene_scale = 1.0 55 | if scale is None: 56 | scale = 1.0 57 | 58 | self.device = device 59 | self.permutation = permutation 60 | self.epoch_size = epoch_size 61 | all_c2w = [] 62 | all_gt = [] 63 | 64 | split_name = split if split != "test_train" else "train" 65 | 66 | print("LOAD NSVF DATA", root, 'split', split) 67 | 68 | self.split = split 69 | 70 | def sort_key(x): 71 | if len(x) > 2 and x[1] == "_": 72 | return x[2:] 73 | return x 74 | def look_for_dir(cands, required=True): 75 | for cand in cands: 76 | if path.isdir(path.join(root, cand)): 77 | return cand 78 | if required: 79 | assert False, "None of " + str(cands) + " found in data directory" 80 | return "" 81 | 82 | img_dir_name = look_for_dir(["images", "image", "rgb"]) 83 | pose_dir_name = look_for_dir(["poses", "pose"]) 84 | # intrin_dir_name = look_for_dir(["intrin"], required=False) 85 | img_files = sorted(os.listdir(path.join(root, img_dir_name)), key=sort_key) 86 | 87 | # Select subset of files 88 | if self.split == "train" or self.split == "test_train": 89 | img_files = [x for x in img_files if x.startswith("0_")] 90 | elif self.split == "val": 91 | img_files = [x for x in img_files if x.startswith("1_")] 92 | elif self.split == "test": 93 | test_img_files = [x for x in img_files if x.startswith("2_")] 94 | if len(test_img_files) == 0: 95 | test_img_files = [x for x in img_files if x.startswith("1_")] 96 | img_files = test_img_files 97 | 98 | assert len(img_files) > 0, "No matching images in directory: " + path.join(data_dir, img_dir_name) 99 | self.img_files = img_files 100 | 101 | dynamic_resize = scale < 1 102 | self.use_integral_scaling = False 103 | scaled_img_dir = '' 104 | if dynamic_resize and abs((1.0 / scale) - round(1.0 / scale)) < 1e-9: 105 | resized_dir = img_dir_name + "_" + str(round(1.0 / scale)) 106 | 
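        # e.g. scale=0.5 maps to a pre-downsampled folder such as "images_2";
        # if that folder exists (checked below) it is used directly instead
        # of resizing every image at load time.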
if path.exists(path.join(root, resized_dir)): 107 | img_dir_name = resized_dir 108 | dynamic_resize = False 109 | print("> Pre-resized images from", img_dir_name) 110 | if dynamic_resize: 111 | print("> WARNING: Dynamically resizing images") 112 | 113 | full_size = [0, 0] 114 | rsz_h = rsz_w = 0 115 | 116 | for img_fname in tqdm(img_files): 117 | img_path = path.join(root, img_dir_name, img_fname) 118 | image = imageio.imread(img_path) 119 | pose_fname = path.splitext(img_fname)[0] + ".txt" 120 | pose_path = path.join(root, pose_dir_name, pose_fname) 121 | # intrin_path = path.join(root, intrin_dir_name, pose_fname) 122 | 123 | cam_mtx = np.loadtxt(pose_path).reshape(-1, 4) 124 | if len(cam_mtx) == 3: 125 | bottom = np.array([[0.0, 0.0, 0.0, 1.0]]) 126 | cam_mtx = np.concatenate([cam_mtx, bottom], axis=0) 127 | all_c2w.append(torch.from_numpy(cam_mtx)) # C2W (4, 4) OpenCV 128 | full_size = list(image.shape[:2]) 129 | rsz_h, rsz_w = [round(hw * scale) for hw in full_size] 130 | if dynamic_resize: 131 | image = cv2.resize(image, (rsz_w, rsz_h), interpolation=cv2.INTER_AREA) 132 | 133 | all_gt.append(torch.from_numpy(image)) 134 | 135 | 136 | self.c2w_f64 = torch.stack(all_c2w) 137 | 138 | print('NORMALIZE BY?', 'bbox' if normalize_by_bbox else 'camera' if normalize_by_camera else 'manual') 139 | if normalize_by_bbox: 140 | # Not used, but could be helpful 141 | bbox_path = path.join(root, "bbox.txt") 142 | if path.exists(bbox_path): 143 | bbox_data = np.loadtxt(bbox_path) 144 | center = (bbox_data[:3] + bbox_data[3:6]) * 0.5 145 | radius = (bbox_data[3:6] - bbox_data[:3]) * 0.5 * data_bbox_scale 146 | 147 | # Recenter 148 | self.c2w_f64[:, :3, 3] -= center 149 | # Rescale 150 | scene_scale = 1.0 / radius.max() 151 | else: 152 | warn('normalize_by_bbox=True but bbox.txt was not available') 153 | elif normalize_by_camera: 154 | norm_pose_files = sorted(os.listdir(path.join(root, pose_dir_name)), key=sort_key) 155 | norm_poses = np.stack([np.loadtxt(path.join(root, pose_dir_name, x)).reshape(-1, 4) 156 | for x in norm_pose_files], axis=0) 157 | 158 | # Select subset of files 159 | T, sscale = similarity_from_cameras(norm_poses) 160 | 161 | self.c2w_f64 = torch.from_numpy(T) @ self.c2w_f64 162 | scene_scale = cam_scale_factor * sscale 163 | 164 | # center = np.mean(norm_poses[:, :3, 3], axis=0) 165 | # radius = np.median(np.linalg.norm(norm_poses[:, :3, 3] - center, axis=-1)) 166 | # self.c2w_f64[:, :3, 3] -= center 167 | # scene_scale = cam_scale_factor / radius 168 | # print('good', self.c2w_f64[:2], scene_scale) 169 | 170 | print('scene_scale', scene_scale) 171 | self.c2w_f64[:, :3, 3] *= scene_scale 172 | self.c2w = self.c2w_f64.float() 173 | 174 | self.gt = torch.stack(all_gt).double() / 255.0 175 | if self.gt.size(-1) == 4: 176 | if white_bkgd: 177 | # Apply alpha channel 178 | self.gt = self.gt[..., :3] * self.gt[..., 3:] + (1.0 - self.gt[..., 3:]) 179 | else: 180 | self.gt = self.gt[..., :3] 181 | self.gt = self.gt.float() 182 | 183 | assert full_size[0] > 0 and full_size[1] > 0, "Empty images" 184 | self.n_images, self.h_full, self.w_full, _ = self.gt.shape 185 | 186 | intrin_path = path.join(root, "intrinsics.txt") 187 | assert path.exists(intrin_path), "intrinsics unavailable" 188 | try: 189 | K: np.ndarray = np.loadtxt(intrin_path) 190 | fx = K[0, 0] 191 | fy = K[1, 1] 192 | cx = K[0, 2] 193 | cy = K[1, 2] 194 | except: 195 | # Weird format sometimes in NSVF data 196 | with open(intrin_path, "r") as f: 197 | spl = f.readline().split() 198 | fx = fy = float(spl[0]) 199 | cx = 
float(spl[1]) 200 | cy = float(spl[2]) 201 | if scale < 1.0: 202 | scale_w = rsz_w / full_size[1] 203 | scale_h = rsz_h / full_size[0] 204 | fx *= scale_w 205 | cx *= scale_w 206 | fy *= scale_h 207 | cy *= scale_h 208 | 209 | self.intrins_full : Intrin = Intrin(fx, fy, cx, cy) 210 | print(' intrinsics (loaded reso)', self.intrins_full) 211 | 212 | self.scene_scale = scene_scale 213 | if self.split == "train": 214 | self.gen_rays(factor=factor) 215 | else: 216 | # Rays are not needed for testing 217 | self.h, self.w = self.h_full, self.w_full 218 | self.intrins : Intrin = self.intrins_full 219 | -------------------------------------------------------------------------------- /Svox2/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | import os 3 | import os.path as osp 4 | import warnings 5 | 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | ROOT_DIR = osp.dirname(osp.abspath(__file__)) 9 | 10 | __version__ = None 11 | exec(open('svox2/version.py', 'r').read()) 12 | 13 | CUDA_FLAGS = [] 14 | INSTALL_REQUIREMENTS = [] 15 | include_dirs = [osp.join(ROOT_DIR, "svox2", "csrc", "include")] 16 | 17 | # From PyTorch3D 18 | cub_home = os.environ.get("CUB_HOME", None) 19 | if cub_home is None: 20 | prefix = os.environ.get("CONDA_PREFIX", None) 21 | if prefix is not None and os.path.isdir(prefix + "/include/cub"): 22 | cub_home = prefix + "/include" 23 | 24 | if cub_home is None: 25 | warnings.warn( 26 | "The environment variable `CUB_HOME` was not found." 27 | "Installation will fail if your system CUDA toolkit version is less than 11." 28 | "NVIDIA CUB can be downloaded " 29 | "from `https://github.com/NVIDIA/cub/releases`. You can unpack " 30 | "it to a location of your choice and set the environment variable " 31 | "`CUB_HOME` to the folder containing the `CMakeListst.txt` file." 
32 | ) 33 | else: 34 | include_dirs.append(os.path.realpath(cub_home).replace("\\ ", " ")) 35 | 36 | try: 37 | ext_modules = [ 38 | CUDAExtension('svox2.csrc', [ 39 | 'svox2/csrc/svox2.cpp', 40 | 'svox2/csrc/svox2_kernel.cu', 41 | 'svox2/csrc/render_lerp_kernel_cuvol.cu', 42 | 'svox2/csrc/render_lerp_kernel_nvol.cu', 43 | 'svox2/csrc/render_svox1_kernel.cu', 44 | 'svox2/csrc/misc_kernel.cu', 45 | 'svox2/csrc/loss_kernel.cu', 46 | 'svox2/csrc/optim_kernel.cu', 47 | ], include_dirs=include_dirs, 48 | optional=False), 49 | ] 50 | except: 51 | import warnings 52 | warnings.warn("Failed to build CUDA extension") 53 | ext_modules = [] 54 | 55 | setup( 56 | name='svox2', 57 | version=__version__, 58 | author='Alex Yu', 59 | author_email='alexyu99126@gmail.com', 60 | description='PyTorch sparse voxel volume extension, including custom CUDA kernels', 61 | long_description='PyTorch sparse voxel volume extension, including custom CUDA kernels', 62 | ext_modules=ext_modules, 63 | setup_requires=['pybind11>=2.5.0'], 64 | packages=['svox2', 'svox2.csrc'], 65 | cmdclass={'build_ext': BuildExtension}, 66 | zip_safe=False, 67 | ) 68 | -------------------------------------------------------------------------------- /Svox2/svox2/__init__.py: -------------------------------------------------------------------------------- 1 | from .defs import * 2 | from .svox2 import SparseGrid, Camera, Rays, RenderOptions, TinySparseGrid 3 | from .version import __version__ 4 | -------------------------------------------------------------------------------- /Svox2/svox2/csrc/.ccls: -------------------------------------------------------------------------------- 1 | %compile_commands.json 2 | %cu -x cuda 3 | %cu --cuda-gpu-arch=sm_61 4 | %cu --cuda-path=/usr/local/cuda-11.2 5 | -------------------------------------------------------------------------------- /Svox2/svox2/csrc/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2021 PlenOctree Authors. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions are met: 5 | # 6 | # 1. Redistributions of source code must retain the above copyright notice, 7 | # this list of conditions and the following disclaimer. 8 | # 9 | # 2. Redistributions in binary form must reproduce the above copyright notice, 10 | # this list of conditions and the following disclaimer in the documentation 11 | # and/or other materials provided with the distribution. 12 | # 13 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 14 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 15 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 16 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 17 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 18 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 19 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 20 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 21 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 22 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 23 | # POSSIBILITY OF SUCH DAMAGE. 24 | 25 | # NOTE: This CMakeLists is for development purposes only 26 | # (To check CUDA compile errors) 27 | # It is NOT necessary to use this for installation. 
Just use pip install .
28 | cmake_minimum_required( VERSION 3.3 )
29 | 
30 | if(NOT CMAKE_BUILD_TYPE)
31 | set(CMAKE_BUILD_TYPE Release)
32 | endif()
33 | if (POLICY CMP0048)
34 | cmake_policy(SET CMP0048 NEW)
35 | endif (POLICY CMP0048)
36 | if (POLICY CMP0069)
37 | cmake_policy(SET CMP0069 NEW)
38 | endif (POLICY CMP0069)
39 | if (POLICY CMP0072)
40 | cmake_policy(SET CMP0072 NEW)
41 | endif (POLICY CMP0072)
42 | 
43 | project( svox2 )
44 | 
45 | set(CMAKE_CXX_STANDARD 14)
46 | enable_language(CUDA)
47 | message(STATUS "CUDA enabled")
48 | set( CMAKE_CUDA_STANDARD 14 )
49 | set( CMAKE_CUDA_STANDARD_REQUIRED ON)
50 | set( CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g -Xcudafe \"--display_error_number --diag_suppress=3057 --diag_suppress=3058 --diag_suppress=3059 --diag_suppress=3060\" -lineinfo -arch=sm_75 ")
51 | # -Xptxas=\"-v\"
52 | 
53 | set( INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/include" )
54 | 
55 | if( MSVC )
56 | set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /MTd")
57 | set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MT /GLT /Ox")
58 | set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler=\"/MT\"" )
59 | endif()
60 | 
61 | file(GLOB SOURCES
62 | ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp
63 | ${CMAKE_CURRENT_SOURCE_DIR}/*.cu)
64 | 
65 | find_package(pybind11 REQUIRED)
66 | find_package(Torch REQUIRED)
67 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
68 | 
69 | include_directories (${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
70 | 
71 | pybind11_add_module(svox2-test SHARED ${SOURCES})
72 | target_link_libraries(svox2-test PRIVATE "${TORCH_LIBRARIES}")
73 | target_include_directories(svox2-test PRIVATE "${INCLUDE_DIR}")
74 | 
75 | if (MSVC)
76 | file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
77 | add_custom_command(TARGET svox2-test
78 | POST_BUILD
79 | COMMAND ${CMAKE_COMMAND} -E copy_if_different
80 | ${TORCH_DLLS}
81 | $<TARGET_FILE_DIR:svox2-test>)
82 | endif (MSVC)
83 | 
-------------------------------------------------------------------------------- /Svox2/svox2/csrc/include/cubemap_util.cuh: --------------------------------------------------------------------------------
1 | #pragma once
2 | #include "cuda_util.cuh"
3 | #include <cmath>
4 | #include <cstdint>
5 | 
6 | #define _AXIS(x) (x>>1)
7 | #define _ORI(x) (x&1)
8 | #define _FACE(axis, ori) uint8_t((axis << 1) | ori)
9 | 
10 | namespace {
11 | namespace device {
12 | 
13 | struct CubemapCoord {
14 | uint8_t face;
15 | float uv[2];
16 | };
17 | 
18 | struct CubemapLocation {
19 | uint8_t face;
20 | int16_t uv[2];
21 | };
22 | 
23 | struct CubemapBilerpQuery {
24 | CubemapLocation ptr[2][2];
25 | float duv[2];
26 | };
27 | 
28 | __device__ __inline__ void
29 | invert_cubemap(int u, int v, float r,
30 | int reso,
31 | float* __restrict__ out) {
32 | const float u_norm = (u + 0.5f) / reso * 2 - 1;
33 | const float v_norm = (v + 0.5f) / reso * 2 - 1;
34 | // EAC
35 | const float tx = tanf((M_PI / 4) * u_norm);
36 | const float ty = tanf((M_PI / 4) * v_norm);
37 | const float common = r * rnorm3df(1.f, tx, ty);
38 | out[0] = tx * common;
39 | out[1] = ty * common;
40 | out[2] = common;
41 | }
42 | 
43 | __device__ __inline__ void
44 | invert_cubemap_traditional(int u, int v, float r,
45 | int reso,
46 | float* __restrict__ out) {
47 | const float u_norm = (u + 0.5f) / reso * 2 - 1;
48 | const float v_norm = (v + 0.5f) / reso * 2 - 1;
49 | const float common = r * rnorm3df(1.f, u_norm, v_norm);
50 | out[0] = u_norm * common;
51 | out[1] = v_norm * common;
52 | out[2] = common;
53 | }
54 | 
55 | __device__ __host__ __inline__ CubemapCoord
56 | 
dir_to_cubemap_coord(const float* __restrict__ xyz_o, 57 | int face_reso, 58 | bool eac = true) { 59 | float maxv; 60 | int ax; 61 | float xyz[3] = {xyz_o[0], xyz_o[1], xyz_o[2]}; 62 | if (fabsf(xyz[0]) >= fabsf(xyz[1]) && fabsf(xyz[0]) >= fabsf(xyz[2])) { 63 | ax = 0; maxv = xyz[0]; 64 | } else if (fabsf(xyz[1]) >= fabsf(xyz[2])) { 65 | ax = 1; maxv = xyz[1]; 66 | } else { 67 | ax = 2; maxv = xyz[2]; 68 | } 69 | const float recip = 1.f / fabsf(maxv); 70 | xyz[0] *= recip; 71 | xyz[1] *= recip; 72 | xyz[2] *= recip; 73 | 74 | if (eac) { 75 | #pragma unroll 3 76 | for (int i = 0; i < 3; ++i) { 77 | xyz[i] = atanf(xyz[i]) * (4 * M_1_PI); 78 | } 79 | } 80 | 81 | CubemapCoord idx; 82 | idx.uv[0] = ((xyz[(ax ^ 1) & 1] + 1) * face_reso - 1) * 0.5; 83 | idx.uv[1] = ((xyz[(ax ^ 2) & 2] + 1) * face_reso - 1) * 0.5; 84 | const int ori = xyz[ax] >= 0; 85 | idx.face = _FACE(ax, ori); 86 | 87 | return idx; 88 | } 89 | 90 | __device__ __host__ __inline__ CubemapBilerpQuery 91 | cubemap_build_query( 92 | const CubemapCoord& idx, 93 | int face_reso) { 94 | const int uv_idx[2] ={ (int)floorf(idx.uv[0]), (int)floorf(idx.uv[1]) }; 95 | 96 | bool m[2][2]; 97 | m[0][0] = uv_idx[0] < 0; 98 | m[0][1] = uv_idx[0] > face_reso - 2; 99 | m[1][0] = uv_idx[1] < 0; 100 | m[1][1] = uv_idx[1] > face_reso - 2; 101 | 102 | const int face = idx.face; 103 | const int ax = _AXIS(face); 104 | const int ori = _ORI(face); 105 | // if ax is one of {0, 1, 2}, this trick gets the 2 106 | // of {0, 1, 2} other than ax 107 | const int uvd[2] = {((ax ^ 1) & 1), ((ax ^ 2) & 2)}; 108 | int uv_ori[2]; 109 | 110 | CubemapBilerpQuery result; 111 | result.duv[0] = idx.uv[0] - uv_idx[0]; 112 | result.duv[1] = idx.uv[1] - uv_idx[1]; 113 | 114 | #pragma unroll 2 115 | for (uv_ori[0] = 0; uv_ori[0] < 2; ++uv_ori[0]) { 116 | #pragma unroll 2 117 | for (uv_ori[1] = 0; uv_ori[1] < 2; ++uv_ori[1]) { 118 | CubemapLocation& nidx = result.ptr[uv_ori[0]][uv_ori[1]]; 119 | nidx.face = face; 120 | nidx.uv[0] = uv_idx[0] + uv_ori[0]; 121 | nidx.uv[1] = uv_idx[1] + uv_ori[1]; 122 | 123 | const bool mu = m[0][uv_ori[0]]; 124 | const bool mv = m[1][uv_ori[1]]; 125 | 126 | int edge_idx = -1; 127 | if (mu) { 128 | // Crosses edge in u-axis 129 | if (mv) { 130 | // FIXME: deal with corners properly, right now 131 | // just clamps, resulting in a little artifact 132 | // at each cube corner 133 | nidx.uv[0] = min(max(nidx.uv[0], 0), face_reso - 1); 134 | nidx.uv[1] = min(max(nidx.uv[1], 0), face_reso - 1); 135 | } else { 136 | edge_idx = 0; 137 | } 138 | } else if (mv) { 139 | // Crosses edge in v-axis 140 | edge_idx = 1; 141 | } 142 | if (~edge_idx) { 143 | const int nax = uvd[edge_idx]; 144 | const int16_t other_coord = nidx.uv[1 - edge_idx]; 145 | 146 | // Determine directions in the new face 147 | const int nud = (nax ^ 1) & 1; 148 | // const int nvd = (nax ^ 2) & 2; 149 | 150 | if (nud == ax) { 151 | nidx.uv[0] = ori ? (face_reso - 1) : 0; 152 | nidx.uv[1] = other_coord; 153 | } else { 154 | nidx.uv[0] = other_coord; 155 | nidx.uv[1] = ori ? 
(face_reso - 1) : 0; 156 | } 157 | 158 | nidx.face = _FACE(nax, uv_ori[edge_idx]); 159 | } 160 | // Interior point: nothing needs to be done 161 | 162 | } 163 | } 164 | 165 | return result; 166 | } 167 | 168 | __device__ __host__ __inline__ float 169 | cubemap_sample( 170 | const float* __restrict__ cubemap, // (6, face_reso, face_reso, n_channels) 171 | const CubemapBilerpQuery& query, 172 | int face_reso, 173 | int n_channels, 174 | int chnl_id) { 175 | 176 | // NOTE: assuming address will fit in int32 177 | const int stride1 = face_reso * n_channels; 178 | const int stride0 = face_reso * stride1; 179 | const CubemapLocation& p00 = query.ptr[0][0]; 180 | const float v00 = cubemap[p00.face * stride0 + p00.uv[0] * stride1 + p00.uv[1] * n_channels + chnl_id]; 181 | const CubemapLocation& p01 = query.ptr[0][1]; 182 | const float v01 = cubemap[p01.face * stride0 + p01.uv[0] * stride1 + p01.uv[1] * n_channels + chnl_id]; 183 | const CubemapLocation& p10 = query.ptr[1][0]; 184 | const float v10 = cubemap[p10.face * stride0 + p10.uv[0] * stride1 + p10.uv[1] * n_channels + chnl_id]; 185 | const CubemapLocation& p11 = query.ptr[1][1]; 186 | const float v11 = cubemap[p11.face * stride0 + p11.uv[0] * stride1 + p11.uv[1] * n_channels + chnl_id]; 187 | 188 | const float val0 = lerp(v00, v01, query.duv[1]); 189 | const float val1 = lerp(v10, v11, query.duv[1]); 190 | 191 | return lerp(val0, val1, query.duv[0]); 192 | } 193 | 194 | __device__ __inline__ void 195 | cubemap_sample_backward( 196 | float* __restrict__ cubemap_grad, // (6, face_reso, face_reso, n_channels) 197 | const CubemapBilerpQuery& query, 198 | int face_reso, 199 | int n_channels, 200 | float grad_out, 201 | int chnl_id, 202 | bool* __restrict__ mask_out = nullptr) { 203 | 204 | // NOTE: assuming address will fit in int32 205 | const float bu = query.duv[0], bv = query.duv[1]; 206 | const float au = 1.f - bu, av = 1.f - bv; 207 | 208 | #define _ADD_CUBEVERT(i, j, val) { \ 209 | const CubemapLocation& p00 = query.ptr[i][j]; \ 210 | const int idx = (p00.face * face_reso + p00.uv[0]) * face_reso + p00.uv[1]; \ 211 | float* __restrict__ v00 = &cubemap_grad[idx * n_channels + chnl_id]; \ 212 | atomicAdd(v00, val); \ 213 | if (mask_out != nullptr) { \ 214 | mask_out[idx] = true; \ 215 | } \ 216 | } 217 | 218 | _ADD_CUBEVERT(0, 0, au * av * grad_out); 219 | _ADD_CUBEVERT(0, 1, au * bv * grad_out); 220 | _ADD_CUBEVERT(1, 0, bu * av * grad_out); 221 | _ADD_CUBEVERT(1, 1, bu * bv * grad_out); 222 | #undef _ADD_CUBEVERT 223 | 224 | } 225 | 226 | __device__ __host__ __inline__ float 227 | multi_cubemap_sample( 228 | const float* __restrict__ cubemap1, // (6, face_reso, face_reso, n_channels) 229 | const float* __restrict__ cubemap2, // (6, face_reso, face_reso, n_channels) 230 | const CubemapBilerpQuery& query, 231 | float interp_wt, 232 | int face_reso, 233 | int n_channels, 234 | int chnl_id) { 235 | const float val1 = cubemap_sample(cubemap1, 236 | query, 237 | face_reso, 238 | n_channels, 239 | chnl_id); 240 | const float val2 = cubemap_sample(cubemap2, 241 | query, 242 | face_reso, 243 | n_channels, 244 | chnl_id); 245 | return lerp(val1, val2, interp_wt); 246 | } 247 | 248 | __device__ __inline__ void 249 | multi_cubemap_sample_backward( 250 | float* __restrict__ cubemap_grad1, // (6, face_reso, face_reso, n_channels) 251 | float* __restrict__ cubemap_grad2, // (6, face_reso, face_reso, n_channels) 252 | const CubemapBilerpQuery& query, 253 | float interp_wt, 254 | int face_reso, 255 | int n_channels, 256 | float grad_out, 257 | int chnl_id, 
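// The backward pass below mirrors the forward lerp over interp_wt in
// multi_cubemap_sample: cubemap_grad1 accumulates grad_out * (1 - interp_wt)
// and cubemap_grad2 accumulates grad_out * interp_wt, e.g. interp_wt = 0.25
// routes 75% of the gradient to the first cubemap. Note mask_out2 is only
// written when mask_out1 is non-null.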
258 | bool* __restrict__ mask_out1 = nullptr, 259 | bool* __restrict__ mask_out2 = nullptr) { 260 | if (cubemap_grad1 == nullptr) return; 261 | cubemap_sample_backward(cubemap_grad1, 262 | query, 263 | face_reso, 264 | n_channels, 265 | grad_out * (1.f - interp_wt), 266 | chnl_id, 267 | mask_out1); 268 | cubemap_sample_backward(cubemap_grad2, 269 | query, 270 | face_reso, 271 | n_channels, 272 | grad_out * interp_wt, 273 | chnl_id, 274 | mask_out1 == nullptr ? nullptr : mask_out2); 275 | } 276 | 277 | 278 | } // namespace device 279 | } // namespace 280 | -------------------------------------------------------------------------------- /Svox2/svox2/csrc/include/cuda_util.cuh: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | #pragma once 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "util.hpp" 8 | 9 | 10 | #define DEVICE_GUARD(_ten) \ 11 | const at::cuda::OptionalCUDAGuard device_guard(device_of(_ten)); 12 | 13 | #define CUDA_GET_THREAD_ID(tid, Q) const int tid = blockIdx.x * blockDim.x + threadIdx.x; \ 14 | if (tid >= Q) return 15 | #define CUDA_GET_THREAD_ID_U64(tid, Q) const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; \ 16 | if (tid >= Q) return 17 | #define CUDA_N_BLOCKS_NEEDED(Q, CUDA_N_THREADS) ((Q - 1) / CUDA_N_THREADS + 1) 18 | #define CUDA_CHECK_ERRORS \ 19 | cudaError_t err = cudaGetLastError(); \ 20 | if (err != cudaSuccess) \ 21 | printf("Error in svox2.%s : %s\n", __FUNCTION__, cudaGetErrorString(err)) 22 | 23 | #define CUDA_MAX_THREADS at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock 24 | 25 | #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 26 | #else 27 | __device__ inline double atomicAdd(double* address, double val){ 28 | unsigned long long int* address_as_ull = (unsigned long long int*)address; 29 | unsigned long long int old = *address_as_ull, assumed; 30 | do { 31 | assumed = old; 32 | old = atomicCAS(address_as_ull, assumed, 33 | __double_as_longlong(val + __longlong_as_double(assumed))); 34 | } while (assumed != old); 35 | return __longlong_as_double(old); 36 | } 37 | #endif 38 | 39 | __device__ inline void atomicMax(float* result, float value){ 40 | unsigned* result_as_u = (unsigned*)result; 41 | unsigned old = *result_as_u, assumed; 42 | do { 43 | assumed = old; 44 | old = atomicCAS(result_as_u, assumed, 45 | __float_as_int(fmaxf(value, __int_as_float(assumed)))); 46 | } while (old != assumed); 47 | return; 48 | } 49 | 50 | __device__ inline void atomicMax(double* result, double value){ 51 | unsigned long long int* result_as_ull = (unsigned long long int*)result; 52 | unsigned long long int old = *result_as_ull, assumed; 53 | do { 54 | assumed = old; 55 | old = atomicCAS(result_as_ull, assumed, 56 | __double_as_longlong(fmaxf(value, __longlong_as_double(assumed)))); 57 | } while (old != assumed); 58 | return; 59 | } 60 | 61 | __device__ __inline__ void transform_coord(float* __restrict__ point, 62 | const float* __restrict__ scaling, 63 | const float* __restrict__ offset) { 64 | point[0] = fmaf(point[0], scaling[0], offset[0]); // a*b + c 65 | point[1] = fmaf(point[1], scaling[1], offset[1]); // a*b + c 66 | point[2] = fmaf(point[2], scaling[2], offset[2]); // a*b + c 67 | } 68 | 69 | // Linear interp 70 | // Subtract and fused multiply-add 71 | // (1-w) a + w b 72 | template 73 | __host__ __device__ __inline__ T lerp(T a, T b, T w) { 74 | return fmaf(w, b - a, a); 75 | } 76 | 77 | __device__ __inline__ static float _norm( 78 | const float* __restrict__ dir) { 79 
| // return sqrtf(dir[0] * dir[0] + dir[1] * dir[1] + dir[2] * dir[2]); 80 | return norm3df(dir[0], dir[1], dir[2]); 81 | } 82 | 83 | __device__ __inline__ static float _rnorm( 84 | const float* __restrict__ dir) { 85 | // return 1.f / _norm(dir); 86 | return rnorm3df(dir[0], dir[1], dir[2]); 87 | } 88 | 89 | __host__ __device__ __inline__ static void xsuby3d( 90 | float* __restrict__ x, 91 | const float* __restrict__ y) { 92 | x[0] -= y[0]; 93 | x[1] -= y[1]; 94 | x[2] -= y[2]; 95 | } 96 | 97 | __host__ __device__ __inline__ static float _dot( 98 | const float* __restrict__ x, 99 | const float* __restrict__ y) { 100 | return x[0] * y[0] + x[1] * y[1] + x[2] * y[2]; 101 | } 102 | 103 | __host__ __device__ __inline__ static void _cross( 104 | const float* __restrict__ a, 105 | const float* __restrict__ b, 106 | float* __restrict__ out) { 107 | out[0] = a[1] * b[2] - a[2] * b[1]; 108 | out[1] = a[2] * b[0] - a[0] * b[2]; 109 | out[2] = a[0] * b[1] - a[1] * b[0]; 110 | } 111 | 112 | __device__ __inline__ static float _dist_ray_to_origin( 113 | const float* __restrict__ origin, 114 | const float* __restrict__ dir) { 115 | // dir must be unit vector 116 | float tmp[3]; 117 | _cross(origin, dir, tmp); 118 | return _norm(tmp); 119 | } 120 | 121 | #define int_div2_ceil(x) ((((x) - 1) >> 1) + 1) 122 | 123 | __host__ __inline__ cudaError_t cuda_assert( 124 | const cudaError_t code, const char* const file, 125 | const int line, const bool abort) { 126 | if (code != cudaSuccess) { 127 | fprintf(stderr, "cuda_assert: %s %s %s %d\n", cudaGetErrorName(code) ,cudaGetErrorString(code), 128 | file, line); 129 | 130 | if (abort) { 131 | cudaDeviceReset(); 132 | exit(code); 133 | } 134 | } 135 | 136 | return code; 137 | } 138 | 139 | #define cuda(...) cuda_assert((cuda##__VA_ARGS__), __FILE__, __LINE__, true); 140 | 141 | -------------------------------------------------------------------------------- /Svox2/svox2/csrc/include/data_spec.hpp: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | #pragma once 3 | #include "util.hpp" 4 | #include 5 | 6 | using torch::Tensor; 7 | 8 | enum BasisType { 9 | // For svox 1 compatibility 10 | // BASIS_TYPE_RGBA = 0 11 | BASIS_TYPE_SH = 1, 12 | // BASIS_TYPE_SG = 2 13 | // BASIS_TYPE_ASG = 3 14 | BASIS_TYPE_3D_TEXTURE = 4, 15 | BASIS_TYPE_MLP = 255, 16 | }; 17 | 18 | struct SparseGridSpec { 19 | Tensor density_data; 20 | Tensor sh_data; 21 | Tensor links; 22 | Tensor _offset; 23 | Tensor _scaling; 24 | 25 | Tensor background_links; 26 | Tensor background_data; 27 | 28 | int basis_dim; 29 | uint8_t basis_type; 30 | Tensor basis_data; 31 | 32 | inline void check() { 33 | CHECK_INPUT(density_data); 34 | CHECK_INPUT(sh_data); 35 | CHECK_INPUT(links); 36 | if (background_links.defined()) { 37 | CHECK_INPUT(background_links); 38 | CHECK_INPUT(background_data); 39 | TORCH_CHECK(background_links.ndimension() == 40 | 2); // (H, W) -> [N] \cup {-1} 41 | TORCH_CHECK(background_data.ndimension() == 3); // (N, D, C) -> R 42 | } 43 | if (basis_data.defined()) { 44 | CHECK_INPUT(basis_data); 45 | } 46 | CHECK_CPU_INPUT(_offset); 47 | CHECK_CPU_INPUT(_scaling); 48 | TORCH_CHECK(density_data.ndimension() == 2); 49 | TORCH_CHECK(sh_data.ndimension() == 2); 50 | TORCH_CHECK(links.ndimension() == 3); 51 | } 52 | }; 53 | 54 | struct GridOutputGrads { 55 | torch::Tensor grad_density_out; 56 | torch::Tensor grad_sh_out; 57 | torch::Tensor grad_basis_out; 58 | torch::Tensor grad_background_out; 59 | 60 | torch::Tensor mask_out; 61 | 
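// When these optional masks are provided, the backward kernels mark each
// voxel (and background cell) that received any gradient, so the sparse
// RMSProp/SGD steps in optim_kernel.cu can restrict their updates to
// touched rows via the bool-indexer path.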
torch::Tensor mask_background_out; 62 | inline void check() { 63 | if (grad_density_out.defined()) { 64 | CHECK_INPUT(grad_density_out); 65 | } 66 | if (grad_sh_out.defined()) { 67 | CHECK_INPUT(grad_sh_out); 68 | } 69 | if (grad_basis_out.defined()) { 70 | CHECK_INPUT(grad_basis_out); 71 | } 72 | if (grad_background_out.defined()) { 73 | CHECK_INPUT(grad_background_out); 74 | } 75 | if (mask_out.defined() && mask_out.size(0) > 0) { 76 | CHECK_INPUT(mask_out); 77 | } 78 | if (mask_background_out.defined() && mask_background_out.size(0) > 0) { 79 | CHECK_INPUT(mask_background_out); 80 | } 81 | } 82 | }; 83 | 84 | struct CameraSpec { 85 | torch::Tensor c2w; 86 | float fx; 87 | float fy; 88 | float cx; 89 | float cy; 90 | int width; 91 | int height; 92 | 93 | float ndc_coeffx; 94 | float ndc_coeffy; 95 | 96 | inline void check() { 97 | CHECK_INPUT(c2w); 98 | TORCH_CHECK(c2w.is_floating_point()); 99 | TORCH_CHECK(c2w.ndimension() == 2); 100 | TORCH_CHECK(c2w.size(1) == 4); 101 | } 102 | }; 103 | 104 | struct RaysSpec { 105 | Tensor origins; 106 | Tensor dirs; 107 | inline void check() { 108 | CHECK_INPUT(origins); 109 | CHECK_INPUT(dirs); 110 | TORCH_CHECK(origins.is_floating_point()); 111 | TORCH_CHECK(dirs.is_floating_point()); 112 | } 113 | }; 114 | 115 | struct RenderOptions { 116 | float background_brightness; 117 | // float step_epsilon; 118 | float step_size; 119 | float sigma_thresh; 120 | float stop_thresh; 121 | 122 | float near_clip; 123 | bool use_spheric_clip; 124 | 125 | bool last_sample_opaque; 126 | 127 | // bool randomize; 128 | // float random_sigma_std; 129 | // float random_sigma_std_background; 130 | // 32-bit RNG state masks 131 | // uint32_t _m1, _m2, _m3; 132 | 133 | // int msi_start_layer = 0; 134 | // int msi_end_layer = 66; 135 | }; 136 | -------------------------------------------------------------------------------- /Svox2/svox2/csrc/include/data_spec_packed.cuh: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | #pragma once 3 | #include 4 | #include "data_spec.hpp" 5 | #include "cuda_util.cuh" 6 | #include "random_util.cuh" 7 | 8 | namespace { 9 | namespace device { 10 | 11 | struct PackedSparseGridSpec { 12 | PackedSparseGridSpec(SparseGridSpec& spec) 13 | : 14 | density_data(spec.density_data.data_ptr()), 15 | sh_data(spec.sh_data.data_ptr()), 16 | links(spec.links.data_ptr()), 17 | basis_type(spec.basis_type), 18 | basis_data(spec.basis_data.defined() ? spec.basis_data.data_ptr() : nullptr), 19 | background_links(spec.background_links.defined() ? 20 | spec.background_links.data_ptr() : 21 | nullptr), 22 | background_data(spec.background_data.defined() ? 23 | spec.background_data.data_ptr() : 24 | nullptr), 25 | size{(int)spec.links.size(0), 26 | (int)spec.links.size(1), 27 | (int)spec.links.size(2)}, 28 | stride_x{(int)spec.links.stride(0)}, 29 | background_reso{ 30 | spec.background_links.defined() ? (int)spec.background_links.size(1) : 0, 31 | }, 32 | background_nlayers{ 33 | spec.background_data.defined() ? (int)spec.background_data.size(1) : 0 34 | }, 35 | basis_dim(spec.basis_dim), 36 | sh_data_dim((int)spec.sh_data.size(1)), 37 | basis_reso(spec.basis_data.defined() ? 
spec.basis_data.size(0) : 0), 38 | _offset{spec._offset.data_ptr()[0], 39 | spec._offset.data_ptr()[1], 40 | spec._offset.data_ptr()[2]}, 41 | _scaling{spec._scaling.data_ptr()[0], 42 | spec._scaling.data_ptr()[1], 43 | spec._scaling.data_ptr()[2]} { 44 | } 45 | 46 | float* __restrict__ density_data; 47 | float* __restrict__ sh_data; 48 | const int32_t* __restrict__ links; 49 | 50 | const uint8_t basis_type; 51 | float* __restrict__ basis_data; 52 | 53 | const int32_t* __restrict__ background_links; 54 | float* __restrict__ background_data; 55 | 56 | const int size[3], stride_x; 57 | const int background_reso, background_nlayers; 58 | 59 | const int basis_dim, sh_data_dim, basis_reso; 60 | const float _offset[3]; 61 | const float _scaling[3]; 62 | }; 63 | 64 | struct PackedGridOutputGrads { 65 | PackedGridOutputGrads(GridOutputGrads& grads) : 66 | grad_density_out(grads.grad_density_out.defined() ? grads.grad_density_out.data_ptr() : nullptr), 67 | grad_sh_out(grads.grad_sh_out.defined() ? grads.grad_sh_out.data_ptr() : nullptr), 68 | grad_basis_out(grads.grad_basis_out.defined() ? grads.grad_basis_out.data_ptr() : nullptr), 69 | grad_background_out(grads.grad_background_out.defined() ? grads.grad_background_out.data_ptr() : nullptr), 70 | mask_out((grads.mask_out.defined() && grads.mask_out.size(0) > 0) ? grads.mask_out.data_ptr() : nullptr), 71 | mask_background_out((grads.mask_background_out.defined() && grads.mask_background_out.size(0) > 0) ? grads.mask_background_out.data_ptr() : nullptr) 72 | {} 73 | float* __restrict__ grad_density_out; 74 | float* __restrict__ grad_sh_out; 75 | float* __restrict__ grad_basis_out; 76 | float* __restrict__ grad_background_out; 77 | 78 | bool* __restrict__ mask_out; 79 | bool* __restrict__ mask_background_out; 80 | }; 81 | 82 | struct PackedCameraSpec { 83 | PackedCameraSpec(CameraSpec& cam) : 84 | c2w(cam.c2w.packed_accessor32()), 85 | fx(cam.fx), fy(cam.fy), 86 | cx(cam.cx), cy(cam.cy), 87 | width(cam.width), height(cam.height), 88 | ndc_coeffx(cam.ndc_coeffx), ndc_coeffy(cam.ndc_coeffy) {} 89 | const torch::PackedTensorAccessor32 90 | c2w; 91 | float fx; 92 | float fy; 93 | float cx; 94 | float cy; 95 | int width; 96 | int height; 97 | 98 | float ndc_coeffx; 99 | float ndc_coeffy; 100 | }; 101 | 102 | struct PackedRaysSpec { 103 | const torch::PackedTensorAccessor32 origins; 104 | const torch::PackedTensorAccessor32 dirs; 105 | PackedRaysSpec(RaysSpec& spec) : 106 | origins(spec.origins.packed_accessor32()), 107 | dirs(spec.dirs.packed_accessor32()) 108 | { } 109 | }; 110 | 111 | struct SingleRaySpec { 112 | SingleRaySpec() = default; 113 | __device__ SingleRaySpec(const float* __restrict__ origin, const float* __restrict__ dir) 114 | : origin{origin[0], origin[1], origin[2]}, 115 | dir{dir[0], dir[1], dir[2]} {} 116 | __device__ void set(const float* __restrict__ origin, const float* __restrict__ dir) { 117 | #pragma unroll 3 118 | for (int i = 0; i < 3; ++i) { 119 | this->origin[i] = origin[i]; 120 | this->dir[i] = dir[i]; 121 | } 122 | } 123 | 124 | float origin[3]; 125 | float dir[3]; 126 | float tmin, tmax, world_step; 127 | 128 | float pos[3]; 129 | int32_t l[3]; 130 | RandomEngine32 rng; 131 | }; 132 | 133 | } // namespace device 134 | } // namespace 135 | -------------------------------------------------------------------------------- /Svox2/svox2/csrc/include/random_util.cuh: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | #pragma once 3 | #include 4 | #include 5 | 6 | // 
A custom xorshift random generator 7 | // Maybe replace with some CUDA internal stuff? 8 | struct RandomEngine32 { 9 | uint32_t x, y, z; 10 | 11 | // Inclusive both 12 | __host__ __device__ 13 | uint32_t randint(uint32_t lo, uint32_t hi) { 14 | if (hi <= lo) return lo; 15 | uint32_t z = (*this)(); 16 | return z % (hi - lo + 1) + lo; 17 | } 18 | 19 | __host__ __device__ 20 | void rand2(float* out1, float* out2) { 21 | const uint32_t z = (*this)(); 22 | const uint32_t fmax = (1 << 16); 23 | const uint32_t z1 = z >> 16; 24 | const uint32_t z2 = z & (fmax - 1); 25 | const float ifmax = 1.f / fmax; 26 | 27 | *out1 = z1 * ifmax; 28 | *out2 = z2 * ifmax; 29 | } 30 | 31 | __host__ __device__ 32 | float rand() { 33 | uint32_t z = (*this)(); 34 | return float(z) / (1LL << 32); 35 | } 36 | 37 | 38 | __host__ __device__ 39 | void randn2(float* out1, float* out2) { 40 | rand2(out1, out2); 41 | // Box-Muller transform 42 | const float srlog = sqrtf(-2 * logf(*out1 + 1e-32f)); 43 | *out2 *= 2 * M_PI; 44 | *out1 = srlog * cosf(*out2); 45 | *out2 = srlog * sinf(*out2); 46 | } 47 | 48 | __host__ __device__ 49 | float randn() { 50 | float x, y; 51 | rand2(&x, &y); 52 | // Box-Muller transform 53 | return sqrtf(-2 * logf(x + 1e-32f))* cosf(2 * M_PI * y); 54 | } 55 | 56 | __host__ __device__ 57 | uint32_t operator()() { 58 | uint32_t t; 59 | x ^= x << 16; 60 | x ^= x >> 5; 61 | x ^= x << 1; 62 | t = x; 63 | x = y; 64 | y = z; 65 | z = t ^ x ^ y; 66 | return z; 67 | } 68 | }; 69 | -------------------------------------------------------------------------------- /Svox2/svox2/csrc/include/util.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | // Changed from x.type().is_cuda() due to deprecation 3 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 4 | #define CHECK_CPU(x) TORCH_CHECK(!x.is_cuda(), #x " must be a CPU tensor") 5 | #define CHECK_CONTIGUOUS(x) \ 6 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 7 | #define CHECK_INPUT(x) \ 8 | CHECK_CUDA(x); \ 9 | CHECK_CONTIGUOUS(x) 10 | #define CHECK_CPU_INPUT(x) \ 11 | CHECK_CPU(x); \ 12 | CHECK_CONTIGUOUS(x) 13 | 14 | #if defined(__CUDACC__) 15 | // #define _EXP(x) expf(x) // SLOW EXP 16 | #define _EXP(x) __expf(x) // FAST EXP 17 | #define _SIGMOID(x) (1 / (1 + _EXP(-(x)))) 18 | 19 | #else 20 | 21 | #define _EXP(x) expf(x) 22 | #define _SIGMOID(x) (1 / (1 + expf(-(x)))) 23 | #endif 24 | #define _SQR(x) ((x) * (x)) 25 | -------------------------------------------------------------------------------- /Svox2/svox2/csrc/optim_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | // Optimizer-related kernels 3 | 4 | #include 5 | #include "cuda_util.cuh" 6 | 7 | namespace { 8 | 9 | const int RMSPROP_STEP_CUDA_THREADS = 256; 10 | const int MIN_BLOCKS_PER_SM = 4; 11 | 12 | namespace device { 13 | 14 | // RMSPROP 15 | __inline__ __device__ void rmsprop_once( 16 | float* __restrict__ ptr_data, 17 | float* __restrict__ ptr_rms, 18 | float* __restrict__ ptr_grad, 19 | const float beta, const float lr, const float epsilon, float minval) { 20 | float rms = *ptr_rms; 21 | rms = rms == 0.f ? 
_SQR(*ptr_grad) : lerp(_SQR(*ptr_grad), rms, beta);
22 | *ptr_rms = rms;
23 | *ptr_data = fmaxf(*ptr_data - lr * (*ptr_grad) / (sqrtf(rms) + epsilon), minval);
24 | *ptr_grad = 0.f;
25 | }
26 | 
27 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM)
28 | __global__ void rmsprop_step_kernel(
29 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_data,
30 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_rms,
31 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_grad,
32 | float beta,
33 | float lr,
34 | float epsilon,
35 | float minval,
36 | float lr_last) {
37 | CUDA_GET_THREAD_ID(tid, all_data.size(0) * all_data.size(1));
38 | int32_t chnl = tid % all_data.size(1);
39 | rmsprop_once(all_data.data() + tid,
40 | all_rms.data() + tid,
41 | all_grad.data() + tid,
42 | beta,
43 | (chnl == all_data.size(1) - 1) ? lr_last : lr,
44 | epsilon,
45 | minval);
46 | }
47 | 
48 | 
49 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM)
50 | __global__ void rmsprop_mask_step_kernel(
51 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_data,
52 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_rms,
53 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_grad,
54 | const bool* __restrict__ mask,
55 | float beta,
56 | float lr,
57 | float epsilon,
58 | float minval,
59 | float lr_last) {
60 | CUDA_GET_THREAD_ID(tid, all_data.size(0) * all_data.size(1));
61 | if (mask[tid / all_data.size(1)] == false) return;
62 | int32_t chnl = tid % all_data.size(1);
63 | rmsprop_once(all_data.data() + tid,
64 | all_rms.data() + tid,
65 | all_grad.data() + tid,
66 | beta,
67 | (chnl == all_data.size(1) - 1) ? lr_last : lr,
68 | epsilon,
69 | minval);
70 | }
71 | 
72 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM)
73 | __global__ void rmsprop_index_step_kernel(
74 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_data,
75 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_rms,
76 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_grad,
77 | torch::PackedTensorAccessor32<int32_t, 1, torch::RestrictPtrTraits> indices,
78 | float beta,
79 | float lr,
80 | float epsilon,
81 | float minval,
82 | float lr_last) {
83 | CUDA_GET_THREAD_ID(tid, indices.size(0) * all_data.size(1));
84 | int32_t i = indices[tid / all_data.size(1)];
85 | int32_t chnl = tid % all_data.size(1);
86 | size_t off = i * all_data.size(1) + chnl;
87 | rmsprop_once(all_data.data() + off, all_rms.data() + off,
88 | all_grad.data() + off,
89 | beta,
90 | (chnl == all_data.size(1) - 1) ? lr_last : lr,
91 | epsilon,
92 | minval);
93 | }
94 | 
95 | 
96 | // SGD
97 | __inline__ __device__ void sgd_once(
98 | float* __restrict__ ptr_data,
99 | float* __restrict__ ptr_grad,
100 | const float lr) {
101 | *ptr_data -= lr * (*ptr_grad);
102 | *ptr_grad = 0.f;
103 | }
104 | 
105 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM)
106 | __global__ void sgd_step_kernel(
107 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_data,
108 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_grad,
109 | float lr,
110 | float lr_last) {
111 | CUDA_GET_THREAD_ID(tid, all_data.size(0) * all_data.size(1));
112 | int32_t chnl = tid % all_data.size(1);
113 | sgd_once(all_data.data() + tid,
114 | all_grad.data() + tid,
115 | (chnl == all_data.size(1) - 1) ? lr_last : lr);
116 | }
117 | 
118 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM)
119 | __global__ void sgd_mask_step_kernel(
120 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_data,
121 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_grad,
122 | const bool* __restrict__ mask,
123 | float lr,
124 | float lr_last) {
125 | CUDA_GET_THREAD_ID(tid, all_data.size(0) * all_data.size(1));
126 | if (mask[tid / all_data.size(1)] == false) return;
127 | int32_t chnl = tid % all_data.size(1);
128 | sgd_once(all_data.data() + tid,
129 | all_grad.data() + tid,
130 | (chnl == all_data.size(1) - 1) ? lr_last : lr);
131 | }
132 | 
133 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM)
134 | __global__ void sgd_index_step_kernel(
135 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_data,
136 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_grad,
137 | torch::PackedTensorAccessor32<int32_t, 1, torch::RestrictPtrTraits> indices,
138 | float lr,
139 | float lr_last) {
140 | CUDA_GET_THREAD_ID(tid, indices.size(0) * all_data.size(1));
141 | int32_t i = indices[tid / all_data.size(1)];
142 | int32_t chnl = tid % all_data.size(1);
143 | size_t off = i * all_data.size(1) + chnl;
144 | sgd_once(all_data.data() + off,
145 | all_grad.data() + off,
146 | (chnl == all_data.size(1) - 1) ? lr_last : lr);
147 | }
148 | 
149 | 
150 | 
151 | } // namespace device
152 | } // namespace
153 | 
154 | void rmsprop_step(
155 | torch::Tensor data,
156 | torch::Tensor rms,
157 | torch::Tensor grad,
158 | torch::Tensor indexer,
159 | float beta,
160 | float lr,
161 | float epsilon,
162 | float minval,
163 | float lr_last) {
164 | 
165 | DEVICE_GUARD(data);
166 | CHECK_INPUT(data);
167 | CHECK_INPUT(rms);
168 | CHECK_INPUT(grad);
169 | CHECK_INPUT(indexer);
170 | 
171 | if (lr_last < 0.f) lr_last = lr;
172 | 
173 | const int cuda_n_threads = RMSPROP_STEP_CUDA_THREADS;
174 | 
175 | if (indexer.dim() == 0) {
176 | const size_t Q = data.size(0) * data.size(1);
177 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
178 | device::rmsprop_step_kernel<<<blocks, cuda_n_threads>>>(
179 | data.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
180 | rms.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
181 | grad.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
182 | beta,
183 | lr,
184 | epsilon,
185 | minval,
186 | lr_last);
187 | } else if (indexer.size(0) == 0) {
188 | // Skip
189 | } else if (indexer.scalar_type() == at::ScalarType::Bool) {
190 | const size_t Q = data.size(0) * data.size(1);
191 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
192 | device::rmsprop_mask_step_kernel<<<blocks, cuda_n_threads>>>(
193 | data.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
194 | rms.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
195 | grad.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
196 | indexer.data_ptr<bool>(),
197 | beta,
198 | lr,
199 | epsilon,
200 | minval,
201 | lr_last);
202 | } else {
203 | const size_t Q = indexer.size(0) * data.size(1);
204 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
205 | device::rmsprop_index_step_kernel<<<blocks, cuda_n_threads>>>(
206 | data.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
207 | rms.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
208 | grad.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
209 | indexer.packed_accessor32<int32_t, 1, torch::RestrictPtrTraits>(),
210 | beta,
211 | lr,
212 | epsilon,
213 | minval,
214 | lr_last);
215 | }
216 | 
217 | CUDA_CHECK_ERRORS;
218 | }
219 | 
220 | void sgd_step(
221 | torch::Tensor data,
222 | torch::Tensor grad,
223 | torch::Tensor indexer,
224 | float lr,
225 | float lr_last) {
226 | 
227 | DEVICE_GUARD(data);
228 | CHECK_INPUT(data);
229 | CHECK_INPUT(grad);
230 | CHECK_INPUT(indexer);
231 | 
232 | if (lr_last < 0.f) lr_last = lr;
233 | 
234 | const int cuda_n_threads = RMSPROP_STEP_CUDA_THREADS;
235 | 
236 | if (indexer.dim() == 0) {
237 | const size_t Q = data.size(0) * data.size(1);
238 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
239 | device::sgd_step_kernel<<<blocks, cuda_n_threads>>>(
240 | data.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
241 | grad.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
242 | lr,
243 | lr_last);
244 | } else if (indexer.size(0) == 0) {
245 | // Skip
246 | } else if (indexer.scalar_type() == at::ScalarType::Bool) {
247 | const size_t Q = data.size(0) * data.size(1);
248 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
249 | device::sgd_mask_step_kernel<<<blocks, cuda_n_threads>>>(
250 | data.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
251 | grad.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
252 | indexer.data_ptr<bool>(),
253 | lr,
254 | lr_last);
255 | } else {
256 | const size_t Q = indexer.size(0) * data.size(1);
257 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
258 | device::sgd_index_step_kernel<<<blocks, cuda_n_threads>>>(
259 | data.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
260 | grad.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
261 | indexer.packed_accessor32<int32_t, 1, torch::RestrictPtrTraits>(),
262 | lr,
263 | lr_last);
264 | }
265 | 
266 | CUDA_CHECK_ERRORS;
267 | }
268 | 
-------------------------------------------------------------------------------- /Svox2/svox2/csrc/svox2.cpp: --------------------------------------------------------------------------------
1 | // Copyright 2021 Alex Yu
2 | 
3 | // This file contains only Python bindings
4 | #include "data_spec.hpp"
5 | #include <cstdint>
6 | #include <torch/extension.h>
7 | #include <tuple>
8 | 
9 | using torch::Tensor;
10 | 
11 | std::tuple<torch::Tensor, torch::Tensor> sample_grid(SparseGridSpec &, Tensor,
12 | bool);
13 | void sample_grid_backward(SparseGridSpec &, Tensor, Tensor, Tensor, Tensor,
14 | Tensor, bool);
15 | 
16 | // ** NeRF rendering formula (trilerp)
17 | Tensor volume_render_cuvol(SparseGridSpec &, RaysSpec &, RenderOptions &);
18 | Tensor volume_render_cuvol_image(SparseGridSpec &, CameraSpec &,
19 | RenderOptions &);
20 | void volume_render_cuvol_backward(SparseGridSpec &, RaysSpec &, RenderOptions &,
21 | Tensor, Tensor, GridOutputGrads &);
22 | void volume_render_cuvol_fused(SparseGridSpec &, RaysSpec &, RenderOptions &,
23 | Tensor, float, float, Tensor, GridOutputGrads &);
24 | // Expected termination (depth) rendering
25 | torch::Tensor volume_render_expected_term(SparseGridSpec &, RaysSpec &,
26 | RenderOptions &);
27 | // Depth rendering based on sigma-threshold as in Dex-NeRF
28 | torch::Tensor volume_render_sigma_thresh(SparseGridSpec &, RaysSpec &,
29 | RenderOptions &, float);
30 | 
31 | // ** NV rendering formula (trilerp)
32 | Tensor volume_render_nvol(SparseGridSpec &, RaysSpec &, RenderOptions &);
33 | void volume_render_nvol_backward(SparseGridSpec &, RaysSpec &, RenderOptions &,
34 | Tensor, Tensor, GridOutputGrads &);
35 | void volume_render_nvol_fused(SparseGridSpec &, RaysSpec &, RenderOptions &,
36 | Tensor, float, float, Tensor, GridOutputGrads &);
37 | 
38 | // ** NeRF rendering formula (nearest-neighbor, infinitely many steps)
39 | Tensor volume_render_svox1(SparseGridSpec &, RaysSpec &, RenderOptions &);
40 | void volume_render_svox1_backward(SparseGridSpec &, RaysSpec &, RenderOptions &,
41 | Tensor, Tensor, GridOutputGrads &);
42 | void volume_render_svox1_fused(SparseGridSpec &, RaysSpec &, RenderOptions &,
43 | Tensor, float, float, Tensor, GridOutputGrads &);
44 | 
45 | // Tensor volume_render_cuvol_image(SparseGridSpec &, CameraSpec &,
46 | // RenderOptions &);
47 | //
48 | // void volume_render_cuvol_image_backward(SparseGridSpec &, CameraSpec &,
49 | // RenderOptions &, Tensor, Tensor,
50 | // GridOutputGrads &);
51 | 
52 | // Misc
53 | Tensor dilate(Tensor);
54 | void accel_dist_prop(Tensor);
55 | void grid_weight_render(Tensor, CameraSpec &, float, float, bool, Tensor,
56 | Tensor, Tensor);
57 | // void sample_cubemap(Tensor, Tensor, bool, Tensor);
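// The declarations above are implemented in the .cu files and exposed to
// Python in PYBIND11_MODULE below. A hedged sketch of driving the optimizer
// entry point from Python -- the alias `_C` for the compiled svox2.csrc
// module is an assumption, not part of this file:
//
//   # import svox2.csrc as _C
//   # _C.rmsprop_step(data, rms, grad, indexer, beta, lr, eps, minval, -1.0)
//
// Per optim_kernel.cu, lr_last < 0 falls back to lr for the last channel,
// and a dim-0 indexer tensor applies the step to every row.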
58 | 59 | // Loss 60 | Tensor tv(Tensor, Tensor, int, int, bool, float, bool, float, float); 61 | void tv_grad(Tensor, Tensor, int, int, float, bool, float, bool, float, float, 62 | Tensor); 63 | void tv_grad_sparse(Tensor, Tensor, Tensor, Tensor, int, int, float, bool, 64 | float, bool, bool, float, float, Tensor); 65 | void msi_tv_grad_sparse(Tensor, Tensor, Tensor, Tensor, float, float, Tensor); 66 | void lumisphere_tv_grad_sparse(SparseGridSpec &, Tensor, Tensor, Tensor, float, 67 | float, float, float, GridOutputGrads &); 68 | 69 | // Optim 70 | void rmsprop_step(Tensor, Tensor, Tensor, Tensor, float, float, float, float, 71 | float); 72 | void sgd_step(Tensor, Tensor, Tensor, float, float); 73 | 74 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 75 | #define _REG_FUNC(funname) m.def(#funname, &funname) 76 | _REG_FUNC(sample_grid); 77 | _REG_FUNC(sample_grid_backward); 78 | _REG_FUNC(volume_render_cuvol); 79 | _REG_FUNC(volume_render_cuvol_image); 80 | _REG_FUNC(volume_render_cuvol_backward); 81 | _REG_FUNC(volume_render_cuvol_fused); 82 | _REG_FUNC(volume_render_expected_term); 83 | _REG_FUNC(volume_render_sigma_thresh); 84 | 85 | _REG_FUNC(volume_render_nvol); 86 | _REG_FUNC(volume_render_nvol_backward); 87 | _REG_FUNC(volume_render_nvol_fused); 88 | 89 | _REG_FUNC(volume_render_svox1); 90 | _REG_FUNC(volume_render_svox1_backward); 91 | _REG_FUNC(volume_render_svox1_fused); 92 | 93 | // _REG_FUNC(volume_render_cuvol_image); 94 | // _REG_FUNC(volume_render_cuvol_image_backward); 95 | 96 | // Loss 97 | _REG_FUNC(tv); 98 | _REG_FUNC(tv_grad); 99 | _REG_FUNC(tv_grad_sparse); 100 | _REG_FUNC(msi_tv_grad_sparse); 101 | _REG_FUNC(lumisphere_tv_grad_sparse); 102 | 103 | // Misc 104 | _REG_FUNC(dilate); 105 | _REG_FUNC(accel_dist_prop); 106 | _REG_FUNC(grid_weight_render); 107 | // _REG_FUNC(sample_cubemap); 108 | 109 | // Optimizer 110 | _REG_FUNC(rmsprop_step); 111 | _REG_FUNC(sgd_step); 112 | #undef _REG_FUNC 113 | 114 | py::class_(m, "SparseGridSpec") 115 | .def(py::init<>()) 116 | .def_readwrite("density_data", &SparseGridSpec::density_data) 117 | .def_readwrite("sh_data", &SparseGridSpec::sh_data) 118 | .def_readwrite("links", &SparseGridSpec::links) 119 | .def_readwrite("_offset", &SparseGridSpec::_offset) 120 | .def_readwrite("_scaling", &SparseGridSpec::_scaling) 121 | .def_readwrite("basis_dim", &SparseGridSpec::basis_dim) 122 | .def_readwrite("basis_type", &SparseGridSpec::basis_type) 123 | .def_readwrite("basis_data", &SparseGridSpec::basis_data) 124 | .def_readwrite("background_links", &SparseGridSpec::background_links) 125 | .def_readwrite("background_data", &SparseGridSpec::background_data); 126 | 127 | py::class_(m, "CameraSpec") 128 | .def(py::init<>()) 129 | .def_readwrite("c2w", &CameraSpec::c2w) 130 | .def_readwrite("fx", &CameraSpec::fx) 131 | .def_readwrite("fy", &CameraSpec::fy) 132 | .def_readwrite("cx", &CameraSpec::cx) 133 | .def_readwrite("cy", &CameraSpec::cy) 134 | .def_readwrite("width", &CameraSpec::width) 135 | .def_readwrite("height", &CameraSpec::height) 136 | .def_readwrite("ndc_coeffx", &CameraSpec::ndc_coeffx) 137 | .def_readwrite("ndc_coeffy", &CameraSpec::ndc_coeffy); 138 | 139 | py::class_(m, "RaysSpec") 140 | .def(py::init<>()) 141 | .def_readwrite("origins", &RaysSpec::origins) 142 | .def_readwrite("dirs", &RaysSpec::dirs); 143 | 144 | py::class_(m, "RenderOptions") 145 | .def(py::init<>()) 146 | .def_readwrite("background_brightness", 147 | &RenderOptions::background_brightness) 148 | .def_readwrite("step_size", &RenderOptions::step_size) 149 | 
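// Each def_readwrite field surfaces as a mutable attribute on the
// Python-side RenderOptions, e.g. grid.opt.sigma_thresh = 0.0 as used in
// the test scripts later in this repository.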
.def_readwrite("sigma_thresh", &RenderOptions::sigma_thresh) 150 | .def_readwrite("stop_thresh", &RenderOptions::stop_thresh) 151 | .def_readwrite("near_clip", &RenderOptions::near_clip) 152 | .def_readwrite("use_spheric_clip", &RenderOptions::use_spheric_clip) 153 | .def_readwrite("last_sample_opaque", &RenderOptions::last_sample_opaque); 154 | // .def_readwrite("randomize", &RenderOptions::randomize) 155 | // .def_readwrite("random_sigma_std", &RenderOptions::random_sigma_std) 156 | // .def_readwrite("random_sigma_std_background", 157 | // &RenderOptions::random_sigma_std_background) 158 | // .def_readwrite("_m1", &RenderOptions::_m1) 159 | // .def_readwrite("_m2", &RenderOptions::_m2) 160 | // .def_readwrite("_m3", &RenderOptions::_m3); 161 | 162 | py::class_(m, "GridOutputGrads") 163 | .def(py::init<>()) 164 | .def_readwrite("grad_density_out", &GridOutputGrads::grad_density_out) 165 | .def_readwrite("grad_sh_out", &GridOutputGrads::grad_sh_out) 166 | .def_readwrite("grad_basis_out", &GridOutputGrads::grad_basis_out) 167 | .def_readwrite("grad_background_out", 168 | &GridOutputGrads::grad_background_out) 169 | .def_readwrite("mask_out", &GridOutputGrads::mask_out) 170 | .def_readwrite("mask_background_out", 171 | &GridOutputGrads::mask_background_out); 172 | } 173 | -------------------------------------------------------------------------------- /Svox2/svox2/csrc/svox2_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | #include 3 | #include 4 | #include "cuda_util.cuh" 5 | #include "data_spec_packed.cuh" 6 | 7 | namespace { 8 | namespace device { 9 | 10 | __global__ void sample_grid_sh_kernel( 11 | PackedSparseGridSpec grid, 12 | const torch::PackedTensorAccessor32 points, 13 | // Output 14 | torch::PackedTensorAccessor32 out) { 15 | CUDA_GET_THREAD_ID(tid, points.size(0) * grid.sh_data_dim); 16 | const int idx = tid % grid.sh_data_dim; 17 | const int pid = tid / grid.sh_data_dim; 18 | 19 | float point[3] = {points[pid][0], points[pid][1], points[pid][2]}; 20 | transform_coord(point, grid._scaling, grid._offset); 21 | 22 | int32_t l[3]; 23 | #pragma unroll 3 24 | for (int i = 0; i < 3; ++i) { 25 | point[i] = fminf(fmaxf(point[i], 0.f), grid.size[i] - 1.f); 26 | l[i] = min((int32_t)point[i], (int32_t)(grid.size[i] - 2)); 27 | point[i] -= l[i]; 28 | } 29 | 30 | const int offy = grid.size[2], offx = grid.size[1] * grid.size[2]; 31 | const int32_t* __restrict__ link_ptr = &grid.links[l[0] * offx + l[1] * offy + l[2]]; 32 | 33 | #define MAYBE_READ_LINK(u) ((link_ptr[u] >= 0) ? 
grid.sh_data[ \ 34 | link_ptr[u] * size_t(grid.sh_data_dim) + idx] : 0.f) 35 | 36 | const float ix0y0 = lerp(MAYBE_READ_LINK(0), MAYBE_READ_LINK(1), point[2]); 37 | const float ix0y1 = lerp(MAYBE_READ_LINK(offy), MAYBE_READ_LINK(offy + 1), point[2]); 38 | const float ix0 = lerp(ix0y0, ix0y1, point[1]); 39 | const float ix1y0 = lerp(MAYBE_READ_LINK(offx), MAYBE_READ_LINK(offx + 1), point[2]); 40 | const float ix1y1 = lerp(MAYBE_READ_LINK(offy + offx), 41 | MAYBE_READ_LINK(offy + offx + 1), point[2]); 42 | const float ix1 = lerp(ix1y0, ix1y1, point[1]); 43 | out[pid][idx] = lerp(ix0, ix1, point[0]); 44 | } 45 | #undef MAYBE_READ_LINK 46 | 47 | __global__ void sample_grid_density_kernel( 48 | PackedSparseGridSpec grid, 49 | const torch::PackedTensorAccessor32 points, 50 | // Output 51 | torch::PackedTensorAccessor32 out) { 52 | CUDA_GET_THREAD_ID(tid, points.size(0)); 53 | 54 | float point[3] = {points[tid][0], points[tid][1], points[tid][2]}; 55 | transform_coord(point, grid._scaling, grid._offset); 56 | 57 | int32_t l[3]; 58 | #pragma unroll 3 59 | for (int i = 0; i < 3; ++i) { 60 | point[i] = fminf(fmaxf(point[i], 0.f), grid.size[i] - 1.f); 61 | l[i] = min((int32_t)point[i], grid.size[i] - 2); 62 | point[i] -= l[i]; 63 | } 64 | 65 | const int offy = grid.size[2], offx = grid.size[1] * grid.size[2]; 66 | const int32_t* __restrict__ link_ptr = &grid.links[l[0] * offx + l[1] * offy + l[2]]; 67 | 68 | #define MAYBE_READ_LINK_D(u) ((link_ptr[u] >= 0) ? grid.density_data[link_ptr[u]] : 0.f) 69 | 70 | const float ix0y0 = lerp(MAYBE_READ_LINK_D(0), MAYBE_READ_LINK_D(1), point[2]); 71 | const float ix0y1 = lerp(MAYBE_READ_LINK_D(offy), MAYBE_READ_LINK_D(offy + 1), point[2]); 72 | const float ix0 = lerp(ix0y0, ix0y1, point[1]); 73 | const float ix1y0 = lerp(MAYBE_READ_LINK_D(offx), MAYBE_READ_LINK_D(offx + 1), point[2]); 74 | const float ix1y1 = lerp(MAYBE_READ_LINK_D(offy + offx), 75 | MAYBE_READ_LINK_D(offy + offx + 1), point[2]); 76 | const float ix1 = lerp(ix1y0, ix1y1, point[1]); 77 | out[tid][0] = lerp(ix0, ix1, point[0]); 78 | } 79 | #undef MAYBE_READ_LINK_D 80 | 81 | __global__ void sample_grid_sh_backward_kernel( 82 | PackedSparseGridSpec grid, 83 | const torch::PackedTensorAccessor32 points, 84 | const torch::PackedTensorAccessor32 grad_out, 85 | // Output 86 | torch::PackedTensorAccessor64 grad_data) { 87 | CUDA_GET_THREAD_ID(tid, points.size(0) * grid.sh_data_dim); 88 | const int idx = tid % grid.sh_data_dim; 89 | const int pid = tid / grid.sh_data_dim; 90 | 91 | float point[3] = {points[pid][0], points[pid][1], points[pid][2]}; 92 | transform_coord(point, grid._scaling, grid._offset); 93 | 94 | int32_t l[3]; 95 | #pragma unroll 3 96 | for (int i = 0; i < 3; ++i) { 97 | point[i] = fminf(fmaxf(point[i], 0.f), grid.size[i] - 1.f); 98 | l[i] = min((int32_t)point[i], grid.size[i] - 2); 99 | point[i] -= l[i]; 100 | } 101 | 102 | const int offy = grid.size[2], offx = grid.size[1] * grid.size[2]; 103 | const int32_t* __restrict__ link_ptr = &grid.links[l[0] * offx + l[1] * offy + l[2]]; 104 | 105 | const float go = grad_out[pid][idx]; 106 | 107 | const float xb = point[0], yb = point[1], zb = point[2]; 108 | const float xa = 1.f - point[0], ya = 1.f - point[1], za = 1.f - point[2]; 109 | 110 | #define MAYBE_ADD_GRAD_LINK_PTR(u, content) if (link_ptr[u] >= 0) \ 111 | atomicAdd(&grad_data[link_ptr[u]][idx], content) 112 | 113 | const float xago = xa * go; 114 | float tmp = ya * xago; 115 | MAYBE_ADD_GRAD_LINK_PTR(0, tmp * za); 116 | MAYBE_ADD_GRAD_LINK_PTR(1, tmp * zb); 117 | tmp = yb * xago; 118 
| MAYBE_ADD_GRAD_LINK_PTR(offy, tmp * za); 119 | MAYBE_ADD_GRAD_LINK_PTR(offy + 1, tmp * zb); 120 | 121 | const float xbgo = xb * go; 122 | tmp = ya * xbgo; 123 | MAYBE_ADD_GRAD_LINK_PTR(offx, tmp * za); 124 | MAYBE_ADD_GRAD_LINK_PTR(offx + 1, tmp * zb); 125 | tmp = yb * xbgo; 126 | MAYBE_ADD_GRAD_LINK_PTR(offx + offy, tmp * za); 127 | MAYBE_ADD_GRAD_LINK_PTR(offx + offy + 1, tmp * zb); 128 | } 129 | #undef MAYBE_ADD_GRAD_LINK_PTR 130 | 131 | __global__ void sample_grid_density_backward_kernel( 132 | PackedSparseGridSpec grid, 133 | const torch::PackedTensorAccessor32 points, 134 | const torch::PackedTensorAccessor32 grad_out, 135 | // Output 136 | torch::PackedTensorAccessor32 grad_data) { 137 | CUDA_GET_THREAD_ID(tid, points.size(0)); 138 | 139 | float point[3] = {points[tid][0], points[tid][1], points[tid][2]}; 140 | transform_coord(point, grid._scaling, grid._offset); 141 | 142 | int32_t l[3]; 143 | #pragma unroll 3 144 | for (int i = 0; i < 3; ++i) { 145 | point[i] = fminf(fmaxf(point[i], 0.f), grid.size[i] - 1.f); 146 | l[i] = min((int32_t)point[i], grid.size[i] - 2); 147 | point[i] -= l[i]; 148 | } 149 | 150 | const int offy = grid.size[2], offx = grid.size[1] * grid.size[2]; 151 | const int32_t* __restrict__ link_ptr = &grid.links[l[0] * offx + l[1] * offy + l[2]]; 152 | 153 | const float go = grad_out[tid][0]; 154 | 155 | const float xb = point[0], yb = point[1], zb = point[2]; 156 | const float xa = 1.f - point[0], ya = 1.f - point[1], za = 1.f - point[2]; 157 | 158 | #define MAYBE_ADD_GRAD_LINK_PTR_D(u, content) if (link_ptr[u] >= 0) \ 159 | atomicAdd(grad_data[link_ptr[u]].data(), content) 160 | 161 | const float xago = xa * go; 162 | float tmp = ya * xago; 163 | MAYBE_ADD_GRAD_LINK_PTR_D(0, tmp * za); 164 | MAYBE_ADD_GRAD_LINK_PTR_D(1, tmp * zb); 165 | tmp = yb * xago; 166 | MAYBE_ADD_GRAD_LINK_PTR_D(offy, tmp * za); 167 | MAYBE_ADD_GRAD_LINK_PTR_D(offy + 1, tmp * zb); 168 | 169 | const float xbgo = xb * go; 170 | tmp = ya * xbgo; 171 | MAYBE_ADD_GRAD_LINK_PTR_D(offx, tmp * za); 172 | MAYBE_ADD_GRAD_LINK_PTR_D(offx + 1, tmp * zb); 173 | tmp = yb * xbgo; 174 | MAYBE_ADD_GRAD_LINK_PTR_D(offx + offy, tmp * za); 175 | MAYBE_ADD_GRAD_LINK_PTR_D(offx + offy + 1, tmp * zb); 176 | } 177 | } // namespace device 178 | } // namespace 179 | 180 | 181 | std::tuple sample_grid(SparseGridSpec& grid, torch::Tensor points, 182 | bool want_colors) { 183 | DEVICE_GUARD(points); 184 | grid.check(); 185 | CHECK_INPUT(points); 186 | TORCH_CHECK(points.ndimension() == 2); 187 | const auto Q = points.size(0) * grid.sh_data.size(1); 188 | const int cuda_n_threads = std::min(Q, CUDA_MAX_THREADS); 189 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads); 190 | const int blocks_density = CUDA_N_BLOCKS_NEEDED(points.size(0), cuda_n_threads); 191 | torch::Tensor result_density = torch::empty({points.size(0), 192 | grid.density_data.size(1)}, points.options()); 193 | torch::Tensor result_sh = torch::empty({want_colors ? 
points.size(0) : 0, 194 | grid.sh_data.size(1)}, points.options()); 195 | 196 | cudaStream_t stream_1, stream_2; 197 | cudaStreamCreate(&stream_1); 198 | cudaStreamCreate(&stream_2); 199 | 200 | device::sample_grid_density_kernel<<>>( 201 | grid, 202 | points.packed_accessor32(), 203 | // Output 204 | result_density.packed_accessor32()); 205 | if (want_colors) { 206 | device::sample_grid_sh_kernel<<>>( 207 | grid, 208 | points.packed_accessor32(), 209 | // Output 210 | result_sh.packed_accessor32()); 211 | } 212 | 213 | cudaStreamSynchronize(stream_1); 214 | cudaStreamSynchronize(stream_2); 215 | CUDA_CHECK_ERRORS; 216 | return std::tuple{result_density, result_sh}; 217 | } 218 | 219 | void sample_grid_backward( 220 | SparseGridSpec& grid, 221 | torch::Tensor points, 222 | torch::Tensor grad_out_density, 223 | torch::Tensor grad_out_sh, 224 | torch::Tensor grad_density_out, 225 | torch::Tensor grad_sh_out, 226 | bool want_colors) { 227 | DEVICE_GUARD(points); 228 | grid.check(); 229 | CHECK_INPUT(points); 230 | CHECK_INPUT(grad_out_density); 231 | CHECK_INPUT(grad_out_sh); 232 | CHECK_INPUT(grad_density_out); 233 | CHECK_INPUT(grad_sh_out); 234 | TORCH_CHECK(points.ndimension() == 2); 235 | TORCH_CHECK(grad_out_density.ndimension() == 2); 236 | TORCH_CHECK(grad_out_sh.ndimension() == 2); 237 | const auto Q = points.size(0) * grid.sh_data.size(1); 238 | 239 | const int cuda_n_threads = std::min(Q, CUDA_MAX_THREADS); 240 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads); 241 | const int blocks_density = CUDA_N_BLOCKS_NEEDED(points.size(0), cuda_n_threads); 242 | 243 | cudaStream_t stream_1, stream_2; 244 | cudaStreamCreate(&stream_1); 245 | cudaStreamCreate(&stream_2); 246 | 247 | device::sample_grid_density_backward_kernel<<>>( 248 | grid, 249 | points.packed_accessor32(), 250 | grad_out_density.packed_accessor32(), 251 | // Output 252 | grad_density_out.packed_accessor32()); 253 | 254 | if (want_colors) { 255 | device::sample_grid_sh_backward_kernel<<>>( 256 | grid, 257 | points.packed_accessor32(), 258 | grad_out_sh.packed_accessor32(), 259 | // Output 260 | grad_sh_out.packed_accessor64()); 261 | } 262 | 263 | cudaStreamSynchronize(stream_1); 264 | cudaStreamSynchronize(stream_2); 265 | 266 | CUDA_CHECK_ERRORS; 267 | } 268 | -------------------------------------------------------------------------------- /Svox2/svox2/defs.py: -------------------------------------------------------------------------------- 1 | # Basis types (copied from C++ data_spec.hpp) 2 | BASIS_TYPE_SH = 1 3 | BASIS_TYPE_3D_TEXTURE = 4 4 | BASIS_TYPE_MLP = 255 5 | -------------------------------------------------------------------------------- /Svox2/svox2/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.0.1.dev0+sphtexcub.lincolor.fast' 2 | -------------------------------------------------------------------------------- /Svox2/test/prof.py: -------------------------------------------------------------------------------- 1 | # nvprof -f --profile-from-start off --quiet --metrics all --events all -o prof.nvvp python prof.py 2 | # then use nvvp to open prof.nvvp 3 | import svox2 4 | import torch 5 | import numpy as np 6 | from util import Timing 7 | from matplotlib import pyplot as plt 8 | 9 | import torch.cuda.profiler as profiler 10 | import pyprof 11 | 12 | device='cuda:0' 13 | 14 | GRID_FILE = 'lego.npy' 15 | grid = svox2.SparseGrid(reso=256, device='cpu', radius=1.3256) 16 | data = torch.from_numpy(np.load(GRID_FILE)).view(-1, 
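# (assumed layout, consistent with the assignments just below: column 0
#  holds density and the remaining columns hold the SH coefficients,
#  e.g. 27 = 9 SH coefficients per RGB channel when basis_dim is 9)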
grid.data_dim) 17 | grid.sh_data.data = data[..., 1:] 18 | grid.density_data.data = data[..., :1] 19 | grid = grid.cuda() 20 | # grid.data.data[..., 0] += 0.1 21 | 22 | N_RAYS = 5000 23 | # origins = torch.full((N_RAYS, 3), fill_value=0.0, device=device, dtype=dtype) 24 | origins = torch.zeros((N_RAYS, 3), device=device, dtype=torch.float32) 25 | dirs : torch.Tensor = torch.randn((N_RAYS, 3), device=device, dtype=torch.float32) 26 | dirs /= torch.norm(dirs, dim=-1, keepdim=True) 27 | rays = svox2.Rays(origins, dirs) 28 | 29 | grid.requires_grad_(True) 30 | 31 | samps = grid.volume_render(rays, use_kernel=True) 32 | # sampt = grid.volume_render(grid, origins, dirs, use_kernel=False) 33 | 34 | pyprof.init() 35 | with torch.autograd.profiler.emit_nvtx(): 36 | profiler.start() 37 | samps = grid.volume_render(rays, use_kernel=True) 38 | s = samps.sum() 39 | s.backward() 40 | profiler.stop() 41 | -------------------------------------------------------------------------------- /Svox2/test/sanity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import svox2 3 | 4 | device = 'cuda:0' 5 | 6 | 7 | torch.random.manual_seed(4000) 8 | g = svox2.SparseGrid(center=[0.0, 0.0, 0.0], 9 | radius=[1.0, 1.0, 1.0], 10 | device=device, 11 | basis_type=svox2.BASIS_TYPE_SH, 12 | background_nlayers=0) 13 | 14 | g.opt.backend = 'nvol' 15 | g.opt.sigma_thresh = 0.0 16 | g.opt.stop_thresh = 0.0 17 | g.opt.background_brightness = 1.0 18 | 19 | g.sh_data.data.normal_() 20 | g.density_data.data[..., 0] = 0.1 21 | g.sh_data.data[..., 0] = 0.5 22 | g.sh_data.data[..., 1:] = torch.randn_like(g.sh_data.data[..., 1:]) * 0.01 23 | 24 | if g.use_background: 25 | g.background_data.data[..., -1] = 1.0 26 | g.background_data.data[..., :-1] = torch.randn_like( 27 | g.background_data.data[..., :-1]) * 0.01 28 | # g.background_data.data[..., :-1] = 0.5 29 | 30 | g.basis_data.data.normal_() 31 | g.basis_data.data *= 10.0 32 | # print('use frustum?', g.use_frustum) 33 | 34 | N_RAYS = 1 35 | 36 | # origins = torch.randn(N_RAYS, 3, device=device) * 3 37 | # dirs = torch.randn(N_RAYS, 3, device=device) 38 | # origins = origins[27513:27514] 39 | # dirs = dirs[27513:27514] 40 | 41 | origins = torch.tensor([[-3.8992738723754883, 4.844727993011475, 4.323856830596924]], device='cuda:0') 42 | dirs = torch.tensor([[1.1424630880355835, -1.2679963111877441, -0.8437137603759766]], device='cuda:0') 43 | dirs = dirs / torch.norm(dirs, dim=-1).unsqueeze(-1) 44 | 45 | rays = svox2.Rays(origins=origins, dirs=dirs) 46 | 47 | rgb = g.volume_render(rays, use_kernel=True) 48 | torch.cuda.synchronize() 49 | rgb_gt = g.volume_render(rays, use_kernel=False) 50 | torch.cuda.synchronize() 51 | 52 | E = torch.abs(rgb - rgb_gt) 53 | err = E.max().detach().item() 54 | print(err) 55 | -------------------------------------------------------------------------------- /Svox2/test/test_render_gradcheck.py: -------------------------------------------------------------------------------- 1 | import svox2 2 | import torch 3 | import torch.nn.functional as F 4 | from util import Timing 5 | 6 | torch.random.manual_seed(2) 7 | # torch.random.manual_seed(8289) 8 | 9 | device = 'cuda:0' 10 | dtype = torch.float32 11 | grid = svox2.SparseGrid( 12 | reso=128, 13 | center=[0.0, 0.0, 0.0], 14 | radius=[1.0, 1.0, 1.0], 15 | basis_dim=9, 16 | use_z_order=True, 17 | device=device, 18 | background_nlayers=0, 19 | basis_type=svox2.BASIS_TYPE_SH) 20 | grid.opt.backend = 'nvol' 21 | grid.opt.sigma_thresh = 0.0 22 | 
grid.opt.stop_thresh = 0.0 23 | grid.opt.background_brightness = 1.0 24 | 25 | print(grid.sh_data.shape) 26 | # grid.sh_data.data.normal_() 27 | grid.sh_data.data[..., 0] = 0.5 28 | grid.sh_data.data[..., 1:].normal_(std=0.1) 29 | grid.density_data.data[:] = 100.0 30 | 31 | if grid.use_background: 32 | grid.background_data.data[..., -1] = 0.5 33 | grid.background_data.data[..., :-1] = torch.randn_like( 34 | grid.background_data.data[..., :-1]) * 0.01 35 | 36 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 37 | grid.basis_data.data.normal_() 38 | grid.basis_data.data += 1.0 39 | 40 | ENABLE_TORCH_CHECK = True 41 | # N_RAYS = 5000 #200 * 200 42 | N_RAYS = 200 * 200 43 | origins = torch.randn((N_RAYS, 3), device=device, dtype=dtype) * 3 44 | dirs = torch.randn((N_RAYS, 3), device=device, dtype=dtype) 45 | # origins = torch.clip(origins, -0.8, 0.8) 46 | 47 | # origins = torch.tensor([[-0.6747068762779236, -0.752697229385376, -0.800000011920929]], device=device, dtype=dtype) 48 | # dirs = torch.tensor([[0.6418760418891907, -0.37417781352996826, 0.6693176627159119]], device=device, dtype=dtype) 49 | dirs /= torch.norm(dirs, dim=-1, keepdim=True) 50 | 51 | # start = 71 52 | # end = 72 53 | # origins = origins[start:end] 54 | # dirs = dirs[start:end] 55 | # print(origins.tolist(), dirs.tolist()) 56 | 57 | # breakpoint() 58 | rays = svox2.Rays(origins, dirs) 59 | 60 | rgb_gt = torch.zeros((origins.size(0), 3), device=device, dtype=dtype) 61 | 62 | # grid.requires_grad_(True) 63 | 64 | # samps = grid.volume_render(rays, use_kernel=True) 65 | # sampt = grid.volume_render(grid, origins, dirs, use_kernel=False) 66 | 67 | with Timing("ours"): 68 | samps = grid.volume_render(rays, use_kernel=True) 69 | s = F.mse_loss(samps, rgb_gt) 70 | 71 | print(s) 72 | print('bkwd..') 73 | with Timing("ours_backward"): 74 | s.backward() 75 | grid_sh_grad_s = grid.sh_data.grad.clone().cpu() 76 | grid_density_grad_s = grid.density_data.grad.clone().cpu() 77 | grid.sh_data.grad = None 78 | grid.density_data.grad = None 79 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 80 | grid_basis_grad_s = grid.basis_data.grad.clone().cpu() 81 | grid.basis_data.grad = None 82 | if grid.use_background: 83 | grid_bg_grad_s = grid.background_data.grad.clone().cpu() 84 | grid.background_data.grad = None 85 | 86 | if ENABLE_TORCH_CHECK: 87 | with Timing("torch"): 88 | sampt = grid.volume_render(rays, use_kernel=False) 89 | s = F.mse_loss(sampt, rgb_gt) 90 | with Timing("torch_backward"): 91 | s.backward() 92 | grid_sh_grad_t = grid.sh_data.grad.clone().cpu() if grid.sh_data.grad is not None else torch.zeros_like(grid_sh_grad_s) 93 | grid_density_grad_t = grid.density_data.grad.clone().cpu() if grid.density_data.grad is not None else torch.zeros_like(grid_density_grad_s) 94 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 95 | grid_basis_grad_t = grid.basis_data.grad.clone().cpu() 96 | if grid.use_background: 97 | grid_bg_grad_t = grid.background_data.grad.clone().cpu() if grid.background_data.grad is not None else torch.zeros_like(grid_bg_grad_s) 98 | 99 | E = torch.abs(grid_sh_grad_s-grid_sh_grad_t) 100 | Ed = torch.abs(grid_density_grad_s-grid_density_grad_t) 101 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 102 | Eb = torch.abs(grid_basis_grad_s-grid_basis_grad_t) 103 | if grid.use_background: 104 | Ebg = torch.abs(grid_bg_grad_s-grid_bg_grad_t) 105 | print('err', torch.abs(samps - sampt).max()) 106 | print('err_sh_grad\n', E.max()) 107 | print(' mean\n', E.mean()) 108 | print('err_density_grad\n', Ed.max()) 109 | 
print(' mean\n', Ed.mean()) 110 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 111 | print('err_basis_grad\n', Eb.max()) 112 | print(' mean\n', Eb.mean()) 113 | if grid.use_background: 114 | print('err_background_grad\n', Ebg.max()) 115 | print(' mean\n', Ebg.mean()) 116 | print() 117 | print('g_ours sh min/max\n', grid_sh_grad_s.min(), grid_sh_grad_s.max()) 118 | print('g_torch sh min/max\n', grid_sh_grad_t.min(), grid_sh_grad_t.max()) 119 | print('g_ours sigma min/max\n', grid_density_grad_s.min(), grid_density_grad_s.max()) 120 | print('g_torch sigma min/max\n', grid_density_grad_t.min(), grid_density_grad_t.max()) 121 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 122 | print('g_ours basis min/max\n', grid_basis_grad_s.min(), grid_basis_grad_s.max()) 123 | print('g_torch basis min/max\n', grid_basis_grad_t.min(), grid_basis_grad_t.max()) 124 | if grid.use_background: 125 | print('g_ours bg min/max\n', grid_bg_grad_s.min(), grid_bg_grad_s.max()) 126 | print('g_torch bg min/max\n', grid_bg_grad_t.min(), grid_bg_grad_t.max()) 127 | -------------------------------------------------------------------------------- /Svox2/test/test_render_timing.py: -------------------------------------------------------------------------------- 1 | import svox2 2 | import torch 3 | from util import Timing 4 | 5 | torch.random.manual_seed(0) 6 | 7 | device = 'cuda:0' 8 | dtype = torch.float32 9 | grid = svox2.SparseGrid( 10 | reso=256, 11 | center=[0.0, 0.0, 0.0], 12 | radius=[1.0, 1.0, 1.0], 13 | basis_dim=9, 14 | use_z_order=True, 15 | device=device) 16 | grid.opt.sigma_thresh = 0.0 17 | grid.opt.stop_thresh = 0.0 18 | 19 | grid.sh_data.data.normal_() 20 | grid.density_data.data[:] = 0.1 21 | 22 | N_RAYS = 200 * 200 23 | # origins = torch.full((N_RAYS, 3), fill_value=0.0, device=device, dtype=dtype) 24 | origins = torch.zeros((N_RAYS, 3), device=device, dtype=dtype) 25 | dirs : torch.Tensor = torch.randn((N_RAYS, 3), device=device, dtype=dtype) 26 | dirs /= torch.norm(dirs, dim=-1, keepdim=True) 27 | rays = svox2.Rays(origins, dirs) 28 | 29 | grid.requires_grad_(True) 30 | 31 | samps = grid.volume_render(rays, use_kernel=True) 32 | # sampt = grid.volume_render(grid, origins, dirs, use_kernel=False) 33 | 34 | with Timing("ours"): 35 | samps = grid.volume_render(rays, use_kernel=True) 36 | s = samps.sum() 37 | with Timing("ours_backward"): 38 | s.backward() 39 | -------------------------------------------------------------------------------- /Svox2/test/test_render_timing_smallbat.py: -------------------------------------------------------------------------------- 1 | import svox2 2 | import torch 3 | from util import Timing 4 | 5 | torch.random.manual_seed(0) 6 | 7 | device = 'cuda:0' 8 | dtype = torch.float32 9 | grid = svox2.SparseGrid( 10 | reso=256, 11 | center=[0.0, 0.0, 0.0], 12 | radius=[1.0, 1.0, 1.0], 13 | basis_dim=9, 14 | use_z_order=True, 15 | device=device) 16 | grid.opt.sigma_thresh = 0.0 17 | grid.opt.stop_thresh = 0.0 18 | 19 | grid.sh_data.data.normal_() 20 | grid.density_data.data[:] = 0.1 21 | 22 | N_RAYS = 5000 23 | # origins = torch.full((N_RAYS, 3), fill_value=0.0, device=device, dtype=dtype) 24 | origins = torch.zeros((N_RAYS, 3), device=device, dtype=dtype) 25 | dirs : torch.Tensor = torch.randn((N_RAYS, 3), device=device, dtype=dtype) 26 | dirs /= torch.norm(dirs, dim=-1, keepdim=True) 27 | rays = svox2.Rays(origins, dirs) 28 | 29 | grid.requires_grad_(True) 30 | 31 | samps = grid.volume_render(rays, use_kernel=True) 32 | # sampt = grid.volume_render(grid, origins, dirs, 
use_kernel=False) 33 | 34 | with Timing("ours"): 35 | samps = grid.volume_render(rays, use_kernel=True) 36 | s = samps.sum() 37 | with Timing("ours_backward"): 38 | s.backward() 39 | -------------------------------------------------------------------------------- /Svox2/test/test_render_visual.py: -------------------------------------------------------------------------------- 1 | import svox2 2 | import torch 3 | import numpy as np 4 | from util import Timing 5 | from matplotlib import pyplot as plt 6 | device='cuda:0' 7 | 8 | GRID_FILE = 'lego.npy' 9 | grid = svox2.SparseGrid(reso=256, device='cpu', radius=1.3256) 10 | data = torch.from_numpy(np.load(GRID_FILE)).view(-1, grid.data_dim) 11 | grid.sh_data.data = data[..., 1:] 12 | grid.density_data.data = data[..., :1] 13 | # grid.resample(128, use_z_order=True) 14 | grid = grid.cuda() 15 | 16 | c2w = torch.tensor([ 17 | [ -0.9999999403953552, 0.0, 0.0, 0.0 ], 18 | [ 0.0, -0.7341099977493286, 0.6790305972099304, 2.737260103225708 ], 19 | [ 0.0, 0.6790306568145752, 0.7341098785400391, 2.959291696548462 ], 20 | [ 0.0, 0.0, 0.0, 1.0 ], 21 | ], device=device) 22 | 23 | with torch.no_grad(): 24 | width = height = 800 25 | fx = fy = 1111 26 | origins = c2w[None, :3, 3].expand(height * width, -1).contiguous() 27 | yy, xx = torch.meshgrid( 28 | torch.arange(height, dtype=torch.float64, device=c2w.device), 29 | torch.arange(width, dtype=torch.float64, device=c2w.device), 30 | ) 31 | xx = (xx - width * 0.5) / float(fx) 32 | yy = (yy - height * 0.5) / float(fy) 33 | zz = torch.ones_like(xx) 34 | dirs = torch.stack((xx, -yy, -zz), dim=-1) 35 | dirs /= torch.norm(dirs, dim=-1, keepdim=True) 36 | dirs = dirs.reshape(-1, 3) 37 | del xx, yy, zz 38 | dirs = torch.matmul(c2w[None, :3, :3].double(), dirs[..., None])[..., 0].float() 39 | dirs = dirs / torch.norm(dirs, dim=-1, keepdim=True) 40 | 41 | rays = svox2.Rays(origins, dirs) 42 | 43 | for i in range(5): 44 | with Timing("ours"): 45 | im = grid.volume_render(rays, use_kernel=True) 46 | 47 | im = im.reshape(height, width, 3) 48 | im = im.detach().clamp_(0.0, 1.0).cpu() 49 | plt.imshow(im) 50 | plt.show() 51 | -------------------------------------------------------------------------------- /Svox2/test/test_sample.py: -------------------------------------------------------------------------------- 1 | import svox2 2 | import torch 3 | import numpy as np 4 | from util import Timing 5 | 6 | torch.random.manual_seed(0) 7 | 8 | device = 'cuda:0' 9 | 10 | # GRID_FILE = 'lego.npy' 11 | # grid = svox2.SparseGrid(reso=256, device='cpu', radius=1.3256) 12 | # grid.data.data = torch.from_numpy(np.load(GRID_FILE)).view(-1, grid.data_dim) 13 | # grid = grid.cuda() 14 | 15 | grid = svox2.SparseGrid(reso=256, center=[0.0, 0.0, 0.0], 16 | radius=1.0, device=device) 17 | grid.sh_data.data.normal_(0.0, 1.0) 18 | grid.density_data.data.normal_(0.1, 0.05).clamp_min_(0.0) 19 | # grid.density_data.data[:] = 1.0 20 | # grid = torch.rand((2, 2, 2, 4), device=device, dtype=torch.float32) 21 | 22 | N_POINTS = 5000 * 1024 23 | points = torch.rand(N_POINTS, 3, device=device) * 2 - 1 24 | # points = torch.tensor([[0.49, 0.49, 0.49], [0.9985, 0.4830, 0.4655]], device=device) 25 | # points.clamp_(-0.999, 0.999) 26 | 27 | _ = grid.sample(points) 28 | _ = grid.sample(points, use_kernel=False) 29 | 30 | grid.requires_grad_(True) 31 | 32 | with Timing("ours"): 33 | sigma_c, rgb_c = grid.sample(points) 34 | 35 | s = sigma_c.sum() + rgb_c.sum() 36 | with Timing("our_back"): 37 | s.backward() 38 | gdo = grid.density_data.grad.clone() 39 | gso 
= grid.sh_data.grad.clone() 40 | grid.density_data.grad = None 41 | grid.sh_data.grad = None 42 | 43 | with Timing("torch"): 44 | sigma_t, rgb_t = grid.sample(points, use_kernel=False) 45 | s = sigma_t.sum() + rgb_t.sum() 46 | with Timing("torch_back"): 47 | s.backward() 48 | gdt = grid.density_data.grad.clone() 49 | gst = grid.sh_data.grad.clone() 50 | 51 | # print('c\n', sampc) 52 | # print('t\n', sampt) 53 | print('err_sigma\n', torch.abs(sigma_t-sigma_c).max()) 54 | print('err_rgb\n', torch.abs(rgb_t-rgb_c).max()) 55 | print('err_grad_sigma\n', torch.abs(gdo-gdt).max()) 56 | print('err_grad_rgb\n', torch.abs(gso-gst).max()) 57 | -------------------------------------------------------------------------------- /Svox2/test/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.cuda 3 | 4 | class Timing: 5 | def __init__(self, name): 6 | self.name = name 7 | 8 | def __enter__(self): 9 | self.start = torch.cuda.Event(enable_timing=True) 10 | self.end = torch.cuda.Event(enable_timing=True) 11 | self.start.record() 12 | 13 | def __exit__(self, type, value, traceback): 14 | self.end.record() 15 | torch.cuda.synchronize() 16 | print(self.name, 'elapsed', self.start.elapsed_time(self.end), 'ms') 17 | -------------------------------------------------------------------------------- /data/nerf_datasets/rs_dtu_4/check_same.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('path1', type=str) 6 | parser.add_argument('path2', type=str) 7 | args = parser.parse_args() 8 | 9 | z1 = np.load(args.path1) 10 | z2 = np.load(args.path2) 11 | 12 | assert z1.keys() == z2.keys() 13 | 14 | for k in z1.keys(): 15 | assert k in z2.keys() 16 | err = np.max(np.abs(z1[k] - z2[k])) 17 | assert err < 1e-10 18 | -------------------------------------------------------------------------------- /data/nerf_datasets/rs_dtu_4/proc.py: -------------------------------------------------------------------------------- 1 | def read_list(path): 2 | with open(path, 'r') as f: 3 | ids = [int(x[4:]) for x in f.readlines()] 4 | return ids 5 | 6 | mvsnet_ids = [] 7 | mvsnet_ids.extend(read_list('DTU/mvsnet_train.lst')) 8 | mvsnet_ids.extend(read_list('DTU/mvsnet_val.lst')) 9 | mvsnet_ids.extend(read_list('DTU/mvsnet_test.lst')) 10 | mvsnet_ids = sorted(mvsnet_ids) 11 | 12 | new_val_ids = read_list('DTU/new_val.lst') 13 | print(new_val_ids) 14 | 15 | manual_exclude = [1, 2, 7, 29, 39, 51, 56, 57, 58, 83, 111, 112, 113, 115, 116, 117] 16 | 17 | 18 | remaining = [i for i in range(1, 129) if i in mvsnet_ids and i not in new_val_ids and i not in manual_exclude] 19 | print('Extracted', len(remaining)) 20 | print(remaining) 21 | 22 | txt = '\n'.join(['scan' + str(i) for i in remaining]) 23 | with open('DTU/new_train.lst', 'w') as f: 24 | f.write(txt) 25 | 26 | -------------------------------------------------------------------------------- /data/nerf_datasets/rs_dtu_4/resize_cams.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | from tqdm import tqdm 5 | import cv2 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument( 9 | "--data_dir", 10 | type=str, 11 | default="DTU", 12 | help="Data directory", 13 | ) 14 | args = parser.parse_args() 15 | 16 | cam_paths = [os.path.join(args.data_dir, x, 'cameras.npz') for x in 
os.listdir(args.data_dir)] 17 | 18 | scale_fact = 4 19 | 20 | for cam_path in tqdm(cam_paths): 21 | if not os.path.exists(cam_path): 22 | continue 23 | z = dict(np.load(cam_path)) 24 | for k in z.keys(): 25 | if k.startswith("camera_mat_inv_"): 26 | pass 27 | elif k.startswith("world_mat_inv_"): 28 | pass 29 | elif k.startswith("camera_mat_"): 30 | z[k][:3, :3] = z[k][:3, :3] * scale_fact 31 | elif k.startswith("world_mat_"): 32 | # K, R, t = cv2.decomposeProjectionMatrix(z[k][:3])[:3] 33 | # print('FROM') 34 | # print(K) 35 | # print(R) 36 | # print(t) 37 | z[k][:2] = z[k][:2] / scale_fact 38 | # K, R, t = cv2.decomposeProjectionMatrix(z[k][:3])[:3] 39 | # print('TO') 40 | # print(K) 41 | # print(R) 42 | # print(t) 43 | 44 | for k in z.keys(): 45 | if k.startswith("camera_mat_inv_"): 46 | noninv = "camera_mat_" + k[k.rindex('_') + 1:] 47 | z[k] = np.linalg.inv(z[noninv]) 48 | elif k.startswith("world_mat_inv_"): 49 | noninv = "world_mat_" + k[k.rindex('_') + 1:] 50 | z[k] = np.linalg.inv(z[noninv]) 51 | np.savez(cam_path, **z) 52 | -------------------------------------------------------------------------------- /data/nerf_datasets/rs_dtu_4/resize_imgs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import cv2 4 | from concurrent.futures import ProcessPoolExecutor 5 | from tqdm import tqdm 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument( 9 | "--data_dir", 10 | type=str, 11 | default="DTU", 12 | help="Data directory", 13 | ) 14 | parser.add_argument( 15 | "--imsize", 16 | type=int, 17 | default=128, 18 | help="Output image size", 19 | ) 20 | parser.add_argument( 21 | "--num_workers", "-j", 22 | type=int, 23 | default=8, 24 | help="Num processes", 25 | ) 26 | args = parser.parse_args() 27 | 28 | 29 | def process_image(im_path): 30 | im = cv2.imread(im_path) 31 | if im is None: 32 | print('FAIL', im_path) 33 | return 34 | H, W, C = im.shape 35 | if H <= 600: 36 | return 37 | im = cv2.pyrDown(im) 38 | im = cv2.pyrDown(im) 39 | cv2.imwrite(im_path, im) 40 | 41 | 42 | futures = [] 43 | dir_paths = [os.path.join(args.data_dir, x) for x in os.listdir(args.data_dir)] 44 | im_paths = [os.path.join(dirpath, 'image', x) for dirpath in dir_paths if os.path.isdir(dirpath) for x in os.listdir(os.path.join(dirpath, 'image'))] 45 | # mask_paths = [os.path.join(dirpath, 'mask', x) for dirpath in dir_paths if os.path.isdir(dirpath) for x in os.listdir(os.path.join(dirpath, 'mask'))] 46 | # im_paths.extend(mask_paths) 47 | progress = tqdm(total=len(im_paths)) 48 | with ProcessPoolExecutor(max_workers=args.num_workers) as executor: 49 | for im_path in im_paths: 50 | futures.append( 51 | executor.submit( 52 | process_image, 53 | im_path, 54 | ) 55 | ) 56 | for future in futures: 57 | _ = future.result() 58 | progress.update(1) 59 | -------------------------------------------------------------------------------- /nerf_ops.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from matplotlib.pyplot import axis 4 | 5 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "pixel_nerf", "src"))) 6 | # sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 7 | 8 | 9 | import torch 10 | # import torch.nn.functional as F 11 | import numpy as np 12 | import imageio 13 | import pixel_nerf.src.util as util 14 | from pathlib import Path 15 | 16 | from extra_utils import check_folder 17 | import warnings 18 | from util.util 
import compute_ssim
19 | # from data import get_split_dataset
20 | from render import NeRFRenderer
21 | from model import make_model
22 | # from scipy.interpolate import CubicSpline
23 | import tqdm
24 | import math
25 | from pyhocon import ConfigFactory
26 |
27 |
28 | def vidoe_distance(frames1, frames2, no_lpips=False, number_of_chunks=1):
29 | """
30 | A util function to measure the distance between two videos (as torch tensors `N*H*W*C`) in PSNR, SSIM and, optionally, LPIPS.
31 | """
32 | avg_psnr, avg_ssim, avg_lpips = 0.0, 0.0, 0.0
33 | # assert frames1.shape == frames2.shape, "the two videos are not equal in size"
34 | if frames1.shape != frames2.shape:
35 | return {"PSNR": avg_psnr, "SSIM": avg_ssim, "LPIPS": avg_lpips}
36 |
37 | if not no_lpips:
38 | import lpips
39 | lpips_vgg = lpips.LPIPS(
40 | net="vgg", verbose=False).eval().to(frames1.device)
41 | mse = (frames1 - frames2) ** 2
42 | mse_num: float = mse.mean().item()
43 | try:
44 | psnr = -10.0 * math.log10(mse_num)
45 | except ValueError:  # log10(0) when the two videos match exactly
46 | psnr = 0.0
47 |
48 | avg_psnr += psnr
49 | ssim = compute_ssim(frames1, frames2).mean().item()
50 | avg_ssim += ssim
51 | if not no_lpips:
52 | chunk_frames = int(frames1.shape[0] / number_of_chunks)  # chunking bounds LPIPS GPU memory
53 | for ii in range(number_of_chunks):
54 | lpips_i = lpips_vgg(frames1[ii*chunk_frames:ii*chunk_frames+chunk_frames].permute([0, 3, 1, 2]).contiguous(),
55 | frames2[ii*chunk_frames:ii*chunk_frames+chunk_frames].permute([0, 3, 1, 2]).contiguous(), normalize=True).mean().item()
56 | avg_lpips += lpips_i * chunk_frames  # weight each chunk mean by its length
57 | passed = chunk_frames * number_of_chunks
58 | if passed != frames1.shape[0]:  # leftover frames when the video does not split evenly
59 | lpips_i = lpips_vgg(frames1[passed:].permute([0, 3, 1, 2]).contiguous(),
60 | frames2[passed:].permute([0, 3, 1, 2]).contiguous(), normalize=True).mean().item()
61 | avg_lpips += lpips_i * (frames1.shape[0] - passed)
62 | avg_lpips /= frames1.shape[0]  # frame-weighted average over all chunks
63 | return {"PSNR": avg_psnr, "SSIM": avg_ssim, "LPIPS": avg_lpips}
64 |
65 |
66 | def evaluate_pixel_images(d_dir, eval_frames, traj_type="zoom", srf=None, device=None):
67 | if traj_type != "hard" and traj_type != "test":
68 | gt_vid = os.path.join(d_dir, "SRF", "vox512", "full", "{}_renders.mp4".format(traj_type))
69 | gt_frames = imageio.mimread(gt_vid)
70 | gt_frames = torch.from_numpy(np.concatenate([x[None,...] for x in gt_frames]))/255.0
71 | else:
72 | _, gts, masks = srf.load_c2ws_images(data_dir=d_dir, device=device, split=traj_type, c_rf_variant=0, randomized_views=False, all_views=True)
73 | gt_frames = torch.Tensor(gts).permute(0, 3, 1, 2)[None, ...]
74 |
75 | eval_frames = torch.from_numpy(np.concatenate([x[None, ...] 
for x in eval_frames]))/255.0
76 |
77 | metrics = vidoe_distance(gt_frames, eval_frames)
78 | return metrics
79 |
80 | def find_split_and_indx(shape_id, lists_dir):
81 | shape_id = os.path.split(shape_id)[1]
82 | for lbl in ["train", "val", "test"]:
83 | with open(os.path.join(lists_dir, "snr_{}.txt".format(lbl)), 'r') as file1:  # context manager closes the handle
84 | Lines = file1.read().splitlines()
85 | if shape_id in Lines:
86 | return lbl, Lines.index(shape_id)
87 | return None, None
88 |
89 |
90 | def visualize_pixel_nerf2(data_dict, batch_indx, render_dir, num_views=200, vizualization_id=0, gif=False, traj_type="zoom", setup=None, device="cuda:0", srf=None):
91 | shape_id = data_dict["labels"][batch_indx]
92 | views = '2' if setup["nb_views"] == 1 else '2 6 10'
93 | fps = 30 if num_views == 200 else int(30*num_views/200.0)
94 | found_split, found_indx = find_split_and_indx(shape_id, setup["data_dir"])
95 | c2ws = "None" if traj_type not in ["hard", "test"] else srf.load_c2ws_images(data_dir=shape_id, device=device, split=traj_type, c_rf_variant=0, randomized_views=False, all_views=True)[0].tolist()
96 | # print("$$$$$$$$$$$$$$$$$$$",len(c2ws))
97 | command = "python pixel_nerf/eval/gen_video.py -n sn64 --gpu_id=0 --split {} -P '{}' -D data/nerf_datasets/NMR_Dataset -S {} --conf pixel_nerf/conf/exp/sn64.conf --checkpoints_path pixel_nerf/checkpoints --visual_path {} --radius 0.0 --num_views {} --traj_type {} --vizualization_id {} --new_res {} --fps {} --c2ws '{}'".format(
98 | found_split, views, found_indx, render_dir, num_views, traj_type, vizualization_id, setup["img_res"], fps, str(c2ws))
99 | # print(command)
100 | os.system(command)
101 | vid_file = os.path.join(render_dir, str(vizualization_id)+".mp4")
102 | frames = imageio.mimread(vid_file)
103 | # print("$$$$$$$$$$$", len(frames), frames[0].shape, frames[0].max())
104 | return frames
105 |
106 |
107 | def visualize_vision_nerf(data_dict, batch_indx, render_dir, num_views=200, vizualization_id=0, gif=False, traj_type="zoom", setup=None, device="cuda:0", srf=None):
108 | shape_id = data_dict["labels"][batch_indx]
109 | print("$$$$$$$$$$$$$$", vizualization_id, " : ", shape_id)
110 | views = '2'  # if setup["nb_views"] == 1 else '2 6 10'
111 | fps = 30 if num_views == 200 else int(30*num_views/200.0)
112 | found_split, found_indx = find_split_and_indx(shape_id, setup["data_dir"])
113 | c2ws = "None" if traj_type not in ["hard", "test"] else srf.load_c2ws_images(data_dir=shape_id, device=device, split=traj_type, c_rf_variant=0, randomized_views=False, all_views=True)[0].tolist()  # np.transpose(..., (0, 1, 2)) was a no-op and is dropped
114 |
115 | command = "python vision_nerf/eval_nmr.py --config vision_nerf/configs/render_nmr.txt --use_data_index --data_indices {} --mode {} --new_res {} --fps {} --num_views {} --traj_type {} --outdir {} --vizualization_id {} --pose_index {} --c2ws '{}'".format(
116 | found_indx, found_split, setup["img_res"], fps, num_views, traj_type, render_dir, vizualization_id, views, str(c2ws))
117 | os.system(command)
118 | vid_file = os.path.join(render_dir, str(vizualization_id), "{}_renders_{}.mp4".format(traj_type, str(vizualization_id)))
119 | frames = imageio.mimread(vid_file)
120 | # print("$$$$$$$$$$$", len(frames), frames[0].shape, frames[0].max())
121 | return frames
122 |
123 | def visualize_pixel_nerf(data_dict, batch_indx, net, render_dir, num_views=200, vizualization_id=0, gif=False, traj_type="zoom", setup=None, conf=None, device=None):
124 | elevation = -10
125 | elevation2 = 20
126 | radius = 0.0 # 0.85
127 | focal = torch.tensor(482.842712474619, dtype=torch.float32)[None] 
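# NOTE: fixed evaluation defaults for the ShapeNet/NMR renders follow: one pinhole focal length shared by all views (above), near/far clip planes z_near/z_far, and ray_batch_size rays rendered per chunk.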
128 | lindisp = False
129 | z_near = 1.2
130 | split = "test"
131 | z_far = 4.0
132 | ray_batch_size = 50000
133 | scale = 1.0
134 | fps = 30
135 | num_views = setup["nb_frames"]
136 | c = None # torch.tensor((setup["img_res"]/2, setup["img_res"]/2),dtype=torch.float32).to(device=device)
137 | source = torch.tensor(list(range(setup["nb_views"])), dtype=torch.long)
138 | data_path = data_dict["labels"][batch_indx]
139 | print("Data instance loaded:", data_path)
140 |
141 | images = data_dict["imgs"][batch_indx] # (NV, 3, H, W)
142 |
143 | poses = data_dict["c2ws"][batch_indx]
144 |
145 |
146 | # c = data.get("c")
147 | # if c is not None:
148 | # c = c.to(device=device).unsqueeze(0)
149 |
150 | NV, _, H, W = images.shape
151 |
152 | if scale != 1.0:
153 | Ht = int(H * scale)
154 | Wt = int(W * scale)
155 | if abs(Ht / scale - H) > 1e-10 or abs(Wt / scale - W) > 1e-10:
156 | warnings.warn(
157 | "Inexact scaling, please check {} times ({}, {}) is integral".format(
158 | scale, H, W
159 | )
160 | )
161 | H, W = Ht, Wt
162 |
163 |
164 | renderer = NeRFRenderer.from_conf(
165 | conf["renderer"], lindisp=lindisp, eval_batch_size=ray_batch_size,
166 | ).to(device=device)
167 |
168 | render_par = renderer.bind_parallel(net, "0", simple_output=True).eval()
169 |
170 | # Get the distance from camera to origin
171 | # z_near = dset.z_near
172 | # z_far = dset.z_far
173 |
174 | print("Generating rays")
175 |
176 | # dtu_format = hasattr(dset, "sub_format") and dset.sub_format == "dtu"
177 |
178 | print("Using default (360 loop) camera trajectory")
179 | if radius == 0.0:
180 | # default: place the orbit halfway between the near and far planes
181 | radius = (z_near + z_far) * 0.5
182 | print("> Using default camera radius", radius)
183 |
184 |
185 | # Use 360 pose sequence from NeRF
186 | render_poses = torch.stack(
187 | [
188 | util.pose_spherical(angle, elevation, radius)
189 | for angle in np.linspace(-180, 180, num_views + 1)[:-1]
190 | ],
191 | 0,
192 | ) # (num_views, 4, 4)
193 |
194 | render_rays = util.gen_rays(
195 | render_poses,
196 | W,
197 | H,
198 | focal * scale,
199 | z_near,
200 | z_far,
201 | c=c * scale if c is not None else None,
202 | ).to(device=device)
203 | # (num_views, H, W, 8)
204 |
205 | focal = focal.to(device=device)
206 |
207 | # source = torch.tensor(list(map(int, args.source.split())), dtype=torch.long)
208 |
209 | NS = len(source)
210 | print("$$$$$$$$$$$$$", focal)
211 |
212 | random_source = NS == 1 and source[0] == -1
213 | assert not (source >= NV).any()
214 |
215 | if renderer.n_coarse < 64:
216 | # Ensure decent sampling resolution
217 | renderer.n_coarse = 64
218 | renderer.n_fine = 128
219 |
220 | with torch.no_grad():
221 | print("Encoding source view(s)")
222 | if random_source:
223 | src_view = torch.randint(0, NV, (1,))
224 | else:
225 | src_view = source
226 |
227 | net.encode(
228 | images[src_view].unsqueeze(0),
229 | poses[src_view].unsqueeze(0).to(device=device),
230 | focal,
231 | c=c,
232 | )
233 |
234 | print("Rendering", num_views * H * W, "rays")
235 | all_rgb_fine = []
236 | for rays in tqdm.tqdm(torch.split(render_rays.view(-1, 8), ray_batch_size, dim=0)):
237 | rgb, _depth = render_par(rays[None])
238 | all_rgb_fine.append(rgb[0])
239 | _depth = None
240 | rgb_fine = torch.cat(all_rgb_fine)
241 | # rgb_fine (V*H*W, 3)
242 |
243 | frames = rgb_fine.view(-1, H, W, 3)
244 |
245 | print("Writing video")
246 | vid_name = "{}".format(vizualization_id)
247 | vid_path = os.path.join(render_dir, vid_name + ".mp4")
248 | # viewimg_path = os.path.join(render_dir, args.name, "video" + vid_name + "_view.jpg") 
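# writing mp4 goes through imageio's ffmpeg backend (the imageio-ffmpeg package from INSTALL.md); quality uses ffmpeg's 0-10 scale, so 8 favors fidelity over file size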
249 | imageio.mimwrite(vid_path, (frames.cpu().numpy() * 255).astype(np.uint8), fps=fps, quality=8
250 | )
251 |
252 |
253 | return frames
254 |
255 |
256 | def evaluate_pixel_nerf(val_loader, device, srf, setup):
257 | setup["pixel_dir"] = os.path.join(setup["root_dir"], "pixel_nerf")
258 |
259 | # model_path = os.path.join(setup["pixel_dir"], "checkpoints","sn64","pixel_nerf_latest" )
260 | # conf = ConfigFactory.parse_file(os.path.join(setup["pixel_dir"],"conf","exp","sn64.conf"))
261 | # net = make_model(conf["model"]).to(device=device)
262 | # net.my_load_weights(model_path)
263 | val_ssim, val_psnr, val_lpips = [], [], []
264 | losses = []  # no reconstruction loss is computed for these baselines, so this stays empty
265 |
266 | # args, conf = util.args.parse_args(extra_args)
267 | # args.resume = True
268 |
269 |
270 | for i, data_dict in enumerate(val_loader):
271 | short_list_cond = i * setup["batch_size"] in range(setup["visualizations_nb"])  # only visualize the first few batches
272 | if not short_list_cond:
273 | continue
274 |
275 | for ii, d_dir in enumerate(data_dict["labels"]):
276 | render_dir = os.path.join(setup["baseline_dir"], setup["run"], "view"+str(setup["nb_views"]), "vids")
277 | Path(render_dir).mkdir(parents=True, exist_ok=True)
278 | # frames = visualize_pixel_nerf(data_dict, batch_indx=ii, net=net, render_dir=render_dir, num_views=setup["nb_frames"], vizualization_id=setup["batch_size"]*i + ii, gif=setup["gif"], traj_type=setup["traj_type"], setup=setup, conf=conf, device=device)
279 | if setup["run"] == "pixel":
280 | frames = visualize_pixel_nerf2(data_dict, batch_indx=ii, render_dir=render_dir, num_views=setup["nb_frames"], vizualization_id=setup["batch_size"]*i + ii, gif=setup["gif"], traj_type=setup["traj_type"], setup=setup, device=device, srf=srf)
281 | elif setup["run"] == "vision":
282 | frames = visualize_vision_nerf(data_dict, batch_indx=ii, render_dir=render_dir, num_views=setup["nb_frames"], vizualization_id=setup["batch_size"]*i + ii, gif=setup["gif"], traj_type=setup["traj_type"], setup=setup, device=device, srf=srf)
283 |
284 | pred_metrics = evaluate_pixel_images(d_dir, frames, traj_type=setup["traj_type"], srf=srf, device=device)
285 |
286 |
287 | # if setup["gif"]:
288 | # wandb.log({"renderings/{}".format(setup["batch_size"]*i + ii): wandb.Video(np.transpose(np.concatenate([fr[None, ...] for fr in frames], axis=0), (
289 | # 0, 3, 1, 2)), fps=2, format="gif"), "epoch": str(epoch)}, commit=False)
290 | # if setup["concat_gt_output"] and not setup["gif"] and setup["visualize_gt"]:
291 | # concat_render_dir = os.path.join(
292 | # setup["output_dir"], "comparisons", str(epoch))
293 | # os.makedirs(concat_render_dir, exist_ok=True)
294 | # out_vid = os.path.join(render_dir, "{}_renders_{}.mp4".format(
295 | # setup["traj_type"], str(setup["batch_size"]*i + ii)))
296 | # gt_vid = os.path.join(gt_render_dir, "{}_renders_{}.mp4".format(
297 | # setup["traj_type"], str(setup["batch_size"]*i + ii)))
298 | # concat_vid = os.path.join(concat_render_dir, "{}_renders_{}.mp4".format(
299 | # setup["traj_type"], str(setup["batch_size"]*i + ii)))
300 | # concat_horizontal_videos(source_videos_list=[out_vid, gt_vid], output_file=concat_vid)
301 |
302 | val_ssim.append(pred_metrics["SSIM"])
303 | val_psnr.append(pred_metrics["PSNR"])
304 | val_lpips.append(pred_metrics["LPIPS"])
305 | torch.cuda.empty_cache()
306 |
307 | return {"loss": np.mean(losses) if losses else 0.0, "ssim": np.mean(val_ssim), "psnr": np.mean(val_psnr), "lpips": np.mean(val_lpips)}  # guard: np.mean([]) would return nan
308 |
--------------------------------------------------------------------------------
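For readers tracing the CUDA code in `Svox2/svox2/csrc/svox2_kernel.cu` earlier in this dump, here is a minimal pure-PyTorch sketch of the gradient scatter that `sample_grid_density_backward_kernel` performs with per-corner `atomicAdd`s. It is not part of the repo: the function name and the dense `(X, Y, Z)` layout of `links` are assumptions for illustration; the clamping, corner selection, and trilinear weights mirror the kernel.

```python
import torch

def sample_density_backward(links, density_data, points, grad_out):
    """Reference (hypothetical) version of sample_grid_density_backward_kernel.

    links:        (X, Y, Z) int tensor mapping voxels to rows of density_data, -1 = pruned
    density_data: (N, 1) float tensor (only its shape/dtype are used)
    points:       (B, 3) float tensor, already transformed to grid coordinates
    grad_out:     (B, 1) float tensor, upstream gradient for each sampled density
    """
    grad = torch.zeros_like(density_data)
    size = torch.tensor(links.shape, dtype=points.dtype)
    # clamp into the volume and split into lower corner + fractional offset,
    # matching the kernel's fminf/fmaxf and min((int32_t)point, size - 2) lines
    p = torch.minimum(torch.maximum(points, torch.zeros_like(points)), size - 1.0)
    l = torch.minimum(p.long(), (size - 2.0).long())
    t = p - l
    # per-axis weights: column 0 holds (xa, ya, za), column 1 holds (xb, yb, zb)
    w = torch.stack([1.0 - t, t], dim=-1)  # (B, 3, 2)
    for dx in (0, 1):
        for dy in (0, 1):
            for dz in (0, 1):
                # trilinear weight of this corner times the upstream gradient,
                # i.e. the kernel's products such as xa * ya * za * go
                wc = w[:, 0, dx] * w[:, 1, dy] * w[:, 2, dz] * grad_out[:, 0]
                idx = links[l[:, 0] + dx, l[:, 1] + dy, l[:, 2] + dz].long()
                keep = idx >= 0  # the MAYBE_ADD_GRAD_LINK_PTR_D guard for pruned voxels
                grad.index_add_(0, idx[keep], wc[keep].unsqueeze(1))
    return grad
```

Here `index_add_` stands in for the atomics: several samples can land in the same voxel, so their eight weighted contributions must accumulate rather than overwrite. The same eight-corner pattern with the `offx`/`offy` link offsets appears in both the density and SH backward kernels.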