├── .gitignore
├── LICENSE
├── README.md
├── environment.yml
├── github_img
│   ├── fastopt.gif
│   ├── fox.png
│   ├── garden.png
│   └── pipeline.png
├── manual_install.sh
├── opt
│   ├── autotune.py
│   ├── calc_metrics.py
│   ├── co3d_tmp
│   │   └── co3d.txt
│   ├── configs
│   │   ├── co3d.json
│   │   ├── custom.json
│   │   ├── custom_alt.json
│   │   ├── llff.json
│   │   ├── llff_hitv.json
│   │   ├── syn.json
│   │   ├── syn_nv.json
│   │   └── tnt.json
│   ├── launch.sh
│   ├── opt.py
│   ├── render_imgs.py
│   ├── render_imgs_circle.py
│   ├── scripts
│   │   ├── colmap2nsvf.py
│   │   ├── create_split.py
│   │   ├── ingp2nsvf.py
│   │   ├── proc_colmap.sh
│   │   ├── proc_record3d.py
│   │   ├── run_colmap.py
│   │   ├── unsplit.py
│   │   ├── vendor
│   │   │   └── read_write_model.py
│   │   └── view_data.py
│   ├── tasks
│   │   ├── eval.json
│   │   ├── eval_ff.json
│   │   ├── eval_real_iconic.json
│   │   ├── eval_tnt.json
│   │   ├── interpablate.json
│   │   ├── ntrainablate.json
│   │   ├── sanity.json
│   │   └── tvearlyonly.json
│   ├── to_svox1.py
│   └── util
│       ├── __init__.py
│       ├── co3d_dataset.py
│       ├── config_util.py
│       ├── dataset.py
│       ├── dataset_base.py
│       ├── llff_dataset.py
│       ├── load_llff.py
│       ├── nerf_dataset.py
│       ├── nsvf_dataset.py
│       └── util.py
├── setup.py
├── svox2
│   ├── __init__.py
│   ├── csrc
│   │   ├── .ccls
│   │   ├── CMakeLists.txt
│   │   ├── include
│   │   │   ├── cubemap_util.cuh
│   │   │   ├── cuda_util.cuh
│   │   │   ├── data_spec.hpp
│   │   │   ├── data_spec_packed.cuh
│   │   │   ├── random_util.cuh
│   │   │   ├── render_util.cuh
│   │   │   └── util.hpp
│   │   ├── loss_kernel.cu
│   │   ├── misc_kernel.cu
│   │   ├── optim_kernel.cu
│   │   ├── render_lerp_kernel_cuvol.cu
│   │   ├── render_lerp_kernel_nvol.cu
│   │   ├── render_svox1_kernel.cu
│   │   ├── svox2.cpp
│   │   └── svox2_kernel.cu
│   ├── defs.py
│   ├── svox2.py
│   ├── utils.py
│   └── version.py
└── test
    ├── prof.py
    ├── sanity.py
    ├── test_render_gradcheck.py
    ├── test_render_timing.py
    ├── test_render_timing_smallbat.py
    ├── test_render_visual.py
    ├── test_sample.py
    └── util.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
*.py[cod]
*$py.class
Session.vim
ckpt/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
BSD 2-Clause License

Copyright (c) 2021, the Plenoxels authors
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Plenoxels: Radiance Fields without Neural Networks

Alex Yu\*, Sara Fridovich-Keil\*, Matthew Tancik, Qinhong Chen, Benjamin Recht, Angjoo Kanazawa

UC Berkeley

Website and video: https://alexyu.net/plenoxels

arXiv: https://arxiv.org/abs/2112.05131

[Featured at Two Minute Papers on YouTube](https://youtu.be/yptwRRpPEBM) (2022-01-11)

Despite the name, it is not strictly intended to be a successor of svox.

Citation:
```
@inproceedings{yu2022plenoxels,
    title={Plenoxels: Radiance Fields without Neural Networks},
    author={Sara Fridovich-Keil and Alex Yu and Matthew Tancik and Qinhong Chen and Benjamin Recht and Angjoo Kanazawa},
    year={2022},
    booktitle={CVPR},
}
```
Note that the joint first authors decided to swap the order of names between the arXiv and CVPR proceedings versions.

This repository contains the official optimization code.
A JAX implementation is also available at https://github.com/sarafridov/plenoxels. However, note that the JAX version is currently feature-limited, running at about 1 hour per epoch and only supporting bounded scenes at present.

![Fast optimization](https://raw.githubusercontent.com/sxyu/svox2/master/github_img/fastopt.gif)

![Overview](https://raw.githubusercontent.com/sxyu/svox2/master/github_img/pipeline.png)

### Example use cases

Check out PeRFception [Jeong, Shin, Lee, et al.], which uses Plenoxels with tuned parameters to generate a large dataset of radiance fields:
https://github.com/POSTECH-CVLab/PeRFception

Artistic Radiance Fields by Kai Zhang et al.:
https://github.com/Kai-46/ARF-svox2

## Setup

**Windows is not officially supported, and we have only tested with Linux. Adding support would be welcome.**

First create the virtualenv; we recommend using conda:
```sh
conda env create -f environment.yml
conda activate plenoxel
```

Then clone the repo and install the library at the root (svox2), which includes a CUDA extension.

**If and only if** your CUDA toolkit is older than 11, you will need to install CUB as follows:
`conda install -c bottler nvidiacub`.
Since CUDA 11, CUB is shipped with the toolkit, and installing this package may lead to build errors.

To install the main library, simply run
```
pip install -e . --verbose
```
in the repo root directory.
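As a quick sanity check that the package (including the CUDA extension) actually built, you can try importing it; build problems typically surface as errors or warnings at import time. This smoke test is not part of the official instructions, it just prints the install location:
```sh
python -c "import svox2; print(svox2.__file__)"
```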
## Getting datasets

We have backends for the NeRF-Blender, LLFF, NSVF, and CO3D dataset formats, and the format will be auto-detected.

Please get the NeRF-synthetic and LLFF datasets from:

(`nerf_synthetic.zip` and `nerf_llff_data.zip`).

We provide a processed Tanks and Temples dataset (with background) in NSVF format at:

Note this data should be identical to that used in NeRF++.

Finally, the real Lego capture can be downloaded from:
https://drive.google.com/file/d/1PG-KllCv4vSRPO7n5lpBjyTjlUyT8Nag/view?usp=sharing

**Note: we currently do not support the instant-ngp data format (the project was released before instant-ngp). Using it will trigger the NeRF-synthetic (Blender) data loader due to the formats' similarity, but training will not work properly. For real data we use the NSVF format.**

To convert instant-ngp data, please try our script:
```
cd opt/scripts
python ingp2nsvf.py
```

## Optimization

For training a single scene, see `opt/opt.py`. The launch script makes this easier.

Inside `opt/`, run

`./launch.sh <exp_name> <GPU_id> <data_dir> -c <config>`

where `<config>` should be, for example, `configs/syn.json` for NeRF-synthetic scenes,
`configs/llff.json` for forward-facing scenes, and
`configs/tnt.json` for Tanks and Temples scenes.

The dataset format will be auto-detected from `data_dir`.
Checkpoints will be in `ckpt/exp_name`.

**For pretrained checkpoints please see:** https://drive.google.com/drive/folders/1SOEJDw8mot7kf5viUK9XryOAmZGe_vvE?usp=sharing

## Evaluation

Use `opt/render_imgs.py`.

Usage (in `opt/`):

`python render_imgs.py <checkpoint.npz> <data_dir>`

By default this saves all frames, which is very slow. Add `--no_imsave` to avoid this.

## Rendering a spiral

Use `opt/render_imgs_circle.py`.

Usage (in `opt/`):

`python render_imgs_circle.py <checkpoint.npz> <data_dir>`

## Parallel task executor

We provide a parallel task executor based on the task manager from PlenOctrees, which automatically schedules many tasks across sets of scenes or hyperparameters.
This is used for evaluation, ablations, and hyperparameter tuning.
See `opt/autotune.py`; configs are in `opt/tasks/*.json`.

For example, to automatically train and eval all synthetic scenes,
you will need to change `train_root` and `data_root` in `tasks/eval.json`, then run:
```sh
python autotune.py -g '<space-delimited GPU ids>' tasks/eval.json
```

For forward-facing scenes:
```sh
python autotune.py -g '<space-delimited GPU ids>' tasks/eval_ff.json
```

For Tanks and Temples scenes:
```sh
python autotune.py -g '<space-delimited GPU ids>' tasks/eval_tnt.json
```
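For reference, here is a minimal sketch of a task file. The top-level keys (`data_root`, `train_root`, `config`, `eval`, `tasks`, `variables`) and the `{...}` substitution / `loglin(...)` sweep syntax mirror what `opt/autotune.py` reads, but the concrete paths, scene name, and swept range below are made up for illustration:
```json
{
    "data_root": "~/data/nerf_synthetic",
    "train_root": "~/ckpt/tv_sweep",
    "config": "configs/syn.json",
    "eval": true,
    "tasks": [
        {
            "train_dir": "lego_tv{lambda_tv}",
            "data_dir": "lego",
            "flags": ["--lambda_tv", "{lambda_tv}"]
        }
    ],
    "variables": {
        "lambda_tv": "loglin(1e-6, 1e-4, 3)"
    }
}
```
Each entry in `tasks` is expanded once per combination of variable values, placeholders such as `{lambda_tv}` are substituted into strings via `str.format`, and results land under `train_root` (see `create_prodvars` and `recursive_replace` in `autotune.py`).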
## Using a custom image set (360)

Please take images all around the object, and try to capture from several different elevations.
First make sure you have COLMAP installed. Then run

(in `opt/scripts`)

`bash proc_colmap.sh <img_dir> --noradial`

where `<img_dir>` should be a directory directly containing png/jpg images from a normal perspective camera.
UPDATE: `--noradial` is recommended, since otherwise the script performs undistortion, which does not seem to work well and makes results blurry.
Support for the complete OPENCV camera model, which has been used by more recent projects, would be welcome:
https://github.com/google-research/multinerf/blob/1c8b1c552133cdb2de1c1f3c871b2813f6662265/internal/camera_utils.py#L477
For custom datasets we adopt a data format similar to that of NSVF.

You should be able to use this dataset directly afterwards; the format will be auto-detected.

To view the data (and check the scene normalization), use:

`python view_data.py <data_dir>`

You will need nerfvis: `pip install nerfvis`.

This should launch a server at localhost:8889.

Now follow the "Optimization" section above to train:

`./launch.sh <exp_name> <GPU_id> <data_dir> -c configs/custom.json`

`custom.json` was used for the real Lego bulldozer scene.
You can also try `configs/custom_alt.json`, which has some minor differences, **most notably that `near_clip` is eliminated**. If the scene's central object is totally messed up, this may be due to the aggressive near clip, and the alt config fixes it.

You may need to tune the TV and sparsity losses for best results.

To render a video, please see the "Rendering a spiral" section above.
To convert to a svox1-compatible PlenOctree (not perfect quality, since interpolation is not implemented),
you can try `to_svox1.py <ckpt.npz>`.

Example result with the mip-NeRF 360 garden data (using the custom_alt config as provided):
![Garden](https://raw.githubusercontent.com/sxyu/svox2/master/github_img/garden.png)

Fox data (converted with the script `opt/scripts/ingp2nsvf.py`):
![Fox](https://raw.githubusercontent.com/sxyu/svox2/master/github_img/fox.png)

### Common Capture Tips

Floaters and poor-quality surfaces can be caused by the following:

- Dynamic objects. Dynamic-object modelling is not supported in this repo; if anything moves, it will probably lead to floaters.
- Specularity. Very shiny surfaces will lead to floaters and/or poor surfaces.
- Exposure variations. Please lock the exposure when recording a video if possible.
- Lighting variations. Sometimes the clouds move when capturing outdoors; try to capture within a short time frame.
- Motion blur and DoF blur. Try to move slowly and make sure the object is in focus. For small objects, DoF tends to be a substantial issue.
- Image quality. Images may have severe JPEG compression artifacts, for example.

## Potential extensions

Due to limited time we did not implement the following extensions, which should further improve quality and speed:

- Use an exp activation instead of ReLU; this may help with the semi-transparent look issue.
- Add the mip-NeRF 360 distortion loss to reduce floaters. PeRFception also tuned some parameters to help with quality.
- Exposure modelling.
- Use FP16 training (this codebase still uses FP32); this should improve speed and memory use.
- Add a GUI viewer.

## Random tip: how to make pip install faster for native extensions

You may notice that this CUDA extension takes forever to install.
A suggestion is to use ninja: on Ubuntu, install it with `sudo apt install ninja-build`.
This will enable parallel compilation and significantly improve iteration speed.
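As a concrete example: PyTorch's C++/CUDA extension builder (`torch.utils.cpp_extension`) picks up ninja automatically when present, and reads the `MAX_JOBS` environment variable to cap the number of parallel compile jobs; the value 8 below is just an example, tune it to your core count:
```sh
sudo apt install ninja-build
# then, from the repo root:
MAX_JOBS=8 pip install -e . --verbose
```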

--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
# run: conda env create -f environment.yml
name: plenoxel
channels:
  - pytorch
  - defaults
dependencies:
  - python=3.8.8
  - numpy>=1.16.4,<1.19.0
  - pip
  - pip:
    - imageio
    - imageio-ffmpeg
    - ipdb
    - lpips
    - opencv-python>=4.4.0
    - Pillow>=7.2.0
    - pyyaml>=5.3.1
    - tensorboard>=2.4.0
    - pymcubes
    - moviepy
    - matplotlib
    - scipy>=1.6.0
  - pytorch=1.11.0
  - torchvision
  - cudatoolkit
  - tqdm

--------------------------------------------------------------------------------
/github_img/fastopt.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sxyu/svox2/ee80e2c4df8f29a407fda5729a494be94ccf9234/github_img/fastopt.gif

--------------------------------------------------------------------------------
/github_img/fox.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sxyu/svox2/ee80e2c4df8f29a407fda5729a494be94ccf9234/github_img/fox.png

--------------------------------------------------------------------------------
/github_img/garden.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sxyu/svox2/ee80e2c4df8f29a407fda5729a494be94ccf9234/github_img/garden.png

--------------------------------------------------------------------------------
/github_img/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sxyu/svox2/ee80e2c4df8f29a407fda5729a494be94ccf9234/github_img/pipeline.png

--------------------------------------------------------------------------------
/manual_install.sh:
--------------------------------------------------------------------------------
cp svox2/svox2.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/svox2.py
cp svox2/utils.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/utils.py
cp svox2/version.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/version.py
cp svox2/defs.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/defs.py
cp svox2/__init__.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/__init__.py

--------------------------------------------------------------------------------
/opt/autotune.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Alex Yu 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import random 7 | from multiprocessing import Process, Queue 8 | import os 9 | from os import path, listdir 10 | import argparse 11 | import json 12 | import subprocess 13 | import sys 14 | from typing import List, Dict 15 | import itertools 16 | from warnings import warn 17 | from datetime import datetime 18 | import numpy as np 19 | from glob import glob 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("task_json", type=str) 23 | parser.add_argument("--gpus", "-g", type=str, required=True, 24 | help="space delimited GPU id list (global id in nvidia-smi, " 25 | "not considering CUDA_VISIBLE_DEVICES)") 26 | parser.add_argument('--eval', action='store_true',
default=False, 27 | help='evaluation mode (run the render_imgs script)') 28 | parser.add_argument('--render', action='store_true', default=False, 29 | help='also run render_imgs.py with --render_path to render a rotating trajectory (forward-facing case)') 30 | args = parser.parse_args() 31 | 32 | PSNR_FILE_NAME = 'test_psnr.txt' 33 | 34 | def run_exp(env, eval_mode:bool, enable_render:bool, train_dir, data_dir, config, flags, eval_flags, common_flags): 35 | opt_base_cmd = [ "python", "opt.py", "--tune_mode" ] 36 | 37 | if not eval_mode: 38 | opt_base_cmd += ["--tune_nosave"] 39 | opt_base_cmd += [ 40 | "-t", train_dir, 41 | data_dir 42 | ] 43 | if config != '': 44 | opt_base_cmd += ['-c', config] 45 | log_file_path = path.join(train_dir, 'log') 46 | psnr_file_path = path.join(train_dir, PSNR_FILE_NAME) 47 | ckpt_path = path.join(train_dir, 'ckpt.npz') 48 | if path.isfile(psnr_file_path): 49 | print('! SKIP', train_dir) 50 | return 51 | print('********************************************') 52 | if eval_mode: 53 | print('EVAL MODE') 54 | 55 | if eval_mode and path.isfile(ckpt_path): 56 | print('! SKIP training because ckpt exists', ckpt_path) 57 | opt_ret = "" # Silence 58 | else: 59 | print('! RUN opt.py -t', train_dir) 60 | opt_cmd = ' '.join(opt_base_cmd + flags + common_flags) 61 | print(opt_cmd) 62 | try: 63 | opt_ret = subprocess.check_output(opt_cmd, shell=True, env=env).decode( 64 | sys.stdout.encoding) 65 | except subprocess.CalledProcessError: 66 | print('Error occurred while running OPT for exp', train_dir, 'on', env["CUDA_VISIBLE_DEVICES"]) 67 | return 68 | with open(log_file_path, 'w') as f: 69 | f.write(opt_ret) 70 | 71 | if eval_mode: 72 | eval_base_cmd = [ 73 | "python", "render_imgs.py", 74 | ckpt_path, 75 | data_dir 76 | ] 77 | if config != '': 78 | eval_base_cmd += ['-c', config] 79 | psnr_file_path = path.join(train_dir, 'test_renders', 'psnr.txt') 80 | if not path.exists(psnr_file_path): 81 | eval_cmd = ' '.join(eval_base_cmd + eval_flags + common_flags) 82 | print('! RUN render_imgs.py', ckpt_path) 83 | print(eval_cmd) 84 | try: 85 | eval_ret = subprocess.check_output(eval_cmd, shell=True, env=env).decode( 86 | sys.stdout.encoding) 87 | except subprocess.CalledProcessError: 88 | print('Error occurred while running EVAL for exp', train_dir, 'on', env["CUDA_VISIBLE_DEVICES"]) 89 | return 90 | else: 91 | print('! 
SKIP eval because psnr.txt exists', psnr_file_path) 92 | 93 | if enable_render: 94 | eval_base_cmd += ['--render_path'] 95 | render_cmd = ' '.join(eval_base_cmd + eval_flags + common_flags) 96 | try: 97 | render_ret = subprocess.check_output(render_cmd, shell=True, env=env).decode( 98 | sys.stdout.encoding) 99 | except subprocess.CalledProcessError: 100 | print('Error occurred while running RENDER for exp', train_dir, 'on', env["CUDA_VISIBLE_DEVICES"]) 101 | return 102 | else: 103 | test_stats = [eval(x.split('eval stats:')[-1].strip()) 104 | for x in opt_ret.split('\n') if 105 | x.startswith('eval stats: ')] 106 | if len(test_stats) == 0: 107 | print('note: invalid config or crash') 108 | final_test_psnr = 0.0 109 | else: 110 | test_psnrs = [stats['psnr'] for stats in test_stats if 'psnr' in stats.keys()] 111 | print('final psnrs', test_psnrs[-5:]) 112 | final_test_psnr = test_psnrs[-1] 113 | with open(psnr_file_path, 'w') as f: 114 | f.write(str(final_test_psnr)) 115 | 116 | def process_main(device, eval_mode:bool, enable_render:bool, queue): 117 | # Set CUDA_VISIBLE_DEVICES programmatically 118 | env = os.environ.copy() 119 | env["CUDA_VISIBLE_DEVICES"] = str(device) 120 | while True: 121 | task = queue.get() 122 | if len(task) == 0: 123 | break 124 | run_exp(env, eval_mode, enable_render, **task) 125 | 126 | # Variable value list generation helpers 127 | def lin(start, stop, num): 128 | return np.linspace(start, stop, num).tolist() 129 | 130 | def randlin(start, stop, num): 131 | lst = np.linspace(start, stop, num + 1)[:-1] 132 | lst += np.random.uniform(low=0.0, high=(lst[1] - lst[0]), size=lst.shape) 133 | return lst.tolist() 134 | 135 | def loglin(start, stop, num): 136 | return np.exp(np.linspace(np.log(start), np.log(stop), num)).tolist() 137 | 138 | def randloglin(start, stop, num): 139 | lst = np.linspace(np.log(start), np.log(stop), num + 1)[:-1] 140 | lst += np.random.uniform(low=0.0, high=(lst[1] - lst[0]), size=lst.shape) 141 | return np.exp(lst).tolist() 142 | # End variable value list generation helpers 143 | 144 | def create_prodvars(variables, noise_stds={}): 145 | """ 146 | Create a dict for each setting of variable values 147 | (product across lists) 148 | """ 149 | 150 | def auto_list(x): 151 | if isinstance(x, list): 152 | return x 153 | elif isinstance(x, dict) or isinstance(x, set): 154 | return [x] 155 | elif isinstance(x, str): 156 | return eval(x) 157 | else: 158 | raise NotImplementedError('variable value must be list of values, or str generator') 159 | 160 | variables = {varname:auto_list(variables[varname]) for varname in variables} 161 | print('variables (prod)', variables) 162 | varnames = list(variables.keys()) 163 | noise_stds = np.array([noise_stds.get(varname, 0.0) for varname in varnames]) 164 | variables = [[(i, val) for val in variables[varname]] for i, varname in enumerate(varnames)] 165 | prodvars = list(itertools.product(*variables)) 166 | noise_vals = np.random.randn(len(prodvars), len(varnames)) * noise_stds 167 | prodvars = [{varnames[i]:((val + n) if n != 0.0 else val) for (i, val), n in zip(sample, noise_vals_samp)} for sample, noise_vals_samp in zip(prodvars, noise_vals)] 168 | return prodvars 169 | 170 | 171 | def recursive_replace(data, variables): 172 | if isinstance(data, str): 173 | return data.format(**variables) 174 | elif isinstance(data, list): 175 | return [recursive_replace(d, variables) for d in data] 176 | elif isinstance(data, dict): 177 | return {k:recursive_replace(data[k], variables) for k in data.keys()} 178 | else: 179 | 
return data 180 | 181 | 182 | if __name__ == '__main__': 183 | with open(args.task_json, 'r') as f: 184 | tasks_file = json.load(f) 185 | assert isinstance(tasks_file, dict), 'Root of json must be dict' 186 | all_tasks_templ = tasks_file.get('tasks', []) 187 | all_tasks = [] 188 | data_root = path.expanduser(tasks_file['data_root']) # Required 189 | train_root = path.expanduser(tasks_file['train_root']) # Required 190 | base_flags = tasks_file.get('base_flags', []) 191 | base_eval_flags = tasks_file.get('base_eval_flags', []) 192 | base_common_flags = tasks_file.get('base_common_flags', []) 193 | default_config = tasks_file.get('config', '') 194 | 195 | if 'eval' in tasks_file: 196 | args.eval = tasks_file['eval'] 197 | print('Eval mode?', args.eval) 198 | if 'render' in tasks_file: 199 | args.render = tasks_file['render'] 200 | print('Render traj?', args.render) 201 | pqueue = Queue() 202 | 203 | leaderboard_path = path.join(train_root, 'results.txt' if args.eval else 'leaderboard.txt') 204 | print('Leaderboard path:', leaderboard_path) 205 | 206 | variables : Dict = tasks_file.get('variables', {}) 207 | noises : Dict = tasks_file.get('noises', {}) 208 | assert isinstance(variables, dict), 'var must be dict' 209 | 210 | prodvars : List[Dict] = create_prodvars(variables, noises) 211 | del variables 212 | 213 | for task_templ in all_tasks_templ: 214 | for variables in prodvars: 215 | task : Dict = recursive_replace(task_templ, variables) 216 | task['train_dir'] = path.join(train_root, task['train_dir']) # Required 217 | task['data_dir'] = path.join(data_root, task.get('data_dir', '')).rstrip('/') 218 | task['flags'] = task.get('flags', []) + base_flags 219 | task['eval_flags'] = task.get('eval_flags', []) + base_eval_flags 220 | task['common_flags'] = task.get('common_flags', []) + base_common_flags 221 | task['config'] = task.get('config', default_config) 222 | os.makedirs(task['train_dir'], exist_ok=True) 223 | # santity check 224 | assert path.exists(task['train_dir']), task['train_dir'] + ' does not exist' 225 | assert path.exists(task['data_dir']), task['data_dir'] + ' does not exist' 226 | all_tasks.append(task) 227 | task = None 228 | # Shuffle the tasks 229 | if not args.eval: 230 | random.shuffle(all_tasks) 231 | 232 | for task in all_tasks: 233 | pqueue.put(task) 234 | 235 | args.gpus = list(map(int, args.gpus.split())) 236 | print('GPUS:', args.gpus) 237 | 238 | for _ in args.gpus: 239 | pqueue.put({}) 240 | 241 | all_procs = [] 242 | for i, gpu in enumerate(args.gpus): 243 | process = Process(target=process_main, args=(gpu, args.eval, args.render, pqueue)) 244 | process.daemon = True 245 | process.start() 246 | all_procs.append(process) 247 | 248 | for i, gpu in enumerate(args.gpus): 249 | all_procs[i].join() 250 | 251 | if args.eval: 252 | print('Done') 253 | with open(leaderboard_path, 'w') as leaderboard_file: 254 | lines = [f'dir\tPSNR\tSSIM\tLPIPS\nminutes\n'] 255 | all_tasks = sorted(all_tasks, key=lambda task:task['train_dir']) 256 | all_psnr = [] 257 | all_ssim = [] 258 | all_lpips = [] 259 | all_times = [] 260 | for task in all_tasks: 261 | train_dir = task['train_dir'] 262 | psnr_file_path = path.join(train_dir, 'test_renders', 'psnr.txt') 263 | ssim_file_path = path.join(train_dir, 'test_renders', 'ssim.txt') 264 | lpips_file_path = path.join(train_dir, 'test_renders', 'lpips.txt') 265 | time_file_path = path.join(train_dir, 'time_mins.txt') 266 | 267 | if path.isfile(psnr_file_path): 268 | with open(psnr_file_path, 'r') as f: 269 | psnr = float(f.read()) 270 | 
all_psnr.append(psnr) 271 | psnr_txt = f'{psnr:.10f}' 272 | else: 273 | psnr_txt = 'ERR' 274 | if path.isfile(ssim_file_path): 275 | with open(ssim_file_path, 'r') as f: 276 | ssim = float(f.read()) 277 | all_ssim.append(ssim) 278 | ssim_txt = f'{ssim:.10f}' 279 | else: 280 | ssim_txt = 'ERR' 281 | if path.isfile(lpips_file_path): 282 | with open(lpips_file_path, 'r') as f: 283 | lpips = float(f.read()) 284 | all_lpips.append(lpips) 285 | lpips_txt = f'{lpips:.10f}' 286 | else: 287 | lpips_txt = 'ERR' 288 | if path.isfile(time_file_path): 289 | with open(time_file_path, 'r') as f: 290 | time_mins = float(f.read()) 291 | all_times.append(time_mins) 292 | time_txt = f'{time_mins:.10f}' 293 | else: 294 | time_txt = 'ERR' 295 | line = f'{path.basename(train_dir.rstrip("/"))}\t{psnr_txt}\t{ssim_txt}\t{lpips_txt}\t{time_txt}\n' 296 | lines.append(line) 297 | lines.append('---------\n') 298 | if len(all_psnr): 299 | lines.append('Average PSNR: ' + str(sum(all_psnr) / len(all_psnr)) + '\n') 300 | if len(all_ssim): 301 | lines.append('Average SSIM: ' + str(sum(all_ssim) / len(all_ssim)) + '\n') 302 | if len(all_lpips): 303 | lines.append('Average LPIPS: ' + str(sum(all_lpips) / len(all_lpips)) + '\n') 304 | if len(all_times): 305 | lines.append('Average Time (mins): ' + str(sum(all_times) / len(all_times)) + '\n') 306 | leaderboard_file.writelines(lines) 307 | 308 | else: 309 | with open(leaderboard_path, 'w') as leaderboard_file: 310 | exps = [] 311 | for task in all_tasks: 312 | train_dir = task['train_dir'] 313 | psnr_file_path = path.join(train_dir, PSNR_FILE_NAME) 314 | 315 | with open(psnr_file_path, 'r') as f: 316 | test_psnr = float(f.read()) 317 | print(train_dir, test_psnr) 318 | exps.append((test_psnr, train_dir)) 319 | exps = sorted(exps, key = lambda x: -x[0]) 320 | lines = [f'{psnr:.10f}\t{train_dir}\n' for psnr, train_dir in exps] 321 | leaderboard_file.writelines(lines) 322 | print('Wrote', leaderboard_path) 323 | 324 | -------------------------------------------------------------------------------- /opt/calc_metrics.py: -------------------------------------------------------------------------------- 1 | # Calculate metrics on saved images 2 | 3 | # Usage: python calc_metrics.py 4 | # Where is ckpt_dir/test_renders 5 | # or jaxnerf test renders dir 6 | 7 | from util.dataset import datasets 8 | from util.util import compute_ssim, viridis_cmap 9 | from util import config_util 10 | from os import path 11 | from glob import glob 12 | import imageio 13 | import math 14 | import argparse 15 | import torch 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('render_dir', type=str) 19 | parser.add_argument('--crop', type=float, default=1.0, help='center crop') 20 | config_util.define_common_args(parser) 21 | args = parser.parse_args() 22 | 23 | if path.isfile(args.render_dir): 24 | print('please give the test_renders directory (not checkpoint) in the future') 25 | args.render_dir = path.join(path.dirname(args.render_dir), 'test_renders') 26 | 27 | device = 'cuda:0' 28 | 29 | import lpips 30 | lpips_vgg = lpips.LPIPS(net="vgg").eval().to(device) 31 | 32 | dset = datasets[args.dataset_type](args.data_dir, split="test", 33 | **config_util.build_data_options(args)) 34 | 35 | 36 | im_files = sorted(glob(path.join(args.render_dir, "*.png"))) 37 | im_files = [x for x in im_files if not path.basename(x).startswith('disp_')] # Remove depths 38 | assert len(im_files) == dset.n_images, \ 39 | f'number of images found {len(im_files)} differs from test set images:{dset.n_images}' 40 | 
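# NOTE (added for clarity): the PSNR accumulated in the loop below follows
#   psnr = -10 * log10(mse)
# with MSE taken over [0, 1]-scaled images, so e.g. mse = 1e-3 corresponds to 30 dB.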
41 | avg_psnr = 0.0 42 | avg_ssim = 0.0 43 | avg_lpips = 0.0 44 | n_images_gen = 0 45 | for i, im_path in enumerate(im_files): 46 | im = torch.from_numpy(imageio.imread(im_path)) 47 | im_gt = dset.gt[i] 48 | if im.shape[1] >= im_gt.shape[1] * 2: 49 | # Assume we have some gt/baselines on the left 50 | im = im[:, -im_gt.shape[1]:] 51 | im = im.float() / 255 52 | if args.crop != 1.0: 53 | del_tb = int(im.shape[0] * (1.0 - args.crop) * 0.5) 54 | del_lr = int(im.shape[1] * (1.0 - args.crop) * 0.5) 55 | im = im[del_tb:-del_tb, del_lr:-del_lr] 56 | im_gt = im_gt[del_tb:-del_tb, del_lr:-del_lr] 57 | 58 | mse = (im - im_gt) ** 2 59 | mse_num : float = mse.mean().item() 60 | psnr = -10.0 * math.log10(mse_num) 61 | ssim = compute_ssim(im_gt, im).item() 62 | lpips_i = lpips_vgg(im_gt.permute([2, 0, 1]).cuda().contiguous(), 63 | im.permute([2, 0, 1]).cuda().contiguous(), 64 | normalize=True).item() 65 | 66 | print(i, 'of', len(im_files), '; PSNR', psnr, 'SSIM', ssim, 'LPIPS', lpips_i) 67 | avg_psnr += psnr 68 | avg_ssim += ssim 69 | avg_lpips += lpips_i 70 | n_images_gen += 1 # Just to be sure 71 | 72 | avg_psnr /= n_images_gen 73 | avg_ssim /= n_images_gen 74 | avg_lpips /= n_images_gen 75 | print('AVERAGES') 76 | print('PSNR:', avg_psnr) 77 | print('SSIM:', avg_ssim) 78 | print('LPIPS:', avg_lpips) 79 | postfix = '_cropped' if args.crop != 1.0 else '' 80 | # with open(path.join(args.render_dir, f'psnr{postfix}.txt'), 'w') as f: 81 | # f.write(str(avg_psnr)) 82 | # with open(path.join(args.render_dir, f'ssim{postfix}.txt'), 'w') as f: 83 | # f.write(str(avg_ssim)) 84 | # with open(path.join(args.render_dir, f'lpips{postfix}.txt'), 'w') as f: 85 | # f.write(str(avg_lpips)) 86 | -------------------------------------------------------------------------------- /opt/configs/co3d.json: -------------------------------------------------------------------------------- 1 | { 2 | "reso": "[[128, 128, 128], [256, 256, 256], [512, 512, 512], [640, 640, 640]]", 3 | "seq_id": 5075, 4 | "n_iters": 102400, 5 | "background_nlayers": 64, 6 | "background_reso": 1024, 7 | "upsamp_every": 25600, 8 | "near_clip": 0.35, 9 | "lr_sigma": 3e1, 10 | "lr_sh": 1e-2, 11 | "lr_sigma_delay_steps": 0, 12 | "lr_fg_begin_step": 1000, 13 | "thresh_type": "weight", 14 | "weight_thresh": 1.28, 15 | "lambda_tv": 5e-5, 16 | "lambda_tv_sh": 5e-3, 17 | "lambda_tv_background_sigma": 1e-3, 18 | "lambda_tv_background_color": 1e-3, 19 | "lambda_beta": 1e-5, 20 | "lambda_sparsity": 1e-11, 21 | "background_brightness": 0.5, 22 | "tv_early_only": 0, 23 | "tv_decay": 0.5 24 | } 25 | -------------------------------------------------------------------------------- /opt/configs/custom.json: -------------------------------------------------------------------------------- 1 | { 2 | "reso": "[[128, 128, 128], [256, 256, 256], [512, 512, 512]]", 3 | "n_iters": 102400, 4 | "background_nlayers": 64, 5 | "background_reso": 1024, 6 | "cam_scale_factor": 0.9, 7 | "upsamp_every": 25600, 8 | "near_clip": 0.35, 9 | "lr_sigma": 3e1, 10 | "lr_sh": 1e-2, 11 | "lr_sigma_delay_steps": 0, 12 | "lr_fg_begin_step": 1000, 13 | "thresh_type": "weight", 14 | "weight_thresh": 1.28, 15 | "lambda_tv": 5e-3, 16 | "lambda_tv_sh": 5e-3, 17 | "lambda_tv_background_sigma": 5e-3, 18 | "lambda_tv_background_color": 5e-3, 19 | "lambda_beta": 1e-5, 20 | "lambda_sparsity": 1e-11, 21 | "background_brightness": 0.5, 22 | "tv_early_only": 0, 23 | "tv_decay": 0.5 24 | } 25 | -------------------------------------------------------------------------------- /opt/configs/custom_alt.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "reso": "[[128, 128, 128], [256, 256, 256], [512, 512, 512]]", 3 | "n_iters": 102400, 4 | "background_nlayers": 64, 5 | "background_reso": 1024, 6 | "cam_scale_factor": 0.95, 7 | "upsamp_every": 38400, 8 | "lr_sigma": 3e1, 9 | "lr_sh": 1e-2, 10 | "lr_sigma_delay_steps": 35000, 11 | "lr_fg_begin_step": 50, 12 | "thresh_type": "weight", 13 | "weight_thresh": 1.28, 14 | "lambda_tv": 5e-5, 15 | "lambda_tv_sh": 5e-3, 16 | "lambda_tv_background_sigma": 1e-3, 17 | "lambda_tv_background_color": 1e-3, 18 | "lambda_beta": 1e-5, 19 | "lambda_sparsity": 1e-11, 20 | "background_brightness": 1.0, 21 | "tv_early_only": 0 22 | } 23 | -------------------------------------------------------------------------------- /opt/configs/llff.json: -------------------------------------------------------------------------------- 1 | { 2 | "reso": "[[256, 256, 128], [512, 512, 128], [1408, 1156, 128]]", 3 | "upsamp_every": 38400, 4 | "lr_sigma": 3e1, 5 | "lr_sh": 1e-2, 6 | "thresh_type": "sigma", 7 | "density_thresh": 5, 8 | "lambda_tv": 5e-4, 9 | "lambda_tv_sh": 5e-3, 10 | "lambda_sparsity": 1e-12, 11 | "background_brightness": 0.5, 12 | "last_sample_opaque": false, 13 | "tv_early_only": 0 14 | } 15 | -------------------------------------------------------------------------------- /opt/configs/llff_hitv.json: -------------------------------------------------------------------------------- 1 | { 2 | "reso": "[[256, 256, 128], [512, 512, 128], [1408, 1156, 128]]", 3 | "upsamp_every": 38400, 4 | "lr_sigma": 3e1, 5 | "lr_sh": 1e-2, 6 | "thresh_type": "sigma", 7 | "density_thresh": 5, 8 | "lambda_tv": 5e-3, 9 | "lambda_tv_sh": 5e-2, 10 | "lambda_sparsity": 1e-12, 11 | "background_brightness": 0.5, 12 | "last_sample_opaque": false, 13 | "tv_early_only": 0 14 | } 15 | -------------------------------------------------------------------------------- /opt/configs/syn.json: -------------------------------------------------------------------------------- 1 | { 2 | "reso": "[[256, 256, 256], [512, 512, 512]]", 3 | "upsamp_every": 38400, 4 | "lr_sigma": 3e1, 5 | "lr_sh": 1e-2, 6 | "lambda_tv": 1e-5, 7 | "lambda_tv_sh": 1e-3 8 | } 9 | -------------------------------------------------------------------------------- /opt/configs/syn_nv.json: -------------------------------------------------------------------------------- 1 | { 2 | "reso": "[[256, 256, 256]]", 3 | "upsamp_every": 38400, 4 | "renderer_backend": "nvol", 5 | "lr_sigma": 3e-1, 6 | "lr_sigma_final": 3e-5, 7 | "lr_sigma_delay_steps": 0, 8 | "lr_sh": 1e-2, 9 | "lambda_tv": 0.0, 10 | "lambda_tv_sh": 0.0 11 | } 12 | -------------------------------------------------------------------------------- /opt/configs/tnt.json: -------------------------------------------------------------------------------- 1 | { 2 | "reso": "[[128, 128, 128], [256, 256, 256], [512, 512, 512], [640, 640, 640]]", 3 | "n_iters": 102400, 4 | "background_nlayers": 64, 5 | "background_reso": 1024, 6 | "upsamp_every": 25600, 7 | "lr_sigma": 3e1, 8 | "lr_sh": 1e-2, 9 | "lr_sigma_delay_steps": 15000, 10 | "thresh_type": "weight", 11 | "weight_thresh": 1.28, 12 | "lambda_tv": 5e-5, 13 | "lambda_tv_sh": 5e-3, 14 | "lambda_tv_background_sigma": 1e-3, 15 | "lambda_tv_background_color": 1e-3, 16 | "lambda_beta": 1e-5, 17 | "lambda_sparsity": 1e-11, 18 | "background_brightness": 1.0, 19 | "tv_early_only": 0 20 | } 21 | -------------------------------------------------------------------------------- /opt/launch.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo Launching experiment $1 4 | echo GPU $2 5 | echo EXTRA ${@:3} 6 | 7 | CKPT_DIR=ckpt/$1 8 | mkdir -p $CKPT_DIR 9 | NOHUP_FILE=$CKPT_DIR/log 10 | echo CKPT $CKPT_DIR 11 | echo LOGFILE $NOHUP_FILE 12 | 13 | CUDA_VISIBLE_DEVICES=$2 nohup python -u opt.py -t $CKPT_DIR ${@:3} > $NOHUP_FILE 2>&1 & 14 | echo DETACH 15 | -------------------------------------------------------------------------------- /opt/render_imgs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Alex Yu 2 | # Eval 3 | 4 | import torch 5 | import svox2 6 | import svox2.utils 7 | import math 8 | import argparse 9 | import numpy as np 10 | import os 11 | from os import path 12 | from util.dataset import datasets 13 | from util.util import Timing, compute_ssim, viridis_cmap 14 | from util import config_util 15 | 16 | import imageio 17 | import cv2 18 | from tqdm import tqdm 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('ckpt', type=str) 21 | 22 | config_util.define_common_args(parser) 23 | 24 | parser.add_argument('--n_eval', '-n', type=int, default=100000, help='images to evaluate (equal interval), at most evals every image') 25 | parser.add_argument('--train', action='store_true', default=False, help='render train set') 26 | parser.add_argument('--render_path', 27 | action='store_true', 28 | default=False, 29 | help="Render path instead of test images (no metrics will be given)") 30 | parser.add_argument('--timing', 31 | action='store_true', 32 | default=False, 33 | help="Run only for timing (do not save images or use LPIPS/SSIM; " 34 | "still computes PSNR to make sure images are being generated)") 35 | parser.add_argument('--no_lpips', 36 | action='store_true', 37 | default=False, 38 | help="Disable LPIPS (faster load)") 39 | parser.add_argument('--no_vid', 40 | action='store_true', 41 | default=False, 42 | help="Disable video generation") 43 | parser.add_argument('--no_imsave', 44 | action='store_true', 45 | default=False, 46 | help="Disable image saving (can still save video; MUCH faster)") 47 | parser.add_argument('--fps', 48 | type=int, 49 | default=30, 50 | help="FPS of video") 51 | 52 | # Camera adjustment 53 | parser.add_argument('--crop', 54 | type=float, 55 | default=1.0, 56 | help="Crop (0, 1], 1.0 = full image") 57 | 58 | # Foreground/background only 59 | parser.add_argument('--nofg', 60 | action='store_true', 61 | default=False, 62 | help="Do not render foreground (if using BG model)") 63 | parser.add_argument('--nobg', 64 | action='store_true', 65 | default=False, 66 | help="Do not render background (if using BG model)") 67 | 68 | # Random debugging features 69 | parser.add_argument('--blackbg', 70 | action='store_true', 71 | default=False, 72 | help="Force a black BG (behind BG model) color; useful for debugging 'clouds'") 73 | parser.add_argument('--ray_len', 74 | action='store_true', 75 | default=False, 76 | help="Render the ray lengths") 77 | 78 | args = parser.parse_args() 79 | config_util.maybe_merge_config_file(args, allow_invalid=True) 80 | device = 'cuda:0' 81 | 82 | if args.timing: 83 | args.no_lpips = True 84 | args.no_vid = True 85 | args.ray_len = False 86 | 87 | if not args.no_lpips: 88 | import lpips 89 | lpips_vgg = lpips.LPIPS(net="vgg").eval().to(device) 90 | if not path.isfile(args.ckpt): 91 | args.ckpt = path.join(args.ckpt, 'ckpt.npz') 92 | 93 | render_dir = path.join(path.dirname(args.ckpt), 94 | 'train_renders' 
if args.train else 'test_renders') 95 | want_metrics = True 96 | if args.render_path: 97 | assert not args.train 98 | render_dir += '_path' 99 | want_metrics = False 100 | 101 | # Handle various image transforms 102 | if not args.render_path: 103 | # Do not crop if not render_path 104 | args.crop = 1.0 105 | if args.crop != 1.0: 106 | render_dir += f'_crop{args.crop}' 107 | if args.ray_len: 108 | render_dir += f'_raylen' 109 | want_metrics = False 110 | 111 | dset = datasets[args.dataset_type](args.data_dir, split="test_train" if args.train else "test", 112 | **config_util.build_data_options(args)) 113 | 114 | grid = svox2.SparseGrid.load(args.ckpt, device=device) 115 | 116 | if grid.use_background: 117 | if args.nobg: 118 | # grid.background_cubemap.data = grid.background_cubemap.data.cuda() 119 | grid.background_data.data[..., -1] = 0.0 120 | render_dir += '_nobg' 121 | if args.nofg: 122 | grid.density_data.data[:] = 0.0 123 | # grid.sh_data.data[..., 0] = 1.0 / svox2.utils.SH_C0 124 | # grid.sh_data.data[..., 9] = 1.0 / svox2.utils.SH_C0 125 | # grid.sh_data.data[..., 18] = 1.0 / svox2.utils.SH_C0 126 | render_dir += '_nofg' 127 | 128 | # DEBUG 129 | # grid.links.data[grid.links.size(0)//2:] = -1 130 | # render_dir += "_chopx2" 131 | 132 | config_util.setup_render_opts(grid.opt, args) 133 | 134 | if args.blackbg: 135 | print('Forcing black bg') 136 | render_dir += '_blackbg' 137 | grid.opt.background_brightness = 0.0 138 | 139 | print('Writing to', render_dir) 140 | os.makedirs(render_dir, exist_ok=True) 141 | 142 | if not args.no_imsave: 143 | print('Will write out all frames as PNG (this take most of the time)') 144 | 145 | # NOTE: no_grad enables the fast image-level rendering kernel for cuvol backend only 146 | # other backends will manually generate rays per frame (slow) 147 | with torch.no_grad(): 148 | n_images = dset.render_c2w.size(0) if args.render_path else dset.n_images 149 | img_eval_interval = max(n_images // args.n_eval, 1) 150 | avg_psnr = 0.0 151 | avg_ssim = 0.0 152 | avg_lpips = 0.0 153 | n_images_gen = 0 154 | c2ws = dset.render_c2w.to(device=device) if args.render_path else dset.c2w.to(device=device) 155 | # DEBUGGING 156 | # rad = [1.496031746031746, 1.6613756613756614, 1.0] 157 | # half_sz = [grid.links.size(0) // 2, grid.links.size(1) // 2] 158 | # pad_size_x = int(half_sz[0] - half_sz[0] / 1.496031746031746) 159 | # pad_size_y = int(half_sz[1] - half_sz[1] / 1.6613756613756614) 160 | # print(pad_size_x, pad_size_y) 161 | # grid.links[:pad_size_x] = -1 162 | # grid.links[-pad_size_x:] = -1 163 | # grid.links[:, :pad_size_y] = -1 164 | # grid.links[:, -pad_size_y:] = -1 165 | # grid.links[:, :, -8:] = -1 166 | 167 | # LAYER = -16 168 | # grid.links[:, :, :LAYER] = -1 169 | # grid.links[:, :, LAYER+1:] = -1 170 | 171 | frames = [] 172 | # im_gt_all = dset.gt.to(device=device) 173 | 174 | for img_id in tqdm(range(0, n_images, img_eval_interval)): 175 | dset_h, dset_w = dset.get_image_size(img_id) 176 | im_size = dset_h * dset_w 177 | w = dset_w if args.crop == 1.0 else int(dset_w * args.crop) 178 | h = dset_h if args.crop == 1.0 else int(dset_h * args.crop) 179 | 180 | cam = svox2.Camera(c2ws[img_id], 181 | dset.intrins.get('fx', img_id), 182 | dset.intrins.get('fy', img_id), 183 | dset.intrins.get('cx', img_id) + (w - dset_w) * 0.5, 184 | dset.intrins.get('cy', img_id) + (h - dset_h) * 0.5, 185 | w, h, 186 | ndc_coeffs=dset.ndc_coeffs) 187 | im = grid.volume_render_image(cam, use_kernel=True, return_raylen=args.ray_len) 188 | if args.ray_len: 189 | minv, meanv, 
maxv = im.min().item(), im.mean().item(), im.max().item() 190 | im = viridis_cmap(im.cpu().numpy()) 191 | cv2.putText(im, f"{minv=:.4f} {meanv=:.4f} {maxv=:.4f}", (10, 20), 192 | 0, 0.5, [255, 0, 0]) 193 | im = torch.from_numpy(im).to(device=device) 194 | im.clamp_(0.0, 1.0) 195 | 196 | if not args.render_path: 197 | im_gt = dset.gt[img_id].to(device=device) 198 | mse = (im - im_gt) ** 2 199 | mse_num : float = mse.mean().item() 200 | psnr = -10.0 * math.log10(mse_num) 201 | avg_psnr += psnr 202 | if not args.timing: 203 | ssim = compute_ssim(im_gt, im).item() 204 | avg_ssim += ssim 205 | if not args.no_lpips: 206 | lpips_i = lpips_vgg(im_gt.permute([2, 0, 1]).contiguous(), 207 | im.permute([2, 0, 1]).contiguous(), normalize=True).item() 208 | avg_lpips += lpips_i 209 | print(img_id, 'PSNR', psnr, 'SSIM', ssim, 'LPIPS', lpips_i) 210 | else: 211 | print(img_id, 'PSNR', psnr, 'SSIM', ssim) 212 | img_path = path.join(render_dir, f'{img_id:04d}.png'); 213 | im = im.cpu().numpy() 214 | if not args.render_path: 215 | im_gt = dset.gt[img_id].numpy() 216 | im = np.concatenate([im_gt, im], axis=1) 217 | if not args.timing: 218 | im = (im * 255).astype(np.uint8) 219 | if not args.no_imsave: 220 | imageio.imwrite(img_path,im) 221 | if not args.no_vid: 222 | frames.append(im) 223 | im = None 224 | n_images_gen += 1 225 | if want_metrics: 226 | print('AVERAGES') 227 | 228 | avg_psnr /= n_images_gen 229 | with open(path.join(render_dir, 'psnr.txt'), 'w') as f: 230 | f.write(str(avg_psnr)) 231 | print('PSNR:', avg_psnr) 232 | if not args.timing: 233 | avg_ssim /= n_images_gen 234 | print('SSIM:', avg_ssim) 235 | with open(path.join(render_dir, 'ssim.txt'), 'w') as f: 236 | f.write(str(avg_ssim)) 237 | if not args.no_lpips: 238 | avg_lpips /= n_images_gen 239 | print('LPIPS:', avg_lpips) 240 | with open(path.join(render_dir, 'lpips.txt'), 'w') as f: 241 | f.write(str(avg_lpips)) 242 | if not args.no_vid and len(frames): 243 | vid_path = render_dir + '.mp4' 244 | imageio.mimwrite(vid_path, frames, fps=args.fps, macro_block_size=8) # pip install imageio-ffmpeg 245 | 246 | 247 | -------------------------------------------------------------------------------- /opt/render_imgs_circle.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Alex Yu 2 | # Render 360 circle path 3 | 4 | import torch 5 | import svox2 6 | import svox2.utils 7 | import math 8 | import argparse 9 | import numpy as np 10 | import os 11 | from os import path 12 | from util.dataset import datasets 13 | from util.util import Timing, compute_ssim, viridis_cmap, pose_spherical 14 | from util import config_util 15 | 16 | import imageio 17 | import cv2 18 | from tqdm import tqdm 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('ckpt', type=str) 21 | 22 | config_util.define_common_args(parser) 23 | 24 | parser.add_argument('--n_eval', '-n', type=int, default=100000, help='images to evaluate (equal interval), at most evals every image') 25 | parser.add_argument('--traj_type', 26 | choices=['spiral', 'circle'], 27 | default='spiral', 28 | help="Render a spiral (doubles length, using 2 elevations), or just a cirle") 29 | parser.add_argument('--fps', 30 | type=int, 31 | default=30, 32 | help="FPS of video") 33 | parser.add_argument( 34 | "--width", "-W", type=float, default=None, help="Rendering image width (only if not --traj)" 35 | ) 36 | parser.add_argument( 37 | "--height", "-H", type=float, default=None, help="Rendering image height (only if not --traj)" 38 | ) 39 | 
parser.add_argument( 40 | "--num_views", "-N", type=int, default=600, 41 | help="Number of frames to render" 42 | ) 43 | 44 | # Path adjustment 45 | parser.add_argument( 46 | "--offset", type=str, default="0,0,0", help="Center point to rotate around (only if not --traj)" 47 | ) 48 | parser.add_argument("--radius", type=float, default=0.85, help="Radius of orbit (only if not --traj)") 49 | parser.add_argument( 50 | "--elevation", 51 | type=float, 52 | default=-45.0, 53 | help="Elevation of orbit in deg, negative is above", 54 | ) 55 | parser.add_argument( 56 | "--elevation2", 57 | type=float, 58 | default=-12.0, 59 | help="Max elevation, only for spiral", 60 | ) 61 | parser.add_argument( 62 | "--vec_up", 63 | type=str, 64 | default=None, 65 | help="up axis for camera views (only if not --traj);" 66 | "3 floats separated by ','; if not given automatically determined", 67 | ) 68 | parser.add_argument( 69 | "--vert_shift", 70 | type=float, 71 | default=0.0, 72 | help="vertical shift by up axis" 73 | ) 74 | 75 | # Camera adjustment 76 | parser.add_argument('--crop', 77 | type=float, 78 | default=1.0, 79 | help="Crop (0, 1], 1.0 = full image") 80 | 81 | # Foreground/background only 82 | parser.add_argument('--nofg', 83 | action='store_true', 84 | default=False, 85 | help="Do not render foreground (if using BG model)") 86 | parser.add_argument('--nobg', 87 | action='store_true', 88 | default=False, 89 | help="Do not render background (if using BG model)") 90 | 91 | # Random debugging features 92 | parser.add_argument('--blackbg', 93 | action='store_true', 94 | default=False, 95 | help="Force a black BG (behind BG model) color; useful for debugging 'clouds'") 96 | 97 | args = parser.parse_args() 98 | config_util.maybe_merge_config_file(args, allow_invalid=True) 99 | device = 'cuda:0' 100 | 101 | 102 | dset = datasets[args.dataset_type](args.data_dir, split="test", 103 | **config_util.build_data_options(args)) 104 | 105 | if args.vec_up is None: 106 | up_rot = dset.c2w[:, :3, :3].cpu().numpy() 107 | ups = np.matmul(up_rot, np.array([0, -1.0, 0])[None, :, None])[..., 0] 108 | args.vec_up = np.mean(ups, axis=0) 109 | args.vec_up /= np.linalg.norm(args.vec_up) 110 | print(' Auto vec_up', args.vec_up) 111 | else: 112 | args.vec_up = np.array(list(map(float, args.vec_up.split(",")))) 113 | 114 | 115 | args.offset = np.array(list(map(float, args.offset.split(",")))) 116 | if args.traj_type == 'spiral': 117 | angles = np.linspace(-180, 180, args.num_views + 1)[:-1] 118 | elevations = np.linspace(args.elevation, args.elevation2, args.num_views) 119 | c2ws = [ 120 | pose_spherical( 121 | angle, 122 | ele, 123 | args.radius, 124 | args.offset, 125 | vec_up=args.vec_up, 126 | ) 127 | for ele, angle in zip(elevations, angles) 128 | ] 129 | c2ws += [ 130 | pose_spherical( 131 | angle, 132 | ele, 133 | args.radius, 134 | args.offset, 135 | vec_up=args.vec_up, 136 | ) 137 | for ele, angle in zip(reversed(elevations), angles) 138 | ] 139 | else : 140 | c2ws = [ 141 | pose_spherical( 142 | angle, 143 | args.elevation, 144 | args.radius, 145 | args.offset, 146 | vec_up=args.vec_up, 147 | ) 148 | for angle in np.linspace(-180, 180, args.num_views + 1)[:-1] 149 | ] 150 | c2ws = np.stack(c2ws, axis=0) 151 | if args.vert_shift != 0.0: 152 | c2ws[:, :3, 3] += np.array(args.vec_up) * args.vert_shift 153 | c2ws = torch.from_numpy(c2ws).to(device=device) 154 | 155 | if not path.isfile(args.ckpt): 156 | args.ckpt = path.join(args.ckpt, 'ckpt.npz') 157 | 158 | render_out_path = path.join(path.dirname(args.ckpt), 
'circle_renders') 159 | 160 | # Handle various image transforms 161 | if args.crop != 1.0: 162 | render_out_path += f'_crop{args.crop}' 163 | if args.vert_shift != 0.0: 164 | render_out_path += f'_vshift{args.vert_shift}' 165 | 166 | grid = svox2.SparseGrid.load(args.ckpt, device=device) 167 | print(grid.center, grid.radius) 168 | 169 | # DEBUG 170 | # grid.background_data.data[:, 32:, -1] = 0.0 171 | # render_out_path += '_front' 172 | 173 | if grid.use_background: 174 | if args.nobg: 175 | grid.background_data.data[..., -1] = 0.0 176 | render_out_path += '_nobg' 177 | if args.nofg: 178 | grid.density_data.data[:] = 0.0 179 | # grid.sh_data.data[..., 0] = 1.0 / svox2.utils.SH_C0 180 | # grid.sh_data.data[..., 9] = 1.0 / svox2.utils.SH_C0 181 | # grid.sh_data.data[..., 18] = 1.0 / svox2.utils.SH_C0 182 | render_out_path += '_nofg' 183 | 184 | # # DEBUG 185 | # grid.background_data.data[..., -1] = 100.0 186 | # a1 = torch.linspace(0, 1, grid.background_data.size(0) // 2, dtype=torch.float32, device=device)[:, None] 187 | # a2 = torch.linspace(1, 0, (grid.background_data.size(0) - 1) // 2 + 1, dtype=torch.float32, device=device)[:, None] 188 | # a = torch.cat([a1, a2], dim=0) 189 | # c = torch.stack([a, 1-a, torch.zeros_like(a)], dim=-1) 190 | # grid.background_data.data[..., :-1] = c 191 | # render_out_path += "_gradient" 192 | 193 | config_util.setup_render_opts(grid.opt, args) 194 | 195 | if args.blackbg: 196 | print('Forcing black bg') 197 | render_out_path += '_blackbg' 198 | grid.opt.background_brightness = 0.0 199 | 200 | render_out_path += '.mp4' 201 | print('Writing to', render_out_path) 202 | 203 | # NOTE: no_grad enables the fast image-level rendering kernel for cuvol backend only 204 | # other backends will manually generate rays per frame (slow) 205 | with torch.no_grad(): 206 | n_images = c2ws.size(0) 207 | img_eval_interval = max(n_images // args.n_eval, 1) 208 | avg_psnr = 0.0 209 | avg_ssim = 0.0 210 | avg_lpips = 0.0 211 | n_images_gen = 0 212 | frames = [] 213 | # if args.near_clip >= 0.0: 214 | grid.opt.near_clip = 0.0 #args.near_clip 215 | if args.width is None: 216 | args.width = dset.get_image_size(0)[1] 217 | if args.height is None: 218 | args.height = dset.get_image_size(0)[0] 219 | 220 | for img_id in tqdm(range(0, n_images, img_eval_interval)): 221 | dset_h, dset_w = args.height, args.width 222 | im_size = dset_h * dset_w 223 | w = dset_w if args.crop == 1.0 else int(dset_w * args.crop) 224 | h = dset_h if args.crop == 1.0 else int(dset_h * args.crop) 225 | 226 | cam = svox2.Camera(c2ws[img_id], 227 | dset.intrins.get('fx', 0), 228 | dset.intrins.get('fy', 0), 229 | w * 0.5, 230 | h * 0.5, 231 | w, h, 232 | ndc_coeffs=(-1.0, -1.0)) 233 | torch.cuda.synchronize() 234 | im = grid.volume_render_image(cam, use_kernel=True) 235 | torch.cuda.synchronize() 236 | im.clamp_(0.0, 1.0) 237 | 238 | im = im.cpu().numpy() 239 | im = (im * 255).astype(np.uint8) 240 | frames.append(im) 241 | im = None 242 | n_images_gen += 1 243 | if len(frames): 244 | vid_path = render_out_path 245 | imageio.mimwrite(vid_path, frames, fps=args.fps, macro_block_size=8) # pip install imageio-ffmpeg 246 | 247 | 248 | -------------------------------------------------------------------------------- /opt/scripts/colmap2nsvf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Process COLMAP output into NSVF data format (almost) ready to use with our system. 
3 | Usage: /sparse/0 (or replace 0 with some other partial map) 4 | 5 | Add -s to tune the world scale (default 1). You can tune this by 6 | view_colmap_data.py or other tools. 7 | 8 | NOTE: This file should probably be merged into the run_colmap.py 9 | 10 | IMPORTANT: The dataset will not be split into train/test sets as currently required by 11 | our code. You need to append 1_ before test set image names and 0_ before train set image names. 12 | Use python create_split.py to split it automatically if lazy. 13 | 14 | ******* 15 | 16 | The root directory will look like 17 | 18 | COLMAP preconditions: 19 | images/ : all the images (expect this to already exist) 20 | sparse/0 : the sparse map (from COLMAP) 21 | database.db 22 | 23 | Our main outputs: 24 | pose/ : 4x4 pose matrix for each image 25 | intrinsics.txt : 4x4 intrinsics matrix, only 0,0 and 1,1 entries (focal length) matter 26 | 27 | Additionally, 28 | points.npy : Nx3 sparse point cloud of features 29 | feature/ : npz for each frame storing info about features in each image. 30 | Contains fields xys and ids, where xys = position of feature on image, 31 | ids = row of this point in points.npy 32 | 33 | ******* 34 | 35 | If you want to pre-downscale the images, save the corresponding image files 36 | in directory 37 | images_/ 38 | where is something like 2 or 4. Then set data.factor in the config. 39 | You can use mogrify to do this. Or use the script downsample.py, which uses OpenCV and concurrent.futures: 40 | 41 | python downsample.py /images 42 | 43 | Note that if you do not do this and set a factor in the config anyway, 44 | the images will be resized dynamically on load. 45 | 46 | ******* 47 | 48 | The code to parse the COLMAP sparse bin files is from LLFF. 49 | """ 50 | # Copyright 2021 Alex Yu 51 | import os 52 | import os.path as osp 53 | import numpy as np 54 | import struct 55 | import collections 56 | import argparse 57 | import shutil 58 | 59 | CameraModel = collections.namedtuple( 60 | "CameraModel", ["model_id", "model_name", "num_params"] 61 | ) 62 | Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"]) 63 | BaseImage = collections.namedtuple( 64 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"] 65 | ) 66 | Point3D = collections.namedtuple( 67 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"] 68 | ) 69 | 70 | 71 | def qvec2rotmat(qvec): 72 | return np.array( 73 | [ 74 | [ 75 | 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, 76 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 77 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], 78 | ], 79 | [ 80 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 81 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, 82 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], 83 | ], 84 | [ 85 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 86 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 87 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, 88 | ], 89 | ] 90 | ) 91 | 92 | 93 | class Image(BaseImage): 94 | def qvec2rotmat(self): 95 | return qvec2rotmat(self.qvec) 96 | 97 | 98 | CAMERA_MODELS = { 99 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 100 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 101 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 102 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 103 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 104 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 105 | 
CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 106 | CameraModel(model_id=7, model_name="FOV", num_params=5), 107 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 108 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 109 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), 110 | } 111 | CAMERA_MODEL_IDS = dict( 112 | [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS] 113 | ) 114 | 115 | 116 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 117 | """Read and unpack the next bytes from a binary file. 118 | :param fid: 119 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 120 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 121 | :param endian_character: Any of {@, =, <, >, !} 122 | :return: Tuple of read and unpacked values. 123 | """ 124 | data = fid.read(num_bytes) 125 | return struct.unpack(endian_character + format_char_sequence, data) 126 | 127 | 128 | def read_colmap_sparse(sparse_path): 129 | cameras = [] 130 | with open(osp.join(sparse_path, "cameras.bin"), "rb") as fid: 131 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 132 | assert num_cameras == 1, "Only supports single camera" 133 | for _ in range(num_cameras): 134 | camera_properties = read_next_bytes( 135 | fid, num_bytes=24, format_char_sequence="iiQQ" 136 | ) 137 | camera_id = camera_properties[0] 138 | model_id = camera_properties[1] 139 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 140 | assert model_name in ["SIMPLE_PINHOLE", "SIMPLE_RADIAL"], \ 141 | "Only SIMPLE_PINHOLE/SIMPLE_RADIAL supported" 142 | width = camera_properties[2] 143 | height = camera_properties[3] 144 | num_params = CAMERA_MODEL_IDS[model_id].num_params 145 | params = read_next_bytes( 146 | fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params 147 | ) 148 | cameras.append( 149 | Camera( 150 | id=camera_id, 151 | model=model_name, 152 | width=width, 153 | height=height, 154 | params=np.array(params), 155 | ) 156 | ) 157 | assert len(cameras) == num_cameras 158 | points3D_idmap = {} 159 | points3D = [] 160 | with open(osp.join(sparse_path, "points3D.bin"), "rb") as fid: 161 | num_points = read_next_bytes(fid, 8, "Q")[0] 162 | for i in range(num_points): 163 | binary_point_line_properties = read_next_bytes( 164 | fid, num_bytes=43, format_char_sequence="QdddBBBd" 165 | ) 166 | point3D_id = binary_point_line_properties[0] 167 | points3D_idmap[point3D_id] = i 168 | xyz = np.array(binary_point_line_properties[1:4]) 169 | rgb = np.array(binary_point_line_properties[4:7]) 170 | error = np.array(binary_point_line_properties[7]) 171 | track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ 172 | 0 173 | ] 174 | track_elems = read_next_bytes( 175 | fid, 176 | num_bytes=8 * track_length, 177 | format_char_sequence="ii" * track_length, 178 | ) 179 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 180 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) 181 | points3D.append( 182 | Point3D( 183 | id=point3D_id, 184 | xyz=xyz, 185 | rgb=rgb, 186 | error=error, 187 | image_ids=image_ids, 188 | point2D_idxs=point2D_idxs, 189 | ) 190 | ) 191 | images = [] 192 | with open(osp.join(sparse_path, "images.bin"), "rb") as fid: 193 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 194 | for _ in range(num_reg_images): 195 | binary_image_properties = read_next_bytes( 196 | fid, num_bytes=64, 
format_char_sequence="idddddddi" 197 | ) 198 | image_id = binary_image_properties[0] 199 | qvec = np.array(binary_image_properties[1:5]) 200 | tvec = np.array(binary_image_properties[5:8]) 201 | camera_id = binary_image_properties[8] 202 | image_name = "" 203 | current_char = read_next_bytes(fid, 1, "c")[0] 204 | while current_char != b"\x00": # look for the ASCII 0 entry 205 | image_name += current_char.decode("utf-8") 206 | current_char = read_next_bytes(fid, 1, "c")[0] 207 | num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ 208 | 0 209 | ] 210 | x_y_id_s = read_next_bytes( 211 | fid, 212 | num_bytes=24 * num_points2D, 213 | format_char_sequence="ddq" * num_points2D, 214 | ) 215 | xys = np.column_stack( 216 | [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))] 217 | ) 218 | point3D_ids = list(map(int, x_y_id_s[2::3])) 219 | point3D_ids = [points3D_idmap[x] for x in point3D_ids if x >= 0] 220 | point3D_ids = np.array(point3D_ids) 221 | images.append( 222 | Image( 223 | id=image_id, 224 | qvec=qvec, 225 | tvec=tvec, 226 | camera_id=camera_id, 227 | name=image_name, 228 | xys=xys, 229 | point3D_ids=point3D_ids, 230 | ) 231 | ) 232 | return cameras, images, points3D 233 | 234 | 235 | def main(): 236 | parser = argparse.ArgumentParser() 237 | parser.add_argument( 238 | "sparse_dir", 239 | type=str, 240 | help="COLMAP output sparse model dir e.g. sparse/0. We expect images to be at sparse_dir/../../images", 241 | ) 242 | parser.add_argument( 243 | "--scale", 244 | "-s", 245 | type=float, 246 | default=1.0, 247 | help="Scale to apply to scene, tune this to improve the autoscaling", 248 | ) 249 | parser.add_argument( 250 | "--gl_cam", 251 | action="store_true", 252 | default=False, 253 | help="Change camera space convention to match NeRF, jaxNeRF " 254 | "(our implementation uses OpenCV convention and does not need this, " 255 | "set data.gl_cam_space = True in the config if you use this option)", 256 | ) 257 | parser.add_argument( 258 | "--overwrite", 259 | action="store_true", 260 | default=False, 261 | help="Overwrite output dirs if they exist", 262 | ) 263 | parser.add_argument( 264 | "--overwrite_no_del", 265 | action="store_true", 266 | default=False, 267 | help="Do not delete existing files when overwriting", 268 | ) 269 | parser.add_argument( 270 | "--colmap_suffix", 271 | action="store_true", 272 | default=False, 273 | help="Output to pose_colmap and intrinsics_colmap.txt to retain the gt poses/intrinsics if available", 274 | ) 275 | args = parser.parse_args() 276 | 277 | if args.sparse_dir.endswith("/"): 278 | args.sparse_dir = args.sparse_dir[:-1] 279 | base_dir = osp.dirname(osp.dirname(args.sparse_dir)) 280 | pose_dir = osp.join(base_dir, "pose_colmap" if args.colmap_suffix else "pose") 281 | feat_dir = osp.join(base_dir, "feature") 282 | base_scale_file = osp.join(base_dir, "base_scale.txt") 283 | if osp.exists(base_scale_file): 284 | with open(base_scale_file, 'r') as f: 285 | base_scale = float(f.read()) 286 | print('base_scale', base_scale) 287 | else: 288 | base_scale = 1.0 289 | print('base_scale defaulted to', base_scale) 290 | print("BASE_DIR", base_dir) 291 | print("POSE_DIR", pose_dir) 292 | print("FEATURE_DIR", feat_dir) 293 | print("COLMAP_OUT_DIR", args.sparse_dir) 294 | overwrite = args.overwrite 295 | 296 | def create_or_recreate_dir(dirname): 297 | if osp.isdir(dirname): 298 | import click 299 | 300 | nonlocal overwrite 301 | if overwrite or click.confirm(f"Directory {dirname} exists, overwrite?"): 302 | if not
args.overwrite_no_del: 303 | shutil.rmtree(dirname) 304 | overwrite = True 305 | else: 306 | print("Quitting") 307 | import sys 308 | 309 | sys.exit(1) 310 | os.makedirs(dirname, exist_ok=True) 311 | 312 | cameras, imdata, points3D = read_colmap_sparse(args.sparse_dir) 313 | create_or_recreate_dir(pose_dir) 314 | create_or_recreate_dir(feat_dir) 315 | 316 | print("Get intrinsics") 317 | K = np.eye(4) 318 | K[0, 0] = cameras[0].params[0] / base_scale 319 | K[1, 1] = cameras[0].params[0] / base_scale 320 | K[0, 2] = cameras[0].params[1] / base_scale 321 | K[1, 2] = cameras[0].params[2] / base_scale 322 | print("f", K[0, 0], "c", K[0:2, 2]) 323 | np.savetxt(osp.join(base_dir, "intrinsics_colmap.txt" if args.colmap_suffix else "intrinsics.txt"), K) 324 | del K 325 | 326 | print("Get world scaling") 327 | points = np.stack([p.xyz for p in points3D]) 328 | cen = np.median(points, axis=0) 329 | points -= cen 330 | dists = (points ** 2).sum(axis=1) 331 | 332 | # FIXME: Questionable autoscaling. Adopt method from Noah Snavely 333 | meddist = np.median(dists) 334 | points *= 2 * args.scale / meddist 335 | 336 | # Save the sparse point cloud 337 | np.save(osp.join(base_dir, "points.npy"), points) 338 | print(cen, meddist) 339 | 340 | print("Get cameras") 341 | 342 | bottom = np.array([0, 0, 0, 1.0]).reshape([1, 4]) 343 | coord_trans = np.diag([1, -1, -1, 1.0]) 344 | for im in imdata: 345 | R = im.qvec2rotmat() 346 | t = im.tvec.reshape([3, 1]) 347 | xys = im.xys 348 | point3d_ids = im.point3D_ids 349 | # w2c = np.concatenate([np.concatenate([R, t], 1), bottom], 0) 350 | t_world = -R.T @ t 351 | t_world = (t_world - cen[:, None]) * 2 * args.scale / meddist 352 | c2w = np.concatenate([np.concatenate([R.T, t_world], 1), bottom], 0) 353 | 354 | if args.gl_cam: 355 | # Use the alternate camera space convention of jaxNeRF, OpenGL etc 356 | # We use OpenCV convention 357 | c2w = c2w @ coord_trans 358 | 359 | imfile_name = osp.splitext(osp.basename(im.name))[0] 360 | pose_path = osp.join(pose_dir, imfile_name + ".txt") 361 | feat_path = osp.join(feat_dir, imfile_name + ".npz") # NOT USED but maybe nice? 362 | np.savetxt(pose_path, c2w) 363 | np.savez(feat_path, xys=xys, ids=point3d_ids) 364 | print(" Total cameras:", len(imdata)) 365 | print("Done!") 366 | 367 | 368 | if __name__ == "__main__": 369 | main() 370 | -------------------------------------------------------------------------------- /opt/scripts/create_split.py: -------------------------------------------------------------------------------- 1 | """ 2 | Splits dataset using NSVF conventions. 
3 | Every --every-th image (default: every 16th) is used as a test image (1_ prefix) and the other images are train (0_ prefix) 4 | 5 | Usage: 6 | python create_split.py <data_set_root> 7 | data_set_root should contain directories like images/, pose/ 8 | """ 9 | # Copyright 2021 Alex Yu 10 | import os 11 | import os.path as osp 12 | from typing import NamedTuple, List 13 | import argparse 14 | import random 15 | 16 | parser = argparse.ArgumentParser("Automatic dataset splitting") 17 | parser.add_argument('root_dir', type=str, help="COLMAP dataset root dir") 18 | parser.add_argument('--every', type=int, default=16, help="Use every Nth image for testing") 19 | parser.add_argument('--dry_run', action='store_true', help="Dry run, prints renames without modifying any files") 20 | parser.add_argument('--yes', '-y', action='store_true', help="Answer yes") 21 | parser.add_argument('--random', action='store_true', help="If set, chooses the split randomly rather than at a fixed interval " 22 | "(but the number of images in the train/test sets is the same)") 23 | args = parser.parse_args() 24 | 25 | class Dir(NamedTuple): 26 | name: str 27 | valid_exts: List[str] 28 | 29 | def list_filter_dirs(base): 30 | all_dirs = [x for x in os.listdir(base) if osp.isdir(osp.join(base, x))] 31 | image_exts = [".png", ".jpg", ".jpeg", ".gif", ".tif", ".tiff", ".bmp"] 32 | depth_exts = [".exr", ".pfm", ".png", ".npy"] 33 | dirs_prefixes = [Dir(name="pose", valid_exts=[".txt"]), 34 | Dir(name="poses", valid_exts=[".txt"]), 35 | Dir(name="feature", valid_exts=[".npz"]), 36 | Dir(name="rgb", valid_exts=image_exts), 37 | Dir(name="images", valid_exts=image_exts), 38 | Dir(name="image", valid_exts=image_exts), 39 | Dir(name="c2w", valid_exts=image_exts), 40 | Dir(name="depths", valid_exts=depth_exts)] 41 | dirs = [] 42 | dir_idx = 0 43 | for pfx in dirs_prefixes: 44 | for d in all_dirs: 45 | if d.startswith(pfx.name): 46 | if d == "pose": 47 | dir_idx = len(dirs) 48 | dirs.append(Dir(name=osp.join(base, d), valid_exts=pfx.valid_exts)) 49 | return dirs, dir_idx 50 | 51 | dirs, dir_idx = list_filter_dirs(args.root_dir) 52 | 53 | refdir = dirs[dir_idx] 54 | print("going to split", [x.name for x in dirs], "reference", refdir.name) 55 | do_proceed = args.dry_run or args.yes 56 | if not do_proceed: 57 | import click 58 | do_proceed = click.confirm("Continue?", default=True) 59 | if do_proceed: 60 | filedata = {} 61 | base_files = [osp.splitext(x)[0] for x in sorted(os.listdir(refdir.name)) 62 | if osp.splitext(x)[1].lower() in refdir.valid_exts] 63 | if args.random: 64 | print('random enabled') 65 | random.shuffle(base_files) 66 | base_files_map = {x: f"{int(i % args.every == 0)}_" + x for i, x in enumerate(base_files)} 67 | 68 | for dir_obj in dirs: 69 | dirname = dir_obj.name 70 | files = sorted(os.listdir(dirname)) 71 | for filename in files: 72 | full_filename = osp.join(dirname, filename) 73 | if filename.startswith("0_") or filename.startswith("1_"): 74 | continue 75 | if not osp.isfile(full_filename): 76 | continue 77 | base_file, ext = osp.splitext(filename) 78 | if ext.lower() not in dir_obj.valid_exts: 79 | print('SKIP ', full_filename, ' Since it has an unsupported extension') 80 | continue 81 | if base_file not in base_files_map: 82 | print('SKIP ', full_filename, ' Since it does not match any reference file') 83 | continue 84 | new_base_file = base_files_map[base_file] 85 | new_full_filename = osp.join(dirname, new_base_file + ext) 86 | print('rename', full_filename, 'to', new_full_filename) 87 | if not args.dry_run: 88 | os.rename(full_filename, new_full_filename) 89 |
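# To make the prefix convention concrete, with the default --every 16 and
# hypothetical frame names, the loop above performs renames like:
#   images/00000.png -> images/1_00000.png   (every 16th frame -> test, '1_' prefix)
#   images/00001.png -> images/0_00001.png   (all other frames -> train, '0_' prefix)
#   pose/00000.txt   -> pose/1_00000.txt     (companion dirs get the same prefix)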
if args.dry_run: 90 | print('(dry run complete)') 91 | else: 92 | print('use unsplit.py to undo this operation') 93 | -------------------------------------------------------------------------------- /opt/scripts/ingp2nsvf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert NeRF-iNGP data to NSVF 3 | python ingp2nsvf.py <data_dir> <out_data_dir> 4 | """ 5 | import os 6 | import shutil 7 | from glob import glob 8 | import json 9 | 10 | import numpy as np 11 | from PIL import Image 12 | import argparse 13 | 14 | def convert(data_dir : str, out_data_dir : str): 15 | """ 16 | Convert Instant-NGP (modified NeRF) data to NSVF 17 | 18 | :param data_dir: the dataset dir (NeRF-NGP format) to convert 19 | :param out_data_dir: output dataset directory (NSVF format) 20 | """ 21 | 22 | images_dir_name = os.path.join(out_data_dir, "images") 23 | pose_dir_name = os.path.join(out_data_dir, "pose") 24 | 25 | os.makedirs(images_dir_name, exist_ok=True) 26 | os.makedirs(pose_dir_name, exist_ok=True) 27 | 28 | def get_subdir(name): 29 | if name.endswith("_train.json"): 30 | return "train" 31 | elif name.endswith("_val.json"): 32 | return "val" 33 | elif name.endswith("_test.json"): 34 | return "test" 35 | return "" 36 | 37 | def get_out_prefix(name): 38 | if name.endswith("_train.json"): 39 | return "0_" 40 | elif name.endswith("_val.json"): 41 | return "1_" 42 | elif name.endswith("_test.json"): 43 | return "2_" 44 | return "" 45 | 46 | jsons = { 47 | x: (get_subdir(x), get_out_prefix(x)) 48 | for x in glob(os.path.join(data_dir, "*.json")) 49 | } 50 | 51 | # OpenGL -> OpenCV 52 | cam_trans = np.diag(np.array([1.0, -1.0, -1.0, 1.0])) 53 | 54 | # fmt: off 55 | world_trans = np.array( 56 | [ 57 | [0.0, -1.0, 0.0, 0.0], 58 | [0.0, 0.0, -1.0, 0.0], 59 | [1.0, 0.0, 0.0, 0.0], 60 | [0.0, 0.0, 0.0, 1.0], 61 | ] 62 | ) 63 | # fmt: on 64 | 65 | assert len(jsons) > 0, f"No jsons found in {data_dir}, can't convert" 66 | cnt = 0 67 | 68 | example_fpath = None 69 | tj = {} 70 | for tj_path, (tj_subdir, tj_out_prefix) in jsons.items(): 71 | with open(tj_path, "r") as f: 72 | tj = json.load(f) 73 | if "frames" not in tj: 74 | print(f"No frames in json {tj_path}, skipping") 75 | continue 76 | 77 | for frame in tj["frames"]: 78 | # Try direct relative path (used in newer NGP datasets) 79 | fpath = os.path.join(data_dir, frame["file_path"]) 80 | if not os.path.isfile(fpath): 81 | # Legacy path (NeRF) 82 | fpath = os.path.join( 83 | data_dir, tj_subdir, os.path.basename(frame["file_path"]) + ".png" 84 | ) 85 | example_fpath = fpath 86 | if not os.path.isfile(fpath): 87 | print("Could not find image:", frame["file_path"], "(this may be ok)") 88 | continue 89 | 90 | ext = os.path.splitext(fpath)[1] 91 | 92 | c2w = np.array(frame["transform_matrix"]) 93 | c2w = world_trans @ c2w @ cam_trans # To OpenCV 94 | 95 | image_fname = tj_out_prefix + f"{cnt:05d}" 96 | 97 | pose_path = os.path.join(pose_dir_name, image_fname + ".txt") 98 | 99 | # Save 4x4 OpenCV C2W pose 100 | np.savetxt(pose_path, c2w) 101 | 102 | # Copy images 103 | new_fpath = os.path.join(images_dir_name, image_fname + ext) 104 | shutil.copyfile(fpath, new_fpath) 105 | cnt += 1 106 | 107 | assert len(tj) > 0, f"No valid jsons found in {data_dir}, can't convert" 108 | 109 | w = tj.get("w") 110 | h = tj.get("h") 111 | 112 | if w is None or h is None: 113 | assert example_fpath is not None 114 | # Image size not in the json, so load an image and get the size 115 | w, h = Image.open(example_fpath).size 116 | 117 | fx = float(0.5 * w / np.tan(0.5 *
tj["camera_angle_x"])) 118 | if "camera_angle_y" in tj: 119 | fy = float(0.5 * h / np.tan(0.5 * tj["camera_angle_y"])) 120 | else: 121 | fy = fx 122 | 123 | cx = tj.get("cx", w * 0.5) 124 | cy = tj.get("cy", h * 0.5) 125 | 126 | intrin_mtx = np.array([ 127 | [fx, 0.0, cx, 0.0], 128 | [0.0, fy, cy, 0.0], 129 | [0.0, 0.0, 1.0, 0.0], 130 | [0.0, 0.0, 0.0, 1.0], 131 | ]) 132 | # Write intrinsics 133 | np.savetxt(os.path.join(out_data_dir, "intrinsics.txt"), intrin_mtx) 134 | 135 | 136 | if __name__ == "__main__": 137 | parser = argparse.ArgumentParser() 138 | parser.add_argument("data_dir", type=str, help="NeRF-NGP data directory") 139 | parser.add_argument("out_data_dir", type=str, help="Output NSVF data directory") 140 | args = parser.parse_args() 141 | convert(args.data_dir, args.out_data_dir) 142 | -------------------------------------------------------------------------------- /opt/scripts/proc_colmap.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # USAGE: bash proc_colmap.sh 4 | 5 | python run_colmap.py $1 ${@:2} 6 | python colmap2nsvf.py $1/sparse/0 7 | python create_split.py -y $1 8 | -------------------------------------------------------------------------------- /opt/scripts/proc_record3d.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import os 4 | from os import path 5 | import glob 6 | import numpy as np 7 | import cv2 8 | from tqdm import tqdm 9 | from scipy.spatial.transform import Rotation 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('data_dir', type=str) 13 | parser.add_argument('--every', type=int, default=15) 14 | parser.add_argument('--factor', type=int, default=2, help='downsample') 15 | args = parser.parse_args() 16 | 17 | video_file = glob.glob(args.data_dir + '/*.mp4')[0] 18 | print('Video file:', video_file) 19 | json_meta = path.join(args.data_dir, 'metadata.json') 20 | meta = json.load(open(json_meta, 'r')) 21 | 22 | K_3 = np.array(meta['K']).reshape(3, 3) 23 | K = np.eye(4) 24 | K[:3, :3] = K_3.T / args.factor 25 | output_intrin_file = path.join(args.data_dir, 'intrinsics.txt') 26 | np.savetxt(output_intrin_file, K) 27 | 28 | poses = np.array(meta['poses']) 29 | 30 | t = poses[:, 4:] 31 | q = poses[:, :4] 32 | R = Rotation.from_quat(q).as_matrix() 33 | 34 | # Recenter the poses 35 | center = np.mean(t, axis=0) 36 | print('Scene center', center) 37 | t -= center 38 | 39 | all_poses = np.zeros((q.shape[0], 4, 4)) 40 | all_poses[:, -1, -1] = 1 41 | 42 | Rt = np.concatenate([R, t[:, :, None]], axis=2) 43 | all_poses[:, :3] = Rt 44 | all_poses = all_poses @ np.diag([1, -1, -1, 1]) 45 | video = cv2.VideoCapture(str(video_file)) 46 | print(Rt.shape) 47 | 48 | fps = video.get(cv2.CAP_PROP_FPS) 49 | img_wh = ori_w, ori_h = ( 50 | int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) // 2, 51 | int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)), 52 | ) 53 | 54 | print('image size', img_wh) 55 | pose_dir = path.join(args.data_dir, 'pose') 56 | os.makedirs(pose_dir, exist_ok=True) 57 | 58 | image_dir = path.join(args.data_dir, 'rgb') 59 | os.makedirs(image_dir, exist_ok=True) 60 | video_length = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 61 | print('length', video_length) 62 | 63 | idx = 0 64 | for i in tqdm(range(0, video_length, args.every)): 65 | video.set(cv2.CAP_PROP_POS_FRAMES, i) 66 | ret, frame = video.read() 67 | if not ret or frame is None: 68 | print('skip', i) 69 | continue 70 | assert frame.shape[1] == img_wh[0] * 2 71 | assert 
frame.shape[0] == img_wh[1] 72 | frame = frame[:, img_wh[0]:] 73 | image_path = path.join(image_dir, f"{idx:05d}.png") 74 | pose_path = path.join(pose_dir, f"{idx:05d}.txt") 75 | 76 | if args.factor != 1: 77 | frame = cv2.resize(frame, (img_wh[0] // args.factor, img_wh[1] // args.factor), cv2.INTER_AREA) 78 | 79 | cv2.imwrite(image_path, frame) 80 | np.savetxt(pose_path, all_poses[i]) 81 | idx += 1 82 | -------------------------------------------------------------------------------- /opt/scripts/unsplit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inverse of create_split.py 3 | """ 4 | # Copyright 2021 Alex Yu 5 | import os 6 | import os.path as osp 7 | import click 8 | from typing import NamedTuple, List 9 | import argparse 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('root_dir', type=str, help="COLMAP dataset root dir") 13 | parser.add_argument('--dry_run', action='store_true', help="Dry run, prints renames without modifying any files") 14 | parser.add_argument('--yes', '-y', action='store_true', help="Answer yes") 15 | args = parser.parse_args() 16 | 17 | class Dir(NamedTuple): 18 | name: str 19 | valid_exts: List[str] 20 | 21 | def list_filter_dirs(base): 22 | all_dirs = [x for x in os.listdir(base) if osp.isdir(osp.join(base, x))] 23 | image_exts = [".png", ".jpg", ".jpeg", ".gif", ".tif", ".tiff", ".bmp"] 24 | depth_exts = [".exr", ".pfm", ".png", ".npy"] 25 | dirs_prefixes = [Dir(name="pose", valid_exts=[".txt"]), 26 | Dir(name="feature", valid_exts=[".npz"]), 27 | Dir(name="rgb", valid_exts=image_exts), 28 | Dir(name="images", valid_exts=image_exts), 29 | Dir(name="depths", valid_exts=depth_exts)] 30 | dirs = [] 31 | dir_idx = 0 32 | for pfx in dirs_prefixes: 33 | for d in all_dirs: 34 | if d.startswith(pfx.name): 35 | if d == "pose": 36 | dir_idx = len(dirs) 37 | dirs.append(Dir(name=osp.join(base, d), valid_exts=pfx.valid_exts)) 38 | return dirs, dir_idx 39 | 40 | dirs, dir_idx = list_filter_dirs(args.root_dir) 41 | 42 | refdir = dirs[dir_idx] 43 | print("going to unsplit", [x.name for x in dirs], "reference", dirs[dir_idx].name) 44 | do_proceed = args.dry_run or args.yes 45 | if not do_proceed: 46 | import click 47 | do_proceed = click.confirm("Continue?", default=True) 48 | if do_proceed: 49 | filedata = {} 50 | base_files = [osp.splitext(x)[0] for x in sorted(os.listdir(refdir.name)) 51 | if osp.splitext(x)[1] in refdir.valid_exts and 52 | (x.startswith('0_') or x.startswith('1_'))] 53 | base_files_map = {x: '_'.join(x.split('_')[1:]) for x in base_files} 54 | 55 | for dir_obj in dirs: 56 | dirname = dir_obj.name 57 | files = sorted(os.listdir(dirname)) 58 | for filename in files: 59 | full_filename = osp.join(dirname, filename) 60 | if not osp.isfile(full_filename): 61 | continue 62 | base_file, ext = osp.splitext(filename) 63 | if ext.lower() not in dir_obj.valid_exts: 64 | print('SKIP ', full_filename, ' Since it has an unsupported extension') 65 | continue 66 | if base_file not in base_files_map: 67 | print('SKIP ', full_filename, ' Since it does not match any reference file') 68 | continue 69 | new_base_file = base_files_map[base_file] 70 | new_full_filename = osp.join(dirname, new_base_file + ext) 71 | print('rename', full_filename, 'to', new_full_filename) 72 | if not args.dry_run: 73 | os.rename(full_filename, new_full_filename) 74 | if args.dry_run: 75 | print('(dry run complete)') 76 | else: 77 | print('use create_split.py to split again') 78 | 
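# Round-trip sketch (hypothetical names): create_split.py renames
# images/00000.png -> images/1_00000.png; this script strips everything up to
# the first underscore ('_'.join(x.split('_')[1:])), mapping
# images/1_00000.png back to images/00000.png.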
-------------------------------------------------------------------------------- /opt/tasks/eval.json: -------------------------------------------------------------------------------- 1 | { 2 | "eval": true, 3 | "data_root": "/home/sxyu/data/nerf_synthetic", 4 | "train_root": "/home/sxyu/proj/svox2/opt/ckpt_auto/256_to_512_fasttv", 5 | "variables": { 6 | "scene": ["lego", "mic", "ship", "chair", "ficus", "materials", "drums", "hotdog"] 7 | }, 8 | "tasks": [{ 9 | "train_dir": "{scene}", 10 | "data_dir": "{scene}", 11 | "config": "configs/syn.json" 12 | }] 13 | } 14 | -------------------------------------------------------------------------------- /opt/tasks/eval_ff.json: -------------------------------------------------------------------------------- 1 | { 2 | "eval": true, 3 | "render": true, 4 | "data_root": "/home/sxyu/data/nerf_llff_data", 5 | "train_root": "/home/sxyu/proj/svox2/opt/ckpt_auto/llff_c2f_fasttv_10e_lsopaque_fixtvbord", 6 | "variables": { 7 | "scene": ["fern", "room", "horns", "trex", "flower", "leaves", "orchids", "fortress"] 8 | }, 9 | "tasks": [{ 10 | "train_dir": "{scene}", 11 | "data_dir": "{scene}", 12 | "config": "configs/llff.json" 13 | }] 14 | } 15 | -------------------------------------------------------------------------------- /opt/tasks/eval_real_iconic.json: -------------------------------------------------------------------------------- 1 | { 2 | "eval": true, 3 | "render": true, 4 | "data_root": "/home/sxyu/data/real_iconic", 5 | "train_root": "/home/sxyu/proj/svox2/opt/ckpt_auto/real_iconic_mass_tv2x", 6 | "variables": { 7 | "scene": "[x for x in listdir('/home/sxyu/data/real_iconic') if path.isfile(path.join('/home/sxyu/data/real_iconic', x, 'poses_bounds.npy'))]" 8 | }, 9 | "tasks": [{ 10 | "train_dir": "{scene}", 11 | "data_dir": "{scene}", 12 | "config": "configs/llff_hitv.json", 13 | "eval_flags": ["--crop", "0.95"] 14 | }] 15 | } 16 | -------------------------------------------------------------------------------- /opt/tasks/eval_tnt.json: -------------------------------------------------------------------------------- 1 | { 2 | "eval": true, 3 | "data_root": "/home/sxyu/data/TanksAndTempleBG", 4 | "train_root": "/home/sxyu/proj/svox2/opt/ckpt_auto/tnt_equirectlin_fasttv_unify2_grey_delfg", 5 | "variables": { 6 | "scene": ["Train", "M60", "Truck", "Playground"] 7 | }, 8 | "tasks": [{ 9 | "train_dir": "{scene}", 10 | "data_dir": "{scene}", 11 | "config": "configs/tnt.json" 12 | }] 13 | } 14 | -------------------------------------------------------------------------------- /opt/tasks/interpablate.json: -------------------------------------------------------------------------------- 1 | { 2 | "eval": true, 3 | "data_root": "/home/sxyu/data/nerf_synthetic", 4 | "train_root": "/home/sfk/svox2/opt/ckpt_auto/nearestneighbor", 5 | "variables": { 6 | "scene": ["lego", "mic", "ship", "chair", "ficus", "materials", "drums", "hotdog"], 7 | "reso": ["[[256,256,256]]", "[[128,128,128]]"] 8 | }, 9 | "tasks": [{ 10 | "train_dir": "{reso}_3e0_{scene}", 11 | "data_dir": "{scene}", 12 | "flags": [ 13 | "-B", "svox1", 14 | "--lr_sigma", "3e0", 15 | "--reso", "{reso}" 16 | ] 17 | }] 18 | } 19 | -------------------------------------------------------------------------------- /opt/tasks/ntrainablate.json: -------------------------------------------------------------------------------- 1 | { 2 | "eval": true, 3 | "data_root": "/home/sxyu/data/nerf_synthetic", 4 | "train_root": "/home/sfk/svox2/opt/ckpt_auto/256_to_512_tvearlyonly_ntrain", 5 | "variables": { 6 | "scene": 
["lego", "mic", "ship", "chair", "ficus", "materials", "drums", "hotdog"], 7 | "n_train": [25] 8 | }, 9 | "tasks": [{ 10 | "train_dir": "train_{n_train}_{scene}", 11 | "data_dir": "{scene}", 12 | "flags": [ 13 | "--n_train", "{n_train}" 14 | ] 15 | }] 16 | } 17 | -------------------------------------------------------------------------------- /opt/tasks/sanity.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_root": "/home/sxyu/data/nerf_synthetic/ship", 3 | "train_root": "/home/sxyu/proj/svox2/opt/ckpt_tune/ship_sweep", 4 | "variables": { 5 | "lr_sh_final": "loglin(5e-7, 5e-2, 10)", 6 | "lr_sigma": "loglin(5e0, 2e2, 4)", 7 | "lr_sigma_delay_steps": [25000, 40000, 55000] 8 | }, 9 | "tasks": [{ 10 | "train_dir": "lrcf{lr_sh_final}_lrs{lr_sigma}_del{lr_sigma_delay_steps}", 11 | "flags": [ 12 | "--lr_sh_final", "{lr_sh_final}", 13 | "--lr_sigma", "{lr_sigma}", 14 | "--lr_sigma_delay_steps", "{lr_sigma_delay_steps}" 15 | ] 16 | }] 17 | } 18 | -------------------------------------------------------------------------------- /opt/tasks/tvearlyonly.json: -------------------------------------------------------------------------------- 1 | { 2 | "eval": true, 3 | "data_root": "/home/sxyu/data/nerf_synthetic", 4 | "train_root": "/home/sfk/svox2/opt/ckpt_auto/256_to_512", 5 | "variables": { 6 | "scene": ["lego", "mic", "ship", "chair", "ficus", "materials", "drums", "hotdog"], 7 | "tv_early_only": [0, 1] 8 | }, 9 | "tasks": [{ 10 | "train_dir": "tv_early_only_{tv_early_only}_{scene}", 11 | "data_dir": "{scene}", 12 | "flags": [ 13 | "--tv_early_only", "{tv_early_only}" 14 | ] 15 | }] 16 | } -------------------------------------------------------------------------------- /opt/to_svox1.py: -------------------------------------------------------------------------------- 1 | import svox2 2 | import svox 3 | import math 4 | import argparse 5 | from os import path 6 | from tqdm import tqdm 7 | import torch 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('ckpt', type=str) 11 | args = parser.parse_args() 12 | 13 | grid = svox2.SparseGrid.load(args.ckpt) 14 | t = grid.to_svox1() 15 | print(t) 16 | 17 | out_path = path.splitext(args.ckpt)[0] + '_svox1.npz' 18 | print('Saving', out_path) 19 | t.save(out_path) 20 | -------------------------------------------------------------------------------- /opt/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxyu/svox2/ee80e2c4df8f29a407fda5729a494be94ccf9234/opt/util/__init__.py -------------------------------------------------------------------------------- /opt/util/co3d_dataset.py: -------------------------------------------------------------------------------- 1 | # CO3D dataset loader 2 | # https://github.com/facebookresearch/co3d/ 3 | # 4 | # Adapted from basenerf 5 | # Copyright 2021 Alex Yu 6 | import torch 7 | import torch.nn.functional as F 8 | import numpy as np 9 | import os 10 | import cv2 11 | from tqdm import tqdm 12 | from os import path 13 | import json 14 | import gzip 15 | 16 | from scipy.spatial.transform import Rotation 17 | from typing import NamedTuple, Optional, List, Union 18 | from .util import Rays, Intrin, similarity_from_cameras 19 | from .dataset_base import DatasetBase 20 | 21 | 22 | class CO3DDataset(DatasetBase): 23 | """ 24 | CO3D Dataset 25 | Preloads all images for an object. 26 | Will create a data index on first load, to make later loads faster. 
27 | """ 28 | 29 | def __init__( 30 | self, 31 | root, 32 | split, 33 | seq_id : Optional[int] = None, 34 | epoch_size : Optional[int] = None, 35 | permutation: bool = True, 36 | device: Union[str, torch.device] = "cpu", 37 | max_image_dim: int = 800, 38 | max_pose_dist: float = 5.0, 39 | cam_scale_factor: float = 0.95, 40 | hold_every=8, 41 | **kwargs, 42 | ): 43 | """ 44 | :param root: str dataset root directory 45 | :param device: data prefetch device 46 | """ 47 | super().__init__() 48 | os.makedirs('co3d_tmp', exist_ok=True) 49 | index_file = path.join('co3d_tmp', 'co3d_index.npz') 50 | self.split = split 51 | self.permutation = permutation 52 | self.data_dir = root 53 | self.epoch_size = epoch_size 54 | self.max_image_dim = max_image_dim 55 | self.max_pose_dist = max_pose_dist 56 | self.cam_scale_factor = cam_scale_factor 57 | 58 | self.cats = sorted([x for x in os.listdir(root) if path.isdir( 59 | path.join(root, x))]) 60 | self.gt = [] 61 | self.n_images = 0 62 | self.curr_offset = 0 63 | self.next_offset = 0 64 | self.hold_every = hold_every 65 | self.curr_seq_cat = self.curr_seq_name = '' 66 | self.device = device 67 | if path.exists(index_file): 68 | print(' Using cached CO3D index', index_file) 69 | z = np.load(index_file) 70 | self.seq_cats = z.f.seq_cats 71 | self.seq_names = z.f.seq_names 72 | self.seq_offsets = z.f.seq_offsets 73 | self.all_image_size = z.f.image_size # NOTE: w, h 74 | self.image_path = z.f.image_path 75 | self.image_pose = z.f.pose 76 | self.fxy = z.f.fxy 77 | self.cxy = z.f.cxy 78 | else: 79 | print(' Constructing CO3D index (1st run only), this may take a while') 80 | cam_trans = np.diag(np.array([-1, -1, 1, 1], dtype=np.float32)) 81 | frame_data_by_seq = {} 82 | self.seq_cats = [] 83 | self.seq_names = [] 84 | self.seq_offsets = [] 85 | self.image_path = [] 86 | self.all_image_size = [] 87 | self.image_pose = [] 88 | self.fxy = [] 89 | self.cxy = [] 90 | for i, cat in enumerate(self.cats): 91 | print(cat, '- category', i + 1, 'of', len(self.cats)) 92 | cat_dir = path.join(root, cat) 93 | if not path.isdir(cat_dir): 94 | continue 95 | frame_data_path = path.join(cat_dir, 'frame_annotations.jgz') 96 | with gzip.open(frame_data_path, 'r') as f: 97 | all_frames_data = json.load(f) 98 | for frame_data in tqdm(all_frames_data): 99 | seq_name = cat + '//' + frame_data['sequence_name'] 100 | # frame_number = frame_data['frame_number'] 101 | if seq_name not in frame_data_by_seq: 102 | frame_data_by_seq[seq_name] = [] 103 | pose = np.zeros((4, 4)) 104 | image_size_hw = frame_data['image']['size'] # H, W 105 | H, W = image_size_hw 106 | half_wh = np.array([W * 0.5, H * 0.5], dtype=np.float32) 107 | R = np.array(frame_data['viewpoint']['R']) 108 | T = np.array(frame_data['viewpoint']['T']) 109 | fxy = np.array(frame_data['viewpoint']['focal_length']) 110 | cxy = np.array(frame_data['viewpoint']['principal_point']) 111 | focal = fxy * half_wh 112 | prp = -1.0 * (cxy - 1.0) * half_wh 113 | pose[:3, :3] = R 114 | pose[:3, 3:] = -R @ T[..., None] 115 | pose[3, 3] = 1.0 116 | pose = pose @ cam_trans 117 | frame_data_obj = { 118 | 'frame_number':frame_data['frame_number'], 119 | 'image_path':frame_data['image']['path'], 120 | 'image_size':np.array([W, H]), # NOTE: this is w, h 121 | 'pose':pose, 122 | 'fxy':focal, # NOTE: this is x, y 123 | 'cxy':prp, # NOTE: this is x, y 124 | } 125 | frame_data_by_seq[seq_name].append(frame_data_obj) 126 | print(' Sorting by sequence') 127 | for k in frame_data_by_seq: 128 | fd = sorted(frame_data_by_seq[k], 129 | key=lambda x: 
x['frame_number']) 130 | spl = k.split('//') 131 | self.seq_cats.append(spl[0]) 132 | self.seq_names.append(spl[1]) 133 | self.seq_offsets.append(len(self.image_path)) 134 | self.image_path.extend([x['image_path'] for x in fd]) 135 | self.all_image_size.extend([x['image_size'] for x in fd]) 136 | self.image_pose.extend([x['pose'] for x in fd]) 137 | self.fxy.extend([x['fxy'] for x in fd]) 138 | self.cxy.extend([x['cxy'] for x in fd]) 139 | self.all_image_size = np.stack(self.all_image_size) 140 | self.image_pose = np.stack(self.image_pose) 141 | self.fxy = np.stack(self.fxy) 142 | self.cxy = np.stack(self.cxy) 143 | self.seq_offsets.append(len(self.image_path)) 144 | self.seq_offsets = np.array(self.seq_offsets) 145 | print(' Saving to index') 146 | np.savez(index_file, 147 | seq_cats=self.seq_cats, 148 | seq_names=self.seq_names, 149 | seq_offsets=self.seq_offsets, 150 | image_size=self.all_image_size, 151 | image_path=self.image_path, 152 | pose=self.image_pose, 153 | fxy=self.fxy, 154 | cxy=self.cxy) 155 | self.n_seq = len(self.seq_names) 156 | print( 157 | " Loaded CO3D dataset", 158 | root, 159 | "n_seq", self.n_seq 160 | ) 161 | 162 | if seq_id is not None: 163 | self.load_sequence(seq_id) 164 | 165 | 166 | def load_sequence(self, sequence_id : int): 167 | """ 168 | Load a different CO3D sequence 169 | sequence_id should be at least 0 and at most (n_seq - 1) 170 | see co3d_tmp/co3d.txt for sequence ID -> name mappings 171 | """ 172 | print(' Loading single CO3D sequence:', 173 | self.seq_cats[sequence_id], self.seq_names[sequence_id]) 174 | self.curr_seq_cat = self.seq_cats[sequence_id] 175 | self.curr_seq_name = self.seq_names[sequence_id] 176 | self.curr_offset = self.seq_offsets[sequence_id] 177 | self.next_offset = self.seq_offsets[sequence_id + 1] 178 | self.gt = [] 179 | fxs, fys, cxs, cys = [], [], [], [] 180 | image_sizes = [] 181 | c2ws = [] 182 | ref_c2ws = [] 183 | for i in tqdm(range(self.curr_offset, self.next_offset)): 184 | is_train = i % self.hold_every != 0 185 | ref_c2ws.append(self.image_pose[i]) 186 | if self.split.endswith('train') != is_train: 187 | continue 188 | im = cv2.imread(path.join(self.data_dir, self.image_path[i])) 189 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0 190 | im = im[..., :3] 191 | h, w, _ = im.shape 192 | max_hw = max(h, w) 193 | approx_scale = self.max_image_dim / max_hw 194 | 195 | if approx_scale < 1.0: 196 | h2 = int(approx_scale * h) 197 | w2 = int(approx_scale * w) 198 | im = cv2.resize(im, (w2, h2), interpolation=cv2.INTER_AREA) 199 | else: 200 | h2 = h 201 | w2 = w 202 | scale = np.array([w2 / w, h2 / h], dtype=np.float32) 203 | image_sizes.append(np.array([h2, w2])) 204 | cxy = self.cxy[i] * scale 205 | fxy = self.fxy[i] * scale 206 | fxs.append(fxy[0]) 207 | fys.append(fxy[1]) 208 | cxs.append(cxy[0]) 209 | cys.append(cxy[1]) 210 | # grid = data_util.gen_grid(h2, w2, cxy.astype(np.float32), normalize_scale=False) 211 | # grid /= fxy.astype(np.float32) 212 | self.gt.append(torch.from_numpy(im)) 213 | c2ws.append(self.image_pose[i]) 214 | c2w = np.stack(c2ws, axis=0) 215 | ref_c2ws = np.stack(ref_c2ws, axis=0) # For rescaling scene 216 | self.image_size = np.stack(image_sizes) 217 | fxs = torch.tensor(fxs) 218 | fys = torch.tensor(fys) 219 | cxs = torch.tensor(cxs) 220 | cys = torch.tensor(cys) 221 | 222 | # Filter out crazy poses 223 | dists = np.linalg.norm(c2w[:, :3, 3] - np.median(c2w[:, :3, 3], axis=0), axis=-1) 224 | med = np.median(dists) 225 | good_mask = dists < med * self.max_pose_dist 226 | c2w = 
c2w[good_mask] 227 | self.image_size = self.image_size[good_mask] 228 | good_idx = np.where(good_mask)[0] 229 | self.gt = [self.gt[i] for i in good_idx] 230 | 231 | self.intrins_full = Intrin(fxs[good_mask], fys[good_mask], 232 | cxs[good_mask], cys[good_mask]) 233 | 234 | # Normalize 235 | # c2w[:, :3, 3] -= np.mean(c2w[:, :3, 3], axis=0) 236 | # dists = np.linalg.norm(c2w[:, :3, 3], axis=-1) 237 | # c2w[:, :3, 3] *= self.cam_scale_factor / np.median(dists) 238 | 239 | T, sscale = similarity_from_cameras(ref_c2ws) 240 | c2w = T @ c2w 241 | c2w[:, :3, 3] *= self.cam_scale_factor * sscale 242 | 243 | self.c2w = torch.from_numpy(c2w).float() 244 | self.cam_n_rays = self.image_size[:, 0] * self.image_size[:, 1] 245 | self.n_images = len(self.gt) 246 | self.image_size_full = self.image_size 247 | 248 | if self.split == "train": 249 | self.gen_rays(factor=1) 250 | else: 251 | # Rays are not needed for testing 252 | self.intrins : Intrin = self.intrins_full 253 | 254 | 255 | def gen_rays(self, factor=1): 256 | print(" Generating rays, scaling factor", factor) 257 | # Generate rays 258 | self.factor = factor 259 | self.image_size = self.image_size_full // factor 260 | true_factor = self.image_size_full[:, 0] / self.image_size[:, 0] 261 | self.intrins = self.intrins_full.scale(1.0 / true_factor) 262 | 263 | all_origins = [] 264 | all_dirs = [] 265 | all_gts = [] 266 | for i in tqdm(range(self.n_images)): 267 | yy, xx = torch.meshgrid( 268 | torch.arange(self.image_size[i, 0], dtype=torch.float32) + 0.5, 269 | torch.arange(self.image_size[i, 1], dtype=torch.float32) + 0.5, 270 | ) 271 | xx = (xx - self.intrins.get('cx', i)) / self.intrins.get('fx', i) 272 | yy = (yy - self.intrins.get('cy', i)) / self.intrins.get('fy', i) 273 | zz = torch.ones_like(xx) 274 | dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention 275 | dirs /= torch.norm(dirs, dim=-1, keepdim=True) 276 | dirs = dirs.reshape(-1, 3, 1) 277 | del xx, yy, zz 278 | dirs = (self.c2w[i, None, :3, :3] @ dirs)[..., 0] 279 | 280 | if factor != 1: 281 | gt = F.interpolate( 282 | self.gt[i].permute([2, 0, 1])[None], size=(self.image_size[i, 0], 283 | self.image_size[i, 1]), 284 | mode="area" 285 | )[0].permute([1, 2, 0]) 286 | gt = gt.reshape(-1, 3) 287 | else: 288 | gt = self.gt[i].reshape(-1, 3) 289 | origins = self.c2w[i, None, :3, 3].expand(self.image_size[i, 0] * 290 | self.image_size[i, 1], -1).contiguous() 291 | all_origins.append(origins) 292 | all_dirs.append(dirs) 293 | all_gts.append(gt) 294 | origins = all_origins 295 | dirs = all_dirs 296 | gt = all_gts 297 | 298 | if self.split == "train": 299 | origins = torch.cat([o.view(-1, 3) for o in origins], dim=0) 300 | dirs = torch.cat([o.view(-1, 3) for o in dirs], dim=0) 301 | gt = torch.cat([o.reshape(-1, 3) for o in gt], dim=0) 302 | 303 | self.rays_init = Rays(origins=origins, dirs=dirs, gt=gt) 304 | self.rays = self.rays_init 305 | -------------------------------------------------------------------------------- /opt/util/config_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from util.dataset import datasets 4 | import json 5 | 6 | 7 | def define_common_args(parser : argparse.ArgumentParser): 8 | parser.add_argument('data_dir', type=str) 9 | 10 | parser.add_argument('--config', '-c', 11 | type=str, 12 | default=None, 13 | help="Config json file (will override args)") 14 | 15 | group = parser.add_argument_group("Data loading") 16 | group.add_argument('--dataset_type',
choices=list(datasets.keys()) + ["auto"], 18 | default="auto", 19 | help="Dataset type (specify type or use auto)") 20 | group.add_argument('--scene_scale', 21 | type=float, 22 | default=None, 23 | help="Global scene scaling (or use dataset default)") 24 | group.add_argument('--scale', 25 | type=float, 26 | default=None, 27 | help="Image scale, e.g. 0.5 for half resolution (or use dataset default)") 28 | group.add_argument('--seq_id', 29 | type=int, 30 | default=1000, 31 | help="Sequence ID (for CO3D only)") 32 | group.add_argument('--epoch_size', 33 | type=int, 34 | default=12800, 35 | help="Pseudo-epoch size in terms of batches (to be consistent across datasets)") 36 | group.add_argument('--white_bkgd', 37 | type=bool, 38 | default=True, 39 | help="Whether to use white background (ignored in some datasets)") 40 | group.add_argument('--llffhold', 41 | type=int, 42 | default=8, 43 | help="Hold out every Nth image for testing (LLFF-style)") 44 | group.add_argument('--normalize_by_bbox', 45 | type=bool, 46 | default=False, 47 | help="Normalize by bounding box in bbox.txt, if available (NSVF dataset only); precedes normalize_by_camera") 48 | group.add_argument('--data_bbox_scale', 49 | type=float, 50 | default=1.2, 51 | help="Data bbox scaling (NSVF dataset only)") 52 | group.add_argument('--cam_scale_factor', 53 | type=float, 54 | default=0.95, 55 | help="Camera autoscale factor (NSVF/CO3D dataset only)") 56 | group.add_argument('--normalize_by_camera', 57 | type=bool, 58 | default=True, 59 | help="Normalize using cameras, assuming a 360 capture (NSVF dataset only); only used if not normalize_by_bbox") 60 | group.add_argument('--perm', action='store_true', default=False, 61 | help='sample by permutation of rays (true epoch) instead of ' 62 | 'uniformly random rays') 63 | 64 | group = parser.add_argument_group("Render options") 65 | group.add_argument('--step_size', 66 | type=float, 67 | default=0.5, 68 | help="Render step size (in voxel size units)") 69 | group.add_argument('--sigma_thresh', 70 | type=float, 71 | default=1e-8, 72 | help="Skips voxels with sigma < this") 73 | group.add_argument('--stop_thresh', 74 | type=float, 75 | default=1e-7, 76 | help="Ray march stopping threshold") 77 | group.add_argument('--background_brightness', 78 | type=float, 79 | default=1.0, 80 | help="Brightness of the infinite background") 81 | group.add_argument('--renderer_backend', '-B', 82 | choices=['cuvol', 'svox1', 'nvol'], 83 | default='cuvol', 84 | help="Renderer backend") 85 | group.add_argument('--random_sigma_std', 86 | type=float, 87 | default=0.0, 88 | help="Random Gaussian std to add to density values (only if enable_random)") 89 | group.add_argument('--random_sigma_std_background', 90 | type=float, 91 | default=0.0, 92 | help="Random Gaussian std to add to density values for BG (only if enable_random)") 93 | group.add_argument('--near_clip', 94 | type=float, 95 | default=0.00, 96 | help="Near clip distance (in world space distance units, only for FG)") 97 | group.add_argument('--use_spheric_clip', 98 | action='store_true', 99 | default=False, 100 | help="Use spheric ray clipping instead of voxel grid AABB " 101 | "(only for FG; changes near_clip to mean 1-near_intersection_radius; " 102 | "far intersection is always at radius 1)") 103 | group.add_argument('--enable_random', 104 | action='store_true', 105 | default=False, 106 | help="Enable random Gaussian noise on density (std set by random_sigma_std / random_sigma_std_background)") 107 | group.add_argument('--last_sample_opaque', 108 | action='store_true', 109 | default=False, 110 | help="Last sample has +1e9 density (used for
LLFF)") 111 | 112 | 113 | def build_data_options(args): 114 | """ 115 | Arguments to pass as kwargs to the dataset constructor 116 | """ 117 | return { 118 | 'dataset_type': args.dataset_type, 119 | 'seq_id': args.seq_id, 120 | 'epoch_size': args.epoch_size * args.__dict__.get('batch_size', 5000), 121 | 'scene_scale': args.scene_scale, 122 | 'scale': args.scale, 123 | 'white_bkgd': args.white_bkgd, 124 | 'hold_every': args.llffhold, 125 | 'normalize_by_bbox': args.normalize_by_bbox, 126 | 'data_bbox_scale': args.data_bbox_scale, 127 | 'cam_scale_factor': args.cam_scale_factor, 128 | 'normalize_by_camera': args.normalize_by_camera, 129 | 'permutation': args.perm 130 | } 131 | 132 | def maybe_merge_config_file(args, allow_invalid=False): 133 | """ 134 | Load json config file if specified and merge the arguments 135 | """ 136 | if args.config is not None: 137 | with open(args.config, "r") as config_file: 138 | configs = json.load(config_file) 139 | invalid_args = list(set(configs.keys()) - set(dir(args))) 140 | if invalid_args and not allow_invalid: 141 | raise ValueError(f"Invalid args {invalid_args} in {args.config}.") 142 | args.__dict__.update(configs) 143 | 144 | def setup_render_opts(opt, args): 145 | """ 146 | Pass render arguments to the SparseGrid renderer options 147 | """ 148 | opt.step_size = args.step_size 149 | opt.sigma_thresh = args.sigma_thresh 150 | opt.stop_thresh = args.stop_thresh 151 | opt.background_brightness = args.background_brightness 152 | opt.backend = args.renderer_backend 153 | opt.random_sigma_std = args.random_sigma_std 154 | opt.random_sigma_std_background = args.random_sigma_std_background 155 | opt.last_sample_opaque = args.last_sample_opaque 156 | opt.near_clip = args.near_clip 157 | opt.use_spheric_clip = args.use_spheric_clip 158 | -------------------------------------------------------------------------------- /opt/util/dataset.py: -------------------------------------------------------------------------------- 1 | from .nerf_dataset import NeRFDataset 2 | from .llff_dataset import LLFFDataset 3 | from .nsvf_dataset import NSVFDataset 4 | from .co3d_dataset import CO3DDataset 5 | from os import path 6 | 7 | def auto_dataset(root : str, *args, **kwargs): 8 | if path.isfile(path.join(root, 'apple', 'eval_batches_multisequence.json')): 9 | print("Detected CO3D dataset") 10 | return CO3DDataset(root, *args, **kwargs) 11 | elif path.isfile(path.join(root, 'poses_bounds.npy')): 12 | print("Detected LLFF dataset") 13 | return LLFFDataset(root, *args, **kwargs) 14 | elif path.isfile(path.join(root, 'transforms.json')) or \ 15 | path.isfile(path.join(root, 'transforms_train.json')): 16 | print("Detected NeRF (Blender) dataset") 17 | return NeRFDataset(root, *args, **kwargs) 18 | else: 19 | print("Defaulting to extended NSVF dataset") 20 | return NSVFDataset(root, *args, **kwargs) 21 | 22 | datasets = { 23 | 'nerf': NeRFDataset, 24 | 'llff': LLFFDataset, 25 | 'nsvf': NSVFDataset, 26 | 'co3d': CO3DDataset, 27 | 'auto': auto_dataset 28 | } 29 | -------------------------------------------------------------------------------- /opt/util/dataset_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from typing import Union, Optional, List 4 | from .util import select_or_shuffle_rays, Rays, Intrin 5 | 6 | class DatasetBase: 7 | split: str 8 | permutation: bool 9 | epoch_size: Optional[int] 10 | n_images: int 11 | h_full: int 12 | w_full: int 13 | intrins_full: Intrin 14 | c2w: 
torch.Tensor # C2W OpenCV poses 15 | gt: Union[torch.Tensor, List[torch.Tensor]] # RGB images 16 | device : Union[str, torch.device] 17 | 18 | def __init__(self): 19 | self.ndc_coeffs = (-1, -1) 20 | self.use_sphere_bound = False 21 | self.should_use_background = True # a hint 22 | self.use_sphere_bound = True 23 | self.scene_center = [0.0, 0.0, 0.0] 24 | self.scene_radius = [1.0, 1.0, 1.0] 25 | self.permutation = False 26 | 27 | def shuffle_rays(self): 28 | """ 29 | Shuffle all rays 30 | """ 31 | if self.split == "train": 32 | del self.rays 33 | self.rays = select_or_shuffle_rays(self.rays_init, self.permutation, 34 | self.epoch_size, self.device) 35 | 36 | def gen_rays(self, factor=1): 37 | print(" Generating rays, scaling factor", factor) 38 | # Generate rays 39 | self.factor = factor 40 | self.h = self.h_full // factor 41 | self.w = self.w_full // factor 42 | true_factor = self.h_full / self.h 43 | self.intrins = self.intrins_full.scale(1.0 / true_factor) 44 | yy, xx = torch.meshgrid( 45 | torch.arange(self.h, dtype=torch.float32) + 0.5, 46 | torch.arange(self.w, dtype=torch.float32) + 0.5, 47 | ) 48 | xx = (xx - self.intrins.cx) / self.intrins.fx 49 | yy = (yy - self.intrins.cy) / self.intrins.fy 50 | zz = torch.ones_like(xx) 51 | dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention 52 | dirs /= torch.norm(dirs, dim=-1, keepdim=True) 53 | dirs = dirs.reshape(1, -1, 3, 1) 54 | del xx, yy, zz 55 | dirs = (self.c2w[:, None, :3, :3] @ dirs)[..., 0] 56 | 57 | if factor != 1: 58 | gt = F.interpolate( 59 | self.gt.permute([0, 3, 1, 2]), size=(self.h, self.w), mode="area" 60 | ).permute([0, 2, 3, 1]) 61 | gt = gt.reshape(self.n_images, -1, 3) 62 | else: 63 | gt = self.gt.reshape(self.n_images, -1, 3) 64 | origins = self.c2w[:, None, :3, 3].expand(-1, self.h * self.w, -1).contiguous() 65 | if self.split == "train": 66 | origins = origins.view(-1, 3) 67 | dirs = dirs.view(-1, 3) 68 | gt = gt.reshape(-1, 3) 69 | 70 | self.rays_init = Rays(origins=origins, dirs=dirs, gt=gt) 71 | self.rays = self.rays_init 72 | 73 | def get_image_size(self, i : int): 74 | # H, W 75 | if hasattr(self, 'image_size'): 76 | return tuple(self.image_size[i]) 77 | else: 78 | return self.h, self.w 79 | -------------------------------------------------------------------------------- /opt/util/load_llff.py: -------------------------------------------------------------------------------- 1 | # Originally from LLFF 2 | # https://github.com/Fyusion/LLFF 3 | # With minor modifications from NeX 4 | # https://github.com/nex-mpi/nex-code 5 | 6 | import numpy as np 7 | import os 8 | import imageio 9 | 10 | def get_image_size(path : str): 11 | """ 12 | Get image size without loading it 13 | """ 14 | from PIL import Image 15 | im = Image.open(path) 16 | return im.size[1], im.size[0] # H, W 17 | 18 | def _minify(basedir, factors=[], resolutions=[]): 19 | needtoload = False 20 | for r in factors: 21 | imgdir = os.path.join(basedir, "images_{}".format(r)) 22 | if not os.path.exists(imgdir): 23 | needtoload = True 24 | for r in resolutions: 25 | imgdir = os.path.join(basedir, "images_{}x{}".format(r[1], r[0])) 26 | if not os.path.exists(imgdir): 27 | needtoload = True 28 | if not needtoload: 29 | return 30 | 31 | from shutil import copy 32 | from subprocess import check_output 33 | 34 | imgdir = os.path.join(basedir, "images") 35 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))] 36 | imgs = [ 37 | f 38 | for f in imgs 39 | if any([f.endswith(ex) for ex in ["JPG", "jpg", "png", "jpeg", "PNG"]]) 40 | ] 41 | 
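    # The steps below: (1) copy the originals into images_{factor}/ (or
    # images_{W}x{H}/), (2) run ImageMagick's `mogrify -resize ... -format png`
    # inside that directory, and (3) if the originals were not png, remove the
    # copied originals so only the resized pngs remain.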
imgdir_orig = imgdir 42 | 43 | wd = os.getcwd() 44 | 45 | for r in factors + resolutions: 46 | if isinstance(r, int): 47 | name = "images_{}".format(r) 48 | resizearg = "{}%".format(100.0 / r) 49 | else: 50 | name = "images_{}x{}".format(r[1], r[0]) 51 | resizearg = "{}x{}".format(r[1], r[0]) 52 | imgdir = os.path.join(basedir, name) 53 | if os.path.exists(imgdir): 54 | continue 55 | 56 | print("Minifying", r, basedir) 57 | 58 | os.makedirs(imgdir) 59 | check_output("cp {}/* {}".format(imgdir_orig, imgdir), shell=True) 60 | 61 | ext = imgs[0].split(".")[-1] 62 | args = " ".join( 63 | ["mogrify", "-resize", resizearg, "-format", "png", "*.{}".format(ext)] 64 | ) 65 | print(args) 66 | os.chdir(imgdir) 67 | check_output(args, shell=True) 68 | os.chdir(wd) 69 | 70 | if ext != "png": 71 | check_output("rm {}/*.{}".format(imgdir, ext), shell=True) 72 | print("Removed duplicates") 73 | print("Done") 74 | 75 | 76 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True): 77 | poses_arr = np.load(os.path.join(basedir, "poses_bounds.npy")) 78 | shape = 5 79 | 80 | # poss llff arr [3, 5, images] [R | T | intrinsic] 81 | # intrinsic same for all images 82 | if os.path.isfile(os.path.join(basedir, "hwf_cxcy.npy")): 83 | shape = 4 84 | # h, w, fx, fy, cx, cy 85 | intrinsic_arr = np.load(os.path.join(basedir, "hwf_cxcy.npy")) 86 | 87 | poses = poses_arr[:, :-2].reshape([-1, 3, shape]).transpose([1, 2, 0]) 88 | bds = poses_arr[:, -2:].transpose([1, 0]) 89 | 90 | if not os.path.isfile(os.path.join(basedir, "hwf_cxcy.npy")): 91 | intrinsic_arr = poses[:, 4, 0] 92 | poses = poses[:, :4, :] 93 | 94 | img0 = [ 95 | os.path.join(basedir, "images", f) 96 | for f in sorted(os.listdir(os.path.join(basedir, "images"))) 97 | if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png") 98 | ][0] 99 | sh = get_image_size(img0) 100 | 101 | sfx = "" 102 | if factor is not None: 103 | sfx = "_{}".format(factor) 104 | _minify(basedir, factors=[factor]) 105 | factor = factor 106 | elif height is not None: 107 | factor = sh[0] / float(height) 108 | width = int(sh[1] / factor) 109 | _minify(basedir, resolutions=[[height, width]]) 110 | sfx = "_{}x{}".format(width, height) 111 | elif width is not None: 112 | factor = sh[1] / float(width) 113 | height = int(sh[0] / factor) 114 | _minify(basedir, resolutions=[[height, width]]) 115 | sfx = "_{}x{}".format(width, height) 116 | else: 117 | factor = 1 118 | 119 | imgdir = os.path.join(basedir, "images" + sfx) 120 | if not os.path.exists(imgdir): 121 | print(imgdir, "does not exist, returning") 122 | return 123 | 124 | imgfiles = [ 125 | os.path.join(imgdir, f) 126 | for f in sorted(os.listdir(imgdir)) 127 | if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png") 128 | ] 129 | if poses.shape[-1] != len(imgfiles): 130 | print( 131 | "Mismatch between imgs {} and poses {} !!!!".format( 132 | len(imgfiles), poses.shape[-1] 133 | ) 134 | ) 135 | return 136 | 137 | if not load_imgs: 138 | return poses, bds, intrinsic_arr 139 | 140 | def imread(f): 141 | if f.endswith("png"): 142 | return imageio.imread(f, ignoregamma=True) 143 | else: 144 | return imageio.imread(f) 145 | 146 | imgs = imgs = [imread(f)[..., :3] / 255.0 for f in imgfiles] 147 | imgs = np.stack(imgs, -1) 148 | 149 | print("Loaded image data", imgs.shape, poses[:, -1, 0]) 150 | return poses, bds, imgs, intrinsic_arr 151 | 152 | 153 | def normalize(x): 154 | return x / np.linalg.norm(x) 155 | 156 | 157 | def viewmatrix(z, up, pos): 158 | vec2 = normalize(z) 159 | vec1_avg = up 160 | vec0 = 
normalize(np.cross(vec1_avg, vec2)) 161 | vec1 = normalize(np.cross(vec2, vec0)) 162 | m = np.stack([vec0, vec1, vec2, pos], 1) 163 | return m 164 | 165 | 166 | def ptstocam(pts, c2w): 167 | tt = np.matmul(c2w[:3, :3].T, (pts - c2w[:3, 3])[..., np.newaxis])[..., 0] 168 | return tt 169 | 170 | 171 | def poses_avg(poses): 172 | # poses [images, 3, 4] not [images, 3, 5] 173 | # hwf = poses[0, :3, -1:] 174 | 175 | center = poses[:, :3, 3].mean(0) 176 | vec2 = normalize(poses[:, :3, 2].sum(0)) 177 | up = poses[:, :3, 1].sum(0) 178 | c2w = np.concatenate([viewmatrix(vec2, up, center)], 1) 179 | 180 | return c2w 181 | 182 | 183 | def render_path_axis(c2w, up, ax, rad, focal, N): 184 | render_poses = [] 185 | center = c2w[:, 3] 186 | hwf = c2w[:, 4:5] 187 | v = c2w[:, ax] * rad 188 | for t in np.linspace(-1.0, 1.0, N + 1)[:-1]: 189 | c = center + t * v 190 | z = normalize(c - (center - focal * c2w[:, 2])) 191 | # render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) 192 | render_poses.append(viewmatrix(z, up, c)) 193 | return render_poses 194 | 195 | 196 | def render_path_spiral(c2w, up, rads, focal, zrate, rots, N): 197 | render_poses = [] 198 | rads = np.array(list(rads) + [1.0]) 199 | # hwf = c2w[:,4:5] 200 | 201 | for theta in np.linspace(0.0, 2.0 * np.pi * rots, N + 1)[:-1]: 202 | c = np.dot( 203 | c2w[:3, :4], 204 | np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0]) 205 | * rads, 206 | ) 207 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.0]))) 208 | # render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) 209 | render_poses.append(viewmatrix(z, up, c)) 210 | return render_poses 211 | 212 | 213 | def recenter_poses(poses): 214 | # poses [images, 3, 4] 215 | poses_ = poses + 0 216 | bottom = np.reshape([0, 0, 0, 1.0], [1, 4]) 217 | c2w = poses_avg(poses) 218 | c2w = np.concatenate([c2w[:3, :4], bottom], -2) 219 | 220 | bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1]) 221 | poses = np.concatenate([poses[:, :3, :4], bottom], -2) 222 | 223 | poses = np.linalg.inv(c2w) @ poses 224 | poses_[:, :3, :4] = poses[:, :3, :4] 225 | poses = poses_ 226 | return poses 227 | 228 | 229 | def spherify_poses(poses, bds): 230 | p34_to_44 = lambda p: np.concatenate( 231 | [p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1 232 | ) 233 | 234 | rays_d = poses[:, :3, 2:3] 235 | rays_o = poses[:, :3, 3:4] 236 | 237 | def min_line_dist(rays_o, rays_d): 238 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1]) 239 | b_i = -A_i @ rays_o 240 | pt_mindist = np.squeeze( 241 | -np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0) 242 | ) 243 | return pt_mindist 244 | 245 | pt_mindist = min_line_dist(rays_o, rays_d) 246 | 247 | center = pt_mindist 248 | up = (poses[:, :3, 3] - center).mean(0) 249 | 250 | vec0 = normalize(up) 251 | vec1 = normalize(np.cross([0.1, 0.2, 0.3], vec0)) 252 | vec2 = normalize(np.cross(vec0, vec1)) 253 | pos = center 254 | c2w = np.stack([vec1, vec2, vec0, pos], 1) 255 | 256 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4]) 257 | 258 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1))) 259 | 260 | sc = 1.0 / rad 261 | poses_reset[:, :3, 3] *= sc 262 | bds *= sc 263 | rad *= sc 264 | 265 | centroid = np.mean(poses_reset[:, :3, 3], 0) 266 | zh = centroid[2] 267 | radcircle = np.sqrt(rad ** 2 - zh ** 2) 268 | new_poses = [] 269 | 270 | for th in np.linspace(0.0, 2.0 * np.pi, 120): 271 | camorigin = 
np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh]) 272 | up = np.array([0, 0, -1.0]) 273 | 274 | vec2 = normalize(camorigin) 275 | vec0 = normalize(np.cross(vec2, up)) 276 | vec1 = normalize(np.cross(vec2, vec0)) 277 | pos = camorigin 278 | p = np.stack([vec0, vec1, vec2, pos], 1) 279 | 280 | new_poses.append(p) 281 | 282 | new_poses = np.stack(new_poses, 0) 283 | 284 | new_poses = np.concatenate( 285 | [new_poses, np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)], -1 286 | ) 287 | poses_reset = np.concatenate( 288 | [ 289 | poses_reset[:, :3, :4], 290 | np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape), 291 | ], 292 | -1, 293 | ) 294 | 295 | return poses_reset, new_poses, bds 296 | 297 | 298 | def load_llff_data( 299 | basedir, 300 | factor=None, 301 | recenter=True, 302 | bd_factor=0.75, 303 | spherify=False, 304 | # path_zflat=False, 305 | split_train_val=8, 306 | render_style="", 307 | ): 308 | 309 | # poses, bds, imgs = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x 310 | poses, bds, intrinsic = _load_data( 311 | basedir, factor=factor, load_imgs=False 312 | ) # factor=8 downsamples original imgs by 8x 313 | 314 | print("Loaded LLFF data", basedir, bds.min(), bds.max()) 315 | 316 | # Correct rotation matrix ordering and move variable dim to axis 0 317 | # poses [R | T] [3, 4, images] 318 | poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) 319 | # poses [3, 4, images] --> [images, 3, 4] 320 | poses = np.moveaxis(poses, -1, 0).astype(np.float32) 321 | 322 | # imgs = np.moveaxis(imgs, -1, 0).astype(np.float32) 323 | # images = imgs 324 | bds = np.moveaxis(bds, -1, 0).astype(np.float32) 325 | 326 | # Rescale if bd_factor is provided 327 | sc = 1.0 if bd_factor is None else 1.0 / (bds.min() * bd_factor) 328 | poses[:, :3, 3] *= sc 329 | bds *= sc 330 | 331 | if recenter: 332 | poses = recenter_poses(poses) 333 | 334 | if spherify: 335 | poses, render_poses, bds = spherify_poses(poses, bds) 336 | else: 337 | c2w = poses_avg(poses) 338 | print("recentered", c2w.shape) 339 | 340 | ## Get spiral 341 | # Get average pose 342 | up = normalize(poses[:, :3, 1].sum(0)) 343 | 344 | close_depth, inf_depth = -1, -1 345 | # Find a reasonable "focus depth" for this dataset 346 | # if os.path.exists(os.path.join(basedir, "planes_spiral.txt")): 347 | # with open(os.path.join(basedir, "planes_spiral.txt"), "r") as fi: 348 | # data = [float(x) for x in fi.readline().split(" ")] 349 | # dmin, dmax = data[:2] 350 | # close_depth = dmin * 0.9 351 | # inf_depth = dmax * 5.0 352 | # elif os.path.exists(os.path.join(basedir, "planes.txt")): 353 | # with open(os.path.join(basedir, "planes.txt"), "r") as fi: 354 | # data = [float(x) for x in fi.readline().split(" ")] 355 | # if len(data) == 3: 356 | # dmin, dmax, invz = data 357 | # elif len(data) == 4: 358 | # dmin, dmax, invz, _ = data 359 | # close_depth = dmin * 0.9 360 | # inf_depth = dmax * 5.0 361 | 362 | prev_close, prev_inf = close_depth, inf_depth 363 | if close_depth < 0 or inf_depth < 0 or render_style == "llff": 364 | close_depth, inf_depth = bds.min() * 0.9, bds.max() * 5.0 365 | 366 | if render_style == "shiny": 367 | close_depth, inf_depth = bds.min() * 0.9, bds.max() * 5.0 368 | if close_depth < prev_close: 369 | close_depth = prev_close 370 | if inf_depth > prev_inf: 371 | inf_depth = prev_inf 372 | 373 | dt = 0.75 374 | mean_dz = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth)) 375 | focal = mean_dz 376 | 377 | # Get radii for spiral path 
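        # (Inserted explanatory note, not original code; the numbers are
        # illustrative only.) `focal` above is the dt-weighted harmonic mean of
        # the scene depth bounds:
        #     mean_dz = 1 / ((1 - dt) / close_depth + dt / inf_depth)
        # e.g. close_depth = 2, inf_depth = 40, dt = 0.75 gives
        # 1 / (0.125 + 0.01875) ~= 6.96, so the spiral "looks at" a point
        # roughly 7 units along the average view axis. The radii computed below
        # are per-axis 90th percentiles of the (recentered) camera positions,
        # so a single outlier camera cannot inflate the rendered orbit.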
378 | tt = poses[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T 379 | rads = np.percentile(np.abs(tt), 90, 0) 380 | c2w_path = c2w 381 | N_views = 120 382 | N_rots = 2 383 | # if path_zflat: 384 | # # zloc = np.percentile(tt, 10, 0)[2] 385 | # zloc = -close_depth * 0.1 386 | # c2w_path[:3, 3] = c2w_path[:3, 3] + zloc * c2w_path[:3, 2] 387 | # rads[2] = 0.0 388 | # N_rots = 1 389 | # N_views /= 2 390 | 391 | render_poses = render_path_spiral( 392 | c2w_path, up, rads, focal, zrate=0.5, rots=N_rots, N=N_views 393 | ) 394 | 395 | render_poses = np.array(render_poses).astype(np.float32) 396 | # reference_view_id should stay in train set only 397 | validation_ids = np.arange(poses.shape[0]) 398 | validation_ids[::split_train_val] = -1 399 | validation_ids = validation_ids < 0 400 | train_ids = np.logical_not(validation_ids) 401 | train_poses = poses[train_ids] 402 | train_bds = bds[train_ids] 403 | c2w = poses_avg(train_poses) 404 | 405 | dists = np.sum(np.square(c2w[:3, 3] - train_poses[:, :3, 3]), -1) 406 | reference_view_id = np.argmin(dists) 407 | reference_depth = train_bds[reference_view_id] 408 | print(reference_depth) 409 | 410 | return ( 411 | reference_depth, 412 | reference_view_id, 413 | render_poses, 414 | poses, 415 | intrinsic 416 | ) 417 | -------------------------------------------------------------------------------- /opt/util/nerf_dataset.py: -------------------------------------------------------------------------------- 1 | # Standard NeRF Blender dataset loader 2 | from .util import Rays, Intrin, select_or_shuffle_rays 3 | from .dataset_base import DatasetBase 4 | import torch 5 | import torch.nn.functional as F 6 | from typing import NamedTuple, Optional, Union 7 | from os import path 8 | import imageio 9 | from tqdm import tqdm 10 | import cv2 11 | import json 12 | import numpy as np 13 | 14 | 15 | class NeRFDataset(DatasetBase): 16 | """ 17 | NeRF dataset loader 18 | 19 | WARNING: this is only intended for use with NeRF Blender data!!!! 
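    (Instant-NGP-format captures should first be converted with
    scripts/ingp2nsvf.py; see the runtime warning printed below.)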
20 | """ 21 | 22 | focal: float 23 | c2w: torch.Tensor # (n_images, 4, 4) 24 | gt: torch.Tensor # (n_images, h, w, 3) 25 | h: int 26 | w: int 27 | n_images: int 28 | rays: Optional[Rays] 29 | split: str 30 | 31 | def __init__( 32 | self, 33 | root, 34 | split, 35 | epoch_size : Optional[int] = None, 36 | device: Union[str, torch.device] = "cpu", 37 | scene_scale: Optional[float] = None, 38 | factor: int = 1, 39 | scale : Optional[float] = None, 40 | permutation: bool = True, 41 | white_bkgd: bool = True, 42 | n_images = None, 43 | **kwargs 44 | ): 45 | super().__init__() 46 | assert path.isdir(root), f"'{root}' is not a directory" 47 | 48 | if scene_scale is None: 49 | scene_scale = 2/3 50 | if scale is None: 51 | scale = 1.0 52 | self.device = device 53 | self.permutation = permutation 54 | self.epoch_size = epoch_size 55 | all_c2w = [] 56 | all_gt = [] 57 | 58 | split_name = split if split != "test_train" else "train" 59 | data_path = path.join(root, split_name) 60 | data_json = path.join(root, "transforms_" + split_name + ".json") 61 | 62 | print("LOAD DATA", data_path) 63 | print("WARNING: This data loader is ONLY intended for use with NeRF-synthetic Blender data!!!!") 64 | print("If you want to try running this code on Instant-NGP data please use scripts/ingp2nsvf.py") 65 | 66 | j = json.load(open(data_json, "r")) 67 | 68 | # OpenGL -> OpenCV 69 | cam_trans = torch.diag(torch.tensor([1, -1, -1, 1], dtype=torch.float32)) 70 | 71 | for frame in tqdm(j["frames"]): 72 | fpath = path.join(data_path, path.basename(frame["file_path"]) + ".png") 73 | c2w = torch.tensor(frame["transform_matrix"], dtype=torch.float32) 74 | c2w = c2w @ cam_trans # To OpenCV 75 | 76 | im_gt = imageio.imread(fpath) 77 | if scale < 1.0: 78 | full_size = list(im_gt.shape[:2]) 79 | rsz_h, rsz_w = [round(hw * scale) for hw in full_size] 80 | im_gt = cv2.resize(im_gt, (rsz_w, rsz_h), interpolation=cv2.INTER_AREA) 81 | 82 | all_c2w.append(c2w) 83 | all_gt.append(torch.from_numpy(im_gt)) 84 | focal = float( 85 | 0.5 * all_gt[0].shape[1] / np.tan(0.5 * j["camera_angle_x"]) 86 | ) 87 | self.c2w = torch.stack(all_c2w) 88 | self.c2w[:, :3, 3] *= scene_scale 89 | 90 | self.gt = torch.stack(all_gt).float() / 255.0 91 | if self.gt.size(-1) == 4: 92 | if white_bkgd: 93 | # Apply alpha channel 94 | self.gt = self.gt[..., :3] * self.gt[..., 3:] + (1.0 - self.gt[..., 3:]) 95 | else: 96 | self.gt = self.gt[..., :3] 97 | 98 | self.n_images, self.h_full, self.w_full, _ = self.gt.shape 99 | # Choose a subset of training images 100 | if n_images is not None: 101 | if n_images > self.n_images: 102 | print(f'using {self.n_images} available training views instead of the requested {n_images}.') 103 | n_images = self.n_images 104 | self.n_images = n_images 105 | self.gt = self.gt[0:n_images,...] 106 | self.c2w = self.c2w[0:n_images,...] 
107 | 108 | self.intrins_full : Intrin = Intrin(focal, focal, 109 | self.w_full * 0.5, 110 | self.h_full * 0.5) 111 | 112 | self.split = split 113 | self.scene_scale = scene_scale 114 | if self.split == "train": 115 | self.gen_rays(factor=factor) 116 | else: 117 | # Rays are not needed for testing 118 | self.h, self.w = self.h_full, self.w_full 119 | self.intrins : Intrin = self.intrins_full 120 | 121 | self.should_use_background = False # Give warning 122 | 123 | -------------------------------------------------------------------------------- /opt/util/nsvf_dataset.py: -------------------------------------------------------------------------------- 1 | # Extended NSVF-format dataset loader 2 | # This is a more sane format vs the NeRF formats 3 | 4 | from .util import Rays, Intrin, similarity_from_cameras 5 | from .dataset_base import DatasetBase 6 | import torch 7 | import torch.nn.functional as F 8 | from typing import NamedTuple, Optional, Union 9 | from os import path 10 | import os 11 | import cv2 12 | import imageio 13 | from tqdm import tqdm 14 | import json 15 | import numpy as np 16 | from warnings import warn 17 | 18 | 19 | class NSVFDataset(DatasetBase): 20 | """ 21 | Extended NSVF dataset loader 22 | """ 23 | 24 | focal: float 25 | c2w: torch.Tensor # (n_images, 4, 4) 26 | gt: torch.Tensor # (n_images, h, w, 3) 27 | h: int 28 | w: int 29 | n_images: int 30 | rays: Optional[Rays] 31 | split: str 32 | 33 | def __init__( 34 | self, 35 | root, 36 | split, 37 | epoch_size : Optional[int] = None, 38 | device: Union[str, torch.device] = "cpu", 39 | scene_scale: Optional[float] = None, # Scene scaling 40 | factor: int = 1, # Image scaling (on ray gen; use gen_rays(factor) to dynamically change scale) 41 | scale : Optional[float] = 1.0, # Image scaling (on load) 42 | permutation: bool = True, 43 | white_bkgd: bool = True, 44 | normalize_by_bbox: bool = False, 45 | data_bbox_scale : float = 1.1, # Only used if normalize_by_bbox 46 | cam_scale_factor : float = 0.95, 47 | normalize_by_camera: bool = True, 48 | **kwargs 49 | ): 50 | super().__init__() 51 | assert path.isdir(root), f"'{root}' is not a directory" 52 | 53 | if scene_scale is None: 54 | scene_scale = 1.0 55 | if scale is None: 56 | scale = 1.0 57 | 58 | self.device = device 59 | self.permutation = permutation 60 | self.epoch_size = epoch_size 61 | all_c2w = [] 62 | all_gt = [] 63 | 64 | split_name = split if split != "test_train" else "train" 65 | 66 | print("LOAD NSVF DATA", root, 'split', split) 67 | 68 | self.split = split 69 | 70 | def sort_key(x): 71 | if len(x) > 2 and x[1] == "_": 72 | return x[2:] 73 | return x 74 | def look_for_dir(cands, required=True): 75 | for cand in cands: 76 | if path.isdir(path.join(root, cand)): 77 | return cand 78 | if required: 79 | assert False, "None of " + str(cands) + " found in data directory" 80 | return "" 81 | 82 | img_dir_name = look_for_dir(["images", "image", "rgb"]) 83 | pose_dir_name = look_for_dir(["poses", "pose"]) 84 | # intrin_dir_name = look_for_dir(["intrin"], required=False) 85 | orig_img_files = sorted(os.listdir(path.join(root, img_dir_name)), key=sort_key) 86 | 87 | # Select subset of files 88 | if self.split == "train" or self.split == "test_train": 89 | img_files = [x for x in orig_img_files if x.startswith("0_")] 90 | elif self.split == "val": 91 | img_files = [x for x in orig_img_files if x.startswith("1_")] 92 | elif self.split == "test": 93 | test_img_files = [x for x in orig_img_files if x.startswith("2_")] 94 | if len(test_img_files) == 0: 95 | test_img_files = 
[x for x in orig_img_files if x.startswith("1_")] 96 | img_files = test_img_files 97 | else: 98 | img_files = orig_img_files 99 | 100 | if len(img_files) == 0: 101 | if self.split == "train": 102 | img_files = [x for i, x in enumerate(orig_img_files) if i % 16 != 0] 103 | else: 104 | img_files = orig_img_files[::16] 105 | 106 | assert len(img_files) > 0, "No matching images in directory: " + path.join(root, img_dir_name) 107 | self.img_files = img_files 108 | 109 | dynamic_resize = scale < 1 110 | self.use_integral_scaling = False 111 | scaled_img_dir = '' 112 | if dynamic_resize and abs((1.0 / scale) - round(1.0 / scale)) < 1e-9: 113 | resized_dir = img_dir_name + "_" + str(round(1.0 / scale)) 114 | if path.exists(path.join(root, resized_dir)): 115 | img_dir_name = resized_dir 116 | dynamic_resize = False 117 | print("> Pre-resized images from", img_dir_name) 118 | if dynamic_resize: 119 | print("> WARNING: Dynamically resizing images") 120 | 121 | full_size = [0, 0] 122 | rsz_h = rsz_w = 0 123 | 124 | for img_fname in tqdm(img_files): 125 | img_path = path.join(root, img_dir_name, img_fname) 126 | image = imageio.imread(img_path) 127 | pose_fname = path.splitext(img_fname)[0] + ".txt" 128 | pose_path = path.join(root, pose_dir_name, pose_fname) 129 | # intrin_path = path.join(root, intrin_dir_name, pose_fname) 130 | 131 | cam_mtx = np.loadtxt(pose_path).reshape(-1, 4) 132 | if len(cam_mtx) == 3: 133 | bottom = np.array([[0.0, 0.0, 0.0, 1.0]]) 134 | cam_mtx = np.concatenate([cam_mtx, bottom], axis=0) 135 | all_c2w.append(torch.from_numpy(cam_mtx)) # C2W (4, 4) OpenCV 136 | full_size = list(image.shape[:2]) 137 | rsz_h, rsz_w = [round(hw * scale) for hw in full_size] 138 | if dynamic_resize: 139 | image = cv2.resize(image, (rsz_w, rsz_h), interpolation=cv2.INTER_AREA) 140 | 141 | all_gt.append(torch.from_numpy(image)) 142 | 143 | 144 | self.c2w_f64 = torch.stack(all_c2w) 145 | 146 | print('NORMALIZE BY?', 'bbox' if normalize_by_bbox else 'camera' if normalize_by_camera else 'manual') 147 | if normalize_by_bbox: 148 | # Not used, but could be helpful 149 | bbox_path = path.join(root, "bbox.txt") 150 | if path.exists(bbox_path): 151 | bbox_data = np.loadtxt(bbox_path) 152 | center = (bbox_data[:3] + bbox_data[3:6]) * 0.5 153 | radius = (bbox_data[3:6] - bbox_data[:3]) * 0.5 * data_bbox_scale 154 | 155 | # Recenter 156 | self.c2w_f64[:, :3, 3] -= center 157 | # Rescale 158 | scene_scale = 1.0 / radius.max() 159 | else: 160 | warn('normalize_by_bbox=True but bbox.txt was not available') 161 | elif normalize_by_camera: 162 | norm_pose_files = sorted(os.listdir(path.join(root, pose_dir_name)), key=sort_key) 163 | norm_poses = np.stack([np.loadtxt(path.join(root, pose_dir_name, x)).reshape(-1, 4) 164 | for x in norm_pose_files], axis=0) 165 | 166 | # Select subset of files 167 | T, sscale = similarity_from_cameras(norm_poses) 168 | 169 | self.c2w_f64 = torch.from_numpy(T) @ self.c2w_f64 170 | scene_scale = cam_scale_factor * sscale 171 | 172 | # center = np.mean(norm_poses[:, :3, 3], axis=0) 173 | # radius = np.median(np.linalg.norm(norm_poses[:, :3, 3] - center, axis=-1)) 174 | # self.c2w_f64[:, :3, 3] -= center 175 | # scene_scale = cam_scale_factor / radius 176 | # print('good', self.c2w_f64[:2], scene_scale) 177 | 178 | print('scene_scale', scene_scale) 179 | self.c2w_f64[:, :3, 3] *= scene_scale 180 | self.c2w = self.c2w_f64.float() 181 | 182 | self.gt = torch.stack(all_gt).double() / 255.0 183 | if self.gt.size(-1) == 4: 184 | if white_bkgd: 185 | # Apply alpha channel 186 | self.gt = 
self.gt[..., :3] * self.gt[..., 3:] + (1.0 - self.gt[..., 3:])
187 |             else:
188 |                 self.gt = self.gt[..., :3]
189 |         self.gt = self.gt.float()
190 | 
191 |         assert full_size[0] > 0 and full_size[1] > 0, "Empty images"
192 |         self.n_images, self.h_full, self.w_full, _ = self.gt.shape
193 | 
194 |         intrin_path = path.join(root, "intrinsics.txt")
195 |         assert path.exists(intrin_path), "intrinsics unavailable"
196 |         try:
197 |             K: np.ndarray = np.loadtxt(intrin_path)
198 |             fx = K[0, 0]
199 |             fy = K[1, 1]
200 |             cx = K[0, 2]
201 |             cy = K[1, 2]
202 |         except Exception:
203 |             # Weird single-line format found in some NSVF data: "f cx cy ..."
204 |             with open(intrin_path, "r") as f:
205 |                 spl = f.readline().split()
206 |                 fx = fy = float(spl[0])
207 |                 cx = float(spl[1])
208 |                 cy = float(spl[2])
209 |         if scale < 1.0:
210 |             scale_w = rsz_w / full_size[1]
211 |             scale_h = rsz_h / full_size[0]
212 |             fx *= scale_w
213 |             cx *= scale_w
214 |             fy *= scale_h
215 |             cy *= scale_h
216 | 
217 |         self.intrins_full : Intrin = Intrin(fx, fy, cx, cy)
218 |         print(' intrinsics (loaded reso)', self.intrins_full)
219 | 
220 |         self.scene_scale = scene_scale
221 |         if self.split == "train":
222 |             self.gen_rays(factor=factor)
223 |         else:
224 |             # Rays are not needed for testing
225 |             self.h, self.w = self.h_full, self.w_full
226 |             self.intrins : Intrin = self.intrins_full
227 | 
-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
1 | from setuptools import setup
2 | import os
3 | import os.path as osp
4 | import warnings
5 | 
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 | 
8 | ROOT_DIR = osp.dirname(osp.abspath(__file__))
9 | 
10 | __version__ = None
11 | exec(open('svox2/version.py', 'r').read())
12 | 
13 | CUDA_FLAGS = []
14 | INSTALL_REQUIREMENTS = []
15 | include_dirs = [osp.join(ROOT_DIR, "svox2", "csrc", "include")]
16 | 
17 | # From PyTorch3D
18 | cub_home = os.environ.get("CUB_HOME", None)
19 | if cub_home is None:
20 |     prefix = os.environ.get("CONDA_PREFIX", None)
21 |     if prefix is not None and os.path.isdir(prefix + "/include/cub"):
22 |         cub_home = prefix + "/include"
23 | 
24 | if cub_home is None:
25 |     warnings.warn(
26 |         "The environment variable `CUB_HOME` was not found. "
27 |         "Installation will fail if your system CUDA toolkit version is less than 11. "
28 |         "NVIDIA CUB can be downloaded "
29 |         "from `https://github.com/NVIDIA/cub/releases`. You can unpack "
30 |         "it to a location of your choice and set the environment variable "
31 |         "`CUB_HOME` to the folder containing the `CMakeLists.txt` file."
32 | ) 33 | else: 34 | include_dirs.append(os.path.realpath(cub_home).replace("\\ ", " ")) 35 | 36 | try: 37 | ext_modules = [ 38 | CUDAExtension('svox2.csrc', [ 39 | 'svox2/csrc/svox2.cpp', 40 | 'svox2/csrc/svox2_kernel.cu', 41 | 'svox2/csrc/render_lerp_kernel_cuvol.cu', 42 | 'svox2/csrc/render_lerp_kernel_nvol.cu', 43 | 'svox2/csrc/render_svox1_kernel.cu', 44 | 'svox2/csrc/misc_kernel.cu', 45 | 'svox2/csrc/loss_kernel.cu', 46 | 'svox2/csrc/optim_kernel.cu', 47 | ], include_dirs=include_dirs, 48 | optional=False), 49 | ] 50 | except: 51 | import warnings 52 | warnings.warn("Failed to build CUDA extension") 53 | ext_modules = [] 54 | 55 | setup( 56 | name='svox2', 57 | version=__version__, 58 | author='Alex Yu', 59 | author_email='alexyu99126@gmail.com', 60 | description='PyTorch sparse voxel volume extension, including custom CUDA kernels', 61 | long_description='PyTorch sparse voxel volume extension, including custom CUDA kernels', 62 | ext_modules=ext_modules, 63 | setup_requires=['pybind11>=2.5.0'], 64 | packages=['svox2', 'svox2.csrc'], 65 | cmdclass={'build_ext': BuildExtension}, 66 | zip_safe=False, 67 | ) 68 | -------------------------------------------------------------------------------- /svox2/__init__.py: -------------------------------------------------------------------------------- 1 | from .defs import * 2 | from .svox2 import SparseGrid, Camera, Rays, RenderOptions 3 | from .version import __version__ 4 | -------------------------------------------------------------------------------- /svox2/csrc/.ccls: -------------------------------------------------------------------------------- 1 | %compile_commands.json 2 | %cu -x cuda 3 | %cu --cuda-gpu-arch=sm_61 4 | %cu --cuda-path=/usr/local/cuda-11.2 5 | -------------------------------------------------------------------------------- /svox2/csrc/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2021 PlenOctree Authors. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions are met: 5 | # 6 | # 1. Redistributions of source code must retain the above copyright notice, 7 | # this list of conditions and the following disclaimer. 8 | # 9 | # 2. Redistributions in binary form must reproduce the above copyright notice, 10 | # this list of conditions and the following disclaimer in the documentation 11 | # and/or other materials provided with the distribution. 12 | # 13 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 14 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 15 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 16 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 17 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 18 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 19 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 20 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 21 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 22 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 23 | # POSSIBILITY OF SUCH DAMAGE. 24 | 25 | # NOTE: This CMakeLists is for development purposes only 26 | # (To check CUDA compile errors) 27 | # It is NOT necessary to use this for installation. Just use pip install . 
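# A possible out-of-source dev build (a sketch, not from this repo; it assumes
# a Python environment where torch and pybind11 are importable/findable):
#   mkdir build && cd build
#   cmake .. -DCMAKE_PREFIX_PATH="$(python -c 'import torch; print(torch.utils.cmake_prefix_path)')"
#   make -j
# Remember to adjust -arch=sm_75 in CMAKE_CUDA_FLAGS below to match your GPU.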
28 | cmake_minimum_required( VERSION 3.3 ) 29 | 30 | if(NOT CMAKE_BUILD_TYPE) 31 | set(CMAKE_BUILD_TYPE Release) 32 | endif() 33 | if (POLICY CMP0048) 34 | cmake_policy(SET CMP0048 NEW) 35 | endif (POLICY CMP0048) 36 | if (POLICY CMP0069) 37 | cmake_policy(SET CMP0069 NEW) 38 | endif (POLICY CMP0069) 39 | if (POLICY CMP0072) 40 | cmake_policy(SET CMP0072 NEW) 41 | endif (POLICY CMP0072) 42 | 43 | project( svox2 ) 44 | 45 | set(CMAKE_CXX_STANDARD 14) 46 | enable_language(CUDA) 47 | message(STATUS "CUDA enabled") 48 | set( CMAKE_CUDA_STANDARD 14 ) 49 | set( CMAKE_CUDA_STANDARD_REQUIRED ON) 50 | set( CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g -Xcudafe \"--display_error_number --diag_suppress=3057 --diag_suppress=3058 --diag_suppress=3059 --diag_suppress=3060\" -lineinfo -arch=sm_75 ") 51 | # -Xptxas=\"-v\" 52 | 53 | set( INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/include" ) 54 | 55 | if( MSVC ) 56 | set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /MTd") 57 | set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MT /GLT /Ox") 58 | set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler=\"/MT\"" ) 59 | endif() 60 | 61 | file(GLOB SOURCES 62 | ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp 63 | ${CMAKE_CURRENT_SOURCE_DIR}/*.cu) 64 | 65 | find_package(pybind11 REQUIRED) 66 | find_package(Torch REQUIRED) 67 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") 68 | 69 | include_directories (${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) 70 | 71 | pybind11_add_module(svox2-test SHARED ${SOURCES}) 72 | target_link_libraries(svox2-test PRIVATE "${TORCH_LIBRARIES}") 73 | target_include_directories(svox2-test PRIVATE "${INCLUDE_DIR}") 74 | 75 | if (MSVC) 76 | file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll") 77 | add_custom_command(TARGET svox2-test 78 | POST_BUILD 79 | COMMAND ${CMAKE_COMMAND} -E copy_if_different 80 | ${TORCH_DLLS} 81 | $) 82 | endif (MSVC) 83 | -------------------------------------------------------------------------------- /svox2/csrc/include/cubemap_util.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "cuda_util.cuh" 3 | #include 4 | #include 5 | 6 | #define _AXIS(x) (x>>1) 7 | #define _ORI(x) (x&1) 8 | #define _FACE(axis, ori) uint8_t((axis << 1) | ori) 9 | 10 | namespace { 11 | namespace device { 12 | 13 | struct CubemapCoord { 14 | uint8_t face; 15 | float uv[2]; 16 | }; 17 | 18 | struct CubemapLocation { 19 | uint8_t face; 20 | int16_t uv[2]; 21 | }; 22 | 23 | struct CubemapBilerpQuery { 24 | CubemapLocation ptr[2][2]; 25 | float duv[2]; 26 | }; 27 | 28 | __device__ __inline__ void 29 | invert_cubemap(int u, int v, float r, 30 | int reso, 31 | float* __restrict__ out) { 32 | const float u_norm = (u + 0.5f) / reso * 2 - 1; 33 | const float v_norm = (v + 0.5f) / reso * 2 - 1; 34 | // EAC 35 | const float tx = tanf((M_PI / 4) * u_norm); 36 | const float ty = tanf((M_PI / 4) * v_norm); 37 | const float common = r * rnorm3df(1.f, tx, ty); 38 | out[0] = tx * common; 39 | out[1] = ty * common; 40 | out[2] = common; 41 | } 42 | 43 | __device__ __inline__ void 44 | invert_cubemap_traditional(int u, int v, float r, 45 | int reso, 46 | float* __restrict__ out) { 47 | const float u_norm = (u + 0.5f) / reso * 2 - 1; 48 | const float v_norm = (v + 0.5f) / reso * 2 - 1; 49 | const float common = r * rnorm3df(1.f, u_norm, v_norm); 50 | out[0] = u_norm * common; 51 | out[1] = v_norm * common; 52 | out[2] = common; 53 | } 54 | 55 | __device__ __host__ __inline__ CubemapCoord 56 | dir_to_cubemap_coord(const 
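// (Inserted note, not original documentation: with eac=true, directions use
// the equi-angular cubemap parameterization -- uv is remapped through atanf so
// texels subtend roughly equal angles; tanf((M_PI / 4) * u) in invert_cubemap
// above is its inverse.)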
float* __restrict__ xyz_o, 57 | int face_reso, 58 | bool eac = true) { 59 | float maxv; 60 | int ax; 61 | float xyz[3] = {xyz_o[0], xyz_o[1], xyz_o[2]}; 62 | if (fabsf(xyz[0]) >= fabsf(xyz[1]) && fabsf(xyz[0]) >= fabsf(xyz[2])) { 63 | ax = 0; maxv = xyz[0]; 64 | } else if (fabsf(xyz[1]) >= fabsf(xyz[2])) { 65 | ax = 1; maxv = xyz[1]; 66 | } else { 67 | ax = 2; maxv = xyz[2]; 68 | } 69 | const float recip = 1.f / fabsf(maxv); 70 | xyz[0] *= recip; 71 | xyz[1] *= recip; 72 | xyz[2] *= recip; 73 | 74 | if (eac) { 75 | #pragma unroll 3 76 | for (int i = 0; i < 3; ++i) { 77 | xyz[i] = atanf(xyz[i]) * (4 * M_1_PI); 78 | } 79 | } 80 | 81 | CubemapCoord idx; 82 | idx.uv[0] = ((xyz[(ax ^ 1) & 1] + 1) * face_reso - 1) * 0.5; 83 | idx.uv[1] = ((xyz[(ax ^ 2) & 2] + 1) * face_reso - 1) * 0.5; 84 | const int ori = xyz[ax] >= 0; 85 | idx.face = _FACE(ax, ori); 86 | 87 | return idx; 88 | } 89 | 90 | __device__ __host__ __inline__ CubemapBilerpQuery 91 | cubemap_build_query( 92 | const CubemapCoord& idx, 93 | int face_reso) { 94 | const int uv_idx[2] ={ (int)floorf(idx.uv[0]), (int)floorf(idx.uv[1]) }; 95 | 96 | bool m[2][2]; 97 | m[0][0] = uv_idx[0] < 0; 98 | m[0][1] = uv_idx[0] > face_reso - 2; 99 | m[1][0] = uv_idx[1] < 0; 100 | m[1][1] = uv_idx[1] > face_reso - 2; 101 | 102 | const int face = idx.face; 103 | const int ax = _AXIS(face); 104 | const int ori = _ORI(face); 105 | // if ax is one of {0, 1, 2}, this trick gets the 2 106 | // of {0, 1, 2} other than ax 107 | const int uvd[2] = {((ax ^ 1) & 1), ((ax ^ 2) & 2)}; 108 | int uv_ori[2]; 109 | 110 | CubemapBilerpQuery result; 111 | result.duv[0] = idx.uv[0] - uv_idx[0]; 112 | result.duv[1] = idx.uv[1] - uv_idx[1]; 113 | 114 | #pragma unroll 2 115 | for (uv_ori[0] = 0; uv_ori[0] < 2; ++uv_ori[0]) { 116 | #pragma unroll 2 117 | for (uv_ori[1] = 0; uv_ori[1] < 2; ++uv_ori[1]) { 118 | CubemapLocation& nidx = result.ptr[uv_ori[0]][uv_ori[1]]; 119 | nidx.face = face; 120 | nidx.uv[0] = uv_idx[0] + uv_ori[0]; 121 | nidx.uv[1] = uv_idx[1] + uv_ori[1]; 122 | 123 | const bool mu = m[0][uv_ori[0]]; 124 | const bool mv = m[1][uv_ori[1]]; 125 | 126 | int edge_idx = -1; 127 | if (mu) { 128 | // Crosses edge in u-axis 129 | if (mv) { 130 | // FIXME: deal with corners properly, right now 131 | // just clamps, resulting in a little artifact 132 | // at each cube corner 133 | nidx.uv[0] = min(max(nidx.uv[0], 0), face_reso - 1); 134 | nidx.uv[1] = min(max(nidx.uv[1], 0), face_reso - 1); 135 | } else { 136 | edge_idx = 0; 137 | } 138 | } else if (mv) { 139 | // Crosses edge in v-axis 140 | edge_idx = 1; 141 | } 142 | if (~edge_idx) { 143 | const int nax = uvd[edge_idx]; 144 | const int16_t other_coord = nidx.uv[1 - edge_idx]; 145 | 146 | // Determine directions in the new face 147 | const int nud = (nax ^ 1) & 1; 148 | // const int nvd = (nax ^ 2) & 2; 149 | 150 | if (nud == ax) { 151 | nidx.uv[0] = ori ? (face_reso - 1) : 0; 152 | nidx.uv[1] = other_coord; 153 | } else { 154 | nidx.uv[0] = other_coord; 155 | nidx.uv[1] = ori ? 
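// (Inserted note; an inference from the indexing, not original documentation:
// when a bilerp neighbor falls off the current face, it is re-indexed onto the
// adjacent face, and the coordinate corresponding to the old face's normal
// axis is pinned via `ori` to the edge shared with the old face.)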
(face_reso - 1) : 0; 156 | } 157 | 158 | nidx.face = _FACE(nax, uv_ori[edge_idx]); 159 | } 160 | // Interior point: nothing needs to be done 161 | 162 | } 163 | } 164 | 165 | return result; 166 | } 167 | 168 | __device__ __host__ __inline__ float 169 | cubemap_sample( 170 | const float* __restrict__ cubemap, // (6, face_reso, face_reso, n_channels) 171 | const CubemapBilerpQuery& query, 172 | int face_reso, 173 | int n_channels, 174 | int chnl_id) { 175 | 176 | // NOTE: assuming address will fit in int32 177 | const int stride1 = face_reso * n_channels; 178 | const int stride0 = face_reso * stride1; 179 | const CubemapLocation& p00 = query.ptr[0][0]; 180 | const float v00 = cubemap[p00.face * stride0 + p00.uv[0] * stride1 + p00.uv[1] * n_channels + chnl_id]; 181 | const CubemapLocation& p01 = query.ptr[0][1]; 182 | const float v01 = cubemap[p01.face * stride0 + p01.uv[0] * stride1 + p01.uv[1] * n_channels + chnl_id]; 183 | const CubemapLocation& p10 = query.ptr[1][0]; 184 | const float v10 = cubemap[p10.face * stride0 + p10.uv[0] * stride1 + p10.uv[1] * n_channels + chnl_id]; 185 | const CubemapLocation& p11 = query.ptr[1][1]; 186 | const float v11 = cubemap[p11.face * stride0 + p11.uv[0] * stride1 + p11.uv[1] * n_channels + chnl_id]; 187 | 188 | const float val0 = lerp(v00, v01, query.duv[1]); 189 | const float val1 = lerp(v10, v11, query.duv[1]); 190 | 191 | return lerp(val0, val1, query.duv[0]); 192 | } 193 | 194 | __device__ __inline__ void 195 | cubemap_sample_backward( 196 | float* __restrict__ cubemap_grad, // (6, face_reso, face_reso, n_channels) 197 | const CubemapBilerpQuery& query, 198 | int face_reso, 199 | int n_channels, 200 | float grad_out, 201 | int chnl_id, 202 | bool* __restrict__ mask_out = nullptr) { 203 | 204 | // NOTE: assuming address will fit in int32 205 | const float bu = query.duv[0], bv = query.duv[1]; 206 | const float au = 1.f - bu, av = 1.f - bv; 207 | 208 | #define _ADD_CUBEVERT(i, j, val) { \ 209 | const CubemapLocation& p00 = query.ptr[i][j]; \ 210 | const int idx = (p00.face * face_reso + p00.uv[0]) * face_reso + p00.uv[1]; \ 211 | float* __restrict__ v00 = &cubemap_grad[idx * n_channels + chnl_id]; \ 212 | atomicAdd(v00, val); \ 213 | if (mask_out != nullptr) { \ 214 | mask_out[idx] = true; \ 215 | } \ 216 | } 217 | 218 | _ADD_CUBEVERT(0, 0, au * av * grad_out); 219 | _ADD_CUBEVERT(0, 1, au * bv * grad_out); 220 | _ADD_CUBEVERT(1, 0, bu * av * grad_out); 221 | _ADD_CUBEVERT(1, 1, bu * bv * grad_out); 222 | #undef _ADD_CUBEVERT 223 | 224 | } 225 | 226 | __device__ __host__ __inline__ float 227 | multi_cubemap_sample( 228 | const float* __restrict__ cubemap1, // (6, face_reso, face_reso, n_channels) 229 | const float* __restrict__ cubemap2, // (6, face_reso, face_reso, n_channels) 230 | const CubemapBilerpQuery& query, 231 | float interp_wt, 232 | int face_reso, 233 | int n_channels, 234 | int chnl_id) { 235 | const float val1 = cubemap_sample(cubemap1, 236 | query, 237 | face_reso, 238 | n_channels, 239 | chnl_id); 240 | const float val2 = cubemap_sample(cubemap2, 241 | query, 242 | face_reso, 243 | n_channels, 244 | chnl_id); 245 | return lerp(val1, val2, interp_wt); 246 | } 247 | 248 | __device__ __inline__ void 249 | multi_cubemap_sample_backward( 250 | float* __restrict__ cubemap_grad1, // (6, face_reso, face_reso, n_channels) 251 | float* __restrict__ cubemap_grad2, // (6, face_reso, face_reso, n_channels) 252 | const CubemapBilerpQuery& query, 253 | float interp_wt, 254 | int face_reso, 255 | int n_channels, 256 | float grad_out, 257 | int chnl_id, 
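// (Inserted note: the forward pass lerps the two cubemaps with weight
// interp_wt, so by the chain rule the body below routes
// grad_out * (1 - interp_wt) into cubemap_grad1 and grad_out * interp_wt
// into cubemap_grad2.)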
258 | bool* __restrict__ mask_out1 = nullptr, 259 | bool* __restrict__ mask_out2 = nullptr) { 260 | if (cubemap_grad1 == nullptr) return; 261 | cubemap_sample_backward(cubemap_grad1, 262 | query, 263 | face_reso, 264 | n_channels, 265 | grad_out * (1.f - interp_wt), 266 | chnl_id, 267 | mask_out1); 268 | cubemap_sample_backward(cubemap_grad2, 269 | query, 270 | face_reso, 271 | n_channels, 272 | grad_out * interp_wt, 273 | chnl_id, 274 | mask_out1 == nullptr ? nullptr : mask_out2); 275 | } 276 | 277 | 278 | } // namespace device 279 | } // namespace 280 | -------------------------------------------------------------------------------- /svox2/csrc/include/cuda_util.cuh: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | #pragma once 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "util.hpp" 8 | 9 | 10 | #define DEVICE_GUARD(_ten) \ 11 | const at::cuda::OptionalCUDAGuard device_guard(device_of(_ten)); 12 | 13 | #define CUDA_GET_THREAD_ID(tid, Q) const int tid = blockIdx.x * blockDim.x + threadIdx.x; \ 14 | if (tid >= Q) return 15 | #define CUDA_GET_THREAD_ID_U64(tid, Q) const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; \ 16 | if (tid >= Q) return 17 | #define CUDA_N_BLOCKS_NEEDED(Q, CUDA_N_THREADS) ((Q - 1) / CUDA_N_THREADS + 1) 18 | #define CUDA_CHECK_ERRORS \ 19 | cudaError_t err = cudaGetLastError(); \ 20 | if (err != cudaSuccess) \ 21 | printf("Error in svox2.%s : %s\n", __FUNCTION__, cudaGetErrorString(err)) 22 | 23 | #define CUDA_MAX_THREADS at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock 24 | 25 | #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 26 | #else 27 | __device__ inline double atomicAdd(double* address, double val){ 28 | unsigned long long int* address_as_ull = (unsigned long long int*)address; 29 | unsigned long long int old = *address_as_ull, assumed; 30 | do { 31 | assumed = old; 32 | old = atomicCAS(address_as_ull, assumed, 33 | __double_as_longlong(val + __longlong_as_double(assumed))); 34 | } while (assumed != old); 35 | return __longlong_as_double(old); 36 | } 37 | #endif 38 | 39 | __device__ inline void atomicMax(float* result, float value){ 40 | unsigned* result_as_u = (unsigned*)result; 41 | unsigned old = *result_as_u, assumed; 42 | do { 43 | assumed = old; 44 | old = atomicCAS(result_as_u, assumed, 45 | __float_as_int(fmaxf(value, __int_as_float(assumed)))); 46 | } while (old != assumed); 47 | return; 48 | } 49 | 50 | __device__ inline void atomicMax(double* result, double value){ 51 | unsigned long long int* result_as_ull = (unsigned long long int*)result; 52 | unsigned long long int old = *result_as_ull, assumed; 53 | do { 54 | assumed = old; 55 | old = atomicCAS(result_as_ull, assumed, 56 | __double_as_longlong(fmaxf(value, __longlong_as_double(assumed)))); 57 | } while (old != assumed); 58 | return; 59 | } 60 | 61 | __device__ __inline__ void transform_coord(float* __restrict__ point, 62 | const float* __restrict__ scaling, 63 | const float* __restrict__ offset) { 64 | point[0] = fmaf(point[0], scaling[0], offset[0]); // a*b + c 65 | point[1] = fmaf(point[1], scaling[1], offset[1]); // a*b + c 66 | point[2] = fmaf(point[2], scaling[2], offset[2]); // a*b + c 67 | } 68 | 69 | // Linear interp 70 | // Subtract and fused multiply-add 71 | // (1-w) a + w b 72 | template 73 | __host__ __device__ __inline__ T lerp(T a, T b, T w) { 74 | return fmaf(w, b - a, a); 75 | } 76 | 77 | __device__ __inline__ static float _norm( 78 | const float* __restrict__ dir) { 79 | // 
return sqrtf(dir[0] * dir[0] + dir[1] * dir[1] + dir[2] * dir[2]); 80 | return norm3df(dir[0], dir[1], dir[2]); 81 | } 82 | 83 | __device__ __inline__ static float _rnorm( 84 | const float* __restrict__ dir) { 85 | // return 1.f / _norm(dir); 86 | return rnorm3df(dir[0], dir[1], dir[2]); 87 | } 88 | 89 | __host__ __device__ __inline__ static void xsuby3d( 90 | float* __restrict__ x, 91 | const float* __restrict__ y) { 92 | x[0] -= y[0]; 93 | x[1] -= y[1]; 94 | x[2] -= y[2]; 95 | } 96 | 97 | __host__ __device__ __inline__ static float _dot( 98 | const float* __restrict__ x, 99 | const float* __restrict__ y) { 100 | return x[0] * y[0] + x[1] * y[1] + x[2] * y[2]; 101 | } 102 | 103 | __host__ __device__ __inline__ static void _cross( 104 | const float* __restrict__ a, 105 | const float* __restrict__ b, 106 | float* __restrict__ out) { 107 | out[0] = a[1] * b[2] - a[2] * b[1]; 108 | out[1] = a[2] * b[0] - a[0] * b[2]; 109 | out[2] = a[0] * b[1] - a[1] * b[0]; 110 | } 111 | 112 | __device__ __inline__ static float _dist_ray_to_origin( 113 | const float* __restrict__ origin, 114 | const float* __restrict__ dir) { 115 | // dir must be unit vector 116 | float tmp[3]; 117 | _cross(origin, dir, tmp); 118 | return _norm(tmp); 119 | } 120 | 121 | #define int_div2_ceil(x) ((((x) - 1) >> 1) + 1) 122 | 123 | __host__ __inline__ cudaError_t cuda_assert( 124 | const cudaError_t code, const char* const file, 125 | const int line, const bool abort) { 126 | if (code != cudaSuccess) { 127 | fprintf(stderr, "cuda_assert: %s %s %s %d\n", cudaGetErrorName(code) ,cudaGetErrorString(code), 128 | file, line); 129 | 130 | if (abort) { 131 | cudaDeviceReset(); 132 | exit(code); 133 | } 134 | } 135 | 136 | return code; 137 | } 138 | 139 | #define cuda(...) cuda_assert((cuda##__VA_ARGS__), __FILE__, __LINE__, true); 140 | 141 | -------------------------------------------------------------------------------- /svox2/csrc/include/data_spec.hpp: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | #pragma once 3 | #include "util.hpp" 4 | #include 5 | 6 | using torch::Tensor; 7 | 8 | enum BasisType { 9 | // For svox 1 compatibility 10 | // BASIS_TYPE_RGBA = 0 11 | BASIS_TYPE_SH = 1, 12 | // BASIS_TYPE_SG = 2 13 | // BASIS_TYPE_ASG = 3 14 | BASIS_TYPE_3D_TEXTURE = 4, 15 | BASIS_TYPE_MLP = 255, 16 | }; 17 | 18 | struct SparseGridSpec { 19 | Tensor density_data; 20 | Tensor sh_data; 21 | Tensor links; 22 | Tensor _offset; 23 | Tensor _scaling; 24 | 25 | Tensor background_links; 26 | Tensor background_data; 27 | 28 | int basis_dim; 29 | uint8_t basis_type; 30 | Tensor basis_data; 31 | 32 | inline void check() { 33 | CHECK_INPUT(density_data); 34 | CHECK_INPUT(sh_data); 35 | CHECK_INPUT(links); 36 | if (background_links.defined()) { 37 | CHECK_INPUT(background_links); 38 | CHECK_INPUT(background_data); 39 | TORCH_CHECK(background_links.ndimension() == 40 | 2); // (H, W) -> [N] \cup {-1} 41 | TORCH_CHECK(background_data.ndimension() == 3); // (N, D, C) -> R 42 | } 43 | if (basis_data.defined()) { 44 | CHECK_INPUT(basis_data); 45 | } 46 | CHECK_CPU_INPUT(_offset); 47 | CHECK_CPU_INPUT(_scaling); 48 | TORCH_CHECK(density_data.ndimension() == 2); 49 | TORCH_CHECK(sh_data.ndimension() == 2); 50 | TORCH_CHECK(links.ndimension() == 3); 51 | } 52 | }; 53 | 54 | struct GridOutputGrads { 55 | torch::Tensor grad_density_out; 56 | torch::Tensor grad_sh_out; 57 | torch::Tensor grad_basis_out; 58 | torch::Tensor grad_background_out; 59 | 60 | torch::Tensor mask_out; 61 | 
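// (Inserted note, inferred from usage: mask_out / mask_background_out are
// filled by the fused backward passes to flag entries actually touched by
// rays, so that the *_mask_step optimizer kernels in optim_kernel.cu can
// skip untouched voxels.)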
torch::Tensor mask_background_out; 62 | inline void check() { 63 | if (grad_density_out.defined()) { 64 | CHECK_INPUT(grad_density_out); 65 | } 66 | if (grad_sh_out.defined()) { 67 | CHECK_INPUT(grad_sh_out); 68 | } 69 | if (grad_basis_out.defined()) { 70 | CHECK_INPUT(grad_basis_out); 71 | } 72 | if (grad_background_out.defined()) { 73 | CHECK_INPUT(grad_background_out); 74 | } 75 | if (mask_out.defined() && mask_out.size(0) > 0) { 76 | CHECK_INPUT(mask_out); 77 | } 78 | if (mask_background_out.defined() && mask_background_out.size(0) > 0) { 79 | CHECK_INPUT(mask_background_out); 80 | } 81 | } 82 | }; 83 | 84 | struct CameraSpec { 85 | torch::Tensor c2w; 86 | float fx; 87 | float fy; 88 | float cx; 89 | float cy; 90 | int width; 91 | int height; 92 | 93 | float ndc_coeffx; 94 | float ndc_coeffy; 95 | 96 | inline void check() { 97 | CHECK_INPUT(c2w); 98 | TORCH_CHECK(c2w.is_floating_point()); 99 | TORCH_CHECK(c2w.ndimension() == 2); 100 | TORCH_CHECK(c2w.size(1) == 4); 101 | } 102 | }; 103 | 104 | struct RaysSpec { 105 | Tensor origins; 106 | Tensor dirs; 107 | inline void check() { 108 | CHECK_INPUT(origins); 109 | CHECK_INPUT(dirs); 110 | TORCH_CHECK(origins.is_floating_point()); 111 | TORCH_CHECK(dirs.is_floating_point()); 112 | } 113 | }; 114 | 115 | struct RenderOptions { 116 | float background_brightness; 117 | // float step_epsilon; 118 | float step_size; 119 | float sigma_thresh; 120 | float stop_thresh; 121 | 122 | float near_clip; 123 | bool use_spheric_clip; 124 | 125 | bool last_sample_opaque; 126 | 127 | // bool randomize; 128 | // float random_sigma_std; 129 | // float random_sigma_std_background; 130 | // 32-bit RNG state masks 131 | // uint32_t _m1, _m2, _m3; 132 | 133 | // int msi_start_layer = 0; 134 | // int msi_end_layer = 66; 135 | }; 136 | -------------------------------------------------------------------------------- /svox2/csrc/include/data_spec_packed.cuh: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | #pragma once 3 | #include 4 | #include "data_spec.hpp" 5 | #include "cuda_util.cuh" 6 | #include "random_util.cuh" 7 | 8 | namespace { 9 | namespace device { 10 | 11 | struct PackedSparseGridSpec { 12 | PackedSparseGridSpec(SparseGridSpec& spec) 13 | : 14 | density_data(spec.density_data.data_ptr()), 15 | sh_data(spec.sh_data.data_ptr()), 16 | links(spec.links.data_ptr()), 17 | basis_type(spec.basis_type), 18 | basis_data(spec.basis_data.defined() ? spec.basis_data.data_ptr() : nullptr), 19 | background_links(spec.background_links.defined() ? 20 | spec.background_links.data_ptr() : 21 | nullptr), 22 | background_data(spec.background_data.defined() ? 23 | spec.background_data.data_ptr() : 24 | nullptr), 25 | size{(int)spec.links.size(0), 26 | (int)spec.links.size(1), 27 | (int)spec.links.size(2)}, 28 | stride_x{(int)spec.links.stride(0)}, 29 | background_reso{ 30 | spec.background_links.defined() ? (int)spec.background_links.size(1) : 0, 31 | }, 32 | background_nlayers{ 33 | spec.background_data.defined() ? (int)spec.background_data.size(1) : 0 34 | }, 35 | basis_dim(spec.basis_dim), 36 | sh_data_dim((int)spec.sh_data.size(1)), 37 | basis_reso(spec.basis_data.defined() ? 
spec.basis_data.size(0) : 0), 38 | _offset{spec._offset.data_ptr()[0], 39 | spec._offset.data_ptr()[1], 40 | spec._offset.data_ptr()[2]}, 41 | _scaling{spec._scaling.data_ptr()[0], 42 | spec._scaling.data_ptr()[1], 43 | spec._scaling.data_ptr()[2]} { 44 | } 45 | 46 | float* __restrict__ density_data; 47 | float* __restrict__ sh_data; 48 | const int32_t* __restrict__ links; 49 | 50 | const uint8_t basis_type; 51 | float* __restrict__ basis_data; 52 | 53 | const int32_t* __restrict__ background_links; 54 | float* __restrict__ background_data; 55 | 56 | const int size[3], stride_x; 57 | const int background_reso, background_nlayers; 58 | 59 | const int basis_dim, sh_data_dim, basis_reso; 60 | const float _offset[3]; 61 | const float _scaling[3]; 62 | }; 63 | 64 | struct PackedGridOutputGrads { 65 | PackedGridOutputGrads(GridOutputGrads& grads) : 66 | grad_density_out(grads.grad_density_out.defined() ? grads.grad_density_out.data_ptr() : nullptr), 67 | grad_sh_out(grads.grad_sh_out.defined() ? grads.grad_sh_out.data_ptr() : nullptr), 68 | grad_basis_out(grads.grad_basis_out.defined() ? grads.grad_basis_out.data_ptr() : nullptr), 69 | grad_background_out(grads.grad_background_out.defined() ? grads.grad_background_out.data_ptr() : nullptr), 70 | mask_out((grads.mask_out.defined() && grads.mask_out.size(0) > 0) ? grads.mask_out.data_ptr() : nullptr), 71 | mask_background_out((grads.mask_background_out.defined() && grads.mask_background_out.size(0) > 0) ? grads.mask_background_out.data_ptr() : nullptr) 72 | {} 73 | float* __restrict__ grad_density_out; 74 | float* __restrict__ grad_sh_out; 75 | float* __restrict__ grad_basis_out; 76 | float* __restrict__ grad_background_out; 77 | 78 | bool* __restrict__ mask_out; 79 | bool* __restrict__ mask_background_out; 80 | }; 81 | 82 | struct PackedCameraSpec { 83 | PackedCameraSpec(CameraSpec& cam) : 84 | c2w(cam.c2w.packed_accessor32()), 85 | fx(cam.fx), fy(cam.fy), 86 | cx(cam.cx), cy(cam.cy), 87 | width(cam.width), height(cam.height), 88 | ndc_coeffx(cam.ndc_coeffx), ndc_coeffy(cam.ndc_coeffy) {} 89 | const torch::PackedTensorAccessor32 90 | c2w; 91 | float fx; 92 | float fy; 93 | float cx; 94 | float cy; 95 | int width; 96 | int height; 97 | 98 | float ndc_coeffx; 99 | float ndc_coeffy; 100 | }; 101 | 102 | struct PackedRaysSpec { 103 | const torch::PackedTensorAccessor32 origins; 104 | const torch::PackedTensorAccessor32 dirs; 105 | PackedRaysSpec(RaysSpec& spec) : 106 | origins(spec.origins.packed_accessor32()), 107 | dirs(spec.dirs.packed_accessor32()) 108 | { } 109 | }; 110 | 111 | struct SingleRaySpec { 112 | SingleRaySpec() = default; 113 | __device__ SingleRaySpec(const float* __restrict__ origin, const float* __restrict__ dir) 114 | : origin{origin[0], origin[1], origin[2]}, 115 | dir{dir[0], dir[1], dir[2]} {} 116 | __device__ void set(const float* __restrict__ origin, const float* __restrict__ dir) { 117 | #pragma unroll 3 118 | for (int i = 0; i < 3; ++i) { 119 | this->origin[i] = origin[i]; 120 | this->dir[i] = dir[i]; 121 | } 122 | } 123 | 124 | float origin[3]; 125 | float dir[3]; 126 | float tmin, tmax, world_step; 127 | 128 | float pos[3]; 129 | int32_t l[3]; 130 | RandomEngine32 rng; 131 | }; 132 | 133 | } // namespace device 134 | } // namespace 135 | -------------------------------------------------------------------------------- /svox2/csrc/include/random_util.cuh: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | #pragma once 3 | #include 4 | #include 5 | 6 | // A 
custom xorshift random generator 7 | // Maybe replace with some CUDA internal stuff? 8 | struct RandomEngine32 { 9 | uint32_t x, y, z; 10 | 11 | // Inclusive both 12 | __host__ __device__ 13 | uint32_t randint(uint32_t lo, uint32_t hi) { 14 | if (hi <= lo) return lo; 15 | uint32_t z = (*this)(); 16 | return z % (hi - lo + 1) + lo; 17 | } 18 | 19 | __host__ __device__ 20 | void rand2(float* out1, float* out2) { 21 | const uint32_t z = (*this)(); 22 | const uint32_t fmax = (1 << 16); 23 | const uint32_t z1 = z >> 16; 24 | const uint32_t z2 = z & (fmax - 1); 25 | const float ifmax = 1.f / fmax; 26 | 27 | *out1 = z1 * ifmax; 28 | *out2 = z2 * ifmax; 29 | } 30 | 31 | __host__ __device__ 32 | float rand() { 33 | uint32_t z = (*this)(); 34 | return float(z) / (1LL << 32); 35 | } 36 | 37 | 38 | __host__ __device__ 39 | void randn2(float* out1, float* out2) { 40 | rand2(out1, out2); 41 | // Box-Muller transform 42 | const float srlog = sqrtf(-2 * logf(*out1 + 1e-32f)); 43 | *out2 *= 2 * M_PI; 44 | *out1 = srlog * cosf(*out2); 45 | *out2 = srlog * sinf(*out2); 46 | } 47 | 48 | __host__ __device__ 49 | float randn() { 50 | float x, y; 51 | rand2(&x, &y); 52 | // Box-Muller transform 53 | return sqrtf(-2 * logf(x + 1e-32f))* cosf(2 * M_PI * y); 54 | } 55 | 56 | __host__ __device__ 57 | uint32_t operator()() { 58 | uint32_t t; 59 | x ^= x << 16; 60 | x ^= x >> 5; 61 | x ^= x << 1; 62 | t = x; 63 | x = y; 64 | y = z; 65 | z = t ^ x ^ y; 66 | return z; 67 | } 68 | }; 69 | -------------------------------------------------------------------------------- /svox2/csrc/include/util.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | // Changed from x.type().is_cuda() due to deprecation 3 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 4 | #define CHECK_CPU(x) TORCH_CHECK(!x.is_cuda(), #x " must be a CPU tensor") 5 | #define CHECK_CONTIGUOUS(x) \ 6 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 7 | #define CHECK_INPUT(x) \ 8 | CHECK_CUDA(x); \ 9 | CHECK_CONTIGUOUS(x) 10 | #define CHECK_CPU_INPUT(x) \ 11 | CHECK_CPU(x); \ 12 | CHECK_CONTIGUOUS(x) 13 | 14 | #if defined(__CUDACC__) 15 | // #define _EXP(x) expf(x) // SLOW EXP 16 | #define _EXP(x) __expf(x) // FAST EXP 17 | #define _SIGMOID(x) (1 / (1 + _EXP(-(x)))) 18 | 19 | #else 20 | 21 | #define _EXP(x) expf(x) 22 | #define _SIGMOID(x) (1 / (1 + expf(-(x)))) 23 | #endif 24 | #define _SQR(x) ((x) * (x)) 25 | -------------------------------------------------------------------------------- /svox2/csrc/optim_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | // Optimizer-related kernels 3 | 4 | #include 5 | #include "cuda_util.cuh" 6 | 7 | namespace { 8 | 9 | const int RMSPROP_STEP_CUDA_THREADS = 256; 10 | const int MIN_BLOCKS_PER_SM = 4; 11 | 12 | namespace device { 13 | 14 | // RMSPROP 15 | __inline__ __device__ void rmsprop_once( 16 | float* __restrict__ ptr_data, 17 | float* __restrict__ ptr_rms, 18 | float* __restrict__ ptr_grad, 19 | const float beta, const float lr, const float epsilon, float minval) { 20 | float rms = *ptr_rms; 21 | rms = rms == 0.f ? 
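// (Inserted note: with lerp(a, b, w) = (1 - w) * a + w * b as defined in
// cuda_util.cuh, the expression below is the usual RMSProp second-moment
// update rms <- beta * rms + (1 - beta) * grad^2, bootstrapped with grad^2
// itself on the very first step.)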
_SQR(*ptr_grad) : lerp(_SQR(*ptr_grad), rms, beta); 22 | *ptr_rms = rms; 23 | *ptr_data = fmaxf(*ptr_data - lr * (*ptr_grad) / (sqrtf(rms) + epsilon), minval); 24 | *ptr_grad = 0.f; 25 | } 26 | 27 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM) 28 | __global__ void rmsprop_step_kernel( 29 | torch::PackedTensorAccessor64 all_data, 30 | torch::PackedTensorAccessor64 all_rms, 31 | torch::PackedTensorAccessor64 all_grad, 32 | float beta, 33 | float lr, 34 | float epsilon, 35 | float minval, 36 | float lr_last) { 37 | CUDA_GET_THREAD_ID(tid, all_data.size(0) * all_data.size(1)); 38 | int32_t chnl = tid % all_data.size(1); 39 | rmsprop_once(all_data.data() + tid, 40 | all_rms.data() + tid, 41 | all_grad.data() + tid, 42 | beta, 43 | (chnl == all_data.size(1) - 1) ? lr_last : lr, 44 | epsilon, 45 | minval); 46 | } 47 | 48 | 49 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM) 50 | __global__ void rmsprop_mask_step_kernel( 51 | torch::PackedTensorAccessor64 all_data, 52 | torch::PackedTensorAccessor64 all_rms, 53 | torch::PackedTensorAccessor64 all_grad, 54 | const bool* __restrict__ mask, 55 | float beta, 56 | float lr, 57 | float epsilon, 58 | float minval, 59 | float lr_last) { 60 | CUDA_GET_THREAD_ID(tid, all_data.size(0) * all_data.size(1)); 61 | if (mask[tid / all_data.size(1)] == false) return; 62 | int32_t chnl = tid % all_data.size(1); 63 | rmsprop_once(all_data.data() + tid, 64 | all_rms.data() + tid, 65 | all_grad.data() + tid, 66 | beta, 67 | (chnl == all_data.size(1) - 1) ? lr_last : lr, 68 | epsilon, 69 | minval); 70 | } 71 | 72 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM) 73 | __global__ void rmsprop_index_step_kernel( 74 | torch::PackedTensorAccessor64 all_data, 75 | torch::PackedTensorAccessor64 all_rms, 76 | torch::PackedTensorAccessor64 all_grad, 77 | torch::PackedTensorAccessor32 indices, 78 | float beta, 79 | float lr, 80 | float epsilon, 81 | float minval, 82 | float lr_last) { 83 | CUDA_GET_THREAD_ID(tid, indices.size(0) * all_data.size(1)); 84 | int32_t i = indices[tid / all_data.size(1)]; 85 | int32_t chnl = tid % all_data.size(1); 86 | size_t off = i * all_data.size(1) + chnl; 87 | rmsprop_once(all_data.data() + off, all_rms.data() + off, 88 | all_grad.data() + off, 89 | beta, 90 | (chnl == all_data.size(1) - 1) ? lr_last : lr, 91 | epsilon, 92 | minval); 93 | } 94 | 95 | 96 | // SGD 97 | __inline__ __device__ void sgd_once( 98 | float* __restrict__ ptr_data, 99 | float* __restrict__ ptr_grad, 100 | const float lr) { 101 | *ptr_data -= lr * (*ptr_grad); 102 | *ptr_grad = 0.f; 103 | } 104 | 105 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM) 106 | __global__ void sgd_step_kernel( 107 | torch::PackedTensorAccessor64 all_data, 108 | torch::PackedTensorAccessor64 all_grad, 109 | float lr, 110 | float lr_last) { 111 | CUDA_GET_THREAD_ID(tid, all_data.size(0) * all_data.size(1)); 112 | int32_t chnl = tid % all_data.size(1); 113 | sgd_once(all_data.data() + tid, 114 | all_grad.data() + tid, 115 | (chnl == all_data.size(1) - 1) ? 
lr_last : lr); 116 | } 117 | 118 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM) 119 | __global__ void sgd_mask_step_kernel( 120 | torch::PackedTensorAccessor64 all_data, 121 | torch::PackedTensorAccessor64 all_grad, 122 | const bool* __restrict__ mask, 123 | float lr, 124 | float lr_last) { 125 | CUDA_GET_THREAD_ID(tid, all_data.size(0) * all_data.size(1)); 126 | if (mask[tid / all_data.size(1)] == false) return; 127 | int32_t chnl = tid % all_data.size(1); 128 | sgd_once(all_data.data() + tid, 129 | all_grad.data() + tid, 130 | (chnl == all_data.size(1) - 1) ? lr_last : lr); 131 | } 132 | 133 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM) 134 | __global__ void sgd_index_step_kernel( 135 | torch::PackedTensorAccessor64 all_data, 136 | torch::PackedTensorAccessor64 all_grad, 137 | torch::PackedTensorAccessor32 indices, 138 | float lr, 139 | float lr_last) { 140 | CUDA_GET_THREAD_ID(tid, indices.size(0) * all_data.size(1)); 141 | int32_t i = indices[tid / all_data.size(1)]; 142 | int32_t chnl = tid % all_data.size(1); 143 | size_t off = i * all_data.size(1) + chnl; 144 | sgd_once(all_data.data() + off, 145 | all_grad.data() + off, 146 | (chnl == all_data.size(1) - 1) ? lr_last : lr); 147 | } 148 | 149 | 150 | 151 | } // namespace device 152 | } // namespace 153 | 154 | void rmsprop_step( 155 | torch::Tensor data, 156 | torch::Tensor rms, 157 | torch::Tensor grad, 158 | torch::Tensor indexer, 159 | float beta, 160 | float lr, 161 | float epsilon, 162 | float minval, 163 | float lr_last) { 164 | 165 | DEVICE_GUARD(data); 166 | CHECK_INPUT(data); 167 | CHECK_INPUT(rms); 168 | CHECK_INPUT(grad); 169 | CHECK_INPUT(indexer); 170 | 171 | if (lr_last < 0.f) lr_last = lr; 172 | 173 | const int cuda_n_threads = RMSPROP_STEP_CUDA_THREADS; 174 | 175 | if (indexer.dim() == 0) { 176 | const size_t Q = data.size(0) * data.size(1); 177 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads); 178 | device::rmsprop_step_kernel<<>>( 179 | data.packed_accessor64(), 180 | rms.packed_accessor64(), 181 | grad.packed_accessor64(), 182 | beta, 183 | lr, 184 | epsilon, 185 | minval, 186 | lr_last); 187 | } else if (indexer.size(0) == 0) { 188 | // Skip 189 | } else if (indexer.scalar_type() == at::ScalarType::Bool) { 190 | const size_t Q = data.size(0) * data.size(1); 191 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads); 192 | device::rmsprop_mask_step_kernel<<>>( 193 | data.packed_accessor64(), 194 | rms.packed_accessor64(), 195 | grad.packed_accessor64(), 196 | indexer.data_ptr(), 197 | beta, 198 | lr, 199 | epsilon, 200 | minval, 201 | lr_last); 202 | } else { 203 | const size_t Q = indexer.size(0) * data.size(1); 204 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads); 205 | device::rmsprop_index_step_kernel<<>>( 206 | data.packed_accessor64(), 207 | rms.packed_accessor64(), 208 | grad.packed_accessor64(), 209 | indexer.packed_accessor32(), 210 | beta, 211 | lr, 212 | epsilon, 213 | minval, 214 | lr_last); 215 | } 216 | 217 | CUDA_CHECK_ERRORS; 218 | } 219 | 220 | void sgd_step( 221 | torch::Tensor data, 222 | torch::Tensor grad, 223 | torch::Tensor indexer, 224 | float lr, 225 | float lr_last) { 226 | 227 | DEVICE_GUARD(data); 228 | CHECK_INPUT(data); 229 | CHECK_INPUT(grad); 230 | CHECK_INPUT(indexer); 231 | 232 | if (lr_last < 0.f) lr_last = lr; 233 | 234 | const int cuda_n_threads = RMSPROP_STEP_CUDA_THREADS; 235 | 236 | if (indexer.dim() == 0) { 237 | const size_t Q = data.size(0) * data.size(1); 238 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, 
cuda_n_threads); 239 | device::sgd_step_kernel<<>>( 240 | data.packed_accessor64(), 241 | grad.packed_accessor64(), 242 | lr, 243 | lr_last); 244 | } else if (indexer.size(0) == 0) { 245 | // Skip 246 | } else if (indexer.scalar_type() == at::ScalarType::Bool) { 247 | const size_t Q = data.size(0) * data.size(1); 248 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads); 249 | device::sgd_mask_step_kernel<<>>( 250 | data.packed_accessor64(), 251 | grad.packed_accessor64(), 252 | indexer.data_ptr(), 253 | lr, 254 | lr_last); 255 | } else { 256 | const size_t Q = indexer.size(0) * data.size(1); 257 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads); 258 | device::sgd_index_step_kernel<<>>( 259 | data.packed_accessor64(), 260 | grad.packed_accessor64(), 261 | indexer.packed_accessor32(), 262 | lr, 263 | lr_last); 264 | } 265 | 266 | CUDA_CHECK_ERRORS; 267 | } 268 | -------------------------------------------------------------------------------- /svox2/csrc/svox2.cpp: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | 3 | // This file contains only Python bindings 4 | #include "data_spec.hpp" 5 | #include 6 | #include 7 | #include 8 | 9 | using torch::Tensor; 10 | 11 | std::tuple sample_grid(SparseGridSpec &, Tensor, 12 | bool); 13 | void sample_grid_backward(SparseGridSpec &, Tensor, Tensor, Tensor, Tensor, 14 | Tensor, bool); 15 | 16 | // ** NeRF rendering formula (trilerp) 17 | Tensor volume_render_cuvol(SparseGridSpec &, RaysSpec &, RenderOptions &); 18 | Tensor volume_render_cuvol_image(SparseGridSpec &, CameraSpec &, 19 | RenderOptions &); 20 | void volume_render_cuvol_backward(SparseGridSpec &, RaysSpec &, RenderOptions &, 21 | Tensor, Tensor, GridOutputGrads &); 22 | void volume_render_cuvol_fused(SparseGridSpec &, RaysSpec &, RenderOptions &, 23 | Tensor, float, float, Tensor, GridOutputGrads &); 24 | // Expected termination (depth) rendering 25 | torch::Tensor volume_render_expected_term(SparseGridSpec &, RaysSpec &, 26 | RenderOptions &); 27 | // Depth rendering based on sigma-threshold as in Dex-NeRF 28 | torch::Tensor volume_render_sigma_thresh(SparseGridSpec &, RaysSpec &, 29 | RenderOptions &, float); 30 | 31 | // ** NV rendering formula (trilerp) 32 | Tensor volume_render_nvol(SparseGridSpec &, RaysSpec &, RenderOptions &); 33 | void volume_render_nvol_backward(SparseGridSpec &, RaysSpec &, RenderOptions &, 34 | Tensor, Tensor, GridOutputGrads &); 35 | void volume_render_nvol_fused(SparseGridSpec &, RaysSpec &, RenderOptions &, 36 | Tensor, float, float, Tensor, GridOutputGrads &); 37 | 38 | // ** NeRF rendering formula (nearest-neighbor, infinitely many steps) 39 | Tensor volume_render_svox1(SparseGridSpec &, RaysSpec &, RenderOptions &); 40 | void volume_render_svox1_backward(SparseGridSpec &, RaysSpec &, RenderOptions &, 41 | Tensor, Tensor, GridOutputGrads &); 42 | void volume_render_svox1_fused(SparseGridSpec &, RaysSpec &, RenderOptions &, 43 | Tensor, float, float, Tensor, GridOutputGrads &); 44 | 45 | // Tensor volume_render_cuvol_image(SparseGridSpec &, CameraSpec &, 46 | // RenderOptions &); 47 | // 48 | // void volume_render_cuvol_image_backward(SparseGridSpec &, CameraSpec &, 49 | // RenderOptions &, Tensor, Tensor, 50 | // GridOutputGrads &); 51 | 52 | // Misc 53 | Tensor dilate(Tensor); 54 | void accel_dist_prop(Tensor); 55 | void grid_weight_render(Tensor, CameraSpec &, float, float, bool, Tensor, 56 | Tensor, Tensor); 57 | // void sample_cubemap(Tensor, Tensor, bool, Tensor); 58 | 
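// (Inserted usage sketch; the exact import spelling lives in svox2/svox2.py
// and may differ.) The functions declared in this file are compiled into the
// `svox2.csrc` extension defined in setup.py and called from Python roughly as
//   import torch
//   import svox2.csrc as _C
//   dilated = _C.dilate(occupancy)  // occupancy: contiguous bool CUDA tensor
// with argument validation handled by the CHECK_INPUT macros from util.hpp.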
--------------------------------------------------------------------------------
/svox2/csrc/svox2.cpp:
--------------------------------------------------------------------------------
1 | // Copyright 2021 Alex Yu
2 | 
3 | // This file contains only Python bindings
4 | #include "data_spec.hpp"
5 | #include <cstdint>
6 | #include <torch/extension.h>
7 | #include <tuple>
8 | 
9 | using torch::Tensor;
10 | 
11 | std::tuple<torch::Tensor, torch::Tensor> sample_grid(SparseGridSpec &, Tensor,
12 |                                                      bool);
13 | void sample_grid_backward(SparseGridSpec &, Tensor, Tensor, Tensor, Tensor,
14 |                           Tensor, bool);
15 | 
16 | // ** NeRF rendering formula (trilerp)
17 | Tensor volume_render_cuvol(SparseGridSpec &, RaysSpec &, RenderOptions &);
18 | Tensor volume_render_cuvol_image(SparseGridSpec &, CameraSpec &,
19 |                                  RenderOptions &);
20 | void volume_render_cuvol_backward(SparseGridSpec &, RaysSpec &, RenderOptions &,
21 |                                   Tensor, Tensor, GridOutputGrads &);
22 | void volume_render_cuvol_fused(SparseGridSpec &, RaysSpec &, RenderOptions &,
23 |                                Tensor, float, float, Tensor, GridOutputGrads &);
24 | // Expected termination (depth) rendering
25 | torch::Tensor volume_render_expected_term(SparseGridSpec &, RaysSpec &,
26 |                                           RenderOptions &);
27 | // Depth rendering based on sigma-threshold as in Dex-NeRF
28 | torch::Tensor volume_render_sigma_thresh(SparseGridSpec &, RaysSpec &,
29 |                                          RenderOptions &, float);
30 | 
31 | // ** NV rendering formula (trilerp)
32 | Tensor volume_render_nvol(SparseGridSpec &, RaysSpec &, RenderOptions &);
33 | void volume_render_nvol_backward(SparseGridSpec &, RaysSpec &, RenderOptions &,
34 |                                  Tensor, Tensor, GridOutputGrads &);
35 | void volume_render_nvol_fused(SparseGridSpec &, RaysSpec &, RenderOptions &,
36 |                               Tensor, float, float, Tensor, GridOutputGrads &);
37 | 
38 | // ** NeRF rendering formula (nearest-neighbor, infinitely many steps)
39 | Tensor volume_render_svox1(SparseGridSpec &, RaysSpec &, RenderOptions &);
40 | void volume_render_svox1_backward(SparseGridSpec &, RaysSpec &, RenderOptions &,
41 |                                   Tensor, Tensor, GridOutputGrads &);
42 | void volume_render_svox1_fused(SparseGridSpec &, RaysSpec &, RenderOptions &,
43 |                                Tensor, float, float, Tensor, GridOutputGrads &);
44 | 
45 | // Tensor volume_render_cuvol_image(SparseGridSpec &, CameraSpec &,
46 | //                                  RenderOptions &);
47 | //
48 | // void volume_render_cuvol_image_backward(SparseGridSpec &, CameraSpec &,
49 | //                                         RenderOptions &, Tensor, Tensor,
50 | //                                         GridOutputGrads &);
51 | 
52 | // Misc
53 | Tensor dilate(Tensor);
54 | void accel_dist_prop(Tensor);
55 | void grid_weight_render(Tensor, CameraSpec &, float, float, bool, Tensor,
56 |                         Tensor, Tensor);
57 | // void sample_cubemap(Tensor, Tensor, bool, Tensor);
58 | 
59 | // Loss
60 | Tensor tv(Tensor, Tensor, int, int, bool, float, bool, float, float);
61 | void tv_grad(Tensor, Tensor, int, int, float, bool, float, bool, float, float,
62 |              Tensor);
63 | void tv_grad_sparse(Tensor, Tensor, Tensor, Tensor, int, int, float, bool,
64 |                     float, bool, bool, float, float, Tensor);
65 | void msi_tv_grad_sparse(Tensor, Tensor, Tensor, Tensor, float, float, Tensor);
66 | void lumisphere_tv_grad_sparse(SparseGridSpec &, Tensor, Tensor, Tensor, float,
67 |                                float, float, float, GridOutputGrads &);
68 | 
69 | // Optim
70 | void rmsprop_step(Tensor, Tensor, Tensor, Tensor, float, float, float, float,
71 |                   float);
72 | void sgd_step(Tensor, Tensor, Tensor, float, float);
73 | 
74 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
75 | #define _REG_FUNC(funname) m.def(#funname, &funname)
76 |     _REG_FUNC(sample_grid);
77 |     _REG_FUNC(sample_grid_backward);
78 |     _REG_FUNC(volume_render_cuvol);
79 |     _REG_FUNC(volume_render_cuvol_image);
80 |     _REG_FUNC(volume_render_cuvol_backward);
81 |     _REG_FUNC(volume_render_cuvol_fused);
82 |     _REG_FUNC(volume_render_expected_term);
83 |     _REG_FUNC(volume_render_sigma_thresh);
84 | 
85 |     _REG_FUNC(volume_render_nvol);
86 |     _REG_FUNC(volume_render_nvol_backward);
87 |     _REG_FUNC(volume_render_nvol_fused);
88 | 
89 |     _REG_FUNC(volume_render_svox1);
90 |     _REG_FUNC(volume_render_svox1_backward);
91 |     _REG_FUNC(volume_render_svox1_fused);
92 | 
93 |     // _REG_FUNC(volume_render_cuvol_image);
94 |     // _REG_FUNC(volume_render_cuvol_image_backward);
95 | 
96 |     // Loss
97 |     _REG_FUNC(tv);
98 |     _REG_FUNC(tv_grad);
99 |     _REG_FUNC(tv_grad_sparse);
100 |     _REG_FUNC(msi_tv_grad_sparse);
101 |     _REG_FUNC(lumisphere_tv_grad_sparse);
102 | 
103 |     // Misc
104 |     _REG_FUNC(dilate);
105 |     _REG_FUNC(accel_dist_prop);
106 |     _REG_FUNC(grid_weight_render);
107 |     // _REG_FUNC(sample_cubemap);
108 | 
109 |     // Optimizer
110 |     _REG_FUNC(rmsprop_step);
111 |     _REG_FUNC(sgd_step);
112 | #undef _REG_FUNC
113 | 
114 |     py::class_<SparseGridSpec>(m, "SparseGridSpec")
115 |         .def(py::init<>())
116 |         .def_readwrite("density_data", &SparseGridSpec::density_data)
117 |         .def_readwrite("sh_data", &SparseGridSpec::sh_data)
118 |         .def_readwrite("links", &SparseGridSpec::links)
119 |         .def_readwrite("_offset", &SparseGridSpec::_offset)
120 |         .def_readwrite("_scaling", &SparseGridSpec::_scaling)
121 |         .def_readwrite("basis_dim", &SparseGridSpec::basis_dim)
122 |         .def_readwrite("basis_type", &SparseGridSpec::basis_type)
123 |         .def_readwrite("basis_data", &SparseGridSpec::basis_data)
124 |         .def_readwrite("background_links", &SparseGridSpec::background_links)
125 |         .def_readwrite("background_data", &SparseGridSpec::background_data);
126 | 
127 |     py::class_<CameraSpec>(m, "CameraSpec")
128 |         .def(py::init<>())
129 |         .def_readwrite("c2w", &CameraSpec::c2w)
130 |         .def_readwrite("fx", &CameraSpec::fx)
131 |         .def_readwrite("fy", &CameraSpec::fy)
132 |         .def_readwrite("cx", &CameraSpec::cx)
133 |         .def_readwrite("cy", &CameraSpec::cy)
134 |         .def_readwrite("width", &CameraSpec::width)
135 |         .def_readwrite("height", &CameraSpec::height)
136 |         .def_readwrite("ndc_coeffx", &CameraSpec::ndc_coeffx)
137 |         .def_readwrite("ndc_coeffy", &CameraSpec::ndc_coeffy);
138 | 
139 |     py::class_<RaysSpec>(m, "RaysSpec")
140 |         .def(py::init<>())
141 |         .def_readwrite("origins", &RaysSpec::origins)
142 |         .def_readwrite("dirs", &RaysSpec::dirs);
143 | 
144 |     py::class_<RenderOptions>(m, "RenderOptions")
145 |         .def(py::init<>())
146 |         .def_readwrite("background_brightness",
147 |                        &RenderOptions::background_brightness)
148 |         .def_readwrite("step_size", &RenderOptions::step_size)
.def_readwrite("sigma_thresh", &RenderOptions::sigma_thresh) 150 | .def_readwrite("stop_thresh", &RenderOptions::stop_thresh) 151 | .def_readwrite("near_clip", &RenderOptions::near_clip) 152 | .def_readwrite("use_spheric_clip", &RenderOptions::use_spheric_clip) 153 | .def_readwrite("last_sample_opaque", &RenderOptions::last_sample_opaque); 154 | // .def_readwrite("randomize", &RenderOptions::randomize) 155 | // .def_readwrite("random_sigma_std", &RenderOptions::random_sigma_std) 156 | // .def_readwrite("random_sigma_std_background", 157 | // &RenderOptions::random_sigma_std_background) 158 | // .def_readwrite("_m1", &RenderOptions::_m1) 159 | // .def_readwrite("_m2", &RenderOptions::_m2) 160 | // .def_readwrite("_m3", &RenderOptions::_m3); 161 | 162 | py::class_(m, "GridOutputGrads") 163 | .def(py::init<>()) 164 | .def_readwrite("grad_density_out", &GridOutputGrads::grad_density_out) 165 | .def_readwrite("grad_sh_out", &GridOutputGrads::grad_sh_out) 166 | .def_readwrite("grad_basis_out", &GridOutputGrads::grad_basis_out) 167 | .def_readwrite("grad_background_out", 168 | &GridOutputGrads::grad_background_out) 169 | .def_readwrite("mask_out", &GridOutputGrads::mask_out) 170 | .def_readwrite("mask_background_out", 171 | &GridOutputGrads::mask_background_out); 172 | } 173 | -------------------------------------------------------------------------------- /svox2/csrc/svox2_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | #include 3 | #include 4 | #include "cuda_util.cuh" 5 | #include "data_spec_packed.cuh" 6 | 7 | namespace { 8 | namespace device { 9 | 10 | __global__ void sample_grid_sh_kernel( 11 | PackedSparseGridSpec grid, 12 | const torch::PackedTensorAccessor32 points, 13 | // Output 14 | torch::PackedTensorAccessor32 out) { 15 | CUDA_GET_THREAD_ID(tid, points.size(0) * grid.sh_data_dim); 16 | const int idx = tid % grid.sh_data_dim; 17 | const int pid = tid / grid.sh_data_dim; 18 | 19 | float point[3] = {points[pid][0], points[pid][1], points[pid][2]}; 20 | transform_coord(point, grid._scaling, grid._offset); 21 | 22 | int32_t l[3]; 23 | #pragma unroll 3 24 | for (int i = 0; i < 3; ++i) { 25 | point[i] = fminf(fmaxf(point[i], 0.f), grid.size[i] - 1.f); 26 | l[i] = min((int32_t)point[i], (int32_t)(grid.size[i] - 2)); 27 | point[i] -= l[i]; 28 | } 29 | 30 | const int offy = grid.size[2], offx = grid.size[1] * grid.size[2]; 31 | const int32_t* __restrict__ link_ptr = &grid.links[l[0] * offx + l[1] * offy + l[2]]; 32 | 33 | #define MAYBE_READ_LINK(u) ((link_ptr[u] >= 0) ? 
--------------------------------------------------------------------------------
/svox2/csrc/svox2_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright 2021 Alex Yu
2 | #include <torch/extension.h>
3 | #include <cstdint>
4 | #include "cuda_util.cuh"
5 | #include "data_spec_packed.cuh"
6 | 
7 | namespace {
8 | namespace device {
9 | 
10 | __global__ void sample_grid_sh_kernel(
11 |         PackedSparseGridSpec grid,
12 |         const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> points,
13 |         // Output
14 |         torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> out) {
15 |     CUDA_GET_THREAD_ID(tid, points.size(0) * grid.sh_data_dim);
16 |     const int idx = tid % grid.sh_data_dim;
17 |     const int pid = tid / grid.sh_data_dim;
18 | 
19 |     float point[3] = {points[pid][0], points[pid][1], points[pid][2]};
20 |     transform_coord(point, grid._scaling, grid._offset);
21 | 
22 |     int32_t l[3];
23 | #pragma unroll 3
24 |     for (int i = 0; i < 3; ++i) {
25 |         point[i] = fminf(fmaxf(point[i], 0.f), grid.size[i] - 1.f);
26 |         l[i] = min((int32_t)point[i], (int32_t)(grid.size[i] - 2));
27 |         point[i] -= l[i];
28 |     }
29 | 
30 |     const int offy = grid.size[2], offx = grid.size[1] * grid.size[2];
31 |     const int32_t* __restrict__ link_ptr = &grid.links[l[0] * offx + l[1] * offy + l[2]];
32 | 
33 | #define MAYBE_READ_LINK(u) ((link_ptr[u] >= 0) ? grid.sh_data[ \
34 |             link_ptr[u] * size_t(grid.sh_data_dim) + idx] : 0.f)
35 | 
36 |     const float ix0y0 = lerp(MAYBE_READ_LINK(0), MAYBE_READ_LINK(1), point[2]);
37 |     const float ix0y1 = lerp(MAYBE_READ_LINK(offy), MAYBE_READ_LINK(offy + 1), point[2]);
38 |     const float ix0 = lerp(ix0y0, ix0y1, point[1]);
39 |     const float ix1y0 = lerp(MAYBE_READ_LINK(offx), MAYBE_READ_LINK(offx + 1), point[2]);
40 |     const float ix1y1 = lerp(MAYBE_READ_LINK(offy + offx),
41 |                              MAYBE_READ_LINK(offy + offx + 1), point[2]);
42 |     const float ix1 = lerp(ix1y0, ix1y1, point[1]);
43 |     out[pid][idx] = lerp(ix0, ix1, point[0]);
44 | }
45 | #undef MAYBE_READ_LINK
46 | 
47 | __global__ void sample_grid_density_kernel(
48 |         PackedSparseGridSpec grid,
49 |         const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> points,
50 |         // Output
51 |         torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> out) {
52 |     CUDA_GET_THREAD_ID(tid, points.size(0));
53 | 
54 |     float point[3] = {points[tid][0], points[tid][1], points[tid][2]};
55 |     transform_coord(point, grid._scaling, grid._offset);
56 | 
57 |     int32_t l[3];
58 | #pragma unroll 3
59 |     for (int i = 0; i < 3; ++i) {
60 |         point[i] = fminf(fmaxf(point[i], 0.f), grid.size[i] - 1.f);
61 |         l[i] = min((int32_t)point[i], grid.size[i] - 2);
62 |         point[i] -= l[i];
63 |     }
64 | 
65 |     const int offy = grid.size[2], offx = grid.size[1] * grid.size[2];
66 |     const int32_t* __restrict__ link_ptr = &grid.links[l[0] * offx + l[1] * offy + l[2]];
67 | 
68 | #define MAYBE_READ_LINK_D(u) ((link_ptr[u] >= 0) ? grid.density_data[link_ptr[u]] : 0.f)
69 | 
70 |     const float ix0y0 = lerp(MAYBE_READ_LINK_D(0), MAYBE_READ_LINK_D(1), point[2]);
71 |     const float ix0y1 = lerp(MAYBE_READ_LINK_D(offy), MAYBE_READ_LINK_D(offy + 1), point[2]);
72 |     const float ix0 = lerp(ix0y0, ix0y1, point[1]);
73 |     const float ix1y0 = lerp(MAYBE_READ_LINK_D(offx), MAYBE_READ_LINK_D(offx + 1), point[2]);
74 |     const float ix1y1 = lerp(MAYBE_READ_LINK_D(offy + offx),
75 |                              MAYBE_READ_LINK_D(offy + offx + 1), point[2]);
76 |     const float ix1 = lerp(ix1y0, ix1y1, point[1]);
77 |     out[tid][0] = lerp(ix0, ix1, point[0]);
78 | }
79 | #undef MAYBE_READ_LINK_D
80 | 
81 | __global__ void sample_grid_sh_backward_kernel(
82 |         PackedSparseGridSpec grid,
83 |         const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> points,
84 |         const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> grad_out,
85 |         // Output
86 |         torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> grad_data) {
87 |     CUDA_GET_THREAD_ID(tid, points.size(0) * grid.sh_data_dim);
88 |     const int idx = tid % grid.sh_data_dim;
89 |     const int pid = tid / grid.sh_data_dim;
90 | 
91 |     float point[3] = {points[pid][0], points[pid][1], points[pid][2]};
92 |     transform_coord(point, grid._scaling, grid._offset);
93 | 
94 |     int32_t l[3];
95 | #pragma unroll 3
96 |     for (int i = 0; i < 3; ++i) {
97 |         point[i] = fminf(fmaxf(point[i], 0.f), grid.size[i] - 1.f);
98 |         l[i] = min((int32_t)point[i], grid.size[i] - 2);
99 |         point[i] -= l[i];
100 |     }
101 | 
102 |     const int offy = grid.size[2], offx = grid.size[1] * grid.size[2];
103 |     const int32_t* __restrict__ link_ptr = &grid.links[l[0] * offx + l[1] * offy + l[2]];
104 | 
105 |     const float go = grad_out[pid][idx];
106 | 
107 |     const float xb = point[0], yb = point[1], zb = point[2];
108 |     const float xa = 1.f - point[0], ya = 1.f - point[1], za = 1.f - point[2];
109 | 
110 | #define MAYBE_ADD_GRAD_LINK_PTR(u, content) if (link_ptr[u] >= 0) \
111 |             atomicAdd(&grad_data[link_ptr[u]][idx], content)
112 | 
113 |     const float xago = xa * go;
114 |     float tmp = ya * xago;
115 |     MAYBE_ADD_GRAD_LINK_PTR(0, tmp * za);
116 |     MAYBE_ADD_GRAD_LINK_PTR(1, tmp * zb);
117 |     tmp = yb * xago;
118 |     MAYBE_ADD_GRAD_LINK_PTR(offy, tmp * za);
119 |     MAYBE_ADD_GRAD_LINK_PTR(offy + 1, tmp * zb);
120 | 
121 |     const float xbgo = xb * go;
122 |     tmp = ya * xbgo;
123 |     MAYBE_ADD_GRAD_LINK_PTR(offx, tmp * za);
124 |     MAYBE_ADD_GRAD_LINK_PTR(offx + 1, tmp * zb);
125 |     tmp = yb * xbgo;
126 |     MAYBE_ADD_GRAD_LINK_PTR(offx + offy, tmp * za);
127 |     MAYBE_ADD_GRAD_LINK_PTR(offx + offy + 1, tmp * zb);
128 | }
129 | #undef MAYBE_ADD_GRAD_LINK_PTR
130 | 
131 | __global__ void sample_grid_density_backward_kernel(
132 |         PackedSparseGridSpec grid,
133 |         const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> points,
134 |         const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> grad_out,
135 |         // Output
136 |         torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> grad_data) {
137 |     CUDA_GET_THREAD_ID(tid, points.size(0));
138 | 
139 |     float point[3] = {points[tid][0], points[tid][1], points[tid][2]};
140 |     transform_coord(point, grid._scaling, grid._offset);
141 | 
142 |     int32_t l[3];
143 | #pragma unroll 3
144 |     for (int i = 0; i < 3; ++i) {
145 |         point[i] = fminf(fmaxf(point[i], 0.f), grid.size[i] - 1.f);
146 |         l[i] = min((int32_t)point[i], grid.size[i] - 2);
147 |         point[i] -= l[i];
148 |     }
149 | 
150 |     const int offy = grid.size[2], offx = grid.size[1] * grid.size[2];
151 |     const int32_t* __restrict__ link_ptr = &grid.links[l[0] * offx + l[1] * offy + l[2]];
152 | 
153 |     const float go = grad_out[tid][0];
154 | 
155 |     const float xb = point[0], yb = point[1], zb = point[2];
156 |     const float xa = 1.f - point[0], ya = 1.f - point[1], za = 1.f - point[2];
157 | 
158 | #define MAYBE_ADD_GRAD_LINK_PTR_D(u, content) if (link_ptr[u] >= 0) \
159 |             atomicAdd(grad_data[link_ptr[u]].data(), content)
160 | 
161 |     const float xago = xa * go;
162 |     float tmp = ya * xago;
163 |     MAYBE_ADD_GRAD_LINK_PTR_D(0, tmp * za);
164 |     MAYBE_ADD_GRAD_LINK_PTR_D(1, tmp * zb);
165 |     tmp = yb * xago;
166 |     MAYBE_ADD_GRAD_LINK_PTR_D(offy, tmp * za);
167 |     MAYBE_ADD_GRAD_LINK_PTR_D(offy + 1, tmp * zb);
168 | 
169 |     const float xbgo = xb * go;
170 |     tmp = ya * xbgo;
171 |     MAYBE_ADD_GRAD_LINK_PTR_D(offx, tmp * za);
172 |     MAYBE_ADD_GRAD_LINK_PTR_D(offx + 1, tmp * zb);
173 |     tmp = yb * xbgo;
174 |     MAYBE_ADD_GRAD_LINK_PTR_D(offx + offy, tmp * za);
175 |     MAYBE_ADD_GRAD_LINK_PTR_D(offx + offy + 1, tmp * zb);
176 | }
177 | } // namespace device
178 | } // namespace
179 | 
180 | 
181 | std::tuple<torch::Tensor, torch::Tensor> sample_grid(SparseGridSpec& grid, torch::Tensor points,
182 |                                                      bool want_colors) {
183 |     DEVICE_GUARD(points);
184 |     grid.check();
185 |     CHECK_INPUT(points);
186 |     TORCH_CHECK(points.ndimension() == 2);
187 |     const auto Q = points.size(0) * grid.sh_data.size(1);
188 |     const int cuda_n_threads = std::min<int>(Q, CUDA_MAX_THREADS);
189 |     const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
190 |     const int blocks_density = CUDA_N_BLOCKS_NEEDED(points.size(0), cuda_n_threads);
191 |     torch::Tensor result_density = torch::empty({points.size(0),
192 |                                     grid.density_data.size(1)}, points.options());
193 |     torch::Tensor result_sh = torch::empty({want_colors ? points.size(0) : 0,
194 |                                     grid.sh_data.size(1)}, points.options());
195 | 
196 |     cudaStream_t stream_1, stream_2;
197 |     cudaStreamCreate(&stream_1);
198 |     cudaStreamCreate(&stream_2);
199 | 
200 |     device::sample_grid_density_kernel<<<blocks_density, cuda_n_threads, 0, stream_1>>>(
201 |             grid,
202 |             points.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
203 |             // Output
204 |             result_density.packed_accessor32<float, 2, torch::RestrictPtrTraits>());
205 |     if (want_colors) {
206 |         device::sample_grid_sh_kernel<<<blocks, cuda_n_threads, 0, stream_2>>>(
207 |                 grid,
208 |                 points.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
209 |                 // Output
210 |                 result_sh.packed_accessor32<float, 2, torch::RestrictPtrTraits>());
211 |     }
212 | 
213 |     cudaStreamSynchronize(stream_1);
214 |     cudaStreamSynchronize(stream_2);
215 |     CUDA_CHECK_ERRORS;
216 |     return std::tuple{result_density, result_sh};
217 | }
218 | 
219 | void sample_grid_backward(
220 |         SparseGridSpec& grid,
221 |         torch::Tensor points,
222 |         torch::Tensor grad_out_density,
223 |         torch::Tensor grad_out_sh,
224 |         torch::Tensor grad_density_out,
225 |         torch::Tensor grad_sh_out,
226 |         bool want_colors) {
227 |     DEVICE_GUARD(points);
228 |     grid.check();
229 |     CHECK_INPUT(points);
230 |     CHECK_INPUT(grad_out_density);
231 |     CHECK_INPUT(grad_out_sh);
232 |     CHECK_INPUT(grad_density_out);
233 |     CHECK_INPUT(grad_sh_out);
234 |     TORCH_CHECK(points.ndimension() == 2);
235 |     TORCH_CHECK(grad_out_density.ndimension() == 2);
236 |     TORCH_CHECK(grad_out_sh.ndimension() == 2);
237 |     const auto Q = points.size(0) * grid.sh_data.size(1);
238 | 
239 |     const int cuda_n_threads = std::min<int>(Q, CUDA_MAX_THREADS);
240 |     const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
241 |     const int blocks_density = CUDA_N_BLOCKS_NEEDED(points.size(0), cuda_n_threads);
242 | 
243 |     cudaStream_t stream_1, stream_2;
244 |     cudaStreamCreate(&stream_1);
245 |     cudaStreamCreate(&stream_2);
246 | 
247 |     device::sample_grid_density_backward_kernel<<<blocks_density, cuda_n_threads, 0, stream_1>>>(
248 |             grid,
249 |             points.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
250 |             grad_out_density.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
251 |             // Output
252 |             grad_density_out.packed_accessor32<float, 2, torch::RestrictPtrTraits>());
253 | 
254 |     if (want_colors) {
255 |         device::sample_grid_sh_backward_kernel<<<blocks, cuda_n_threads, 0, stream_2>>>(
256 |                 grid,
257 |                 points.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
258 |                 grad_out_sh.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
259 |                 // Output
260 |                 grad_sh_out.packed_accessor64<float, 2, torch::RestrictPtrTraits>());
261 |     }
262 | 
263 |     cudaStreamSynchronize(stream_1);
264 |     cudaStreamSynchronize(stream_2);
265 | 
266 |     CUDA_CHECK_ERRORS;
267 | }
268 | 
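Both sampling kernels above clamp the query to the grid, split it into an integer lower corner `l` and a fractional offset, then blend the eight corner values with a chain of `lerp` calls, reading zero wherever the sparse `links` table has no voxel. A pure-PyTorch sketch of the same trilinear blend for a dense `(X, Y, Z, C)` grid, with the sparse `links` indirection omitted (a reference for understanding, not library API):

```python
import torch

def lerp(a, b, t):
    return a + t * (b - a)

def trilerp_dense(values: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
    """Trilinearly sample a dense (X, Y, Z, C) float grid at one 3D point
    given in grid coordinates; mirrors the lerp chain in the kernels above."""
    size = torch.tensor(values.shape[:3], dtype=point.dtype, device=point.device)
    p = torch.minimum(torch.maximum(point, torch.zeros_like(point)), size - 1.0)
    l = torch.minimum(p.long(), (size - 2).long())   # lower corner, clamped
    t = p - l                                        # fractional offsets
    x, y, z = l.tolist()
    v = values
    ix0y0 = lerp(v[x, y, z],     v[x, y, z + 1],     t[2])
    ix0y1 = lerp(v[x, y + 1, z], v[x, y + 1, z + 1], t[2])
    ix0   = lerp(ix0y0, ix0y1, t[1])
    ix1y0 = lerp(v[x + 1, y, z],     v[x + 1, y, z + 1],     t[2])
    ix1y1 = lerp(v[x + 1, y + 1, z], v[x + 1, y + 1, z + 1], t[2])
    ix1   = lerp(ix1y0, ix1y1, t[1])
    return lerp(ix0, ix1, t[0])
```

The backward kernels distribute the incoming gradient to the same eight corners with the matching weights (`xa`/`xb` etc.), using `atomicAdd` since many sample points can touch the same voxel.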
--------------------------------------------------------------------------------
/svox2/defs.py:
--------------------------------------------------------------------------------
1 | # Basis types (copied from C++ data_spec.hpp)
2 | BASIS_TYPE_SH = 1
3 | BASIS_TYPE_3D_TEXTURE = 4
4 | BASIS_TYPE_MLP = 255
--------------------------------------------------------------------------------
/svox2/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.0.1.dev0+sphtexcub.lincolor.fast'
--------------------------------------------------------------------------------
/test/prof.py:
--------------------------------------------------------------------------------
1 | # nvprof -f --profile-from-start off --quiet --metrics all --events all -o prof.nvvp python prof.py
2 | # then use nvvp to open prof.nvvp
3 | import svox2
4 | import torch
5 | import numpy as np
6 | from util import Timing
7 | from matplotlib import pyplot as plt
8 | 
9 | import torch.cuda.profiler as profiler
10 | import pyprof
11 | 
12 | device='cuda:0'
13 | 
14 | GRID_FILE = 'lego.npy'
15 | grid = svox2.SparseGrid(reso=256, device='cpu', radius=1.3256)
16 | data = torch.from_numpy(np.load(GRID_FILE)).view(-1, grid.data_dim)
17 | grid.sh_data.data = data[..., 1:]
18 | grid.density_data.data = data[..., :1]
19 | grid = grid.cuda()
20 | # grid.data.data[..., 0] += 0.1
21 | 
22 | N_RAYS = 5000
23 | # origins = torch.full((N_RAYS, 3), fill_value=0.0, device=device, dtype=dtype)
24 | origins = torch.zeros((N_RAYS, 3), device=device, dtype=torch.float32)
25 | dirs : torch.Tensor = torch.randn((N_RAYS, 3), device=device, dtype=torch.float32)
26 | dirs /= torch.norm(dirs, dim=-1, keepdim=True)
27 | rays = svox2.Rays(origins, dirs)
28 | 
29 | grid.requires_grad_(True)
30 | 
31 | samps = grid.volume_render(rays, use_kernel=True)
32 | # sampt = grid.volume_render(grid, origins, dirs, use_kernel=False)
33 | 
34 | pyprof.init()
35 | with torch.autograd.profiler.emit_nvtx():
36 |     profiler.start()
37 |     samps = grid.volume_render(rays, use_kernel=True)
38 |     s = samps.sum()
39 |     s.backward()
40 |     profiler.stop()
41 | 
--------------------------------------------------------------------------------
/test/sanity.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import svox2
3 | 
4 | device = 'cuda:0'
5 | 
6 | 
7 | torch.random.manual_seed(4000)
8 | g = svox2.SparseGrid(center=[0.0, 0.0, 0.0],
9 |                      radius=[1.0, 1.0, 1.0],
10 |                      device=device,
11 |                      basis_type=svox2.BASIS_TYPE_SH,
12 |                      background_nlayers=0)
13 | 
14 | g.opt.backend = 'nvol'
15 | g.opt.sigma_thresh = 0.0
16 | g.opt.stop_thresh = 0.0
17 | g.opt.background_brightness = 1.0
18 | 
19 | g.sh_data.data.normal_()
20 | g.density_data.data[..., 0] = 0.1
21 | g.sh_data.data[..., 0] = 0.5
22 | g.sh_data.data[..., 1:] = torch.randn_like(g.sh_data.data[..., 1:]) * 0.01
23 | 
24 | if g.use_background:
25 |     g.background_data.data[..., -1] = 1.0
26 |     g.background_data.data[..., :-1] = torch.randn_like(
27 |             g.background_data.data[..., :-1]) * 0.01
28 |     # g.background_data.data[..., :-1] = 0.5
29 | 
30 | g.basis_data.data.normal_()
31 | g.basis_data.data *= 10.0
32 | # print('use frustum?', g.use_frustum)
33 | 
34 | N_RAYS = 1
35 | 
36 | # origins = torch.randn(N_RAYS, 3, device=device) * 3
37 | # dirs = torch.randn(N_RAYS, 3, device=device)
38 | # origins = origins[27513:27514]
39 | # dirs = dirs[27513:27514]
40 | 
41 | origins = torch.tensor([[-3.8992738723754883, 4.844727993011475, 4.323856830596924]], device='cuda:0')
42 | dirs = torch.tensor([[1.1424630880355835, -1.2679963111877441, -0.8437137603759766]], device='cuda:0')
43 | dirs = dirs / torch.norm(dirs, dim=-1).unsqueeze(-1)
44 | 
45 | rays = svox2.Rays(origins=origins, dirs=dirs)
46 | 
47 | rgb = g.volume_render(rays, use_kernel=True)
48 | torch.cuda.synchronize()
49 | rgb_gt = g.volume_render(rays, use_kernel=False)
50 | torch.cuda.synchronize()
51 | 
52 | E = torch.abs(rgb - rgb_gt)
53 | err = E.max().detach().item()
54 | print(err)
55 | 
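sanity.py above is the smallest instance of the pattern used throughout this test directory: render the same rays once with the fused CUDA kernel and once with the slow pure-PyTorch fallback, then compare. Condensed, assuming a populated grid `g` and ray batch `rays` as set up in the file (the tolerance is an arbitrary choice for illustration, not a project constant):

```python
import torch

rgb_kernel = g.volume_render(rays, use_kernel=True)   # fused CUDA path
torch.cuda.synchronize()
rgb_ref = g.volume_render(rays, use_kernel=False)     # PyTorch reference path
max_err = (rgb_kernel - rgb_ref).abs().max().item()
assert max_err < 1e-4, f"kernel/reference mismatch: {max_err}"
```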
--------------------------------------------------------------------------------
/test/test_render_gradcheck.py:
--------------------------------------------------------------------------------
1 | import svox2
2 | import torch
3 | import torch.nn.functional as F
4 | from util import Timing
5 | 
6 | torch.random.manual_seed(2)
7 | # torch.random.manual_seed(8289)
8 | 
9 | device = 'cuda:0'
10 | dtype = torch.float32
11 | grid = svox2.SparseGrid(
12 |     reso=128,
13 |     center=[0.0, 0.0, 0.0],
14 |     radius=[1.0, 1.0, 1.0],
15 |     basis_dim=9,
16 |     use_z_order=True,
17 |     device=device,
18 |     background_nlayers=0,
19 |     basis_type=svox2.BASIS_TYPE_SH)
20 | grid.opt.backend = 'nvol'
21 | grid.opt.sigma_thresh = 0.0
22 | grid.opt.stop_thresh = 0.0
23 | grid.opt.background_brightness = 1.0
24 | 
25 | print(grid.sh_data.shape)
26 | # grid.sh_data.data.normal_()
27 | grid.sh_data.data[..., 0] = 0.5
28 | grid.sh_data.data[..., 1:].normal_(std=0.1)
29 | grid.density_data.data[:] = 100.0
30 | 
31 | if grid.use_background:
32 |     grid.background_data.data[..., -1] = 0.5
33 |     grid.background_data.data[..., :-1] = torch.randn_like(
34 |             grid.background_data.data[..., :-1]) * 0.01
35 | 
36 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
37 |     grid.basis_data.data.normal_()
38 |     grid.basis_data.data += 1.0
39 | 
40 | ENABLE_TORCH_CHECK = True
41 | # N_RAYS = 5000 #200 * 200
42 | N_RAYS = 200 * 200
43 | origins = torch.randn((N_RAYS, 3), device=device, dtype=dtype) * 3
44 | dirs = torch.randn((N_RAYS, 3), device=device, dtype=dtype)
45 | # origins = torch.clip(origins, -0.8, 0.8)
46 | 
47 | # origins = torch.tensor([[-0.6747068762779236, -0.752697229385376, -0.800000011920929]], device=device, dtype=dtype)
48 | # dirs = torch.tensor([[0.6418760418891907, -0.37417781352996826, 0.6693176627159119]], device=device, dtype=dtype)
49 | dirs /= torch.norm(dirs, dim=-1, keepdim=True)
50 | 
51 | # start = 71
52 | # end = 72
53 | # origins = origins[start:end]
54 | # dirs = dirs[start:end]
55 | # print(origins.tolist(), dirs.tolist())
56 | 
57 | # breakpoint()
58 | rays = svox2.Rays(origins, dirs)
59 | 
60 | rgb_gt = torch.zeros((origins.size(0), 3), device=device, dtype=dtype)
61 | 
62 | # grid.requires_grad_(True)
63 | 
64 | # samps = grid.volume_render(rays, use_kernel=True)
65 | # sampt = grid.volume_render(grid, origins, dirs, use_kernel=False)
66 | 
67 | with Timing("ours"):
68 |     samps = grid.volume_render(rays, use_kernel=True)
69 | s = F.mse_loss(samps, rgb_gt)
70 | 
71 | print(s)
72 | print('bkwd..')
73 | with Timing("ours_backward"):
74 |     s.backward()
75 | grid_sh_grad_s = grid.sh_data.grad.clone().cpu()
76 | grid_density_grad_s = grid.density_data.grad.clone().cpu()
77 | grid.sh_data.grad = None
78 | grid.density_data.grad = None
79 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
80 |     grid_basis_grad_s = grid.basis_data.grad.clone().cpu()
81 |     grid.basis_data.grad = None
82 | if grid.use_background:
83 |     grid_bg_grad_s = grid.background_data.grad.clone().cpu()
84 |     grid.background_data.grad = None
85 | 
86 | if ENABLE_TORCH_CHECK:
87 |     with Timing("torch"):
88 |         sampt = grid.volume_render(rays, use_kernel=False)
89 |     s = F.mse_loss(sampt, rgb_gt)
90 |     with Timing("torch_backward"):
91 |         s.backward()
92 |     grid_sh_grad_t = grid.sh_data.grad.clone().cpu() if grid.sh_data.grad is not None else torch.zeros_like(grid_sh_grad_s)
93 |     grid_density_grad_t = grid.density_data.grad.clone().cpu() if grid.density_data.grad is not None else torch.zeros_like(grid_density_grad_s)
94 |     if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
95 |         grid_basis_grad_t = grid.basis_data.grad.clone().cpu()
96 |     if grid.use_background:
97 |         grid_bg_grad_t = grid.background_data.grad.clone().cpu() if grid.background_data.grad is not None else torch.zeros_like(grid_bg_grad_s)
98 | 
99 |     E = torch.abs(grid_sh_grad_s-grid_sh_grad_t)
100 |     Ed = torch.abs(grid_density_grad_s-grid_density_grad_t)
101 |     if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
102 |         Eb = torch.abs(grid_basis_grad_s-grid_basis_grad_t)
103 |     if grid.use_background:
104 |         Ebg = torch.abs(grid_bg_grad_s-grid_bg_grad_t)
105 |     print('err', torch.abs(samps - sampt).max())
106 |     print('err_sh_grad\n', E.max())
107 |     print(' mean\n', E.mean())
108 |     print('err_density_grad\n', Ed.max())
109 |     print(' mean\n', Ed.mean())
110 |     if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
111 |         print('err_basis_grad\n', Eb.max())
112 |         print(' mean\n', Eb.mean())
113 |     if grid.use_background:
114 |         print('err_background_grad\n', Ebg.max())
115 |         print(' mean\n', Ebg.mean())
116 |     print()
117 |     print('g_ours sh min/max\n', grid_sh_grad_s.min(), grid_sh_grad_s.max())
118 |     print('g_torch sh min/max\n', grid_sh_grad_t.min(), grid_sh_grad_t.max())
119 |     print('g_ours sigma min/max\n', grid_density_grad_s.min(), grid_density_grad_s.max())
120 |     print('g_torch sigma min/max\n', grid_density_grad_t.min(), grid_density_grad_t.max())
121 |     if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
122 |         print('g_ours basis min/max\n', grid_basis_grad_s.min(), grid_basis_grad_s.max())
123 |         print('g_torch basis min/max\n', grid_basis_grad_t.min(), grid_basis_grad_t.max())
124 |     if grid.use_background:
125 |         print('g_ours bg min/max\n', grid_bg_grad_s.min(), grid_bg_grad_s.max())
126 |         print('g_torch bg min/max\n', grid_bg_grad_t.min(), grid_bg_grad_t.max())
127 | 
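test_render_gradcheck.py applies the same kernel-vs-reference idea to gradients: one backward pass through the fused CUDA kernel, one through the autograd graph of the PyTorch renderer, with gradients snapshotted and reset between runs so they do not accumulate. The core of that comparison, condensed (assuming `grid`, `rays`, and `rgb_gt` as set up in the file above):

```python
import torch.nn.functional as F

def loss_grads(use_kernel):
    F.mse_loss(grid.volume_render(rays, use_kernel=use_kernel), rgb_gt).backward()
    g_sh = grid.sh_data.grad.clone().cpu()
    g_sigma = grid.density_data.grad.clone().cpu()
    grid.sh_data.grad = None          # reset so the two runs stay independent
    grid.density_data.grad = None
    return g_sh, g_sigma

sh_k, sigma_k = loss_grads(use_kernel=True)
sh_t, sigma_t = loss_grads(use_kernel=False)
print('max SH grad err:', (sh_k - sh_t).abs().max().item())
print('max sigma grad err:', (sigma_k - sigma_t).abs().max().item())
```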
Timing("ours"): 35 | samps = grid.volume_render(rays, use_kernel=True) 36 | s = samps.sum() 37 | with Timing("ours_backward"): 38 | s.backward() 39 | -------------------------------------------------------------------------------- /test/test_render_visual.py: -------------------------------------------------------------------------------- 1 | import svox2 2 | import torch 3 | import numpy as np 4 | from util import Timing 5 | from matplotlib import pyplot as plt 6 | device='cuda:0' 7 | 8 | GRID_FILE = 'lego.npy' 9 | grid = svox2.SparseGrid(reso=256, device='cpu', radius=1.3256) 10 | data = torch.from_numpy(np.load(GRID_FILE)).view(-1, grid.data_dim) 11 | grid.sh_data.data = data[..., 1:] 12 | grid.density_data.data = data[..., :1] 13 | # grid.resample(128, use_z_order=True) 14 | grid = grid.cuda() 15 | 16 | c2w = torch.tensor([ 17 | [ -0.9999999403953552, 0.0, 0.0, 0.0 ], 18 | [ 0.0, -0.7341099977493286, 0.6790305972099304, 2.737260103225708 ], 19 | [ 0.0, 0.6790306568145752, 0.7341098785400391, 2.959291696548462 ], 20 | [ 0.0, 0.0, 0.0, 1.0 ], 21 | ], device=device) 22 | 23 | with torch.no_grad(): 24 | width = height = 800 25 | fx = fy = 1111 26 | origins = c2w[None, :3, 3].expand(height * width, -1).contiguous() 27 | yy, xx = torch.meshgrid( 28 | torch.arange(height, dtype=torch.float64, device=c2w.device), 29 | torch.arange(width, dtype=torch.float64, device=c2w.device), 30 | ) 31 | xx = (xx - width * 0.5) / float(fx) 32 | yy = (yy - height * 0.5) / float(fy) 33 | zz = torch.ones_like(xx) 34 | dirs = torch.stack((xx, -yy, -zz), dim=-1) 35 | dirs /= torch.norm(dirs, dim=-1, keepdim=True) 36 | dirs = dirs.reshape(-1, 3) 37 | del xx, yy, zz 38 | dirs = torch.matmul(c2w[None, :3, :3].double(), dirs[..., None])[..., 0].float() 39 | dirs = dirs / torch.norm(dirs, dim=-1, keepdim=True) 40 | 41 | rays = svox2.Rays(origins, dirs) 42 | 43 | for i in range(5): 44 | with Timing("ours"): 45 | im = grid.volume_render(rays, use_kernel=True) 46 | 47 | im = im.reshape(height, width, 3) 48 | im = im.detach().clamp_(0.0, 1.0).cpu() 49 | plt.imshow(im) 50 | plt.show() 51 | -------------------------------------------------------------------------------- /test/test_sample.py: -------------------------------------------------------------------------------- 1 | import svox2 2 | import torch 3 | import numpy as np 4 | from util import Timing 5 | 6 | torch.random.manual_seed(0) 7 | 8 | device = 'cuda:0' 9 | 10 | # GRID_FILE = 'lego.npy' 11 | # grid = svox2.SparseGrid(reso=256, device='cpu', radius=1.3256) 12 | # grid.data.data = torch.from_numpy(np.load(GRID_FILE)).view(-1, grid.data_dim) 13 | # grid = grid.cuda() 14 | 15 | grid = svox2.SparseGrid(reso=256, center=[0.0, 0.0, 0.0], 16 | radius=1.0, device=device) 17 | grid.sh_data.data.normal_(0.0, 1.0) 18 | grid.density_data.data.normal_(0.1, 0.05).clamp_min_(0.0) 19 | # grid.density_data.data[:] = 1.0 20 | # grid = torch.rand((2, 2, 2, 4), device=device, dtype=torch.float32) 21 | 22 | N_POINTS = 5000 * 1024 23 | points = torch.rand(N_POINTS, 3, device=device) * 2 - 1 24 | # points = torch.tensor([[0.49, 0.49, 0.49], [0.9985, 0.4830, 0.4655]], device=device) 25 | # points.clamp_(-0.999, 0.999) 26 | 27 | _ = grid.sample(points) 28 | _ = grid.sample(points, use_kernel=False) 29 | 30 | grid.requires_grad_(True) 31 | 32 | with Timing("ours"): 33 | sigma_c, rgb_c = grid.sample(points) 34 | 35 | s = sigma_c.sum() + rgb_c.sum() 36 | with Timing("our_back"): 37 | s.backward() 38 | gdo = grid.density_data.grad.clone() 39 | gso = grid.sh_data.grad.clone() 40 | 
40 | grid.density_data.grad = None
41 | grid.sh_data.grad = None
42 | 
43 | with Timing("torch"):
44 |     sigma_t, rgb_t = grid.sample(points, use_kernel=False)
45 | s = sigma_t.sum() + rgb_t.sum()
46 | with Timing("torch_back"):
47 |     s.backward()
48 | gdt = grid.density_data.grad.clone()
49 | gst = grid.sh_data.grad.clone()
50 | 
51 | # print('c\n', sampc)
52 | # print('t\n', sampt)
53 | print('err_sigma\n', torch.abs(sigma_t-sigma_c).max())
54 | print('err_rgb\n', torch.abs(rgb_t-rgb_c).max())
55 | print('err_grad_sigma\n', torch.abs(gdo-gdt).max())
56 | print('err_grad_rgb\n', torch.abs(gso-gst).max())
--------------------------------------------------------------------------------
/test/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.cuda
3 | 
4 | class Timing:
5 |     def __init__(self, name):
6 |         self.name = name
7 | 
8 |     def __enter__(self):
9 |         self.start = torch.cuda.Event(enable_timing=True)
10 |         self.end = torch.cuda.Event(enable_timing=True)
11 |         self.start.record()
12 | 
13 |     def __exit__(self, type, value, traceback):
14 |         self.end.record()
15 |         torch.cuda.synchronize()
16 |         print(self.name, 'elapsed', self.start.elapsed_time(self.end), 'ms')
--------------------------------------------------------------------------------
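For reference, `Timing` above brackets a block with CUDA events; the `torch.cuda.synchronize()` in `__exit__` ensures the device has finished before the elapsed time is read, so the printed number reflects GPU time rather than just Python time. A typical usage sketch (hypothetical snippet, not from the repo):

```python
import torch
from util import Timing  # test/util.py, above

x = torch.randn(4096, 4096, device='cuda')
with Timing("matmul"):   # prints: matmul elapsed <t> ms
    y = x @ x
```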