├── .gitignore ├── LICENSE.txt ├── README.md ├── environment.yml ├── nerf_sh ├── __init__.py ├── config │ ├── blender.yaml │ ├── misc │ │ ├── og_nerf.yaml │ │ ├── proj.yaml │ │ └── sg.yaml │ └── tt.yaml ├── eval.py ├── gen_mesh.py ├── gen_video.py ├── nerf │ ├── __init__.py │ ├── datasets.py │ ├── model_utils.py │ ├── models.py │ ├── sg.py │ ├── sh.py │ └── utils.py ├── parse_timing.py └── train.py ├── octree ├── compression.py ├── config │ ├── syn_sg25.json │ ├── syn_sh16.json │ └── tt_sh25.json ├── evaluation.py ├── extraction.py ├── nerf │ ├── __init__.py │ ├── datasets.py │ ├── model_utils.py │ ├── models.py │ ├── sh_proj.py │ └── utils.py ├── optimization.py └── task_manager.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | data 3 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright The PlenOctree Authors 2021 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, 7 | this list of conditions and the following disclaimer. 8 | 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 14 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 15 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 16 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 17 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 18 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 19 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 20 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 21 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 22 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 23 | POSSIBILITY OF SUCH DAMAGE. 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PlenOctrees Official Repo: NeRF-SH training and conversion 2 | 3 | This repository contains code to train NeRF-SH and 4 | to extract the PlenOctree, constituting part of the code release for: 5 | 6 | PlenOctrees for Real Time Rendering of Neural Radiance Fields
7 | Alex Yu, Ruilong Li, Matthew Tancik, Hao Li, Ren Ng, Angjoo Kanazawa
8 | 
9 | https://alexyu.net/plenoctrees
10 | 
11 | ```
12 | @inproceedings{yu2021plenoctrees,
13 |     title={{PlenOctrees} for Real-time Rendering of Neural Radiance Fields},
14 |     author={Alex Yu and Ruilong Li and Matthew Tancik and Hao Li and Ren Ng and Angjoo Kanazawa},
15 |     year={2021},
16 |     booktitle={ICCV},
17 | }
18 | ```
19 | 
20 | Please see the following repository for our C++ PlenOctrees volume renderer:
21 | <https://github.com/sxyu/volrend>
22 | 
23 | ## Setup
24 | 
25 | Please use conda for a reproducible environment.
26 | ```
27 | conda env create -f environment.yml
28 | conda activate plenoctree
29 | pip install --upgrade pip
30 | ```
31 | 
32 | Or you can install the dependencies manually by:
33 | ```
34 | conda install pytorch torchvision cudatoolkit=11.0 -c pytorch
35 | conda install tqdm
36 | pip install -r requirements.txt
37 | ```
38 | 
39 | [Optional] Install GPU and TPU support for JAX. This is useful for NeRF-SH training.
40 | Remember to **change cuda110 to your CUDA version**, e.g. cuda102 for CUDA 10.2.
41 | ```
42 | pip install --upgrade jax jaxlib==0.1.65+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
43 | ```
44 | 
45 | ## NeRF-SH Training
46 | 
47 | We release our trained NeRF-SH models as well as converted plenoctrees at
48 | [Google Drive](https://drive.google.com/drive/folders/1J0lRiDn_wOiLVpCraf6jM7vvCwDr9Dmx?usp=sharing).
49 | You can also use the following commands to reproduce the NeRF-SH models.
50 | 
51 | Training and evaluation on the **NeRF-Synthetic dataset** ([Google Drive](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)):
52 | ```
53 | export DATA_ROOT=./data/NeRF/nerf_synthetic/
54 | export CKPT_ROOT=./data/Plenoctree/checkpoints/syn_sh16/
55 | export SCENE=chair
56 | export CONFIG_FILE=nerf_sh/config/blender
57 | 
58 | python -m nerf_sh.train \
59 |     --train_dir $CKPT_ROOT/$SCENE/ \
60 |     --config $CONFIG_FILE \
61 |     --data_dir $DATA_ROOT/$SCENE/
62 | 
63 | python -m nerf_sh.eval \
64 |     --chunk 4096 \
65 |     --train_dir $CKPT_ROOT/$SCENE/ \
66 |     --config $CONFIG_FILE \
67 |     --data_dir $DATA_ROOT/$SCENE/
68 | ```
69 | Note that for `SCENE=mic`, we adopt a warmup learning rate schedule (`--lr_delay_steps 50000 --lr_delay_mult 0.01`) to avoid unstable initialization.
70 | 
71 | 
72 | Training and evaluation on the **TanksAndTemple dataset**
73 | ([Download Link](https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip)) from the [NSVF](https://github.com/facebookresearch/NSVF) paper:
74 | ```
75 | export DATA_ROOT=./data/TanksAndTemple/
76 | export CKPT_ROOT=./data/Plenoctree/checkpoints/tt_sh25/
77 | export SCENE=Barn
78 | export CONFIG_FILE=nerf_sh/config/tt
79 | 
80 | python -m nerf_sh.train \
81 |     --train_dir $CKPT_ROOT/$SCENE/ \
82 |     --config $CONFIG_FILE \
83 |     --data_dir $DATA_ROOT/$SCENE/
84 | 
85 | python -m nerf_sh.eval \
86 |     --chunk 4096 \
87 |     --train_dir $CKPT_ROOT/$SCENE/ \
88 |     --config $CONFIG_FILE \
89 |     --data_dir $DATA_ROOT/$SCENE/
90 | ```
91 | 
92 | ## PlenOctrees Conversion and Optimization
93 | 
94 | Before converting the NeRF-SH models into plenoctrees, you should already have the
95 | NeRF-SH models trained/downloaded and placed at `./data/Plenoctree/checkpoints/{syn_sh16, tt_sh25}/`.
96 | Also make sure you have the training data placed at
97 | `./data/NeRF/nerf_synthetic` and/or `./data/TanksAndTemple`.
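As an aside, the `sh16` / `sh25` suffixes in these checkpoint names count the spherical-harmonic (SH) coefficients per color channel: `syn_sh16` corresponds to `sh_deg: 3` in `nerf_sh/config/blender.yaml` and `tt_sh25` to `sh_deg: 4` in `nerf_sh/config/tt.yaml`, and the raw MLP output width scales as `(sh_deg + 1)**2` per channel (see `construct_nerf` in `nerf_sh/nerf/models.py`). A minimal sketch of this bookkeeping (an illustrative helper, not part of the codebase):

```python
# Sketch: how sh_deg determines the network's raw RGB output width,
# mirroring the channel computation in nerf_sh/nerf/models.py.
def sh_output_channels(sh_deg: int, num_rgb_channels: int = 3) -> int:
    """Raw output channels when the MLP predicts SH coefficients."""
    return num_rgb_channels * (sh_deg + 1) ** 2

assert sh_output_channels(3) == 3 * 16  # syn_sh16: 16 coefficients per channel
assert sh_output_channels(4) == 3 * 25  # tt_sh25: 25 coefficients per channel
```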
98 | 
99 | To reproduce our results in the paper, you can simply run:
100 | ```
101 | # NeRF-Synthetic dataset
102 | python -m octree.task_manager octree/config/syn_sh16.json --gpus="0 1 2 3"
103 | 
104 | # TanksAndTemple dataset
105 | python -m octree.task_manager octree/config/tt_sh25.json --gpus="0 1 2 3"
106 | ```
107 | The above commands will parallelize all scenes in the dataset across the GPUs you set. The json files
108 | contain dedicated hyperparameters tuned for better performance (PSNR, SSIM, LPIPS). In this setting, a 24GB GPU is
109 | needed for each scene, and on average the process takes about 15 minutes to finish. The converted plenoctree
110 | will be saved to `./data/Plenoctree/checkpoints/{syn_sh16, tt_sh25}/$SCENE/octrees/`.
111 | 
112 | 
113 | Below is a more straightforward script for demonstration purposes:
114 | ```
115 | export DATA_ROOT=./data/NeRF/nerf_synthetic/
116 | export CKPT_ROOT=./data/Plenoctree/checkpoints/syn_sh16
117 | export SCENE=chair
118 | export CONFIG_FILE=nerf_sh/config/blender
119 | 
120 | python -m octree.extraction \
121 |     --train_dir $CKPT_ROOT/$SCENE/ --is_jaxnerf_ckpt \
122 |     --config $CONFIG_FILE \
123 |     --data_dir $DATA_ROOT/$SCENE/ \
124 |     --output $CKPT_ROOT/$SCENE/octrees/tree.npz
125 | 
126 | python -m octree.optimization \
127 |     --input $CKPT_ROOT/$SCENE/octrees/tree.npz \
128 |     --config $CONFIG_FILE \
129 |     --data_dir $DATA_ROOT/$SCENE/ \
130 |     --output $CKPT_ROOT/$SCENE/octrees/tree_opt.npz
131 | 
132 | python -m octree.evaluation \
133 |     --input $CKPT_ROOT/$SCENE/octrees/tree_opt.npz \
134 |     --config $CONFIG_FILE \
135 |     --data_dir $DATA_ROOT/$SCENE/
136 | 
137 | # [Optional] Only used for in-browser viewing.
138 | python -m octree.compression \
139 |     $CKPT_ROOT/$SCENE/octrees/tree_opt.npz \
140 |     --out_dir $CKPT_ROOT/$SCENE/ \
141 |     --overwrite
142 | ```
143 | 
144 | ## MISC
145 | 
146 | ### Project Vanilla NeRF to PlenOctree
147 | 
148 | A vanilla trained NeRF can also be converted to a plenoctree for fast inference. To mimic the
149 | view-independent network outputs of a NeRF-SH model, we project the vanilla NeRF model onto SH basis functions
150 | by sampling view directions for every point in space. Though this makes converting a vanilla NeRF to
151 | a plenoctree possible, the projection process inevitably degrades the quality of the model, even with a large number
152 | of sampled view directions (which takes hours to finish). We therefore recommend directly training a NeRF-SH model end-to-end.
153 | 
154 | Below is an example of projecting a trained vanilla NeRF model from the
155 | [JaxNeRF repo](https://github.com/google-research/google-research/tree/master/jaxnerf)
156 | ([Download Link](http://storage.googleapis.com/gresearch/jaxnerf/jaxnerf_pretrained_models.zip)) to a plenoctree.
157 | After extraction, you can optimize, evaluate, and compress the plenoctree as usual:
158 | ```
159 | export DATA_ROOT=./data/NeRF/nerf_synthetic/
160 | export CKPT_ROOT=./data/JaxNeRF/jaxnerf_models/blender/
161 | export SCENE=drums
162 | export CONFIG_FILE=nerf_sh/config/misc/proj
163 | 
164 | python -m octree.extraction \
165 |     --train_dir $CKPT_ROOT/$SCENE/ --is_jaxnerf_ckpt \
166 |     --config $CONFIG_FILE \
167 |     --data_dir $DATA_ROOT/$SCENE/ \
168 |     --output $CKPT_ROOT/$SCENE/octrees/tree.npz \
169 |     --projection_samples 100 \
170 |     --radius 1.3
171 | ```
172 | Note that `--projection_samples` controls how many sampled view directions are used. More view directions give better
173 | projection quality but take longer to finish; the core projection step is sketched below.
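For intuition, projecting a single point boils down to a Monte-Carlo estimate of the SH projection integral over view directions. Below is a minimal NumPy sketch restricted to SH degree 1 for brevity; `radiance_fn` is a hypothetical stand-in for querying the trained NeRF's color at one fixed 3D point over many view directions (the actual implementation lives in `octree/nerf/sh_proj.py`):

```python
import numpy as np

def sh_basis_deg1(dirs):
    """Real SH basis values (Y_00, Y_1m1, Y_10, Y_11) for unit directions (N, 3)."""
    x, y, z = dirs[:, 0], dirs[:, 1], dirs[:, 2]
    return np.stack(
        [np.full_like(x, 0.282095), -0.488603 * y, 0.488603 * z, -0.488603 * x],
        axis=-1,
    )

def project_point_to_sh(radiance_fn, n_samples=10000, seed=0):
    rng = np.random.default_rng(seed)
    # Sample view directions uniformly on the unit sphere.
    dirs = rng.normal(size=(n_samples, 3))
    dirs /= np.linalg.norm(dirs, axis=-1, keepdims=True)
    rgb = radiance_fn(dirs)      # (N, 3) colors seen along each direction
    basis = sh_basis_deg1(dirs)  # (N, 4) SH basis values
    # Monte-Carlo estimate of k_lm = integral of c(d) * Y_lm(d) over the sphere.
    return (4.0 * np.pi / n_samples) * basis.T @ rgb  # (4, 3) SH coefficients
```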
For example, for the `drums` scene
174 | in the NeRF-Synthetic dataset, `100 / 10000` sampled view directions take about `2 mins / 2 hours` to finish the plenoctree extraction.
175 | It produces *raw* plenoctrees with `PSNR=22.49 / 23.84` (before optimization). Note that extraction from a NeRF-SH model produces
176 | a *raw* plenoctree with `PSNR=25.01`.
177 | 
178 | ### List of possible improvements
179 | 
180 | In the interest of reproducibility, the parameters used in the paper are also used here.
181 | For future work we recommend trying the changes in mip-NeRF
182 | for improved stability and quality:
183 | 
184 | - Centered pixels (+ 0.5 on x, y) when generating rays
185 | - Use shifted SoftPlus instead of ReLU for density (including for octree optimization)
186 | - Pad the RGB sigmoid output (avoid low gradient region near 0/1 color)
187 | - Multi-scale training from mip-NeRF
188 | 
189 | 
190 | --------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | # run: conda env create -f environment.yml
2 | name: plenoctree
3 | channels:
4 |   - pytorch
5 |   - defaults
6 | dependencies:
7 |   - python=3.8.8
8 |   - numpy>=1.16.4,<1.19.0
9 |   - pip
10 |   - pip:
11 |     - dotmap
12 |     - imageio
13 |     - imageio-ffmpeg
14 |     - ipdb
15 |     - pretrainedmodels
16 |     - lpips
17 |     - jax==0.2.9
18 |     - jaxlib>=0.1.57
19 |     - flax>=0.3.1
20 |     - opencv-python>=4.4.0
21 |     - Pillow>=7.2.0
22 |     - pyyaml>=5.3.1
23 |     - tensorboard>=2.4.0
24 |     - tensorflow>=2.3.1
25 |     - pymcubes
26 |     - svox>=0.2.28
27 |     - scipy>=1.6.0
28 |     - torch>=1.7.0,<=1.7.1
29 |     - tqdm
30 | 
31 | --------------------------------------------------------------------------------
/nerf_sh/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | -------------------------------------------------------------------------------- /nerf_sh/config/blender.yaml: -------------------------------------------------------------------------------- 1 | dataset: blender 2 | image_batching: false 3 | factor: 0 4 | num_coarse_samples: 64 5 | num_fine_samples: 128 6 | use_viewdirs: false 7 | white_bkgd: true 8 | batch_size: 1024 9 | sh_deg: 3 10 | randomized: true 11 | max_steps: 2000000 12 | -------------------------------------------------------------------------------- /nerf_sh/config/misc/og_nerf.yaml: -------------------------------------------------------------------------------- 1 | dataset: blender 2 | image_batching: false 3 | factor: 0 4 | num_coarse_samples: 64 5 | num_fine_samples: 128 6 | use_viewdirs: true 7 | white_bkgd: true 8 | batch_size: 1024 9 | randomized: true 10 | sparsity_weight: 0.0 11 | max_steps: 2000000 12 | -------------------------------------------------------------------------------- /nerf_sh/config/misc/proj.yaml: -------------------------------------------------------------------------------- 1 | dataset: blender 2 | image_batching: false 3 | factor: 0 4 | num_coarse_samples: 64 5 | num_fine_samples: 128 6 | use_viewdirs: true 7 | white_bkgd: true 8 | batch_size: 1024 9 | sh_deg: 4 10 | randomized: true 11 | max_steps: 2000000 12 | -------------------------------------------------------------------------------- /nerf_sh/config/misc/sg.yaml: -------------------------------------------------------------------------------- 1 | dataset: blender 2 | image_batching: false 3 | factor: 0 4 | num_coarse_samples: 64 5 | num_fine_samples: 128 6 | use_viewdirs: false 7 | white_bkgd: true 8 | batch_size: 1024 9 | sg_dim: 25 10 | randomized: true 11 | max_steps: 2000000 12 | -------------------------------------------------------------------------------- /nerf_sh/config/tt.yaml: -------------------------------------------------------------------------------- 1 | # Tanks and Temples dataset 2 | dataset: nsvf 3 | image_batching: false 4 | factor: 0 5 | num_coarse_samples: 64 6 | num_fine_samples: 128 7 | use_viewdirs: false 8 | white_bkgd: true # No alpha channel 9 | batch_size: 1024 10 | randomized: true 11 | sh_deg: 4 12 | max_steps: 2000000 13 | near: 0.0 14 | far: 4.0 15 | sparsity_radius: 5.0 16 | sparsity_length: 0.2 17 | -------------------------------------------------------------------------------- /nerf_sh/eval.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Modifications Copyright 2021 The PlenOctree Authors. 3 | # Original Copyright 2021 The Google Research Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | # Lint as: python3 18 | """Evaluation script for Nerf.""" 19 | 20 | import os 21 | # Get rid of ugly TF logs 22 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 23 | 24 | import functools 25 | from os import path 26 | 27 | from absl import app 28 | from absl import flags 29 | import flax 30 | from flax.metrics import tensorboard 31 | from flax.training import checkpoints 32 | import jax 33 | from jax import random 34 | import numpy as np 35 | 36 | from nerf_sh.nerf import datasets 37 | from nerf_sh.nerf import models 38 | from nerf_sh.nerf import utils 39 | 40 | FLAGS = flags.FLAGS 41 | 42 | utils.define_flags() 43 | 44 | 45 | def main(unused_argv): 46 | rng = random.PRNGKey(20200823) 47 | rng, key = random.split(rng) 48 | 49 | utils.update_flags(FLAGS) 50 | utils.check_flags(FLAGS) 51 | 52 | dataset = datasets.get_dataset("test", FLAGS) 53 | model, state = models.get_model_state(key, FLAGS, restore=False) 54 | 55 | # Rendering is forced to be deterministic even if training was randomized, as 56 | # this eliminates "speckle" artifacts. 57 | render_pfn = utils.get_render_pfn(model, randomized=False) 58 | 59 | # Compiling to the CPU because it's faster and more accurate. 60 | ssim_fn = jax.jit(functools.partial(utils.compute_ssim, max_val=1.0), backend="cpu") 61 | 62 | last_step = 0 63 | out_dir = path.join( 64 | FLAGS.train_dir, "path_renders" if FLAGS.render_path else "test_preds" 65 | ) 66 | if not FLAGS.eval_once: 67 | summary_writer = tensorboard.SummaryWriter(path.join(FLAGS.train_dir, "eval")) 68 | while True: 69 | print('Loading model') 70 | state = checkpoints.restore_checkpoint(FLAGS.train_dir, state) 71 | step = int(state.optimizer.state.step) 72 | if step <= last_step: 73 | continue 74 | if FLAGS.save_output and (not utils.isdir(out_dir)): 75 | utils.makedirs(out_dir) 76 | psnrs = [] 77 | ssims = [] 78 | if not FLAGS.eval_once: 79 | showcase_index = np.random.randint(0, dataset.size) 80 | for idx in range(dataset.size): 81 | print(f"Evaluating {idx+1}/{dataset.size}") 82 | batch = next(dataset) 83 | if idx % FLAGS.approx_eval_skip != 0: 84 | continue 85 | pred_color, pred_disp, pred_acc = utils.render_image( 86 | functools.partial(render_pfn, state.optimizer.target), 87 | batch["rays"], 88 | rng, 89 | FLAGS.dataset == "llff", 90 | chunk=FLAGS.chunk, 91 | ) 92 | if jax.host_id() != 0: # Only record via host 0. 
93 | continue 94 | if not FLAGS.eval_once and idx == showcase_index: 95 | showcase_color = pred_color 96 | showcase_disp = pred_disp 97 | showcase_acc = pred_acc 98 | if not FLAGS.render_path: 99 | showcase_gt = batch["pixels"] 100 | # if not FLAGS.render_path: 101 | # psnr = utils.compute_psnr(((pred_color - batch["pixels"]) ** 2).mean()) 102 | # ssim = ssim_fn(pred_color, batch["pixels"]) 103 | # print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}") 104 | # psnrs.append(float(psnr)) 105 | # ssims.append(float(ssim)) 106 | if FLAGS.save_output: 107 | utils.save_img(pred_color, path.join(out_dir, "{:03d}.png".format(idx))) 108 | utils.save_img( 109 | pred_disp[Ellipsis, 0], 110 | path.join(out_dir, "disp_{:03d}.png".format(idx)), 111 | ) 112 | if (not FLAGS.eval_once) and (jax.host_id() == 0): 113 | summary_writer.image("pred_color", showcase_color, step) 114 | summary_writer.image("pred_disp", showcase_disp, step) 115 | summary_writer.image("pred_acc", showcase_acc, step) 116 | # if not FLAGS.render_path: 117 | # summary_writer.scalar("psnr", np.mean(np.array(psnrs)), step) 118 | # summary_writer.scalar("ssim", np.mean(np.array(ssims)), step) 119 | # summary_writer.image("target", showcase_gt, step) 120 | # if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0): 121 | # with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f: 122 | # f.write("{}".format(np.mean(np.array(psnrs)))) 123 | # with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f: 124 | # f.write("{}".format(np.mean(np.array(ssims)))) 125 | # with utils.open_file(path.join(out_dir, f"psnrs_{step}.txt"), "w") as f: 126 | # f.write(" ".join([str(v) for v in psnrs])) 127 | # with utils.open_file(path.join(out_dir, f"ssims_{step}.txt"), "w") as f: 128 | # f.write(" ".join([str(v) for v in ssims])) 129 | if FLAGS.eval_once: 130 | break 131 | if int(step) >= FLAGS.max_steps: 132 | break 133 | last_step = step 134 | 135 | 136 | if __name__ == "__main__": 137 | app.run(main) 138 | -------------------------------------------------------------------------------- /nerf_sh/gen_mesh.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 | import os
24 | # Get rid of ugly TF logs
25 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
26 | 
27 | from absl import app
28 | from absl import flags
29 | import mcubes
30 | 
31 | import jax
32 | from jax import config
33 | from jax import random
34 | import jax.numpy as jnp
35 | import numpy as np
36 | import flax
37 | from flax.training import checkpoints
38 | 
39 | import functools
40 | 
41 | from nerf_sh.nerf import models
42 | from nerf_sh.nerf import utils
43 | from nerf_sh.nerf.utils import host0_print as h0print
44 | 
45 | FLAGS = flags.FLAGS
46 | 
47 | utils.define_flags()
48 | 
49 | flags.DEFINE_string(
50 |     "reso",
51 |     "300 300 300",
52 |     "Marching cube resolution in each dimension: x y z",
53 | )
54 | flags.DEFINE_string(
55 |     "c1",
56 |     "-2 -2 -2",
57 |     "Marching cubes bounds lower corner 1 in x y z OR single number",
58 | )
59 | flags.DEFINE_string(
60 |     "c2",
61 |     "2 2 2",
62 |     "Marching cubes bounds upper corner in x y z OR single number",
63 | )
64 | flags.DEFINE_float(
65 |     "iso", 6.0, "Marching cubes isosurface"
66 | )
67 | flags.DEFINE_bool(
68 |     "coarse",
69 |     False,
70 |     "Force use of the coarse network (else depends on renderer n_fine in conf)",
71 | )
72 | flags.DEFINE_integer(
73 |     "point_chunk",
74 |     720720,
75 |     "Chunk (batch) size of points for evaluation. NOTE: --chunk will be ignored",
76 | )
77 | # TODO: implement color
78 | # flags.DEFINE_bool(
79 | #     "color",
80 | #     False,
81 | #     "Generate colored mesh."
82 | # )
83 | 
84 | 
85 | config.parse_flags_with_absl()
86 | 
87 | 
88 | def marching_cubes(
89 |     fn,
90 |     c1,
91 |     c2,
92 |     reso,
93 |     isosurface,
94 |     chunk,
95 | ):
96 |     """
97 |     Run marching cubes on network. Uses PyMCubes.
98 |     Args:
99 |         fn: main NeRF-type network
100 |         c1: list corner 1 of marching cube bounds x,y,z
101 |         c2: list corner 2 of marching cube bounds x,y,z (all > c1)
102 |         reso: list resolutions of marching cubes x,y,z
103 |         isosurface: float sigma-isosurface of marching cubes
104 |     """
105 |     grid = np.vstack(
106 |         np.meshgrid(
107 |             *(np.linspace(lo, hi, sz, dtype=np.float32)
108 |               for lo, hi, sz in zip(c1, c2, reso)),
109 |             indexing="ij"
110 |         )
111 |     ).reshape(3, -1).T
112 | 
113 |     h0print("* Evaluating sigma @", grid.shape[0], "points")
114 |     rgbs, sigmas = utils.eval_points(
115 |         fn,
116 |         grid,
117 |         chunk,
118 |     )
119 |     sigmas = sigmas.reshape(*reso)
120 |     del rgbs
121 | 
122 |     if jax.host_id() == 0:
123 |         print("* Running marching cubes")
124 |         vertices, triangles = mcubes.marching_cubes(sigmas, isosurface)
125 |         # Scale
126 |         c1, c2 = np.array(c1), np.array(c2)
127 |         vertices *= (c2 - c1) / np.array(reso)
128 | 
129 |         return vertices + c1, triangles
130 |     return None, None
131 | 
132 | 
133 | def save_obj(vertices, triangles, path, vert_rgb=None):
134 |     """
135 |     Save OBJ file, optionally with vertex colors.
136 |     This version is faster than the PyMCubes exporter and supports color.
137 |     Taken from PIFu.
138 | :param vertices (N, 3) 139 | :param triangles (N, 3) 140 | :param vert_rgb (N, 3) rgb 141 | """ 142 | file = open(path, "w") 143 | if vert_rgb is None: 144 | # No color 145 | for v in vertices: 146 | file.write("v %.4f %.4f %.4f\n" % (v[0], v[1], v[2])) 147 | else: 148 | # Color 149 | for idx, v in enumerate(vertices): 150 | c = vert_rgb[idx] 151 | file.write( 152 | "v %.4f %.4f %.4f %.4f %.4f %.4f\n" 153 | % (v[0], v[1], v[2], c[0], c[1], c[2]) 154 | ) 155 | for f in triangles: 156 | f_plus = f + 1 157 | file.write("f %d %d %d\n" % (f_plus[0], f_plus[1], f_plus[2])) 158 | file.close() 159 | 160 | 161 | def main(unused_argv): 162 | rng = random.PRNGKey(20200823) 163 | 164 | utils.update_flags(FLAGS) 165 | utils.check_flags(FLAGS, require_data=False) 166 | 167 | reso = list(map(int, FLAGS.reso.split())) 168 | if len(reso) == 1: 169 | reso *= 3 170 | c1 = list(map(float, FLAGS.c1.split())) 171 | if len(c1) == 1: 172 | c1 *= 3 173 | c2 = list(map(float, FLAGS.c2.split())) 174 | if len(c2) == 1: 175 | c2 *= 3 176 | 177 | rng, key = random.split(rng) 178 | 179 | h0print('* Creating model') 180 | model, state = models.get_model_state(key, FLAGS) 181 | h0print('* Eval reso', FLAGS.reso, 'coarse?', FLAGS.coarse) 182 | 183 | eval_points_pfn = utils.get_eval_points_pfn(model, raw_rgb=True, 184 | coarse=FLAGS.coarse) 185 | 186 | verts, faces = marching_cubes( 187 | functools.partial(eval_points_pfn, state.optimizer.target), 188 | c1=c1, c2=c2, reso=reso, isosurface=FLAGS.iso, chunk=FLAGS.point_chunk 189 | ) 190 | 191 | if jax.host_id() == 0: 192 | mesh_path = os.path.join(FLAGS.train_dir, 'mesh.obj') 193 | print(' Saving to', mesh_path) 194 | save_obj(verts, faces, mesh_path) 195 | 196 | 197 | if __name__ == "__main__": 198 | app.run(main) 199 | -------------------------------------------------------------------------------- /nerf_sh/gen_video.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 
23 | """ 24 | Simple video generation script (for 360 Blender scene only) 25 | """ 26 | import os 27 | # Get rid of ugly TF logs 28 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 29 | 30 | from absl import app 31 | from absl import flags 32 | 33 | import jax 34 | from jax import config 35 | from jax import random 36 | import jax.numpy as jnp 37 | import numpy as np 38 | import flax 39 | from flax.training import checkpoints 40 | 41 | import functools 42 | 43 | from nerf_sh.nerf import models 44 | from nerf_sh.nerf import utils 45 | from nerf_sh.nerf.utils import host0_print as h0print 46 | 47 | import imageio 48 | 49 | FLAGS = flags.FLAGS 50 | 51 | utils.define_flags() 52 | 53 | flags.DEFINE_float( 54 | "elevation", 55 | -30.0, 56 | "Elevation angle (negative is above)", 57 | ) 58 | flags.DEFINE_integer( 59 | "num_views", 60 | 40, 61 | "The number of views to generate.", 62 | ) 63 | flags.DEFINE_integer( 64 | "height", 65 | 800, 66 | "The size of images to generate.", 67 | ) 68 | flags.DEFINE_integer( 69 | "width", 70 | 800, 71 | "The size of images to generate.", 72 | ) 73 | flags.DEFINE_float( 74 | "camera_angle_x", 75 | 0.7, 76 | "The camera angle in rad in x direction (used to get focal length).", 77 | short_name='A', 78 | ) 79 | flags.DEFINE_string( 80 | "intrin", 81 | None, 82 | "Intrinsics file. If set, overrides camera_angle_x", 83 | ) 84 | flags.DEFINE_float( 85 | "radius", 86 | 4.0, 87 | "Radius to origin of camera path.", 88 | ) 89 | flags.DEFINE_integer( 90 | "fps", 91 | 20, 92 | "FPS of generated video", 93 | ) 94 | flags.DEFINE_integer( 95 | "up_axis", 96 | 1, 97 | "up axis for camera views; 1-6: Z up/Z down/Y up/Y down/X up/X down; " + 98 | "same effect as pressing number keys in volrend", 99 | ) 100 | flags.DEFINE_string( 101 | "write_poses", 102 | None, 103 | "Specify to write poses to given file (4N x 4), does not write poses else", 104 | ) 105 | 106 | config.parse_flags_with_absl() 107 | 108 | def main(unused_argv): 109 | rng = random.PRNGKey(20200823) 110 | 111 | utils.update_flags(FLAGS) 112 | utils.check_flags(FLAGS, require_data=False) 113 | 114 | rng, key = random.split(rng) 115 | 116 | h0print('* Generating poses') 117 | render_poses = np.stack( 118 | [ 119 | utils.pose_spherical(angle, FLAGS.elevation, FLAGS.radius, FLAGS.up_axis - 1) 120 | for angle in np.linspace(-180, 180, FLAGS.num_views + 1)[:-1] 121 | ], 122 | 0, 123 | ) # (NV, 4, 4) 124 | 125 | if FLAGS.write_poses: 126 | np.savetxt(FLAGS.write_poses, render_poses.reshape(-1, 4)) 127 | print('Saved poses to', FLAGS.write_poses) 128 | 129 | h0print('* Generating rays') 130 | focal = 0.5 * FLAGS.width / np.tan(0.5 * FLAGS.camera_angle_x) 131 | 132 | if FLAGS.intrin is not None: 133 | print('Load focal length from intrin file') 134 | K : np.ndarray = np.loadtxt(FLAGS.intrin) 135 | focal = (K[0, 0] + K[1, 1]) * 0.5 136 | 137 | rays = utils.generate_rays(FLAGS.width, FLAGS.height, focal, render_poses) 138 | 139 | h0print('* Creating model') 140 | model, state = models.get_model_state(key, FLAGS) 141 | render_pfn = utils.get_render_pfn(model, randomized=False) 142 | 143 | h0print('* Rendering') 144 | 145 | vid_name = "e{:03}".format(int(-FLAGS.elevation * 10)) 146 | video_dir = os.path.join(FLAGS.train_dir, 'video', vid_name) 147 | frames_dir = os.path.join(video_dir, 'frames') 148 | h0print(' Saving to', video_dir) 149 | utils.makedirs(frames_dir) 150 | 151 | frames = [] 152 | for i in range(FLAGS.num_views): 153 | h0print(f'** View {i+1}/{FLAGS.num_views} = {i / FLAGS.num_views * 100}%') 154 | pred_color, 
pred_disp, pred_acc = utils.render_image( 155 | functools.partial(render_pfn, state.optimizer.target), 156 | utils.to_device(utils.namedtuple_map(lambda x: x[i], rays)), 157 | rng, 158 | FLAGS.dataset == "llff", 159 | chunk=FLAGS.chunk, 160 | ) 161 | if jax.host_id() == 0: 162 | utils.save_img(pred_color, os.path.join(frames_dir, f'{i:04}.png')) 163 | frames.append(np.array(pred_color)) 164 | 165 | if jax.host_id() == 0: 166 | frames = np.stack(frames) 167 | vid_path = os.path.join(video_dir, "video.mp4") 168 | print('* Writing video', vid_path) 169 | imageio.mimwrite( 170 | vid_path, (np.clip(frames, 0.0, 1.0) * 255).astype(np.uint8), 171 | fps=FLAGS.fps, quality=8 172 | ) 173 | print('* Done') 174 | 175 | if __name__ == "__main__": 176 | app.run(main) 177 | -------------------------------------------------------------------------------- /nerf_sh/nerf/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /nerf_sh/nerf/model_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Modifications Copyright 2021 The PlenOctree Authors. 3 | # Original Copyright 2021 The Google Research Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # Lint as: python3 18 | """Helper functions/classes for model definition.""" 19 | 20 | import functools 21 | from typing import Any, Callable 22 | 23 | from flax import linen as nn 24 | import jax 25 | from jax import lax 26 | from jax import random 27 | import jax.numpy as jnp 28 | 29 | 30 | class MLP(nn.Module): 31 | """A simple MLP.""" 32 | 33 | net_depth: int = 8 # The depth of the first part of MLP. 34 | net_width: int = 256 # The width of the first part of MLP. 35 | net_depth_condition: int = 1 # The depth of the second part of MLP. 36 | net_width_condition: int = 128 # The width of the second part of MLP. 37 | net_activation: Callable[Ellipsis, Any] = nn.relu # The activation function. 38 | skip_layer: int = 4 # The layer to add skip layers to. 39 | num_rgb_channels: int = 3 # The number of RGB channels. 40 | num_sigma_channels: int = 1 # The number of sigma channels. 41 | 42 | @nn.compact 43 | def __call__(self, x, condition=None): 44 | """Evaluate the MLP. 
45 | 46 | Args: 47 | x: jnp.ndarray(float32), [batch, num_samples, feature], points. 48 | condition: jnp.ndarray(float32), [batch, feature], if not None, this 49 | variable will be part of the input to the second part of the MLP 50 | concatenated with the output vector of the first part of the MLP. If 51 | None, only the first part of the MLP will be used with input x. In the 52 | original paper, this variable is the view direction. 53 | 54 | Returns: 55 | raw_rgb: jnp.ndarray(float32), with a shape of 56 | [batch, num_samples, num_rgb_channels]. 57 | raw_sigma: jnp.ndarray(float32), with a shape of 58 | [batch, num_samples, num_sigma_channels]. 59 | """ 60 | feature_dim = x.shape[-1] 61 | num_samples = x.shape[1] 62 | x = x.reshape([-1, feature_dim]) 63 | dense_layer = functools.partial( 64 | nn.Dense, kernel_init=jax.nn.initializers.glorot_uniform() 65 | ) 66 | inputs = x 67 | for i in range(self.net_depth): 68 | x = dense_layer(self.net_width)(x) 69 | x = self.net_activation(x) 70 | if i % self.skip_layer == 0 and i > 0: 71 | x = jnp.concatenate([x, inputs], axis=-1) 72 | raw_sigma = dense_layer(self.num_sigma_channels)(x).reshape( 73 | [-1, num_samples, self.num_sigma_channels] 74 | ) 75 | 76 | if condition is not None: 77 | # Output of the first part of MLP. 78 | bottleneck = dense_layer(self.net_width)(x) 79 | # Broadcast condition from [batch, feature] to 80 | # [batch, num_samples, feature] since all the samples along the same ray 81 | # have the same viewdir. 82 | condition = jnp.tile(condition[:, None, :], (1, num_samples, 1)) 83 | # Collapse the [batch, num_samples, feature] tensor to 84 | # [batch * num_samples, feature] so that it can be fed into nn.Dense. 85 | condition = condition.reshape([-1, condition.shape[-1]]) 86 | x = jnp.concatenate([bottleneck, condition], axis=-1) 87 | # Here use 1 extra layer to align with the original nerf model. 88 | for i in range(self.net_depth_condition): 89 | x = dense_layer(self.net_width_condition)(x) 90 | x = self.net_activation(x) 91 | raw_rgb = dense_layer(self.num_rgb_channels)(x).reshape( 92 | [-1, num_samples, self.num_rgb_channels] 93 | ) 94 | return raw_rgb, raw_sigma 95 | 96 | 97 | def cast_rays(z_vals, origins, directions): 98 | return ( 99 | origins[Ellipsis, None, :] 100 | + z_vals[Ellipsis, None] * directions[Ellipsis, None, :] 101 | ) 102 | 103 | 104 | def sample_along_rays( 105 | key, origins, directions, num_samples, near, far, randomized, lindisp 106 | ): 107 | """Stratified sampling along the rays. 108 | 109 | Args: 110 | key: jnp.ndarray, random generator key. 111 | origins: jnp.ndarray(float32), [batch_size, 3], ray origins. 112 | directions: jnp.ndarray(float32), [batch_size, 3], ray directions. 113 | num_samples: int. 114 | near: float, near clip. 115 | far: float, far clip. 116 | randomized: bool, use randomized stratified sampling. 117 | lindisp: bool, sampling linearly in disparity rather than depth. 118 | 119 | Returns: 120 | z_vals: jnp.ndarray, [batch_size, num_samples], sampled z values. 121 | points: jnp.ndarray, [batch_size, num_samples, 3], sampled points. 
122 | """ 123 | batch_size = origins.shape[0] 124 | 125 | t_vals = jnp.linspace(0.0, 1.0, num_samples) 126 | if lindisp: 127 | z_vals = 1.0 / (1.0 / near * (1.0 - t_vals) + 1.0 / far * t_vals) 128 | else: 129 | z_vals = near * (1.0 - t_vals) + far * t_vals 130 | 131 | if randomized: 132 | mids = 0.5 * (z_vals[Ellipsis, 1:] + z_vals[Ellipsis, :-1]) 133 | upper = jnp.concatenate([mids, z_vals[Ellipsis, -1:]], -1) 134 | lower = jnp.concatenate([z_vals[Ellipsis, :1], mids], -1) 135 | t_rand = random.uniform(key, [batch_size, num_samples]) 136 | z_vals = lower + (upper - lower) * t_rand 137 | else: 138 | # Broadcast z_vals to make the returned shape consistent. 139 | z_vals = jnp.broadcast_to(z_vals[None, Ellipsis], [batch_size, num_samples]) 140 | 141 | coords = cast_rays(z_vals, origins, directions) 142 | return z_vals, coords 143 | 144 | 145 | def posenc(x, min_deg, max_deg, legacy_posenc_order=False): 146 | """Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1]. 147 | 148 | Instead of computing [sin(x), cos(x)], we use the trig identity 149 | cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]). 150 | 151 | Args: 152 | x: jnp.ndarray, variables to be encoded. Note that x should be in [-pi, pi]. 153 | min_deg: int, the minimum (inclusive) degree of the encoding. 154 | max_deg: int, the maximum (exclusive) degree of the encoding. 155 | legacy_posenc_order: bool, keep the same ordering as the original tf code. 156 | 157 | Returns: 158 | encoded: jnp.ndarray, encoded variables. 159 | """ 160 | if min_deg == max_deg: 161 | return x 162 | scales = jnp.array([2 ** i for i in range(min_deg, max_deg)]) 163 | if legacy_posenc_order: 164 | xb = x[Ellipsis, None, :] * scales[:, None] 165 | four_feat = jnp.reshape( 166 | jnp.sin(jnp.stack([xb, xb + 0.5 * jnp.pi], -2)), list(x.shape[:-1]) + [-1] 167 | ) 168 | else: 169 | xb = jnp.reshape( 170 | (x[Ellipsis, None, :] * scales[:, None]), list(x.shape[:-1]) + [-1] 171 | ) 172 | four_feat = jnp.sin(jnp.concatenate([xb, xb + 0.5 * jnp.pi], axis=-1)) 173 | return jnp.concatenate([x] + [four_feat], axis=-1) 174 | 175 | 176 | def volumetric_rendering(rgb, sigma, z_vals, dirs, white_bkgd): 177 | """Volumetric Rendering Function. 178 | 179 | Args: 180 | rgb: jnp.ndarray(float32), color, [batch_size, num_samples, 3] 181 | sigma: jnp.ndarray(float32), density, [batch_size, num_samples, 1]. 182 | z_vals: jnp.ndarray(float32), [batch_size, num_samples]. 183 | dirs: jnp.ndarray(float32), [batch_size, 3]. 184 | white_bkgd: bool. 185 | 186 | Returns: 187 | comp_rgb: jnp.ndarray(float32), [batch_size, 3]. 188 | disp: jnp.ndarray(float32), [batch_size]. 189 | acc: jnp.ndarray(float32), [batch_size]. 190 | weights: jnp.ndarray(float32), [batch_size, num_samples] 191 | """ 192 | eps = 1e-10 193 | dists = jnp.concatenate( 194 | [ 195 | z_vals[Ellipsis, 1:] - z_vals[Ellipsis, :-1], 196 | jnp.broadcast_to([1e10], z_vals[Ellipsis, :1].shape), 197 | ], 198 | -1, 199 | ) 200 | dists = dists * jnp.linalg.norm(dirs[Ellipsis, None, :], axis=-1) 201 | # Note that we're quietly turning sigma from [..., 0] to [...]. 
202 | alpha = 1.0 - jnp.exp(-sigma[Ellipsis, 0] * dists) 203 | accum_prod = jnp.concatenate( 204 | [ 205 | jnp.ones_like(alpha[Ellipsis, :1], alpha.dtype), 206 | jnp.cumprod(1.0 - alpha[Ellipsis, :-1] + eps, axis=-1), 207 | ], 208 | axis=-1, 209 | ) 210 | weights = alpha * accum_prod 211 | 212 | comp_rgb = (weights[Ellipsis, None] * rgb).sum(axis=-2) 213 | depth = (weights * z_vals).sum(axis=-1) 214 | acc = weights.sum(axis=-1) # Alpha 215 | # Equivalent to (but slightly more efficient and stable than): 216 | # disp = 1 / max(eps, where(acc > eps, depth / acc, 0)) 217 | inv_eps = 1 / eps 218 | disp = acc / depth 219 | disp = jnp.where((disp > 0) & (disp < inv_eps) & (acc > eps), disp, inv_eps) 220 | if white_bkgd: 221 | comp_rgb = comp_rgb + (1.0 - acc[Ellipsis, None]) 222 | return comp_rgb, disp, acc, weights 223 | 224 | 225 | def piecewise_constant_pdf(key, bins, weights, num_samples, randomized): 226 | """Piecewise-Constant PDF sampling. 227 | 228 | Args: 229 | key: jnp.ndarray(float32), [2,], random number generator. 230 | bins: jnp.ndarray(float32), [batch_size, num_bins + 1]. 231 | weights: jnp.ndarray(float32), [batch_size, num_bins]. 232 | num_samples: int, the number of samples. 233 | randomized: bool, use randomized samples. 234 | 235 | Returns: 236 | z_samples: jnp.ndarray(float32), [batch_size, num_samples]. 237 | """ 238 | # Pad each weight vector (only if necessary) to bring its sum to `eps`. This 239 | # avoids NaNs when the input is zeros or small, but has no effect otherwise. 240 | eps = 1e-5 241 | weight_sum = jnp.sum(weights, axis=-1, keepdims=True) 242 | padding = jnp.maximum(0, eps - weight_sum) 243 | weights += padding / weights.shape[-1] 244 | weight_sum += padding 245 | 246 | # Compute the PDF and CDF for each weight vector, while ensuring that the CDF 247 | # starts with exactly 0 and ends with exactly 1. 248 | pdf = weights / weight_sum 249 | cdf = jnp.minimum(1, jnp.cumsum(pdf[Ellipsis, :-1], axis=-1)) 250 | cdf = jnp.concatenate( 251 | [ 252 | jnp.zeros(list(cdf.shape[:-1]) + [1]), 253 | cdf, 254 | jnp.ones(list(cdf.shape[:-1]) + [1]), 255 | ], 256 | axis=-1, 257 | ) 258 | 259 | # Draw uniform samples. 260 | if randomized: 261 | # Note that `u` is in [0, 1) --- it can be zero, but it can never be 1. 262 | u = random.uniform(key, list(cdf.shape[:-1]) + [num_samples]) 263 | else: 264 | # Match the behavior of random.uniform() by spanning [0, 1-eps]. 265 | u = jnp.linspace(0.0, 1.0 - jnp.finfo("float32").eps, num_samples) 266 | u = jnp.broadcast_to(u, list(cdf.shape[:-1]) + [num_samples]) 267 | 268 | # Identify the location in `cdf` that corresponds to a random sample. 269 | # The final `True` index in `mask` will be the start of the sampled interval. 270 | mask = u[Ellipsis, None, :] >= cdf[Ellipsis, :, None] 271 | 272 | def find_interval(x): 273 | # Grab the value where `mask` switches from True to False, and vice versa. 274 | # This approach takes advantage of the fact that `x` is sorted. 275 | x0 = jnp.max(jnp.where(mask, x[Ellipsis, None], x[Ellipsis, :1, None]), -2) 276 | x1 = jnp.min(jnp.where(~mask, x[Ellipsis, None], x[Ellipsis, -1:, None]), -2) 277 | return x0, x1 278 | 279 | bins_g0, bins_g1 = find_interval(bins) 280 | cdf_g0, cdf_g1 = find_interval(cdf) 281 | 282 | t = jnp.clip(jnp.nan_to_num((u - cdf_g0) / (cdf_g1 - cdf_g0), 0), 0, 1) 283 | samples = bins_g0 + t * (bins_g1 - bins_g0) 284 | 285 | # Prevent gradient from backprop-ing through `samples`. 
286 |     return lax.stop_gradient(samples)
287 | 
288 | 
289 | def sample_pdf(
290 |     key, bins, weights, origins, directions, z_vals, num_samples, randomized
291 | ):
292 |     """Hierarchical sampling.
293 | 
294 |     Args:
295 |       key: jnp.ndarray(float32), [2,], random number generator.
296 |       bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
297 |       weights: jnp.ndarray(float32), [batch_size, num_bins].
298 |       origins: jnp.ndarray(float32), [batch_size, 3], ray origins.
299 |       directions: jnp.ndarray(float32), [batch_size, 3], ray directions.
300 |       z_vals: jnp.ndarray(float32), [batch_size, num_coarse_samples].
301 |       num_samples: int, the number of samples.
302 |       randomized: bool, use randomized samples.
303 | 
304 |     Returns:
305 |       z_vals: jnp.ndarray(float32),
306 |         [batch_size, num_coarse_samples + num_fine_samples].
307 |       points: jnp.ndarray(float32),
308 |         [batch_size, num_coarse_samples + num_fine_samples, 3].
309 |     """
310 |     z_samples = piecewise_constant_pdf(key, bins, weights, num_samples, randomized)
311 |     # Combine coarse and fine z_vals and compute sample points
312 |     z_vals = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1)
313 |     coords = cast_rays(z_vals, origins, directions)
314 |     return z_vals, coords
315 | 
316 | 
317 | def add_gaussian_noise(key, raw, noise_std, randomized):
318 |     """Adds gaussian noise to `raw`, which can be used to regularize it.
319 | 
320 |     Args:
321 |       key: jnp.ndarray(float32), [2,], random number generator.
322 |       raw: jnp.ndarray(float32), arbitrary shape.
323 |       noise_std: float, The standard deviation of the noise to be added.
324 |       randomized: bool, add noise if randomized is True.
325 | 
326 |     Returns:
327 |       raw + noise: jnp.ndarray(float32), with the same shape as `raw`.
328 |     """
329 |     if (noise_std is not None) and randomized:
330 |         return raw + random.normal(key, raw.shape, dtype=raw.dtype) * noise_std
331 |     else:
332 |         return raw
333 | --------------------------------------------------------------------------------
/nerf_sh/nerf/models.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Modifications Copyright 2021 The PlenOctree Authors.
3 | # Original Copyright 2021 The Google Research Authors.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 17 | # Lint as: python3 18 | """Different model implementation plus a general port for all the models.""" 19 | from typing import Any, Callable 20 | import flax 21 | from flax import linen as nn 22 | from jax import random 23 | import jax.numpy as jnp 24 | 25 | from nerf_sh.nerf import model_utils 26 | from nerf_sh.nerf import utils 27 | from nerf_sh.nerf import sh 28 | from nerf_sh.nerf import sg 29 | 30 | 31 | def get_model(key, args): 32 | """A helper function that wraps around a 'model zoo'.""" 33 | model_dict = { 34 | "nerf": construct_nerf, 35 | } 36 | return model_dict[args.model](key, args) 37 | 38 | def get_model_state(key, args, restore=True): 39 | """ 40 | Helper for loading model with get_model & creating optimizer & 41 | optionally restoring checkpoint to reduce boilerplate 42 | """ 43 | model, variables = get_model(key, args) 44 | optimizer = flax.optim.Adam(args.lr_init).create(variables) 45 | state = utils.TrainState(optimizer=optimizer) 46 | if restore: 47 | from flax.training import checkpoints 48 | state = checkpoints.restore_checkpoint(args.train_dir, state) 49 | return model, state 50 | 51 | 52 | class NerfModel(nn.Module): 53 | """Nerf NN Model with both coarse and fine MLPs.""" 54 | 55 | num_coarse_samples: int # The number of samples for the coarse nerf. 56 | num_fine_samples: int # The number of samples for the fine nerf. 57 | use_viewdirs: bool # If True, use viewdirs as an input. 58 | sh_deg: int # If != -1, use spherical harmonics output up to given degree 59 | sg_dim: int # If != -1, use spherical gaussians output of given dimension 60 | near: float # The distance to the near plane 61 | far: float # The distance to the far plane 62 | noise_std: float # The std dev of noise added to raw sigma. 63 | net_depth: int # The depth of the first part of MLP. 64 | net_width: int # The width of the first part of MLP. 65 | net_depth_condition: int # The depth of the second part of MLP. 66 | net_width_condition: int # The width of the second part of MLP. 67 | net_activation: Callable[Ellipsis, Any] # MLP activation 68 | skip_layer: int # How often to add skip connections. 69 | num_rgb_channels: int # The number of RGB channels. 70 | num_sigma_channels: int # The number of density channels. 71 | white_bkgd: bool # If True, use a white background. 72 | min_deg_point: int # The minimum degree of positional encoding for positions. 73 | max_deg_point: int # The maximum degree of positional encoding for positions. 74 | deg_view: int # The degree of positional encoding for viewdirs. 75 | lindisp: bool # If True, sample linearly in disparity rather than in depth. 76 | rgb_activation: Callable[Ellipsis, Any] # Output RGB activation. 77 | sigma_activation: Callable[Ellipsis, Any] # Output sigma activation. 78 | legacy_posenc_order: bool # Keep the same ordering as the original tf code. 79 | 80 | def setup(self): 81 | # Construct the "coarse" MLP. Weird name is for 82 | # compatibility with 'compact' version 83 | self.MLP_0 = model_utils.MLP( 84 | net_depth=self.net_depth, 85 | net_width=self.net_width, 86 | net_depth_condition=self.net_depth_condition, 87 | net_width_condition=self.net_width_condition, 88 | net_activation=self.net_activation, 89 | skip_layer=self.skip_layer, 90 | num_rgb_channels=self.num_rgb_channels, 91 | num_sigma_channels=self.num_sigma_channels, 92 | ) 93 | 94 | # Construct the "fine" MLP. 
95 |         self.MLP_1 = model_utils.MLP(
96 |             net_depth=self.net_depth,
97 |             net_width=self.net_width,
98 |             net_depth_condition=self.net_depth_condition,
99 |             net_width_condition=self.net_width_condition,
100 |             net_activation=self.net_activation,
101 |             skip_layer=self.skip_layer,
102 |             num_rgb_channels=self.num_rgb_channels,
103 |             num_sigma_channels=self.num_sigma_channels,
104 |         )
105 | 
106 |         # Construct global learnable variables for spherical gaussians.
107 |         if self.sg_dim > 0:
108 |             key1, key2 = random.split(random.PRNGKey(0), 2)
109 |             self.sg_lambda = self.variable(
110 |                 "params", "sg_lambda",
111 |                 lambda x: jnp.ones([x], jnp.float32), self.sg_dim)
112 |             self.sg_mu_spher = self.variable(
113 |                 "params", "sg_mu_spher",
114 |                 lambda x: jnp.concatenate([
115 |                     random.uniform(key1, [x, 1]) * jnp.pi,      # theta
116 |                     random.uniform(key2, [x, 1]) * jnp.pi * 2,  # phi
117 |                 ], axis=-1), self.sg_dim)
118 | 
119 |     def _quick_init(self):
120 |         points = jnp.zeros((1, 1, 3), dtype=jnp.float32)
121 |         points_enc = model_utils.posenc(
122 |             points,
123 |             self.min_deg_point,
124 |             self.max_deg_point,
125 |             self.legacy_posenc_order,
126 |         )
127 |         if self.use_viewdirs:
128 |             viewdirs = jnp.zeros((1, 1, 3), dtype=jnp.float32)
129 |             viewdirs_enc = model_utils.posenc(
130 |                 viewdirs,
131 |                 0,
132 |                 self.deg_view,
133 |                 self.legacy_posenc_order,
134 |             )
135 |             self.MLP_0(points_enc, viewdirs_enc)
136 |             if self.num_fine_samples > 0:
137 |                 self.MLP_1(points_enc, viewdirs_enc)
138 |         else:
139 |             self.MLP_0(points_enc)
140 |             if self.num_fine_samples > 0:
141 |                 self.MLP_1(points_enc)
142 | 
143 |     def eval_points_raw(self, points, viewdirs=None, coarse=False):
144 |         """
145 |         Evaluate at points, returning rgb and sigma.
146 |         If sh_deg >= 0 / sg_dim > 0 then this will return
147 |         spherical harmonic / spherical gaussian / anisotropic spherical gaussian
148 |         coefficients for RGB. Please see eval_points for an alternate
149 |         version which always returns RGB.
150 |         Args:
151 |             points: jnp.ndarray [B, 3]
152 |             viewdirs: jnp.ndarray [B, 3]
153 |             coarse: if true, uses coarse MLP
154 |         Returns:
155 |             raw_rgb: jnp.ndarray [B, 3 * (sh_deg + 1)**2 or 3 or 3 * sg_dim]
156 |             raw_sigma: jnp.ndarray [B, 1]
157 |         """
158 |         points = points[None]
159 |         points_enc = model_utils.posenc(
160 |             points,
161 |             self.min_deg_point,
162 |             self.max_deg_point,
163 |             self.legacy_posenc_order,
164 |         )
165 |         if self.num_fine_samples > 0 and not coarse:
166 |             mlp = self.MLP_1
167 |         else:
168 |             mlp = self.MLP_0
169 |         if self.use_viewdirs:
170 |             assert viewdirs is not None
171 |             viewdirs = viewdirs[None]
172 |             viewdirs_enc = model_utils.posenc(
173 |                 viewdirs,
174 |                 0,
175 |                 self.deg_view,
176 |                 self.legacy_posenc_order,
177 |             )
178 |             raw_rgb, raw_sigma = mlp(points_enc, viewdirs_enc)
179 |         else:
180 |             raw_rgb, raw_sigma = mlp(points_enc)
181 |         return raw_rgb[0], raw_sigma[0]
182 | 
183 |     def eval_points(self, points, viewdirs=None, coarse=False):
184 |         """
185 |         Evaluate at points, converting spherical harmonics rgb to
186 |         rgb via viewdirs if applicable. Exists since jax does not allow
187 |         size to depend on input.
188 | Args: 189 | points: jnp.ndarray [B, 3] 190 | viewdirs: jnp.ndarray [B, 3] 191 | coarse: if true, uses coarse MLP 192 | Returns: 193 | rgb: jnp.ndarray [B, 3] 194 | sigma: jnp.ndarray [B, 1] 195 | """ 196 | raw_rgb, raw_sigma = self.eval_points_raw(points, viewdirs, coarse) 197 | if self.sh_deg >= 0: 198 | assert viewdirs is not None 199 | # (256, 64, 48) (256, 3) 200 | raw_rgb = sh.eval_sh(self.sh_deg, raw_rgb.reshape( 201 | *raw_rgb.shape[:-1], 202 | -1, 203 | (self.sh_deg + 1) ** 2), viewdirs[:, None]) 204 | elif self.sg_dim > 0: 205 | assert viewdirs is not None 206 | sg_lambda = self.sg_lambda.value 207 | sg_mu_spher = self.sg_mu_spher.value 208 | sg_coeffs = raw_rgb.reshape(*raw_rgb.shape[:-1], -1, self.sg_dim) 209 | raw_rgb = sg.eval_sg( 210 | sg_lambda, sg_mu_spher, sg_coeffs, viewdirs[:, None]) 211 | 212 | rgb = self.rgb_activation(raw_rgb) 213 | sigma = self.sigma_activation(raw_sigma) 214 | return rgb, sigma 215 | 216 | def __call__(self, rng_0, rng_1, rays, randomized): 217 | """Nerf Model. 218 | 219 | Args: 220 | rng_0: jnp.ndarray, random number generator for coarse model sampling. 221 | rng_1: jnp.ndarray, random number generator for fine model sampling. 222 | rays: util.Rays, a namedtuple of ray origins, directions, and viewdirs. 223 | randomized: bool, use randomized stratified sampling. 224 | 225 | Returns: 226 | ret: list, [(rgb_coarse, disp_coarse, acc_coarse), (rgb, disp, acc)] 227 | """ 228 | # Stratified sampling along rays 229 | key, rng_0 = random.split(rng_0) 230 | z_vals, samples = model_utils.sample_along_rays( 231 | key, 232 | rays.origins, 233 | rays.directions, 234 | self.num_coarse_samples, 235 | self.near, 236 | self.far, 237 | randomized, 238 | self.lindisp, 239 | ) 240 | samples_enc = model_utils.posenc( 241 | samples, 242 | self.min_deg_point, 243 | self.max_deg_point, 244 | self.legacy_posenc_order, 245 | ) 246 | 247 | # Point attribute predictions 248 | if self.use_viewdirs: 249 | viewdirs_enc = model_utils.posenc( 250 | rays.viewdirs, 251 | 0, 252 | self.deg_view, 253 | self.legacy_posenc_order, 254 | ) 255 | raw_rgb, raw_sigma = self.MLP_0(samples_enc, viewdirs_enc) 256 | else: 257 | raw_rgb, raw_sigma = self.MLP_0(samples_enc) 258 | # Add noises to regularize the density predictions if needed 259 | key, rng_0 = random.split(rng_0) 260 | raw_sigma = model_utils.add_gaussian_noise( 261 | key, 262 | raw_sigma, 263 | self.noise_std, 264 | randomized, 265 | ) 266 | 267 | if self.sh_deg >= 0: 268 | # (256, 64, 48) (256, 3) 269 | raw_rgb = sh.eval_sh(self.sh_deg, raw_rgb.reshape( 270 | *raw_rgb.shape[:-1], 271 | -1, 272 | (self.sh_deg + 1) ** 2), rays.viewdirs[:, None]) 273 | elif self.sg_dim > 0: 274 | sg_lambda = self.sg_lambda.value 275 | sg_mu_spher = self.sg_mu_spher.value 276 | sg_coeffs = raw_rgb.reshape(*raw_rgb.shape[:-1], -1, self.sg_dim) 277 | raw_rgb = sg.eval_sg( 278 | sg_lambda, sg_mu_spher, sg_coeffs, rays.viewdirs[:, None]) 279 | 280 | rgb = self.rgb_activation(raw_rgb) 281 | sigma = self.sigma_activation(raw_sigma) 282 | 283 | # Volumetric rendering. 
284 | comp_rgb, disp, acc, weights = model_utils.volumetric_rendering( 285 | rgb, 286 | sigma, 287 | z_vals, 288 | rays.directions, 289 | white_bkgd=self.white_bkgd, 290 | ) 291 | ret = [ 292 | (comp_rgb, disp, acc), 293 | ] 294 | # Hierarchical sampling based on coarse predictions 295 | if self.num_fine_samples > 0: 296 | z_vals_mid = 0.5 * (z_vals[Ellipsis, 1:] + z_vals[Ellipsis, :-1]) 297 | key, rng_1 = random.split(rng_1) 298 | z_vals, samples = model_utils.sample_pdf( 299 | key, 300 | z_vals_mid, 301 | weights[Ellipsis, 1:-1], 302 | rays.origins, 303 | rays.directions, 304 | z_vals, 305 | self.num_fine_samples, 306 | randomized, 307 | ) 308 | samples_enc = model_utils.posenc( 309 | samples, 310 | self.min_deg_point, 311 | self.max_deg_point, 312 | self.legacy_posenc_order, 313 | ) 314 | 315 | if self.use_viewdirs: 316 | raw_rgb, raw_sigma = self.MLP_1(samples_enc, viewdirs_enc) 317 | else: 318 | raw_rgb, raw_sigma = self.MLP_1(samples_enc) 319 | key, rng_1 = random.split(rng_1) 320 | raw_sigma = model_utils.add_gaussian_noise( 321 | key, 322 | raw_sigma, 323 | self.noise_std, 324 | randomized, 325 | ) 326 | if self.sh_deg >= 0: 327 | raw_rgb = sh.eval_sh(self.sh_deg, raw_rgb.reshape( 328 | *raw_rgb.shape[:-1], 329 | -1, 330 | (self.sh_deg + 1) ** 2), rays.viewdirs[:, None]) 331 | elif self.sg_dim > 0: 332 | sg_lambda = self.sg_lambda.value 333 | sg_mu_spher = self.sg_mu_spher.value 334 | sg_coeffs = raw_rgb.reshape(*raw_rgb.shape[:-1], -1, self.sg_dim) 335 | raw_rgb = sg.eval_sg( 336 | sg_lambda, sg_mu_spher, sg_coeffs, rays.viewdirs[:, None]) 337 | 338 | rgb = self.rgb_activation(raw_rgb) 339 | sigma = self.sigma_activation(raw_sigma) 340 | comp_rgb, disp, acc, unused_weights = model_utils.volumetric_rendering( 341 | rgb, 342 | sigma, 343 | z_vals, 344 | rays.directions, 345 | white_bkgd=self.white_bkgd, 346 | ) 347 | ret.append((comp_rgb, disp, acc)) 348 | return ret 349 | 350 | 351 | def construct_nerf(key, args): 352 | """Construct a Neural Radiance Field. 353 | 354 | Args: 355 | key: jnp.ndarray. Random number generator. 356 | args: FLAGS class. Hyperparameters of nerf. 357 | 358 | Returns: 359 | model: nn.Model. Nerf model with parameters. 360 | state: flax.Module.state. Nerf model state for stateful parameters. 361 | """ 362 | net_activation = getattr(nn, str(args.net_activation)) 363 | rgb_activation = getattr(nn, str(args.rgb_activation)) 364 | sigma_activation = getattr(nn, str(args.sigma_activation)) 365 | 366 | # Assert that rgb_activation always produces outputs in [0, 1], and 367 | # sigma_activation always produces non-negative outputs.
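# The probe below sweeps activation inputs over a very wide range of magnitudes, of both signs, so an activation that can leave the valid range is caught at model-construction time rather than mid-training.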
368 | x = jnp.exp(jnp.linspace(-90, 90, 1024)) 369 | x = jnp.concatenate([-x[::-1], x], 0) 370 | 371 | rgb = rgb_activation(x) 372 | if jnp.any(rgb < 0) or jnp.any(rgb > 1): 373 | raise NotImplementedError( 374 | "Choice of rgb_activation `{}` produces colors outside of [0, 1]".format( 375 | args.rgb_activation 376 | ) 377 | ) 378 | 379 | sigma = sigma_activation(x) 380 | if jnp.any(sigma < 0): 381 | raise NotImplementedError( 382 | "Choice of sigma_activation `{}` produces negative densities".format( 383 | args.sigma_activation 384 | ) 385 | ) 386 | num_rgb_channels = args.num_rgb_channels 387 | # TODO cleanup assert 388 | if args.sh_deg >= 0: 389 | assert not args.use_viewdirs and args.sg_dim == -1, ( 390 | "You can only use up to one of: SH, SG or use_viewdirs.") 391 | num_rgb_channels *= (args.sh_deg + 1) ** 2 392 | elif args.sg_dim > 0: 393 | assert not args.use_viewdirs and args.sh_deg == -1, ( 394 | "You can only use up to one of: SH, SG or use_viewdirs.") 395 | num_rgb_channels *= args.sg_dim 396 | 397 | model = NerfModel( 398 | min_deg_point=args.min_deg_point, 399 | max_deg_point=args.max_deg_point, 400 | deg_view=args.deg_view, 401 | num_coarse_samples=args.num_coarse_samples, 402 | num_fine_samples=args.num_fine_samples, 403 | use_viewdirs=args.use_viewdirs, 404 | sh_deg=args.sh_deg, 405 | sg_dim=args.sg_dim, 406 | near=args.near, 407 | far=args.far, 408 | noise_std=args.noise_std, 409 | white_bkgd=args.white_bkgd, 410 | net_depth=args.net_depth, 411 | net_width=args.net_width, 412 | net_depth_condition=args.net_depth_condition, 413 | net_width_condition=args.net_width_condition, 414 | skip_layer=args.skip_layer, 415 | num_rgb_channels=num_rgb_channels, 416 | num_sigma_channels=args.num_sigma_channels, 417 | lindisp=args.lindisp, 418 | net_activation=net_activation, 419 | rgb_activation=rgb_activation, 420 | sigma_activation=sigma_activation, 421 | legacy_posenc_order=args.legacy_posenc_order, 422 | ) 423 | key1, key = random.split(key) 424 | init_variables = model.init( 425 | key1, 426 | method=model._quick_init, 427 | ) 428 | return model, init_variables 429 | -------------------------------------------------------------------------------- /nerf_sh/nerf/sg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | import jax 24 | import jax.numpy as jnp 25 | 26 | 27 | def spher2cart(r, theta, phi): 28 | """Convert spherical coordinates into Cartesian coordinates.""" 29 | x = r * jnp.sin(theta) * jnp.cos(phi) 30 | y = r * jnp.sin(theta) * jnp.sin(phi) 31 | z = r * jnp.cos(theta) 32 | return jnp.stack([x, y, z], axis=-1) 33 | 34 | 35 | def eval_sg(sg_lambda, sg_mu, sg_coeffs, dirs): 36 | """ 37 | Evaluate spherical gaussians at unit directions 38 | using learnable SG basis. 39 | Works with jnp. 40 | ... Can be 0 or more batch dimensions. 41 | N is the number of SG basis functions we use. 42 | 43 | Output = \sum_i coeffs_i * \exp(lambda_i * (dot(mu_i, dirs) - 1)) 44 | 45 | Args: 46 | sg_lambda: The sharpness of the SG lobes. [N] or [..., N] 47 | sg_mu: The directions of the SG lobes. [N, 3 or 2] or [..., N, 3 or 2] 48 | sg_coeffs: The coefficients of the SG (lobe amplitude). [..., C, N] 49 | dirs: unit directions [..., 3] 50 | 51 | Returns: 52 | [..., C] 53 | """ 54 | sg_lambda = jax.nn.softplus(sg_lambda) # force lambda > 0 55 | # spherical coordinates -> Cartesian coordinates 56 | if sg_mu.shape[-1] == 2: 57 | theta, phi = sg_mu[..., 0], sg_mu[..., 1] 58 | sg_mu = spher2cart(1.0, theta, phi) # [..., N, 3] 59 | product = jnp.einsum( 60 | "...ij,...j->...i", sg_mu, dirs) # [..., N] 61 | basis = jnp.exp(jnp.einsum( 62 | "...i,...i->...i", sg_lambda, product - 1)) # [..., N] 63 | output = jnp.einsum( 64 | "...ki,...i->...k", sg_coeffs, basis) # [..., C] 65 | output /= sg_lambda.shape[-1] 66 | return output 67 | 68 | 69 | def euler2mat(angle): 70 | """Convert Euler angles to rotation matrix. 71 | 72 | Args: 73 | angle: rotation angle along 3 axes (in radians). [..., 3] 74 | Returns: 75 | Rotation matrix corresponding to the euler angles. [..., 3, 3] 76 | """ 77 | x, y, z = angle[..., 0], angle[..., 1], angle[..., 2] 78 | cosz = jnp.cos(z) 79 | sinz = jnp.sin(z) 80 | cosy = jnp.cos(y) 81 | siny = jnp.sin(y) 82 | cosx = jnp.cos(x) 83 | sinx = jnp.sin(x) 84 | zeros = jnp.zeros_like(z) 85 | ones = jnp.ones_like(z) 86 | zmat = jnp.stack([jnp.stack([cosz, -sinz, zeros], axis=-1), 87 | jnp.stack([sinz, cosz, zeros], axis=-1), 88 | jnp.stack([zeros, zeros, ones], axis=-1)], axis=-1) 89 | ymat = jnp.stack([jnp.stack([ cosy, zeros, siny], axis=-1), 90 | jnp.stack([zeros, ones, zeros], axis=-1), 91 | jnp.stack([-siny, zeros, cosy], axis=-1)], axis=-1) 92 | xmat = jnp.stack([jnp.stack([ ones, zeros, zeros], axis=-1), 93 | jnp.stack([zeros, cosx, -sinx], axis=-1), 94 | jnp.stack([zeros, sinx, cosx], axis=-1)], axis=-1) 95 | rotMat = jnp.einsum("...ij,...jk,...kq->...iq", xmat, ymat, zmat) 96 | return rotMat 97 | -------------------------------------------------------------------------------- /nerf_sh/nerf/sh.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | C0 = 0.28209479177387814 25 | C1 = 0.4886025119029199 26 | C2 = [ 27 | 1.0925484305920792, 28 | -1.0925484305920792, 29 | 0.31539156525252005, 30 | -1.0925484305920792, 31 | 0.5462742152960396 32 | ] 33 | C3 = [ 34 | -0.5900435899266435, 35 | 2.890611442640554, 36 | -0.4570457994644658, 37 | 0.3731763325901154, 38 | -0.4570457994644658, 39 | 1.445305721320277, 40 | -0.5900435899266435 41 | ] 42 | C4 = [ 43 | 2.5033429417967046, 44 | -1.7701307697799304, 45 | 0.9461746957575601, 46 | -0.6690465435572892, 47 | 0.10578554691520431, 48 | -0.6690465435572892, 49 | 0.47308734787878004, 50 | -1.7701307697799304, 51 | 0.6258357354491761, 52 | ] 53 | 54 | def eval_sh(deg, sh, dirs): 55 | """ 56 | Evaluate spherical harmonics at unit directions 57 | using hardcoded SH polynomials. 58 | Works with torch/np/jnp. 59 | ... Can be 0 or more batch dimensions. 60 | 61 | Args: 62 | deg: int SH deg. 
Currently, 0-4 supported 63 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 64 | dirs: jnp.ndarray unit directions [..., 3] 65 | 66 | Returns: 67 | [..., C] 68 | """ 69 | assert deg <= 4 and deg >= 0 70 | assert (deg + 1) ** 2 == sh.shape[-1] 71 | C = sh.shape[-2] 72 | 73 | result = C0 * sh[..., 0] 74 | if deg > 0: 75 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 76 | result = (result - 77 | C1 * y * sh[..., 1] + 78 | C1 * z * sh[..., 2] - 79 | C1 * x * sh[..., 3]) 80 | if deg > 1: 81 | xx, yy, zz = x * x, y * y, z * z 82 | xy, yz, xz = x * y, y * z, x * z 83 | result = (result + 84 | C2[0] * xy * sh[..., 4] + 85 | C2[1] * yz * sh[..., 5] + 86 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 87 | C2[3] * xz * sh[..., 7] + 88 | C2[4] * (xx - yy) * sh[..., 8]) 89 | 90 | if deg > 2: 91 | result = (result + 92 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 93 | C3[1] * xy * z * sh[..., 10] + 94 | C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] + 95 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 96 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 97 | C3[5] * z * (xx - yy) * sh[..., 14] + 98 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 99 | if deg > 3: 100 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 101 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 102 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 103 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 104 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 105 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 106 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 107 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 108 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 109 | return result 110 | -------------------------------------------------------------------------------- /nerf_sh/parse_timing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | """ 24 | A utility to parse timing files, which are saved automatically during training 25 | (checkpoint_dir/timings.txt). 26 | As long as training was not stopped and restarted, this allows for measuring the total training time.
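Example (the checkpoint directory and step number below are placeholders): python -m nerf_sh.parse_timing data/Plenoctree/checkpoints/syn_sh16/chair/timings.txt --times 50000. Run it without --times to list the logged steps; with --times, it prints the elapsed hours from step 0 to each given step.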
27 | """ 28 | import argparse 29 | from datetime import datetime 30 | 31 | parser = argparse.ArgumentParser(); 32 | parser.add_argument("file", type=str); 33 | parser.add_argument("--times", "-t", type=int, default=[], nargs='+'); 34 | args = parser.parse_args(); 35 | 36 | f = open(args.file, 'r') 37 | lines = f.readlines() 38 | lines = [line.strip() for line in lines] 39 | lines = [line.split() for line in lines if len(line)] 40 | lines = {int(line[0]) : datetime.fromisoformat(line[1]) for line in lines} 41 | 42 | if not args.times: 43 | print(list(lines.keys())) 44 | 45 | for t in args.times: 46 | print((lines[t] - lines[0]).total_seconds() / 3600) 47 | -------------------------------------------------------------------------------- /nerf_sh/train.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Modifications Copyright 2021 The PlenOctree Authors. 3 | # Original Copyright 2021 The Google Research Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # Lint as: python3 18 | """Training script for Nerf.""" 19 | 20 | import os 21 | # Get rid of ugly TF logs 22 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 23 | 24 | import functools 25 | import gc 26 | import time 27 | from absl import app 28 | from absl import flags 29 | import flax 30 | import flax.linen as nn 31 | from flax.metrics import tensorboard 32 | from flax.training import checkpoints 33 | import jax 34 | from jax import config 35 | from jax import random 36 | import jax.numpy as jnp 37 | import numpy as np 38 | 39 | 40 | from nerf_sh.nerf import datasets 41 | from nerf_sh.nerf import models 42 | from nerf_sh.nerf import utils 43 | from nerf_sh.nerf.utils import host0_print as h0print 44 | 45 | FLAGS = flags.FLAGS 46 | 47 | utils.define_flags() 48 | config.parse_flags_with_absl() 49 | 50 | 51 | def train_step(model, rng, state, batch, lr): 52 | """One optimization step. 53 | 54 | Args: 55 | model: The linen model. 56 | rng: jnp.ndarray, random number generator. 57 | state: utils.TrainState, state of the model/optimizer. 58 | batch: dict, a mini-batch of data for training. 59 | lr: float, real-time learning rate. 60 | 61 | Returns: 62 | new_state: utils.TrainState, new training state. 63 | stats: list. [(loss, psnr), (loss_coarse, psnr_coarse)]. 64 | rng: jnp.ndarray, updated random number generator. 65 | """ 66 | rng, key_0, key_1, key_2 = random.split(rng, 4) 67 | 68 | def loss_fn(variables): 69 | rays = batch["rays"] 70 | ret = model.apply(variables, key_0, key_1, rays, FLAGS.randomized) 71 | if len(ret) not in (1, 2): 72 | raise ValueError( 73 | "ret should contain either 1 set of output (coarse only), or 2 sets" 74 | "of output (coarse as ret[0] and fine as ret[1])." 
75 | ) 76 | 77 | if FLAGS.sparsity_weight > 0.0: 78 | rng, key = random.split(key_2) 79 | sp_points = random.uniform(key, (FLAGS.sparsity_npoints, 3), minval=-FLAGS.sparsity_radius, maxval=FLAGS.sparsity_radius) 80 | sp_rgb, sp_sigma = model.apply(variables, sp_points, method=model.eval_points_raw) 81 | del sp_rgb 82 | sp_sigma = nn.relu(sp_sigma) 83 | loss_sp = FLAGS.sparsity_weight * (1.0 - jnp.exp(- FLAGS.sparsity_length * sp_sigma).mean()) 84 | else: 85 | loss_sp = 0.0 86 | 87 | # The main prediction is always at the end of the ret list. 88 | rgb, unused_disp, unused_acc = ret[-1] 89 | loss = ((rgb - batch["pixels"][Ellipsis, :3]) ** 2).mean() 90 | psnr = utils.compute_psnr(loss) 91 | if len(ret) > 1: 92 | # If there are both coarse and fine predictions, we compute the loss for 93 | # the coarse prediction (ret[0]) as well. 94 | rgb_c, unused_disp_c, unused_acc_c = ret[0] 95 | loss_c = ((rgb_c - batch["pixels"][Ellipsis, :3]) ** 2).mean() 96 | psnr_c = utils.compute_psnr(loss_c) 97 | else: 98 | loss_c = 0.0 99 | psnr_c = 0.0 100 | 101 | def tree_sum_fn(fn): 102 | return jax.tree_util.tree_reduce( 103 | lambda x, y: x + fn(y), variables, initializer=0 104 | ) 105 | 106 | weight_l2 = tree_sum_fn(lambda z: jnp.sum(z ** 2)) / tree_sum_fn( 107 | lambda z: jnp.prod(jnp.array(z.shape)) 108 | ) 109 | 110 | stats = utils.Stats( 111 | loss=loss, psnr=psnr, loss_c=loss_c, loss_sp=loss_sp, 112 | psnr_c=psnr_c, weight_l2=weight_l2 113 | ) 114 | return loss + loss_c + loss_sp + FLAGS.weight_decay_mult * weight_l2, stats 115 | 116 | (_, stats), grad = jax.value_and_grad(loss_fn, has_aux=True)(state.optimizer.target) 117 | grad = jax.lax.pmean(grad, axis_name="batch") 118 | stats = jax.lax.pmean(stats, axis_name="batch") 119 | new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr) 120 | new_state = state.replace(optimizer=new_optimizer) 121 | return new_state, stats, rng 122 | 123 | 124 | def main(unused_argv): 125 | rng = random.PRNGKey(20200823) 126 | # Shift the numpy random seed by host_id() to shuffle data loaded by different 127 | # hosts. 128 | np.random.seed(20201473 + jax.host_id()) 129 | rng, key = random.split(rng) 130 | 131 | utils.update_flags(FLAGS) 132 | utils.check_flags(FLAGS, require_batch_size_div=True) 133 | 134 | utils.makedirs(FLAGS.train_dir) 135 | render_dir = os.path.join(FLAGS.train_dir, 'render') 136 | utils.makedirs(render_dir) 137 | 138 | # TEMP 139 | timings_file = open(os.path.join(FLAGS.train_dir, 'timings.txt'), 'a') 140 | from datetime import datetime 141 | def write_ts_now(step): 142 | timings_file.write(f"{step} {datetime.now().isoformat()}\n") 143 | timings_file.flush() 144 | write_ts_now(0) 145 | 146 | h0print('* Load train data') 147 | dataset = datasets.get_dataset("train", FLAGS) 148 | h0print('* Load test data') 149 | test_dataset = datasets.get_dataset("test", FLAGS) 150 | 151 | h0print('* Load model') 152 | model, state = models.get_model_state(key, FLAGS) 153 | 154 | learning_rate_fn = functools.partial( 155 | utils.learning_rate_decay, 156 | lr_init=FLAGS.lr_init, 157 | lr_final=FLAGS.lr_final, 158 | max_steps=FLAGS.max_steps, 159 | lr_delay_steps=FLAGS.lr_delay_steps, 160 | lr_delay_mult=FLAGS.lr_delay_mult, 161 | ) 162 | 163 | train_pstep = jax.pmap( 164 | functools.partial(train_step, model), 165 | axis_name="batch", 166 | in_axes=(0, 0, 0, None), 167 | donate_argnums=(2,), 168 | ) 169 | 170 | render_pfn = utils.get_render_pfn(model, randomized=FLAGS.randomized) 171 | 172 | # Compiling to the CPU because it's faster and more accurate. 
173 | ssim_fn = jax.jit(functools.partial(utils.compute_ssim, max_val=1.0), backend="cpu") 174 | 175 | # Resume training at the step of the last checkpoint. 176 | init_step = state.optimizer.state.step + 1 177 | state = flax.jax_utils.replicate(state) 178 | 179 | if jax.host_id() == 0: 180 | summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir) 181 | 182 | h0print('* Prefetch') 183 | # Prefetch buffer size = 3 x batch_size 184 | pdataset = flax.jax_utils.prefetch_to_device(dataset, 3) 185 | n_local_devices = jax.local_device_count() 186 | rng = rng + jax.host_id() # Make random seed separate across hosts. 187 | keys = random.split(rng, n_local_devices) # For pmapping RNG keys. 188 | gc.disable() # Disable automatic garbage collection for efficiency. 189 | stats_trace = [] 190 | 191 | 192 | reset_timer = True 193 | for step, batch in zip(range(init_step, FLAGS.max_steps + 1), pdataset): 194 | if reset_timer: 195 | t_loop_start = time.time() 196 | reset_timer = False 197 | lr = learning_rate_fn(step) 198 | state, stats, keys = train_pstep(keys, state, batch, lr) 199 | if jax.host_id() == 0: 200 | stats_trace.append(stats) 201 | if step % FLAGS.gc_every == 0: 202 | gc.collect() 203 | 204 | # Log training summaries. This is put behind a host_id check because in 205 | # multi-host evaluation, all hosts need to run inference even though we 206 | # only use host 0 to record results. 207 | if jax.host_id() == 0: 208 | if step % FLAGS.print_every == 0: 209 | summary_writer.scalar("train_loss", stats.loss[0], step) 210 | summary_writer.scalar("train_psnr", stats.psnr[0], step) 211 | summary_writer.scalar("train_loss_coarse", stats.loss_c[0], step) 212 | summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0], step) 213 | if FLAGS.sparsity_weight > 0.0: 214 | summary_writer.scalar("train_sparse_loss", stats.loss_sp[0], step) 215 | summary_writer.scalar("weight_l2", stats.weight_l2[0], step) 216 | avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace])) 217 | avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace])) 218 | stats_trace = [] 219 | summary_writer.scalar("train_avg_loss", avg_loss, step) 220 | summary_writer.scalar("train_avg_psnr", avg_psnr, step) 221 | summary_writer.scalar("learning_rate", lr, step) 222 | steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start) 223 | reset_timer = True 224 | rays_per_sec = FLAGS.batch_size * steps_per_sec 225 | summary_writer.scalar("train_steps_per_sec", steps_per_sec, step) 226 | summary_writer.scalar("train_rays_per_sec", rays_per_sec, step) 227 | precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1 228 | print( 229 | ("{:" + "{:d}".format(precision) + "d}").format(step) 230 | + f"/{FLAGS.max_steps:d}: " 231 | + f"i_loss={stats.loss[0]:0.4f}, " 232 | + f"avg_loss={avg_loss:0.4f}, " 233 | + f"weight_l2={stats.weight_l2[0]:0.2e}, " 234 | + f"lr={lr:0.2e}, " 235 | + f"{rays_per_sec:0.0f} rays/sec" 236 | ) 237 | if step % FLAGS.save_every == 0: 238 | print('* Saving') 239 | state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state)) 240 | checkpoints.save_checkpoint( 241 | FLAGS.train_dir, state_to_save, int(step), keep=200 242 | ) 243 | 244 | # Test-set evaluation. 245 | if FLAGS.render_every > 0 and step % FLAGS.render_every == 0: 246 | # We reuse the same random number generator from the optimization step 247 | # here on purpose so that the visualization matches what happened in 248 | # training.
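# keys holds one PRNG key per local device (split above and threaded through the pmapped train step); keys[0] is reused for rendering so the noise stream matches device 0's training stream.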
249 | h0print('\n* Rendering') 250 | t_eval_start = time.time() 251 | eval_variables = jax.device_get( 252 | jax.tree_map(lambda x: x[0], state) 253 | ).optimizer.target 254 | test_case = next(test_dataset) 255 | pred_color, pred_disp, pred_acc = utils.render_image( 256 | functools.partial(render_pfn, eval_variables), 257 | test_case["rays"], 258 | keys[0], 259 | FLAGS.dataset == "llff", 260 | chunk=FLAGS.chunk, 261 | ) 262 | 263 | # Log eval summaries on host 0. 264 | if jax.host_id() == 0: 265 | write_ts_now(step) 266 | psnr = utils.compute_psnr( 267 | ((pred_color - test_case["pixels"]) ** 2).mean() 268 | ) 269 | ssim = ssim_fn(pred_color, test_case["pixels"]) 270 | eval_time = time.time() - t_eval_start 271 | num_rays = jnp.prod(jnp.array(test_case["rays"].directions.shape[:-1])) 272 | rays_per_sec = num_rays / eval_time 273 | summary_writer.scalar("test_rays_per_sec", rays_per_sec, step) 274 | print(f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec") 275 | summary_writer.scalar("test_psnr", psnr, step) 276 | summary_writer.scalar("test_ssim", ssim, step) 277 | # print(pred_color.shape, pred_disp.shape, pred_acc.shape, 278 | # test_case["pixels"].shape) 279 | # print(pred_color.dtype, pred_disp.dtype, pred_acc.dtype, 280 | # test_case["pixels"].dtype) 281 | # print(pred_color.min(), pred_color.max(), 282 | # pred_disp.min(), pred_disp.max(), 283 | # pred_acc.min(), pred_acc.max(), 284 | # test_case['pixels'].min(), test_case['pixels'].max()) 285 | # 0 1. 0.0 1.0 0.90906805 1.0000007 0.0 1.0 286 | 287 | # (800, 800, 3) (800, 800, 1) (800, 800, 1) (800, 800, 3) 288 | # float32 float32 float32 float32 289 | 290 | vis_list = [test_case["pixels"], 291 | pred_color, 292 | np.repeat(pred_disp, 3, axis=-1), 293 | np.repeat(pred_acc, 3, axis=-1)] 294 | out_path = os.path.join(render_dir, '{:010}.png'.format(step)) 295 | utils.save_img(np.hstack(vis_list), out_path) 296 | print(' Rendering saved to ', out_path) 297 | 298 | # Renderings are saved to disk instead of TensorBoard, 299 | # since TensorBoard becomes very slow to load once it holds many images. 300 | 301 | # summary_writer.image("test_pred_color", pred_color, step) 302 | # summary_writer.image("test_pred_disp", pred_disp, step) 303 | # summary_writer.image("test_pred_acc", pred_acc, step) 304 | # summary_writer.image("test_target", test_case["pixels"], step) 305 | 306 | if FLAGS.max_steps % FLAGS.save_every != 0: 307 | state = jax.device_get(jax.tree_map(lambda x: x[0], state)) 308 | checkpoints.save_checkpoint( 309 | FLAGS.train_dir, state, int(FLAGS.max_steps), keep=200 310 | ) 311 | 312 | 313 | if __name__ == "__main__": 314 | app.run(main) 315 | -------------------------------------------------------------------------------- /octree/compression.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution.
11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | """Compress a plenoctree. 24 | 25 | Includes quantization using the median cut algorithm. 26 | 27 | Usage: 28 | python compression.py x.npz [y.npz ...] 29 | """ 30 | import sys 31 | import numpy as np 32 | import os.path as osp 33 | import torch 34 | from svox.helpers import _get_c_extension 35 | from tqdm import tqdm 36 | import os 37 | import argparse 38 | 39 | @torch.no_grad() 40 | def main(): 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('input', type=str, nargs='+', default=None, help='Input npz(s)') 43 | parser.add_argument('--noquant', action='store_true', 44 | help='Disable quantization') 45 | parser.add_argument('--bits', type=int, default=16, 46 | help='Quantization bits (order)') 47 | parser.add_argument('--out_dir', type=str, default='min_alt', 48 | help='Where to write compressed npz') 49 | parser.add_argument('--overwrite', action='store_true', 50 | help='Overwrite existing compressed npz') 51 | parser.add_argument('--weighted', action='store_true', 52 | help='Use weighted median cut (seems quite useless)') 53 | parser.add_argument('--sigma_thresh', type=float, default=2.0, 54 | help='Kill voxels under this sigma') 55 | parser.add_argument('--retain', type=int, default=0, 56 | help='Do not compress first x SH coeffs, needed for some scenes to keep ok quality') 57 | 58 | args = parser.parse_args() 59 | 60 | _C = _get_c_extension() 61 | os.makedirs(args.out_dir, exist_ok=True) 62 | 63 | if args.noquant: 64 | print('Quantization disabled, only applying deflate') 65 | else: 66 | print('Quantization enabled') 67 | 68 | for fname in args.input: 69 | fname_c = osp.join(args.out_dir, osp.basename(fname)) 70 | print('Compressing', fname, 'to', fname_c) 71 | if not args.overwrite and osp.exists(fname_c): 72 | print(' > skip') 73 | continue 74 | 75 | z = np.load(fname) 76 | 77 | if not args.noquant: 78 | if 'quant_colors' in z.files: 79 | print(' > skip since source already compressed') 80 | continue 81 | z = dict(z) 82 | del z['parent_depth'] 83 | del z['geom_resize_fact'] 84 | del z['n_free'] 85 | del z['n_internal'] 86 | del z['depth_limit'] 87 | 88 | if not args.noquant: 89 | data = torch.from_numpy(z['data']) 90 | sigma = data[..., -1].reshape(-1) 91 | snz = sigma > args.sigma_thresh 92 | sigma[~snz] = 0.0 93 | 94 | data = data[..., :-1] 95 | N = data.size(1) 96 | basis_dim = data.size(-1) // 3 97 | 98 | data = data.reshape(-1, 3, basis_dim).float()[snz].unbind(-1) 99 | if args.retain: 100 | retained = data[:args.retain] 101 | data = data[args.retain:] 102 | else: 103 | retained = None 104 | 105 | all_quant_colors = [] 106 | all_quant_maps = [] 107 | 108 | if args.weighted: 109 | weights = 1.0 - torch.exp(-0.01 * sigma.float()) 110 | 
else: 111 | weights = torch.empty((0,)) 112 | 113 | for i, d in tqdm(enumerate(data), total=len(data)): 114 | colors, color_id_map = _C.quantize_median_cut(d.contiguous(), 115 | weights, 116 | args.bits) 117 | color_id_map_full = np.zeros((snz.shape[0],), dtype=np.uint16) 118 | color_id_map_full[snz] = color_id_map 119 | 120 | all_quant_colors.append(colors.numpy().astype(np.float16)) 121 | all_quant_maps.append(color_id_map_full.reshape(-1, N, N, N).astype(np.uint16)) 122 | quant_map = np.stack(all_quant_maps, axis=0) 123 | quant_colors = np.stack(all_quant_colors, axis=0) 124 | del all_quant_maps 125 | del all_quant_colors 126 | z['quant_colors'] = quant_colors 127 | z['quant_map'] = quant_map 128 | z['sigma'] = sigma.reshape(-1, N, N, N) 129 | if args.retain: 130 | all_retained = [] 131 | for i in range(args.retain): 132 | retained_wz = np.zeros((snz.shape[0], 3), dtype=np.float16) 133 | retained_wz[snz] = retained[i] 134 | all_retained.append(retained_wz.reshape(-1, N, N, N, 3)) 135 | all_retained = np.stack(all_retained, axis=0) 136 | del retained 137 | z['data_retained'] = all_retained 138 | del z['data'] 139 | np.savez_compressed(fname_c, **z) 140 | print(' > Size', osp.getsize(fname) // (1024 * 1024), 'MB ->', 141 | osp.getsize(fname_c) // (1024 * 1024), 'MB') 142 | 143 | 144 | if __name__ == '__main__': 145 | main() 146 | -------------------------------------------------------------------------------- /octree/config/syn_sg25.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_root": "./data/NeRF/nerf_synthetic/", 3 | "train_root": "./data/Plenoctree/checkpoints/syn_sg25/", 4 | "scenes": ["chair", "drums", "ficus", "hotdog", "lego", "ship"], 5 | "scene_tasks": [{ 6 | "octree_name": "", 7 | "train_dir": "{%}", 8 | "data_dir": "{%}", 9 | "config": "nerf_sh/config/misc/sg", 10 | "extr_flags": [ 11 | "--autoscale", 12 | "--scale_alpha_thresh", "0.1", 13 | "--radius", "1.4", 14 | "--samples_per_cell", "256", 15 | "--no_early_stop", 16 | "--renderer_step_size", "1e-5"], 17 | "opt_flags": [ 18 | "--num_epochs", "80", 19 | "--sgd", 20 | "--lr", "1e9", 21 | "--no_early_stop", 22 | "--renderer_step_size", "1e-5"], 23 | "eval_flags": [ 24 | "--renderer_step_size", "1e-5"] 25 | }], 26 | "tasks": [{ 27 | "octree_name": "", 28 | "train_dir": "materials", 29 | "data_dir": "materials", 30 | "config": "nerf_sh/config/misc/sg", 31 | "extr_flags": [ 32 | "--autoscale", 33 | "--bbox_scale", "1.2", 34 | "--scale_alpha_thresh", "0.1", 35 | "--radius", "1.4", 36 | "--samples_per_cell", "256", 37 | "--no_early_stop", 38 | "--renderer_step_size", "1e-5"], 39 | "opt_flags": [ 40 | "--num_epochs", "80", 41 | "--sgd", 42 | "--lr", "1e9", 43 | "--no_early_stop", 44 | "--renderer_step_size", "1e-5"], 45 | "eval_flags": [ 46 | "--renderer_step_size", "1e-5"] 47 | },{ 48 | "octree_name": "", 49 | "train_dir": "mic", 50 | "data_dir": "mic", 51 | "config": "nerf_sh/config/misc/sg", 52 | "extr_flags": [ 53 | "--autoscale", 54 | "--scale_alpha_thresh", "0.1", 55 | "--radius", "1.6", 56 | "--samples_per_cell", "256", 57 | "--no_early_stop", 58 | "--renderer_step_size", "1e-5"], 59 | "opt_flags": [ 60 | "--num_epochs", "80", 61 | "--sgd", 62 | "--lr", "1e9", 63 | "--no_early_stop", 64 | "--renderer_step_size", "1e-5"], 65 | "eval_flags": [ 66 | "--renderer_step_size", "1e-5"] 67 | }] 68 | } 69 | -------------------------------------------------------------------------------- /octree/config/syn_sh16.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "data_root": "./data/NeRF/nerf_synthetic/", 3 | "train_root": "./data/Plenoctree/checkpoints/syn_sh16/", 4 | "scenes": ["chair", "drums", "ficus", "hotdog", "lego", "ship"], 5 | "scene_tasks": [{ 6 | "octree_name": "", 7 | "train_dir": "{%}", 8 | "data_dir": "{%}", 9 | "config": "nerf_sh/config/blender", 10 | "extr_flags": [ 11 | "--autoscale", 12 | "--scale_alpha_thresh", "0.1", 13 | "--radius", "1.4", 14 | "--samples_per_cell", "256", 15 | "--no_early_stop", 16 | "--renderer_step_size", "1e-5"], 17 | "opt_flags": [ 18 | "--num_epochs", "80", 19 | "--sgd", 20 | "--lr", "1e7", 21 | "--no_early_stop", 22 | "--renderer_step_size", "1e-5"], 23 | "eval_flags": [ 24 | "--renderer_step_size", "1e-5"] 25 | }], 26 | "tasks": [{ 27 | "octree_name": "", 28 | "train_dir": "materials", 29 | "data_dir": "materials", 30 | "config": "nerf_sh/config/blender", 31 | "extr_flags": [ 32 | "--autoscale", 33 | "--bbox_scale", "1.1", 34 | "--scale_alpha_thresh", "0.1", 35 | "--radius", "1.4", 36 | "--samples_per_cell", "256", 37 | "--no_early_stop", 38 | "--renderer_step_size", "1e-5"], 39 | "opt_flags": [ 40 | "--num_epochs", "80", 41 | "--sgd", 42 | "--lr", "1e7", 43 | "--no_early_stop", 44 | "--renderer_step_size", "1e-5"], 45 | "eval_flags": [ 46 | "--renderer_step_size", "1e-5"] 47 | },{ 48 | "octree_name": "", 49 | "train_dir": "mic", 50 | "data_dir": "mic", 51 | "config": "nerf_sh/config/blender", 52 | "extr_flags": [ 53 | "--autoscale", 54 | "--scale_alpha_thresh", "0.1", 55 | "--radius", "1.6", 56 | "--samples_per_cell", "256", 57 | "--no_early_stop", 58 | "--renderer_step_size", "1e-5"], 59 | "opt_flags": [ 60 | "--num_epochs", "80", 61 | "--sgd", 62 | "--lr", "1e7", 63 | "--no_early_stop", 64 | "--renderer_step_size", "1e-5"], 65 | "eval_flags": [ 66 | "--renderer_step_size", "1e-5"] 67 | }] 68 | } 69 | -------------------------------------------------------------------------------- /octree/config/tt_sh25.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_root": "./data/TanksAndTemple", 3 | "train_root": "./data/Plenoctree/checkpoints/tt_sh25/", 4 | "scenes": ["Barn", "Caterpillar", "Family", "Truck"], 5 | "scene_tasks": [{ 6 | "octree_name": "", 7 | "train_dir": "{%}", 8 | "data_dir": "{%}", 9 | "config": "nerf_sh/config/tt", 10 | "extr_flags": [ 11 | "--autoscale", 12 | "--scale_alpha_thresh", "0.1", 13 | "--bbox_from_data", 14 | "--data_bbox_scale", "1.2", 15 | "--bbox_scale", "1.0", 16 | "--samples_per_cell", "256", 17 | "--chunk", "8192", 18 | "--no_early_stop", 19 | "--renderer_step_size", "1e-5" 20 | ], 21 | "opt_flags": [ 22 | "--num_epochs", "40", 23 | "--sgd", 24 | "--lr", "1.5e6", 25 | "--renderer_step_size", "1e-5", 26 | "--split_train", 27 | "--split_holdout_prop", "0.1" 28 | ], 29 | "eval_flags": [ 30 | "--renderer_step_size", "1e-5" 31 | ] 32 | }], 33 | "tasks": [{ 34 | "octree_name": "", 35 | "train_dir": "Ignatius", 36 | "data_dir": "Ignatius", 37 | "config": "nerf_sh/config/tt", 38 | "extr_flags": [ 39 | "--autoscale", 40 | "--scale_alpha_thresh", "0.1", 41 | "--bbox_from_data", 42 | "--data_bbox_scale", "1.2", 43 | "--bbox_scale", "1.25", 44 | "--samples_per_cell", "256", 45 | "--chunk", "8192", 46 | "--no_early_stop", 47 | "--renderer_step_size", "1e-5" 48 | ], 49 | "opt_flags": [ 50 | "--num_epochs", "40", 51 | "--sgd", 52 | "--lr", "1.5e6", 53 | "--renderer_step_size", "1e-5", 54 | "--split_train", 55 | "--split_holdout_prop", "0.1" 56 | 
], 57 | "eval_flags": [ 58 | "--renderer_step_size", "1e-5" 59 | ] 60 | }] 61 | } 62 | -------------------------------------------------------------------------------- /octree/evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | """Evaluate a plenoctree on the test set. 24 | 25 | Usage: 26 | 27 | export DATA_ROOT=./data/NeRF/nerf_synthetic/ 28 | export CKPT_ROOT=./data/PlenOctree/checkpoints/syn_sh16 29 | export SCENE=chair 30 | export CONFIG_FILE=nerf_sh/config/blender 31 | 32 | python -m octree.evaluation \ 33 | --input $CKPT_ROOT/$SCENE/octrees/tree_opt.npz \ 34 | --config $CONFIG_FILE \ 35 | --data_dir $DATA_ROOT/$SCENE/ 36 | """ 37 | import torch 38 | import numpy as np 39 | import os 40 | from absl import app 41 | from absl import flags 42 | from tqdm import tqdm 43 | import imageio 44 | 45 | from octree.nerf import models 46 | from octree.nerf import utils 47 | from octree.nerf import datasets 48 | 49 | import svox 50 | 51 | FLAGS = flags.FLAGS 52 | 53 | utils.define_flags() 54 | 55 | flags.DEFINE_string( 56 | "input", 57 | "./tree_opt.npz", 58 | "Input octree npz from optimization.py", 59 | ) 60 | flags.DEFINE_string( 61 | "write_vid", 62 | None, 63 | "If specified, writes rendered video to given path (*.mp4)", 64 | ) 65 | flags.DEFINE_string( 66 | "write_images", 67 | None, 68 | "If specified, writes images to given path (*.png)", 69 | ) 70 | 71 | device = "cuda" if torch.cuda.is_available() else "cpu" 72 | 73 | @torch.no_grad() 74 | def main(unused_argv): 75 | utils.set_random_seed(20200823) 76 | utils.update_flags(FLAGS) 77 | 78 | dataset = datasets.get_dataset("test", FLAGS) 79 | 80 | print('N3Tree load', FLAGS.input) 81 | t = svox.N3Tree.load(FLAGS.input, map_location=device) 82 | 83 | avg_psnr, avg_ssim, avg_lpips, out_frames = utils.eval_octree(t, dataset, FLAGS, 84 | want_lpips=True, 85 | want_frames=FLAGS.write_vid is not None or FLAGS.write_images is not None) 86 | print('Average PSNR', avg_psnr, 'SSIM', avg_ssim, 'LPIPS', avg_lpips) 87 | 88 | if FLAGS.write_vid is not None and len(out_frames): 89 | print('Writing to', FLAGS.write_vid) 90 | 
imageio.mimwrite(FLAGS.write_vid, out_frames) 91 | 92 | if FLAGS.write_images is not None and len(out_frames): 93 | print('Writing to', FLAGS.write_images) 94 | os.makedirs(FLAGS.write_images, exist_ok=True) 95 | for idx, frame in tqdm(enumerate(out_frames)): 96 | imageio.imwrite(os.path.join(FLAGS.write_images, f"{idx:03d}.png"), frame) 97 | 98 | if __name__ == "__main__": 99 | app.run(main) 100 | -------------------------------------------------------------------------------- /octree/extraction.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | """Extract a plenoctree from a trained NeRF-SH model. 
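The pipeline has three stages: an optional autoscale pass that fits the bounding box to the occupied region (--autoscale), a coarse grid evaluation that decides which voxels to keep, thresholding either raw sigma or the maximum ray weight (--masking_mode), and an antialiasing pass that averages --samples_per_cell samples per retained leaf.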
24 | 25 | Usage: 26 | 27 | export DATA_ROOT=./data/NeRF/nerf_synthetic/ 28 | export CKPT_ROOT=./data/PlenOctree/checkpoints/syn_sh16 29 | export SCENE=chair 30 | export CONFIG_FILE=nerf_sh/config/blender 31 | 32 | python -m octree.extraction \ 33 | --train_dir $CKPT_ROOT/$SCENE/ --is_jaxnerf_ckpt \ 34 | --config $CONFIG_FILE \ 35 | --data_dir $DATA_ROOT/$SCENE/ \ 36 | --output $CKPT_ROOT/$SCENE/octrees/tree.npz 37 | """ 38 | import os 39 | # Get rid of ugly TF logs 40 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 41 | 42 | import torch 43 | import torch.nn.functional as F 44 | import numpy as np 45 | import os.path as osp 46 | 47 | from absl import app 48 | from absl import flags 49 | 50 | from octree.nerf import models 51 | from octree.nerf import utils 52 | from octree.nerf import datasets 53 | from octree.nerf import sh_proj 54 | 55 | from svox import N3Tree 56 | from svox import NDCConfig, VolumeRenderer 57 | from svox.helpers import _get_c_extension 58 | from tqdm import tqdm 59 | 60 | _C = _get_c_extension() 61 | 62 | FLAGS = flags.FLAGS 63 | 64 | utils.define_flags() 65 | 66 | flags.DEFINE_string( 67 | "output", 68 | "./tree.npz", 69 | "Output file", 70 | ) 71 | flags.DEFINE_string( 72 | "center", 73 | "0 0 0", 74 | "Center of volume in x y z OR single number", 75 | ) 76 | flags.DEFINE_string( 77 | "radius", 78 | "1.5", 79 | "1/2 side length of volume", 80 | ) 81 | flags.DEFINE_float( 82 | "alpha_thresh", 83 | 0.01, 84 | "Alpha threshold to keep a voxel in initial sigma thresholding", 85 | ) 86 | flags.DEFINE_float( 87 | "max_refine_prop", 88 | 0.5, 89 | "Max proportion of cells to refine", 90 | ) 91 | flags.DEFINE_float( 92 | "z_min", 93 | None, 94 | "Discard z axis points below this value, for NDC use", 95 | ) 96 | flags.DEFINE_float( 97 | "z_max", 98 | None, 99 | "Discard z axis points above this value, for NDC use", 100 | ) 101 | flags.DEFINE_integer( 102 | "tree_branch_n", 103 | 2, 104 | "Tree branch factor (2=octree)", 105 | ) 106 | flags.DEFINE_integer( 107 | "init_grid_depth", 108 | 8, 109 | "Initial evaluation grid (2^{x+1} voxel grid)", 110 | ) 111 | flags.DEFINE_integer( 112 | "samples_per_cell", 113 | 8, 114 | "Samples per cell in step 2 (3D antialiasing)", 115 | short_name='S', 116 | ) 117 | flags.DEFINE_bool( 118 | "is_jaxnerf_ckpt", 119 | False, 120 | "Whether the ckpt is from jaxnerf or not.", 121 | ) 122 | flags.DEFINE_enum( 123 | "masking_mode", 124 | "weight", 125 | ["sigma", "weight"], 126 | "How to calculate mask when building the octree", 127 | ) 128 | flags.DEFINE_float( 129 | "weight_thresh", 130 | 0.001, 131 | "Weight threshold to keep a voxel", 132 | ) 133 | flags.DEFINE_integer( 134 | "projection_samples", 135 | 10000, 136 | "Number of rays to sample for SH projection.", 137 | ) 138 | 139 | # Load bbox from dataset 140 | flags.DEFINE_bool( 141 | "bbox_from_data", 142 | False, 143 | "Use bounding box from dataset if possible", 144 | ) 145 | flags.DEFINE_float( 146 | "data_bbox_scale", 147 | 1.0, 148 | "Scaling factor to apply to the bounding box from dataset (before autoscale), " + 149 | "if bbox_from_data is used", 150 | ) 151 | flags.DEFINE_bool( 152 | "autoscale", 153 | False, 154 | "Automatic scaling, after bbox_from_data", 155 | ) 156 | flags.DEFINE_bool( 157 | "bbox_cube", 158 | False, 159 | "Force bbox to be a cube", 160 | ) 161 | flags.DEFINE_float( 162 | "bbox_scale", 163 | 1.0, 164 | "Scaling factor to apply to the bounding box at the end (after load, autoscale)", 165 | ) 166 | flags.DEFINE_float( 167 | "scale_alpha_thresh", 168 | 0.01, 169 | "Alpha 
threshold to keep a voxel in initial sigma thresholding for autoscale", 170 | ) 171 | # For integrated eval (to avoid slow load) 172 | flags.DEFINE_bool( 173 | "eval", 174 | True, 175 | "Evaluate after building the octree", 176 | ) 177 | 178 | device = "cuda" if torch.cuda.is_available() else "cpu" 179 | 180 | 181 | def calculate_grid_weights(dataset, sigmas, reso, invradius, offset): 182 | w, h, focal = dataset.w, dataset.h, dataset.focal 183 | 184 | opts = _C.RenderOptions() 185 | opts.step_size = FLAGS.renderer_step_size 186 | opts.sigma_thresh = 0.0 187 | if 'llff' in FLAGS.config and (not FLAGS.spherify): 188 | ndc_config = NDCConfig(width=w, height=h, focal=focal) 189 | opts.ndc_width = ndc_config.width 190 | opts.ndc_height = ndc_config.height 191 | opts.ndc_focal = ndc_config.focal 192 | else: 193 | opts.ndc_width = -1 194 | 195 | cam = _C.CameraSpec() 196 | cam.fx = focal 197 | cam.fy = focal 198 | cam.width = w 199 | cam.height = h 200 | 201 | grid_data = sigmas.reshape((reso, reso, reso)) 202 | maximum_weight = torch.zeros_like(grid_data) 203 | for idx in tqdm(range(dataset.size)): 204 | cam.c2w = torch.from_numpy(dataset.camtoworlds[idx]).float().to(sigmas.device) 205 | grid_weight, grid_hit = _C.grid_weight_render( 206 | grid_data, 207 | cam, 208 | opts, 209 | offset, 210 | invradius, 211 | ) 212 | maximum_weight = torch.max(maximum_weight, grid_weight) 213 | 214 | return maximum_weight 215 | 216 | 217 | def project_nerf_to_sh(nerf, sh_deg, points): 218 | """ 219 | Args: 220 | points: [N, 3] 221 | Returns: 222 | coeffs for rgb. [N, C * (sh_deg + 1)**2] 223 | """ 224 | nerf.use_viewdirs = True 225 | 226 | def _sperical_func(viewdirs): 227 | # points: [num_points, 3] 228 | # viewdirs: [num_rays, 3] 229 | # raw_rgb: [num_points, num_rays, 3] 230 | # sigma: [num_points] 231 | raw_rgb, sigma = nerf.eval_points_raw(points, viewdirs, cross_broadcast=True) 232 | return raw_rgb, sigma 233 | 234 | coeffs, sigma = sh_proj.ProjectFunctionNeRF( 235 | order=sh_deg, 236 | sperical_func=_sperical_func, 237 | batch_size=points.shape[0], 238 | sample_count=FLAGS.projection_samples, 239 | device=points.device) 240 | 241 | return coeffs.reshape([points.shape[0], -1]), sigma 242 | 243 | 244 | def auto_scale(args, center, radius, nerf): 245 | print('* Step 0: Auto scale') 246 | reso = 2 ** args.init_grid_depth 247 | 248 | radius = torch.tensor(radius, dtype=torch.float32) 249 | center = torch.tensor(center, dtype=torch.float32) 250 | scale = 0.5 / radius 251 | offset = 0.5 * (1.0 - center / radius) 252 | 253 | arr = (torch.arange(0, reso, dtype=torch.float32) + 0.5) / reso 254 | xx = (arr - offset[0]) / scale[0] 255 | yy = (arr - offset[1]) / scale[1] 256 | zz = (arr - offset[2]) / scale[2] 257 | if args.z_min is not None: 258 | zz = zz[zz >= args.z_min] 259 | if args.z_max is not None: 260 | zz = zz[zz <= args.z_max] 261 | 262 | grid = torch.stack(torch.meshgrid(xx, yy, zz)).reshape(3, -1).T 263 | 264 | out_chunks = [] 265 | for i in tqdm(range(0, grid.shape[0], args.chunk)): 266 | grid_chunk = grid[i:i+args.chunk].cuda() 267 | if nerf.use_viewdirs: 268 | fake_viewdirs = torch.zeros([grid_chunk.shape[0], 3], device=grid_chunk.device) 269 | else: 270 | fake_viewdirs = None 271 | rgb, sigma = nerf.eval_points_raw(grid_chunk, fake_viewdirs) 272 | del grid_chunk 273 | out_chunks.append(sigma.squeeze(-1)) 274 | sigmas = torch.cat(out_chunks, 0) 275 | del out_chunks 276 | 277 | approx_delta = 2.0 / reso 278 | sigma_thresh = -np.log(1.0 - args.scale_alpha_thresh) / approx_delta 279 | mask = sigmas >= 
sigma_thresh 280 | 281 | grid = grid[mask] 282 | del mask 283 | 284 | lc = grid.min(dim=0)[0] - 0.5 / reso 285 | uc = grid.max(dim=0)[0] + 0.5 / reso 286 | return ((lc + uc) * 0.5).tolist(), ((uc - lc) * 0.5).tolist() 287 | 288 | def step1(args, tree, nerf, dataset): 289 | print('* Step 1: Grid eval') 290 | reso = 2 ** (args.init_grid_depth + 1) 291 | offset = tree.offset.cpu() 292 | scale = tree.invradius.cpu() 293 | 294 | arr = (torch.arange(0, reso, dtype=torch.float32) + 0.5) / reso 295 | xx = (arr - offset[0]) / scale[0] 296 | yy = (arr - offset[1]) / scale[1] 297 | zz = (arr - offset[2]) / scale[2] 298 | if args.z_min is not None: 299 | zz = zz[zz >= args.z_min] 300 | if args.z_max is not None: 301 | zz = zz[zz <= args.z_max] 302 | 303 | grid = torch.stack(torch.meshgrid(xx, yy, zz)).reshape(3, -1).T 304 | print('init grid', grid.shape) 305 | 306 | approx_delta = 2.0 / reso 307 | sigma_thresh = -np.log(1.0 - args.alpha_thresh) / approx_delta 308 | 309 | out_chunks = [] 310 | for i in tqdm(range(0, grid.shape[0], args.chunk)): 311 | grid_chunk = grid[i:i+args.chunk].cuda() 312 | if nerf.use_viewdirs: 313 | fake_viewdirs = torch.zeros([grid_chunk.shape[0], 3], device=grid_chunk.device) 314 | else: 315 | fake_viewdirs = None 316 | rgb, sigma = nerf.eval_points_raw(grid_chunk, fake_viewdirs) 317 | del grid_chunk 318 | out_chunks.append(sigma.squeeze(-1)) 319 | sigmas = torch.cat(out_chunks, 0) 320 | del out_chunks 321 | 322 | if FLAGS.masking_mode == "sigma": 323 | mask = sigmas >= sigma_thresh 324 | elif FLAGS.masking_mode == "weight": 325 | print ("* Calculating grid weights") 326 | grid_weights = calculate_grid_weights(dataset, 327 | sigmas, reso, tree.invradius, tree.offset) 328 | mask = grid_weights.reshape(-1) >= FLAGS.weight_thresh 329 | del grid_weights 330 | else: 331 | raise ValueError 332 | del sigmas 333 | 334 | grid = grid[mask] 335 | del mask 336 | print(grid.shape, grid.min(), grid.max()) 337 | grid = grid.cuda() 338 | 339 | torch.cuda.empty_cache() 340 | print(' Building octree') 341 | for i in range(args.init_grid_depth - 1): 342 | tree[grid].refine() 343 | refine_chunk = 2000000 344 | if grid.shape[0] <= refine_chunk: 345 | tree[grid].refine() 346 | else: 347 | # Do last layer separately 348 | grid = grid.cpu() 349 | for j in tqdm(range(0, grid.shape[0], refine_chunk)): 350 | tree[grid[j:j+refine_chunk].cuda()].refine() 351 | print(tree) 352 | 353 | assert tree.max_depth == args.init_grid_depth 354 | 355 | def step2(args, tree, nerf): 356 | print('* Step 2: AA', args.samples_per_cell) 357 | 358 | leaf_mask = tree.depths.cpu() == tree.max_depth 359 | leaf_ind = torch.where(leaf_mask)[0] 360 | del leaf_mask 361 | 362 | if args.use_viewdirs: 363 | chunk_size = args.chunk // (args.samples_per_cell * args.projection_samples // 10) 364 | else: 365 | chunk_size = args.chunk // (args.samples_per_cell) 366 | 367 | for i in tqdm(range(0, leaf_ind.size(0), chunk_size)): 368 | chunk_inds = leaf_ind[i:i+chunk_size] 369 | points = tree[chunk_inds].sample(args.samples_per_cell) # (n_cells, n_samples, 3) 370 | points = points.view(-1, 3) 371 | 372 | if not args.use_viewdirs: # trained NeRF-SH/SG model returns rgb as coeffs 373 | rgb, sigma = nerf.eval_points_raw(points) 374 | else: # vanilla NeRF model returns rgb, so we project them into coeffs (only SH supported) 375 | rgb, sigma = project_nerf_to_sh(nerf, args.sh_deg, points) 376 | 377 | if tree.data_format.format == tree.data_format.RGBA: 378 | rgb = rgb.reshape(-1, args.samples_per_cell, tree.data_dim - 1); 379 | sigma = 
sigma.reshape(-1, args.samples_per_cell, 1); 380 | sigma_avg = sigma.mean(dim=1) 381 | 382 | reso = 2 ** (args.init_grid_depth + 1) 383 | approx_delta = 2.0 / reso 384 | alpha = 1.0 - torch.exp(-approx_delta * sigma) 385 | msum = alpha.sum(dim=1) 386 | rgb_avg = (rgb * alpha).sum(dim=1) / msum 387 | rgb_avg[msum[..., 0] < 1e-3] = 0 388 | rgba = torch.cat([rgb_avg, sigma_avg], dim=-1) 389 | del rgb, sigma 390 | else: 391 | rgba = torch.cat([rgb, sigma], dim=-1) 392 | del rgb, sigma 393 | rgba = rgba.reshape(-1, args.samples_per_cell, tree.data_dim).mean(dim=1) 394 | tree[chunk_inds] = rgba 395 | 396 | def euler2mat(angle): 397 | """Convert euler angles to rotation matrix. 398 | 399 | Args: 400 | angle: rotation angle along 3 axis (in radians). [..., 3] 401 | Returns: 402 | Rotation matrix corresponding to the euler angles. [..., 3, 3] 403 | """ 404 | x, y, z = angle[..., 0], angle[..., 1], angle[..., 2] 405 | cosz = torch.cos(z) 406 | sinz = torch.sin(z) 407 | cosy = torch.cos(y) 408 | siny = torch.sin(y) 409 | cosx = torch.cos(x) 410 | sinx = torch.sin(x) 411 | zeros = torch.zeros_like(z) 412 | ones = torch.ones_like(z) 413 | zmat = torch.stack([torch.stack([cosz, -sinz, zeros], dim=-1), 414 | torch.stack([sinz, cosz, zeros], dim=-1), 415 | torch.stack([zeros, zeros, ones], dim=-1)], dim=-1) 416 | ymat = torch.stack([torch.stack([ cosy, zeros, siny], dim=-1), 417 | torch.stack([zeros, ones, zeros], dim=-1), 418 | torch.stack([-siny, zeros, cosy], dim=-1)], dim=-1) 419 | xmat = torch.stack([torch.stack([ ones, zeros, zeros], dim=-1), 420 | torch.stack([zeros, cosx, -sinx], dim=-1), 421 | torch.stack([zeros, sinx, cosx], dim=-1)], dim=-1) 422 | rotMat = torch.einsum("...ij,...jk,...kq->...iq", xmat, ymat, zmat) 423 | return rotMat 424 | 425 | @torch.no_grad() 426 | def main(unused_argv): 427 | utils.set_random_seed(20200823) 428 | utils.update_flags(FLAGS) 429 | 430 | print('* Loading NeRF') 431 | nerf = models.get_model_state(FLAGS, device=device, restore=True) 432 | nerf.eval() 433 | 434 | data_format = None 435 | extra_data = None 436 | if FLAGS.sg_dim > 0: 437 | data_format = f'SG{FLAGS.sg_dim}' 438 | assert FLAGS.sg_global 439 | extra_data = torch.cat(( 440 | F.softplus(nerf.sg_lambda[:, None]), 441 | sh_proj.spher2cart(nerf.sg_mu_spher[:, 0], nerf.sg_mu_spher[:, 1]) 442 | ), dim=-1) 443 | elif FLAGS.sh_deg > 0: 444 | data_format = f'SH{(FLAGS.sh_deg + 1) ** 2}' 445 | if data_format is not None: 446 | print('Detected format:', data_format) 447 | 448 | base_dir = osp.dirname(FLAGS.output) 449 | if base_dir: 450 | os.makedirs(base_dir, exist_ok=True) 451 | 452 | assert FLAGS.data_dir # Dataset is required now 453 | dataset = datasets.get_dataset("train", FLAGS) 454 | 455 | if FLAGS.bbox_from_data: 456 | assert dataset.bbox is not None # Dataset must be NSVF 457 | center = (dataset.bbox[:3] + dataset.bbox[3:6]) * 0.5 458 | radius = (dataset.bbox[3:6] - dataset.bbox[:3]) * 0.5 * FLAGS.data_bbox_scale 459 | print('Bounding box from data: c', center, 'r', radius) 460 | else: 461 | center = list(map(float, FLAGS.center.split())) 462 | if len(center) == 1: 463 | center *= 3 464 | radius = list(map(float, FLAGS.radius.split())) 465 | if len(radius) == 1: 466 | radius *= 3 467 | 468 | if FLAGS.autoscale: 469 | center, radius = auto_scale(FLAGS, center, radius, nerf) 470 | print('Autoscale result center', center, 'radius', radius) 471 | 472 | radius = [r * FLAGS.bbox_scale for r in radius] 473 | if FLAGS.bbox_cube: 474 | radius = [max(radius)] * 3 475 | 476 | num_rgb_channels = 
FLAGS.num_rgb_channels 477 | if FLAGS.sh_deg >= 0: 478 | assert FLAGS.sg_dim == -1, ( 479 | "You can only use up to one of: SH or SG") 480 | num_rgb_channels *= (FLAGS.sh_deg + 1) ** 2 481 | elif FLAGS.sg_dim > 0: 482 | assert FLAGS.sh_deg == -1, ( 483 | "You can only use up to one of: SH or SG") 484 | num_rgb_channels *= FLAGS.sg_dim 485 | data_dim = 1 + num_rgb_channels # alpha + rgb 486 | print('data dim is', data_dim) 487 | 488 | print('* Creating model') 489 | tree = N3Tree(N=FLAGS.tree_branch_n, 490 | data_dim=data_dim, 491 | init_refine=0, 492 | init_reserve=500000, 493 | geom_resize_fact=1.0, 494 | depth_limit=FLAGS.init_grid_depth, 495 | radius=radius, 496 | center=center, 497 | data_format=data_format, 498 | extra_data=extra_data, 499 | map_location=device) 500 | 501 | step1(FLAGS, tree, nerf, dataset) 502 | step2(FLAGS, tree, nerf) 503 | tree[:, -1:].relu_() 504 | tree.shrink_to_fit() 505 | print(tree) 506 | 507 | del dataset.images 508 | print('* Saving', FLAGS.output) 509 | tree.save(FLAGS.output, compress=False) # Faster saving 510 | 511 | if FLAGS.eval: 512 | dataset = datasets.get_dataset("test", FLAGS) 513 | print('* Evaluation (before fine tune)') 514 | avg_psnr, avg_ssim, avg_lpips, out_frames = utils.eval_octree(tree, 515 | dataset, FLAGS, want_lpips=True) 516 | print('Average PSNR', avg_psnr, 'SSIM', avg_ssim, 'LPIPS', avg_lpips) 517 | 518 | 519 | if __name__ == "__main__": 520 | app.run(main) 521 | -------------------------------------------------------------------------------- /octree/nerf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxyu/plenoctree/f0b82631199a1aa7dc9ce263c08980a3a7504014/octree/nerf/__init__.py -------------------------------------------------------------------------------- /octree/nerf/datasets.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Modifications Copyright 2021 The PlenOctree Authors. 3 | # Original Copyright 2021 The Google Research Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | # Lint as: python3 18 | """Different datasets implementation plus a general port for all the datasets.""" 19 | INTERNAL = False # pylint: disable=g-statement-before-imports 20 | import json 21 | import os 22 | from os import path 23 | 24 | if not INTERNAL: 25 | import cv2 # pylint: disable=g-import-not-at-top 26 | import numpy as np 27 | from PIL import Image 28 | from tqdm import tqdm 29 | 30 | from octree.nerf import utils 31 | 32 | 33 | def get_dataset(split, args): 34 | return dataset_dict[args.dataset](split, args) 35 | 36 | 37 | def convert_to_ndc(origins, directions, focal, w, h, near=1.0): 38 | """Convert a set of rays to NDC coordinates.""" 39 | # Shift ray origins to near plane 40 | t = -(near + origins[Ellipsis, 2]) / directions[Ellipsis, 2] 41 | origins = origins + t[Ellipsis, None] * directions 42 | 43 | dx, dy, dz = tuple(np.moveaxis(directions, -1, 0)) 44 | ox, oy, oz = tuple(np.moveaxis(origins, -1, 0)) 45 | 46 | # Projection 47 | o0 = -((2 * focal) / w) * (ox / oz) 48 | o1 = -((2 * focal) / h) * (oy / oz) 49 | o2 = 1 + 2 * near / oz 50 | 51 | d0 = -((2 * focal) / w) * (dx / dz - ox / oz) 52 | d1 = -((2 * focal) / h) * (dy / dz - oy / oz) 53 | d2 = -2 * near / oz 54 | 55 | origins = np.stack([o0, o1, o2], -1) 56 | directions = np.stack([d0, d1, d2], -1) 57 | return origins, directions 58 | 59 | 60 | class Dataset(): 61 | """Dataset Base Class.""" 62 | 63 | def __init__(self, split, args, prefetch=True): 64 | super(Dataset, self).__init__() 65 | self.split = split 66 | self._general_init(args) 67 | 68 | @property 69 | def size(self): 70 | return self.n_examples 71 | 72 | def _general_init(self, args): 73 | bbox_path = path.join(args.data_dir, 'bbox.txt') 74 | if os.path.isfile(bbox_path): 75 | self.bbox = np.loadtxt(bbox_path)[:-1] 76 | else: 77 | self.bbox = None 78 | self._load_renderings(args) 79 | 80 | 81 | class Blender(Dataset): 82 | """Blender Dataset.""" 83 | 84 | def _load_renderings(self, args): 85 | """Load images from disk.""" 86 | if args.render_path: 87 | raise ValueError("render_path cannot be used for the blender dataset.") 88 | with utils.open_file( 89 | path.join(args.data_dir, "transforms_{}.json".format(self.split)), "r" 90 | ) as fp: 91 | meta = json.load(fp) 92 | images = [] 93 | cams = [] 94 | print(' Load Blender', args.data_dir, 'split', self.split) 95 | for i in tqdm(range(len(meta["frames"]))): 96 | frame = meta["frames"][i] 97 | fname = os.path.join(args.data_dir, frame["file_path"] + ".png") 98 | with utils.open_file(fname, "rb") as imgin: 99 | image = np.array(Image.open(imgin), dtype=np.float32) / 255.0 100 | if args.factor == 2: 101 | [halfres_h, halfres_w] = [hw // 2 for hw in image.shape[:2]] 102 | image = cv2.resize( 103 | image, (halfres_w, halfres_h), interpolation=cv2.INTER_AREA 104 | ) 105 | elif args.factor > 0: 106 | raise ValueError( 107 | "Blender dataset only supports factor=0 or 2, {} " 108 | "set.".format(args.factor) 109 | ) 110 | cams.append(frame["transform_matrix"]) 111 | if args.white_bkgd: 112 | mask = image[..., -1:] 113 | image = image[..., :3] * mask + (1.0 - mask) 114 | else: 115 | image = image[..., :3] 116 | images.append(image) 117 | self.images = np.stack(images, axis=0) 118 | self.h, self.w = self.images.shape[1:3] 119 | self.resolution = self.h * self.w 120 | self.camtoworlds = np.stack(cams, axis=0).astype(np.float32) 121 | camera_angle_x = float(meta["camera_angle_x"]) 122 | self.focal = 0.5 * self.w / np.tan(0.5 * camera_angle_x) 123 | self.n_examples = self.images.shape[0] 124 | 125 | 126 | class 
LLFF(Dataset): 127 | """LLFF Dataset.""" 128 | 129 | def _load_renderings(self, args): 130 | """Load images from disk.""" 131 | args.data_dir = path.expanduser(args.data_dir) 132 | print(' Load LLFF', args.data_dir, 'split', self.split) 133 | # Load images. 134 | imgdir_suffix = "" 135 | if args.factor > 0: 136 | imgdir_suffix = "_{}".format(args.factor) 137 | factor = args.factor 138 | else: 139 | factor = 1 140 | imgdir = path.join(args.data_dir, "images" + imgdir_suffix) 141 | if not utils.file_exists(imgdir): 142 | raise ValueError("Image folder {} doesn't exist.".format(imgdir)) 143 | imgfiles = [ 144 | path.join(imgdir, f) 145 | for f in sorted(utils.listdir(imgdir)) 146 | if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png") 147 | ] 148 | images = [] 149 | for imgfile in imgfiles: 150 | with utils.open_file(imgfile, "rb") as imgin: 151 | image = np.array(Image.open(imgin), dtype=np.float32) / 255.0 152 | images.append(image) 153 | images = np.stack(images, axis=-1) 154 | 155 | # Load poses and bds. 156 | with utils.open_file(path.join(args.data_dir, "poses_bounds.npy"), "rb") as fp: 157 | poses_arr = np.load(fp) 158 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0]) 159 | bds = poses_arr[:, -2:].transpose([1, 0]) 160 | if poses.shape[-1] != images.shape[-1]: 161 | raise RuntimeError( 162 | "Mismatch between imgs {} and poses {}".format( 163 | images.shape[-1], poses.shape[-1] 164 | ) 165 | ) 166 | 167 | # Update poses according to downsampling. 168 | poses[:2, 4, :] = np.array(images.shape[:2]).reshape([2, 1]) 169 | poses[2, 4, :] = poses[2, 4, :] * 1.0 / factor 170 | 171 | # Correct rotation matrix ordering and move variable dim to axis 0. 172 | poses = np.concatenate( 173 | [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1 174 | ) 175 | poses = np.moveaxis(poses, -1, 0).astype(np.float32) 176 | images = np.moveaxis(images, -1, 0) 177 | bds = np.moveaxis(bds, -1, 0).astype(np.float32) 178 | 179 | # Rescale according to a default bd factor. 180 | scale = 1.0 / (bds.min() * 0.75) 181 | poses[:, :3, 3] *= scale 182 | bds *= scale 183 | 184 | # Recenter poses. 185 | poses = self._recenter_poses(poses) 186 | 187 | # Generate a spiral/spherical ray path for rendering videos. 188 | if args.spherify: 189 | poses = self._generate_spherical_poses(poses, bds) 190 | self.spherify = True 191 | else: 192 | self.spherify = False 193 | if not args.spherify and self.split == "test": 194 | self._generate_spiral_poses(poses, bds) 195 | 196 | # Select the split. 
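        # (Added annotation, not in the original source.) `llffhold` keeps every
        # Nth image as the test split. A minimal sketch of the same selection,
        # assuming 20 images and the default llffhold=8:
        #
        #     i_test = np.arange(20)[::8]                    # -> array([ 0,  8, 16])
        #     i_train = np.setdiff1d(np.arange(20), i_test)  # the remaining 17 images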
197 | i_test = np.arange(images.shape[0])[:: args.llffhold] 198 | i_train = np.array( 199 | [i for i in np.arange(int(images.shape[0])) if i not in i_test] 200 | ) 201 | if self.split == "train": 202 | indices = i_train 203 | else: 204 | indices = i_test 205 | images = images[indices] 206 | poses = poses[indices] 207 | 208 | self.images = images 209 | self.camtoworlds = poses[:, :3, :4] 210 | self.focal = poses[0, -1, -1] 211 | self.h, self.w = images.shape[1:3] 212 | self.resolution = self.h * self.w 213 | if args.render_path: 214 | self.n_examples = self.render_poses.shape[0] 215 | else: 216 | self.n_examples = images.shape[0] 217 | 218 | def _generate_rays(self): 219 | """Generate normalized device coordinate rays for llff.""" 220 | if self.split == "test": 221 | n_render_poses = self.render_poses.shape[0] 222 | self.camtoworlds = np.concatenate( 223 | [self.render_poses, self.camtoworlds], axis=0 224 | ) 225 | 226 | super()._generate_rays() 227 | 228 | if not self.spherify: 229 | ndc_origins, ndc_directions = convert_to_ndc( 230 | self.rays.origins, self.rays.directions, self.focal, self.w, self.h 231 | ) 232 | self.rays = utils.Rays( 233 | origins=ndc_origins, 234 | directions=ndc_directions, 235 | viewdirs=self.rays.viewdirs, 236 | ) 237 | 238 | # Split poses from the dataset and generated poses 239 | if self.split == "test": 240 | self.camtoworlds = self.camtoworlds[n_render_poses:] 241 | split = [np.split(r, [n_render_poses], 0) for r in self.rays] 242 | split0, split1 = zip(*split) 243 | self.render_rays = utils.Rays(*split0) 244 | self.rays = utils.Rays(*split1) 245 | 246 | def _recenter_poses(self, poses): 247 | """Recenter poses according to the original NeRF code.""" 248 | poses_ = poses.copy() 249 | bottom = np.reshape([0, 0, 0, 1.0], [1, 4]) 250 | c2w = self._poses_avg(poses) 251 | c2w = np.concatenate([c2w[:3, :4], bottom], -2) 252 | bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1]) 253 | poses = np.concatenate([poses[:, :3, :4], bottom], -2) 254 | poses = np.linalg.inv(c2w) @ poses 255 | poses_[:, :3, :4] = poses[:, :3, :4] 256 | poses = poses_ 257 | return poses 258 | 259 | def _poses_avg(self, poses): 260 | """Average poses according to the original NeRF code.""" 261 | hwf = poses[0, :3, -1:] 262 | center = poses[:, :3, 3].mean(0) 263 | vec2 = self._normalize(poses[:, :3, 2].sum(0)) 264 | up = poses[:, :3, 1].sum(0) 265 | c2w = np.concatenate([self._viewmatrix(vec2, up, center), hwf], 1) 266 | return c2w 267 | 268 | def _viewmatrix(self, z, up, pos): 269 | """Construct lookat view matrix.""" 270 | vec2 = self._normalize(z) 271 | vec1_avg = up 272 | vec0 = self._normalize(np.cross(vec1_avg, vec2)) 273 | vec1 = self._normalize(np.cross(vec2, vec0)) 274 | m = np.stack([vec0, vec1, vec2, pos], 1) 275 | return m 276 | 277 | def _normalize(self, x): 278 | """Normalization helper function.""" 279 | return x / np.linalg.norm(x) 280 | 281 | def _generate_spiral_poses(self, poses, bds): 282 | """Generate a spiral path for rendering.""" 283 | c2w = self._poses_avg(poses) 284 | # Get average pose. 285 | up = self._normalize(poses[:, :3, 1].sum(0)) 286 | # Find a reasonable "focus depth" for this dataset. 287 | close_depth, inf_depth = bds.min() * 0.9, bds.max() * 5.0 288 | dt = 0.75 289 | mean_dz = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth)) 290 | focal = mean_dz 291 | # Get radii for spiral path. 
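        # (Added annotation, not in the original source.) The per-axis 90th
        # percentile of the absolute camera translations gives a spiral radius
        # that covers most of the training cameras while ignoring outliers.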
292 | tt = poses[:, :3, 3] 293 | rads = np.percentile(np.abs(tt), 90, 0) 294 | c2w_path = c2w 295 | n_views = 120 296 | n_rots = 2 297 | # Generate poses for spiral path. 298 | render_poses = [] 299 | rads = np.array(list(rads) + [1.0]) 300 | hwf = c2w_path[:, 4:5] 301 | zrate = 0.5 302 | for theta in np.linspace(0.0, 2.0 * np.pi * n_rots, n_views + 1)[:-1]: 303 | c = np.dot( 304 | c2w[:3, :4], 305 | ( 306 | np.array( 307 | [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0] 308 | ) 309 | * rads 310 | ), 311 | ) 312 | z = self._normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.0]))) 313 | render_poses.append(np.concatenate([self._viewmatrix(z, up, c), hwf], 1)) 314 | self.render_poses = np.array(render_poses).astype(np.float32)[:, :3, :4] 315 | 316 | def _generate_spherical_poses(self, poses, bds): 317 | """Generate a 360 degree spherical path for rendering.""" 318 | # pylint: disable=g-long-lambda 319 | p34_to_44 = lambda p: np.concatenate( 320 | [p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1 321 | ) 322 | rays_d = poses[:, :3, 2:3] 323 | rays_o = poses[:, :3, 3:4] 324 | 325 | def min_line_dist(rays_o, rays_d): 326 | a_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1]) 327 | b_i = -a_i @ rays_o 328 | pt_mindist = np.squeeze( 329 | -np.linalg.inv((np.transpose(a_i, [0, 2, 1]) @ a_i).mean(0)) 330 | @ (b_i).mean(0) 331 | ) 332 | return pt_mindist 333 | 334 | pt_mindist = min_line_dist(rays_o, rays_d) 335 | center = pt_mindist 336 | up = (poses[:, :3, 3] - center).mean(0) 337 | vec0 = self._normalize(up) 338 | vec1 = self._normalize(np.cross([0.1, 0.2, 0.3], vec0)) 339 | vec2 = self._normalize(np.cross(vec0, vec1)) 340 | pos = center 341 | c2w = np.stack([vec1, vec2, vec0, pos], 1) 342 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4]) 343 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1))) 344 | sc = 1.0 / rad 345 | poses_reset[:, :3, 3] *= sc 346 | bds *= sc 347 | rad *= sc 348 | centroid = np.mean(poses_reset[:, :3, 3], 0) 349 | zh = centroid[2] 350 | radcircle = np.sqrt(rad ** 2 - zh ** 2) 351 | new_poses = [] 352 | 353 | for th in np.linspace(0.0, 2.0 * np.pi, 120): 354 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh]) 355 | up = np.array([0, 0, -1.0]) 356 | vec2 = self._normalize(camorigin) 357 | vec0 = self._normalize(np.cross(vec2, up)) 358 | vec1 = self._normalize(np.cross(vec2, vec0)) 359 | pos = camorigin 360 | p = np.stack([vec0, vec1, vec2, pos], 1) 361 | new_poses.append(p) 362 | 363 | new_poses = np.stack(new_poses, 0) 364 | new_poses = np.concatenate( 365 | [ 366 | new_poses, 367 | np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape), 368 | ], 369 | -1, 370 | ) 371 | poses_reset = np.concatenate( 372 | [ 373 | poses_reset[:, :3, :4], 374 | np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape), 375 | ], 376 | -1, 377 | ) 378 | if self.split == "test": 379 | self.render_poses = new_poses[:, :3, :4] 380 | return poses_reset 381 | 382 | 383 | class NSVF(Dataset): 384 | """NSVF Generic Dataset.""" 385 | 386 | def _load_renderings(self, args): 387 | """Load images from disk.""" 388 | if args.render_path: 389 | raise ValueError("render_path cannot be used for the NSVF dataset.") 390 | args.data_dir = path.expanduser(args.data_dir) 391 | K : np.ndarray = np.loadtxt(path.join(args.data_dir, "intrinsics.txt")) 392 | pose_files = sorted(os.listdir(path.join(args.data_dir, 'pose'))) 393 | img_files = 
sorted(os.listdir(path.join(args.data_dir, 'rgb'))) 394 | 395 | if self.split == 'train': 396 | pose_files = [x for x in pose_files if x.startswith('0_')] 397 | img_files = [x for x in img_files if x.startswith('0_')] 398 | elif self.split == 'val': 399 | pose_files = [x for x in pose_files if x.startswith('1_')] 400 | img_files = [x for x in img_files if x.startswith('1_')] 401 | elif self.split == 'test': 402 | test_pose_files = [x for x in pose_files if x.startswith('2_')] 403 | test_img_files = [x for x in img_files if x.startswith('2_')] 404 | if len(test_pose_files) == 0: 405 | test_pose_files = [x for x in pose_files if x.startswith('1_')] 406 | test_img_files = [x for x in img_files if x.startswith('1_')] 407 | pose_files = test_pose_files 408 | img_files = test_img_files 409 | 410 | images = [] 411 | cams = [] 412 | 413 | cam_trans = np.diag(np.array([1, -1, -1, 1], dtype=np.float32)) 414 | 415 | assert len(img_files) == len(pose_files) 416 | print(' Load NSVF', args.data_dir, 'split', self.split, 'num_images', len(img_files)) 417 | for img_fname, pose_fname in tqdm(zip(img_files, pose_files), total=len(img_files)): 418 | img_fname = path.join(args.data_dir, 'rgb', img_fname) 419 | with utils.open_file(img_fname, "rb") as imgin: 420 | image = np.array(Image.open(imgin), dtype=np.float32) / 255.0 421 | cam_mtx = np.loadtxt(path.join(args.data_dir, 'pose', pose_fname)) @ cam_trans 422 | cams.append(cam_mtx) # C2W 423 | if image.shape[-1] == 4: 424 | # Alpha channel available 425 | if args.white_bkgd: 426 | mask = image[..., -1:] 427 | image = image[..., :3] * mask + (1.0 - mask) 428 | else: 429 | image = image[..., :3] 430 | if args.factor > 1: 431 | [rsz_h, rsz_w] = [hw // args.factor for hw in image.shape[:2]] 432 | image = cv2.resize( 433 | image, (rsz_w, rsz_h), interpolation=cv2.INTER_AREA 434 | ) 435 | 436 | images.append(image) 437 | self.images = np.stack(images, axis=0) 438 | self.n_examples, self.h, self.w = self.images.shape[:3] 439 | self.resolution = self.h * self.w 440 | self.camtoworlds = np.stack(cams, axis=0).astype(np.float32) 441 | # We assume fx and fy are same 442 | self.focal = (K[0, 0] + K[1, 1]) * 0.5 443 | if args.factor > 1: 444 | self.focal /= args.factor 445 | 446 | 447 | dataset_dict = { 448 | "blender": Blender, 449 | "llff": LLFF, 450 | "nsvf": NSVF, 451 | } 452 | -------------------------------------------------------------------------------- /octree/nerf/model_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Modifications Copyright 2021 The PlenOctree Authors. 3 | # Original Copyright 2021 The Google Research Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | # Lint as: python3 18 | """Helper functions/classes for model definition.""" 19 | 20 | import functools 21 | from typing import Any, Callable 22 | import math 23 | 24 | import torch 25 | import torch.nn as nn 26 | 27 | 28 | def dense_layer(in_features, out_features): 29 | layer = nn.Linear(in_features, out_features) 30 | # The initialization matters! 31 | nn.init.xavier_uniform_(layer.weight) 32 | nn.init.zeros_(layer.bias) 33 | return layer 34 | 35 | 36 | class MLP(nn.Module): 37 | """A simple MLP.""" 38 | 39 | def __init__( 40 | self, 41 | net_depth: int = 8, # The depth of the first part of MLP. 42 | net_width: int = 256, # The width of the first part of MLP. 43 | net_depth_condition: int = 1, # The depth of the second part of MLP. 44 | net_width_condition: int = 128, # The width of the second part of MLP. 45 | net_activation: Callable[Ellipsis, Any] = nn.ReLU(), # The activation function. 46 | skip_layer: int = 4, # The layer to add skip layers to. 47 | num_rgb_channels: int = 3, # The number of RGB channels. 48 | num_sigma_channels: int = 1, # The number of sigma channels. 49 | input_dim: int = 63, # The number of input tensor channels. 50 | condition_dim: int = 27, # The number of conditional tensor channels. 51 | ): 52 | super(MLP, self).__init__() 53 | self.net_depth = net_depth 54 | self.net_width = net_width 55 | self.net_depth_condition = net_depth_condition 56 | self.net_width_condition = net_width_condition 57 | self.net_activation = net_activation 58 | self.skip_layer = skip_layer 59 | self.num_rgb_channels = num_rgb_channels 60 | self.num_sigma_channels = num_sigma_channels 61 | self.input_dim = input_dim 62 | self.condition_dim = condition_dim 63 | 64 | self.input_layers = nn.ModuleList() 65 | in_features = self.input_dim 66 | for i in range(self.net_depth): 67 | self.input_layers.append( 68 | dense_layer(in_features, self.net_width) 69 | ) 70 | if i % self.skip_layer == 0 and i > 0: 71 | in_features = self.net_width + self.input_dim 72 | else: 73 | in_features = self.net_width 74 | self.sigma_layer = dense_layer(in_features, self.num_sigma_channels) 75 | 76 | if self.condition_dim > 0: 77 | self.bottleneck_layer = dense_layer(in_features, self.net_width) 78 | self.condition_layers = nn.ModuleList() 79 | in_features = self.net_width + self.condition_dim 80 | for i in range(self.net_depth_condition): 81 | self.condition_layers.append( 82 | dense_layer(in_features, self.net_width_condition) 83 | ) 84 | in_features = self.net_width_condition 85 | self.rgb_layer = dense_layer(in_features, self.num_rgb_channels) 86 | 87 | def forward(self, x, condition=None, cross_broadcast=False): 88 | """Evaluate the MLP. 89 | 90 | Args: 91 | x: torch.tensor(float32), [batch, num_samples, feature], points. 92 | condition: torch.tensor(float32), 93 | [batch, feature] or [batch, num_samples, feature] or [batch, num_rays, feature], 94 | if not None, this variable will be part of the input to the second part of the MLP 95 | concatenated with the output vector of the first part of the MLP. If 96 | None, only the first part of the MLP will be used with input x. In the 97 | original paper, this variable is the view direction. Note when the shape of this 98 | tensor is [batch, num_rays, feature], where `num_rays` != `num_samples`, this 99 | function will cross broadcast all rays with all samples. And the `cross_broadcast` 100 | option must be set to `True`. 101 | cross_broadcast: if true, cross broadcast the x tensor and the condition 102 | tensor. 
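                (Added shape example, not in the original docstring: with x of
                shape [batch, num_samples, 63], condition of shape
                [batch, num_rays, 27] and cross_broadcast=True, bottleneck and
                condition are tiled against each other so that every sample is
                paired with every ray direction, giving outputs of shape
                [batch, num_samples, num_rays, ...].)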
103 | 
104 |         Returns:
105 |             raw_rgb: torch.tensor(float32), with a shape of
106 |                 [batch, num_samples, num_rgb_channels]. If `cross_broadcast` is true, the return
107 |                 shape would be [batch, num_samples, num_rays, num_rgb_channels].
108 |             raw_sigma: torch.tensor(float32), with a shape of
109 |                 [batch, num_samples, num_sigma_channels]. If `cross_broadcast` is true, the return
110 |                 shape would be [batch, num_samples, num_rays, num_sigma_channels].
111 |         """
112 |         batch_size = x.shape[0]
113 |         feature_dim = x.shape[-1]
114 |         num_samples = x.shape[1]
115 |         x = x.view([-1, feature_dim])
116 |         inputs = x
117 |         for i in range(self.net_depth):
118 |             x = self.input_layers[i](x)
119 |             x = self.net_activation(x)
120 |             if i % self.skip_layer == 0 and i > 0:
121 |                 x = torch.cat([x, inputs], dim=-1)
122 |         raw_sigma = self.sigma_layer(x).view(
123 |             [-1, num_samples, self.num_sigma_channels]
124 |         )
125 | 
126 |         if condition is not None:
127 |             # Output of the first part of MLP.
128 |             bottleneck = self.bottleneck_layer(x)
129 |             # Broadcast condition from [batch, feature] to
130 |             # [batch, num_samples, feature] since all the samples along the same ray
131 |             # have the same viewdir.
132 |             if len(condition.shape) == 2 and (not cross_broadcast):
133 |                 condition = condition[:, None, :].repeat(1, num_samples, 1)
134 |             # Broadcast samples from [batch, num_samples, feature]
135 |             # and condition from [batch, num_rays, feature] to
136 |             # [batch, num_samples, num_rays, feature] since in this case each point
137 |             # is paired with every ray. This option is used for projecting a
138 |             # trained vanilla NeRF to NeRF-SH.
139 |             if cross_broadcast:
140 |                 condition = condition.view([batch_size, -1, condition.shape[-1]])
141 |                 num_rays = condition.shape[1]
142 |                 condition = condition[:, None, :, :].repeat(1, num_samples, 1, 1)
143 |                 bottleneck = bottleneck.view([batch_size, -1, bottleneck.shape[-1]])
144 |                 bottleneck = bottleneck[:, :, None, :].repeat(1, 1, num_rays, 1)
145 |             # Collapse the [batch, num_samples, (num_rays,) feature] tensor to
146 |             # [batch * num_samples (* num_rays), feature] so that it can be fed into nn.Linear.
147 |             x = torch.cat([
148 |                 bottleneck.view([-1, bottleneck.shape[-1]]),
149 |                 condition.view([-1, condition.shape[-1]])], dim=-1)
150 |             # Here use 1 extra layer to align with the original nerf model.
151 |             for i in range(self.net_depth_condition):
152 |                 x = self.condition_layers[i](x)
153 |                 x = self.net_activation(x)
154 |         raw_rgb = self.rgb_layer(x).view(
155 |             [batch_size, num_samples, self.num_rgb_channels] if not cross_broadcast else \
156 |             [batch_size, num_samples, num_rays, self.num_rgb_channels]
157 |         )
158 |         return raw_rgb, raw_sigma
159 | 
160 | 
161 | def posenc(x, min_deg, max_deg, legacy_posenc_order=False):
162 |     """Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].
163 | 
164 |     Instead of computing [sin(x), cos(x)], we use the trig identity
165 |     cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
166 | 
167 |     Args:
168 |         x: torch.tensor, variables to be encoded. Note that x should be in [-pi, pi].
169 |         min_deg: int, the minimum (inclusive) degree of the encoding.
170 |         max_deg: int, the maximum (exclusive) degree of the encoding.
171 |         legacy_posenc_order: bool, keep the same ordering as the original tf code.
172 | 
173 |     Returns:
174 |         encoded: torch.tensor, encoded variables.
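            (Added note, not in the original docstring: with the defaults
            min_deg=0 and max_deg=10, a [..., 3] input maps to [..., 63]
            features: the 3 raw coordinates plus 3 * 2 * 10 = 60 sin/cos
            features, matching MLP's default input_dim=63.)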
175 | """ 176 | if min_deg == max_deg: 177 | return x 178 | scales = torch.tensor([2 ** i for i in range(min_deg, max_deg)], 179 | dtype=x.dtype, device=x.device) 180 | if legacy_posenc_order: 181 | xb = x[Ellipsis, None, :] * scales[:, None] 182 | four_feat = torch.reshape( 183 | torch.sin(torch.stack([xb, xb + 0.5 * math.pi], -2)), list(x.shape[:-1]) + [-1] 184 | ) 185 | else: 186 | xb = torch.reshape( 187 | (x[Ellipsis, None, :] * scales[:, None]), list(x.shape[:-1]) + [-1] 188 | ) 189 | four_feat = torch.sin(torch.cat([xb, xb + 0.5 * math.pi], dim=-1)) 190 | return torch.cat([x] + [four_feat], dim=-1) 191 | 192 | -------------------------------------------------------------------------------- /octree/nerf/models.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Modifications Copyright 2021 The PlenOctree Authors. 3 | # Original Copyright 2021 The Google Research Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # Lint as: python3 18 | """Different model implementation plus a general port for all the models.""" 19 | import os, glob 20 | import inspect 21 | from typing import Any, Callable 22 | import math 23 | 24 | import torch 25 | import torch.nn as nn 26 | 27 | from octree.nerf import model_utils 28 | 29 | 30 | def get_model(args): 31 | """A helper function that wraps around a 'model zoo'.""" 32 | model_dict = { 33 | "nerf": construct_nerf, 34 | } 35 | return model_dict[args.model](args) 36 | 37 | 38 | def get_model_state(args, device="cpu", restore=True): 39 | """ 40 | Helper for loading model with get_model & creating optimizer & 41 | optionally restoring checkpoint to reduce boilerplate 42 | """ 43 | model = get_model(args).to(device) 44 | if restore: 45 | if args.is_jaxnerf_ckpt: 46 | model = restore_model_state_from_jaxnerf(args, model) 47 | else: 48 | model = restore_model_state(args, model) 49 | return model 50 | 51 | 52 | def restore_model_state(args, model): 53 | """ 54 | Helper for restoring checkpoint. 55 | """ 56 | ckpt_paths = sorted( 57 | glob.glob(os.path.join(args.train_dir, "*.ckpt"))) 58 | if len(ckpt_paths) > 0: 59 | ckpt_path = ckpt_paths[-1] 60 | ckpt = torch.load(ckpt_path, map_location=torch.device('cpu')) 61 | model.load_state_dict(ckpt["model"]) 62 | print (f"* restore ckpt from {ckpt_path}.") 63 | return model 64 | 65 | 66 | def restore_model_state_from_jaxnerf(args, model): 67 | """ 68 | Helper for restoring checkpoint for jaxnerf. 
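    (Added note, not in the original docstring: Flax stores each Dense kernel
    as [in_features, out_features], whereas torch.nn.Linear.weight is
    [out_features, in_features]; this is why every kernel is transposed with
    `.T` when copied below.)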
69 | """ 70 | from flax.training import checkpoints 71 | ckpt_paths = sorted( 72 | glob.glob(os.path.join(args.train_dir, "checkpoint_*"))) 73 | 74 | if len(ckpt_paths) > 0: 75 | ckpt_dict = checkpoints.restore_checkpoint( 76 | args.train_dir, target=None)["optimizer"]["target"]["params"] 77 | state_dict = {} 78 | 79 | def _init_layer(from_name, to_name): 80 | state_dict[f"MLP_0.{to_name}.weight"] = \ 81 | ckpt_dict["MLP_0"][f"{from_name}"]["kernel"].T 82 | state_dict[f"MLP_0.{to_name}.bias"] = \ 83 | ckpt_dict["MLP_0"][f"{from_name}"]["bias"] 84 | state_dict[f"MLP_1.{to_name}.weight"] = \ 85 | ckpt_dict["MLP_1"][f"{from_name}"]["kernel"].T 86 | state_dict[f"MLP_1.{to_name}.bias"] = \ 87 | ckpt_dict["MLP_1"][f"{from_name}"]["bias"] 88 | pass 89 | 90 | # init all layers 91 | for i in range(model.net_depth): 92 | _init_layer(f"Dense_{i}", f"input_layers.{i}") 93 | i += 1 94 | _init_layer(f"Dense_{i}", f"sigma_layer") 95 | if model.use_viewdirs: 96 | i += 1 97 | _init_layer(f"Dense_{i}", f"bottleneck_layer") 98 | for j in range(model.net_depth_condition): 99 | i += 1 100 | _init_layer(f"Dense_{i}", f"condition_layers.{j}") 101 | i += 1 102 | _init_layer(f"Dense_{i}", f"rgb_layer") 103 | 104 | # support SG 105 | if model.sg_dim > 0: 106 | state_dict["sg_lambda"] = ckpt_dict["sg_lambda"] 107 | state_dict["sg_mu_spher"] = ckpt_dict["sg_mu_spher"] 108 | 109 | for key in state_dict.keys(): 110 | state_dict[key] = torch.from_numpy(state_dict[key].copy()) 111 | model.load_state_dict(state_dict) 112 | print (f"* restore ckpt from {args.train_dir}") 113 | return model 114 | 115 | 116 | class NerfModel(nn.Module): 117 | """Nerf NN Model with both coarse and fine MLPs.""" 118 | 119 | def __init__( 120 | self, 121 | num_coarse_samples: int = 64, # The number of samples for the coarse nerf. 122 | num_fine_samples: int = 128, # The number of samples for the fine nerf. 123 | use_viewdirs: bool = True, # If True, use viewdirs as an input. 124 | sh_deg: int = -1, # If != -1, use spherical harmonics output of given order 125 | sg_dim: int = -1, # If != -1, use spherical gaussians output of given dimension 126 | near: float = 2.0, # The distance to the near plane 127 | far: float = 6.0, # The distance to the far plane 128 | noise_std: float = 0.0, # The std dev of noise added to raw sigma. 129 | net_depth: int = 8, # The depth of the first part of MLP. 130 | net_width: int = 256, # The width of the first part of MLP. 131 | net_depth_condition: int = 1, # The depth of the second part of MLP. 132 | net_width_condition: int = 128, # The width of the second part of MLP. 133 | net_activation: Callable[Ellipsis, Any] = nn.ReLU(), # MLP activation 134 | skip_layer: int = 4, # How often to add skip connections. 135 | num_rgb_channels: int = 3, # The number of RGB channels. 136 | num_sigma_channels: int = 1, # The number of density channels. 137 | white_bkgd: bool = True, # If True, use a white background. 138 | min_deg_point: int = 0, # The minimum degree of positional encoding for positions. 139 | max_deg_point: int = 10, # The maximum degree of positional encoding for positions. 140 | deg_view: int = 4, # The degree of positional encoding for viewdirs. 141 | lindisp: bool = False, # If True, sample linearly in disparity rather than in depth. 142 | rgb_activation: Callable[Ellipsis, Any] = nn.Sigmoid(), # Output RGB activation. 143 | sigma_activation: Callable[Ellipsis, Any] = nn.ReLU(), # Output sigma activation. 144 | legacy_posenc_order: bool = False, # Keep the same ordering as the original tf code. 
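        # (Added note, not in the original source: sh_deg and sg_dim are mutually
        # exclusive; construct_nerf below asserts that at most one of them is set.)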
145 |     ):
146 |         super(NerfModel, self).__init__()
147 |         self.num_coarse_samples = num_coarse_samples
148 |         self.num_fine_samples = num_fine_samples
149 |         self.use_viewdirs = use_viewdirs
150 |         self.sh_deg = sh_deg
151 |         self.sg_dim = sg_dim
152 |         self.near = near
153 |         self.far = far
154 |         self.noise_std = noise_std
155 |         self.net_depth = net_depth
156 |         self.net_width = net_width
157 |         self.net_depth_condition = net_depth_condition
158 |         self.net_width_condition = net_width_condition
159 |         self.net_activation = net_activation
160 |         self.skip_layer = skip_layer
161 |         self.num_rgb_channels = num_rgb_channels
162 |         self.num_sigma_channels = num_sigma_channels
163 |         self.white_bkgd = white_bkgd
164 |         self.min_deg_point = min_deg_point
165 |         self.max_deg_point = max_deg_point
166 |         self.deg_view = deg_view
167 |         self.lindisp = lindisp
168 |         self.rgb_activation = rgb_activation
169 |         self.sigma_activation = sigma_activation
170 |         self.legacy_posenc_order = legacy_posenc_order
171 |         # Construct the "coarse" MLP. Weird name is for
172 |         # compatibility with 'compact' version
173 |         self.MLP_0 = model_utils.MLP(
174 |             net_depth = self.net_depth,
175 |             net_width = self.net_width,
176 |             net_depth_condition = self.net_depth_condition,
177 |             net_width_condition = self.net_width_condition,
178 |             net_activation = self.net_activation,
179 |             skip_layer = self.skip_layer,
180 |             num_rgb_channels = self.num_rgb_channels,
181 |             num_sigma_channels = self.num_sigma_channels,
182 |             input_dim=3 * (1 + 2 * (self.max_deg_point - self.min_deg_point)),
183 |             condition_dim=3 * (1 + 2 * self.deg_view) if self.use_viewdirs else 0)
184 |         # Construct the "fine" MLP.
185 |         self.MLP_1 = model_utils.MLP(
186 |             net_depth = self.net_depth,
187 |             net_width = self.net_width,
188 |             net_depth_condition = self.net_depth_condition,
189 |             net_width_condition = self.net_width_condition,
190 |             net_activation = self.net_activation,
191 |             skip_layer = self.skip_layer,
192 |             num_rgb_channels = self.num_rgb_channels,
193 |             num_sigma_channels = self.num_sigma_channels,
194 |             input_dim=3 * (1 + 2 * (self.max_deg_point - self.min_deg_point)),
195 |             condition_dim=3 * (1 + 2 * self.deg_view) if self.use_viewdirs else 0)
196 | 
197 |         # Construct learnable variables for spherical gaussians.
198 |         if self.sg_dim > 0:
199 |             self.register_parameter(
200 |                 "sg_lambda",
201 |                 nn.Parameter(torch.ones([self.sg_dim]))
202 |             )
203 |             self.register_parameter(
204 |                 "sg_mu_spher",
205 |                 nn.Parameter(torch.stack([
206 |                     torch.rand([self.sg_dim]) * math.pi,  # theta
207 |                     torch.rand([self.sg_dim]) * math.pi * 2  # phi
208 |                 ], dim=-1))
209 |             )
210 | 
211 |     def eval_points_raw(self, points, viewdirs=None, coarse=False, cross_broadcast=False):
212 |         """
213 |         Evaluate at points, returning rgb and sigma.
214 |         If sh_deg >= 0 then this will return spherical harmonic
215 |         coeffs for RGB. Please see eval_points for an alternate
216 |         version which always returns RGB.
217 | 
218 |         Args:
219 |             points: torch.tensor [B, 3]
220 |             viewdirs: torch.tensor [B, 3]. If cross_broadcast = True, it can be [M, 3].
221 |             coarse: if true, uses coarse MLP.
222 |             cross_broadcast: if true, cross broadcast between points and viewdirs.
223 | 
224 |         Returns:
225 |             raw_rgb: torch.tensor [B, 3 * (sh_deg + 1)**2 or 3]. If cross_broadcast = True, it
226 |                 returns [B, M, 3 * (sh_deg + 1)**2 or 3]
227 |             raw_sigma: torch.tensor [B, 1]
228 |         """
229 |         points = points[None]
230 |         points_enc = model_utils.posenc(
231 |             points,
232 |             self.min_deg_point,
233 |             self.max_deg_point,
234 |             self.legacy_posenc_order,
235 |         )
236 |         if self.num_fine_samples > 0 and not coarse:
237 |             mlp = self.MLP_1
238 |         else:
239 |             mlp = self.MLP_0
240 |         if self.use_viewdirs:
241 |             assert viewdirs is not None
242 |             viewdirs = viewdirs[None]
243 |             viewdirs_enc = model_utils.posenc(
244 |                 viewdirs,
245 |                 0,
246 |                 self.deg_view,
247 |                 self.legacy_posenc_order,
248 |             )
249 |             raw_rgb, raw_sigma = mlp(points_enc, viewdirs_enc, cross_broadcast=cross_broadcast)
250 |         else:
251 |             raw_rgb, raw_sigma = mlp(points_enc)
252 |         return raw_rgb[0], raw_sigma[0]
253 | 
254 | 
255 | def construct_nerf(args):
256 |     """Construct a Neural Radiance Field.
257 | 
258 |     Args:
259 |         args: FLAGS class. Hyperparameters of nerf.
260 | 
261 |     Returns:
262 |         model: nn.Module. Nerf model with parameters. (Unlike the original
263 |             flax version, this PyTorch port returns no separate model state.)
264 |     """
265 |     net_activation = getattr(nn, str(args.net_activation))
266 |     if inspect.isclass(net_activation):
267 |         net_activation = net_activation()
268 |     rgb_activation = getattr(nn, str(args.rgb_activation))
269 |     if inspect.isclass(rgb_activation):
270 |         rgb_activation = rgb_activation()
271 |     sigma_activation = getattr(nn, str(args.sigma_activation))
272 |     if inspect.isclass(sigma_activation):
273 |         sigma_activation = sigma_activation()
274 | 
275 |     # Assert that rgb_activation always produces outputs in [0, 1], and
276 |     # sigma_activation always produces non-negative outputs.
277 |     x = torch.exp(torch.linspace(-90, 90, 1024))
278 |     x = torch.cat([-x, x], dim=0)
279 | 
280 |     rgb = rgb_activation(x)
281 |     if torch.any(rgb < 0) or torch.any(rgb > 1):
282 |         raise NotImplementedError(
283 |             "Choice of rgb_activation `{}` produces colors outside of [0, 1]".format(
284 |                 args.rgb_activation
285 |             )
286 |         )
287 | 
288 |     sigma = sigma_activation(x)
289 |     if torch.any(sigma < 0):
290 |         raise NotImplementedError(
291 |             "Choice of sigma_activation `{}` produces negative densities".format(
292 |                 args.sigma_activation
293 |             )
294 |         )
295 | 
296 |     num_rgb_channels = args.num_rgb_channels
297 |     if not args.use_viewdirs:
298 |         if args.sh_deg >= 0:
299 |             assert args.sg_dim == -1, (
300 |                 "You can only use up to one of: SH or SG.")
301 |             num_rgb_channels *= (args.sh_deg + 1) ** 2
302 |         elif args.sg_dim > 0:
303 |             assert args.sh_deg == -1, (
304 |                 "You can only use up to one of: SH or SG.")
305 |             num_rgb_channels *= args.sg_dim
306 | 
307 |     model = NerfModel(
308 |         min_deg_point=args.min_deg_point,
309 |         max_deg_point=args.max_deg_point,
310 |         deg_view=args.deg_view,
311 |         num_coarse_samples=args.num_coarse_samples,
312 |         num_fine_samples=args.num_fine_samples,
313 |         use_viewdirs=args.use_viewdirs,
314 |         sh_deg=args.sh_deg,
315 |         sg_dim=args.sg_dim,
316 |         near=args.near,
317 |         far=args.far,
318 |         noise_std=args.noise_std,
319 |         white_bkgd=args.white_bkgd,
320 |         net_depth=args.net_depth,
321 |         net_width=args.net_width,
322 |         net_depth_condition=args.net_depth_condition,
323 |         net_width_condition=args.net_width_condition,
324 |         skip_layer=args.skip_layer,
325 |         num_rgb_channels=num_rgb_channels,
326 |         num_sigma_channels=args.num_sigma_channels,
327 |         lindisp=args.lindisp,
328 |         net_activation=net_activation,
329 |         rgb_activation=rgb_activation,
330 |         sigma_activation=sigma_activation,
331 |         legacy_posenc_order=args.legacy_posenc_order,
332 |     )
333 |     return model
334 | 
--------------------------------------------------------------------------------
/octree/nerf/sh_proj.py:
--------------------------------------------------------------------------------
1 | # Modifications Copyright 2021 The PlenOctree Authors.
2 | # Original Copyright 2015 Google Inc. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Spherical harmonics projection related functions
16 | 
17 | Some code is borrowed from:
18 | https://github.com/google/spherical-harmonics/blob/master/sh/spherical_harmonics.cc
19 | """
20 | from typing import Callable
21 | import math
22 | import torch
23 | 
24 | kHardCodedOrderLimit = 4
25 | 
26 | 
27 | def spher2cart(theta, phi):
28 |     """Convert spherical coordinates into Cartesian coordinates (radius 1)."""
29 |     r = torch.sin(theta)
30 |     x = r * torch.cos(phi)
31 |     y = r * torch.sin(phi)
32 |     z = torch.cos(theta)
33 |     return torch.stack([x, y, z], dim=-1)
34 | 
35 | 
36 | # Get the total number of coefficients for a function represented by
37 | # all spherical harmonic basis of degree <= @order (it is a point of
38 | # confusion that the order of an SH refers to its degree and not the order).
39 | def GetCoefficientCount(order: int):
40 |     return (order + 1) ** 2
41 | 
42 | 
43 | # Get the one dimensional index associated with a particular degree @l
44 | # and order @m. This is the index that can be used to access the Coeffs
45 | # returned by SHSolver.
46 | def GetIndex(l: int, m: int):
47 |     return l * (l + 1) + m
48 | 
49 | 
50 | # Hardcoded spherical harmonic functions for low orders (l is first number
51 | # and m is second number (sign encoded as preceding 'p' or 'n')).
52 | #
53 | # As polynomials they are evaluated more efficiently in cartesian coordinates,
54 | # assuming that @{dx, dy, dz} is unit. This is not verified for efficiency.
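# (Added annotation, not in the original source.) A quick sanity check of the
# conventions used below, evaluating at the unit +z direction:
#
#     import torch
#     dx, dy, dz = torch.tensor(0.0), torch.tensor(0.0), torch.tensor(1.0)
#     HardcodedSH00(dx, dy, dz)   # -> 0.2821  (0.5 * sqrt(1/pi), constant band)
#     HardcodedSH10(dx, dy, dz)   # -> 0.4886  (sqrt(3/(4pi)) * z)
#     HardcodedSH1n1(dx, dy, dz)  # -> 0.0     (proportional to y)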
55 | 56 | def HardcodedSH00(dx, dy, dz): 57 | # 0.5 * sqrt(1/pi) 58 | return 0.28209479177387814 + (dx * 0) # keep the shape 59 | 60 | def HardcodedSH1n1(dx, dy, dz): 61 | # -sqrt(3/(4pi)) * y 62 | return -0.4886025119029199 * dy 63 | 64 | def HardcodedSH10(dx, dy, dz): 65 | # sqrt(3/(4pi)) * z 66 | return 0.4886025119029199 * dz 67 | 68 | def HardcodedSH1p1(dx, dy, dz): 69 | # -sqrt(3/(4pi)) * x 70 | return -0.4886025119029199 * dx 71 | 72 | def HardcodedSH2n2(dx, dy, dz): 73 | # 0.5 * sqrt(15/pi) * x * y 74 | return 1.0925484305920792 * dx * dy 75 | 76 | def HardcodedSH2n1(dx, dy, dz): 77 | # -0.5 * sqrt(15/pi) * y * z 78 | return -1.0925484305920792 * dy * dz 79 | 80 | def HardcodedSH20(dx, dy, dz): 81 | # 0.25 * sqrt(5/pi) * (-x^2-y^2+2z^2) 82 | return 0.31539156525252005 * (-dx * dx - dy * dy + 2.0 * dz * dz) 83 | 84 | def HardcodedSH2p1(dx, dy, dz): 85 | # -0.5 * sqrt(15/pi) * x * z 86 | return -1.0925484305920792 * dx * dz 87 | 88 | def HardcodedSH2p2(dx, dy, dz): 89 | # 0.25 * sqrt(15/pi) * (x^2 - y^2) 90 | return 0.5462742152960396 * (dx * dx - dy * dy) 91 | 92 | def HardcodedSH3n3(dx, dy, dz): 93 | # -0.25 * sqrt(35/(2pi)) * y * (3x^2 - y^2) 94 | return -0.5900435899266435 * dy * (3.0 * dx * dx - dy * dy) 95 | 96 | def HardcodedSH3n2(dx, dy, dz): 97 | # 0.5 * sqrt(105/pi) * x * y * z 98 | return 2.890611442640554 * dx * dy * dz 99 | 100 | def HardcodedSH3n1(dx, dy, dz): 101 | # -0.25 * sqrt(21/(2pi)) * y * (4z^2-x^2-y^2) 102 | return -0.4570457994644658 * dy * (4.0 * dz * dz - dx * dx - dy * dy) 103 | 104 | def HardcodedSH30(dx, dy, dz): 105 | # 0.25 * sqrt(7/pi) * z * (2z^2 - 3x^2 - 3y^2) 106 | return 0.3731763325901154 * dz * (2.0 * dz * dz - 3.0 * dx * dx - 3.0 * dy * dy) 107 | 108 | def HardcodedSH3p1(dx, dy, dz): 109 | # -0.25 * sqrt(21/(2pi)) * x * (4z^2-x^2-y^2) 110 | return -0.4570457994644658 * dx * (4.0 * dz * dz - dx * dx - dy * dy) 111 | 112 | def HardcodedSH3p2(dx, dy, dz): 113 | # 0.25 * sqrt(105/pi) * z * (x^2 - y^2) 114 | return 1.445305721320277 * dz * (dx * dx - dy * dy) 115 | 116 | def HardcodedSH3p3(dx, dy, dz): 117 | # -0.25 * sqrt(35/(2pi)) * x * (x^2-3y^2) 118 | return -0.5900435899266435 * dx * (dx * dx - 3.0 * dy * dy) 119 | 120 | def HardcodedSH4n4(dx, dy, dz): 121 | # 0.75 * sqrt(35/pi) * x * y * (x^2-y^2) 122 | return 2.5033429417967046 * dx * dy * (dx * dx - dy * dy) 123 | 124 | def HardcodedSH4n3(dx, dy, dz): 125 | # -0.75 * sqrt(35/(2pi)) * y * z * (3x^2-y^2) 126 | return -1.7701307697799304 * dy * dz * (3.0 * dx * dx - dy * dy) 127 | 128 | def HardcodedSH4n2(dx, dy, dz): 129 | # 0.75 * sqrt(5/pi) * x * y * (7z^2-1) 130 | return 0.9461746957575601 * dx * dy * (7.0 * dz * dz - 1.0) 131 | 132 | def HardcodedSH4n1(dx, dy, dz): 133 | # -0.75 * sqrt(5/(2pi)) * y * z * (7z^2-3) 134 | return -0.6690465435572892 * dy * dz * (7.0 * dz * dz - 3.0) 135 | 136 | def HardcodedSH40(dx, dy, dz): 137 | # 3/16 * sqrt(1/pi) * (35z^4-30z^2+3) 138 | z2 = dz * dz 139 | return 0.10578554691520431 * (35.0 * z2 * z2 - 30.0 * z2 + 3.0) 140 | 141 | def HardcodedSH4p1(dx, dy, dz): 142 | # -0.75 * sqrt(5/(2pi)) * x * z * (7z^2-3) 143 | return -0.6690465435572892 * dx * dz * (7.0 * dz * dz - 3.0) 144 | 145 | def HardcodedSH4p2(dx, dy, dz): 146 | # 3/8 * sqrt(5/pi) * (x^2 - y^2) * (7z^2 - 1) 147 | return 0.47308734787878004 * (dx * dx - dy * dy) * (7.0 * dz * dz - 1.0) 148 | 149 | def HardcodedSH4p3(dx, dy, dz): 150 | # -0.75 * sqrt(35/(2pi)) * x * z * (x^2 - 3y^2) 151 | return -1.7701307697799304 * dx * dz * (dx * dx - 3.0 * dy * dy) 152 | 153 | def HardcodedSH4p4(dx, dy, dz): 
154 | # 3/16*sqrt(35/pi) * (x^2 * (x^2 - 3y^2) - y^2 * (3x^2 - y^2)) 155 | x2 = dx * dx 156 | y2 = dy * dy 157 | return 0.6258357354491761 * (x2 * (x2 - 3.0 * y2) - y2 * (3.0 * x2 - y2)) 158 | 159 | 160 | def EvalSH(l: int, m: int, dirs): 161 | """ 162 | Args: 163 | dirs: array [..., 3]. works with torch/jnp/np 164 | Return: 165 | array [...] 166 | """ 167 | if l <= kHardCodedOrderLimit: 168 | # Validate l and m here (don't do it generally since EvalSHSlow also 169 | # checks it if we delegate to that function). 170 | assert l >= 0, "l must be at least 0." 171 | assert -l <= m and m <= l, "m must be between -l and l." 172 | dx = dirs[..., 0] 173 | dy = dirs[..., 1] 174 | dz = dirs[..., 2] 175 | 176 | if l == 0: 177 | return HardcodedSH00(dx, dy, dz) 178 | elif l == 1: 179 | if m == -1: 180 | return HardcodedSH1n1(dx, dy, dz) 181 | elif m == 0: 182 | return HardcodedSH10(dx, dy, dz) 183 | elif m == 1: 184 | return HardcodedSH1p1(dx, dy, dz) 185 | elif l == 2: 186 | if m == -2: 187 | return HardcodedSH2n2(dx, dy, dz) 188 | elif m == -1: 189 | return HardcodedSH2n1(dx, dy, dz) 190 | elif m == 0: 191 | return HardcodedSH20(dx, dy, dz) 192 | elif m == 1: 193 | return HardcodedSH2p1(dx, dy, dz) 194 | elif m == 2: 195 | return HardcodedSH2p2(dx, dy, dz) 196 | elif l == 3: 197 | if m == -3: 198 | return HardcodedSH3n3(dx, dy, dz) 199 | elif m == -2: 200 | return HardcodedSH3n2(dx, dy, dz) 201 | elif m == -1: 202 | return HardcodedSH3n1(dx, dy, dz) 203 | elif m == 0: 204 | return HardcodedSH30(dx, dy, dz) 205 | elif m == 1: 206 | return HardcodedSH3p1(dx, dy, dz) 207 | elif m == 2: 208 | return HardcodedSH3p2(dx, dy, dz) 209 | elif m == 3: 210 | return HardcodedSH3p3(dx, dy, dz) 211 | elif l == 4: 212 | if m == -4: 213 | return HardcodedSH4n4(dx, dy, dz) 214 | elif m == -3: 215 | return HardcodedSH4n3(dx, dy, dz) 216 | elif m == -2: 217 | return HardcodedSH4n2(dx, dy, dz) 218 | elif m == -1: 219 | return HardcodedSH4n1(dx, dy, dz) 220 | elif m == 0: 221 | return HardcodedSH40(dx, dy, dz) 222 | elif m == 1: 223 | return HardcodedSH4p1(dx, dy, dz) 224 | elif m == 2: 225 | return HardcodedSH4p2(dx, dy, dz) 226 | elif m == 3: 227 | return HardcodedSH4p3(dx, dy, dz) 228 | elif m == 4: 229 | return HardcodedSH4p4(dx, dy, dz) 230 | 231 | # This is unreachable given the CHECK's above but the compiler can't tell. 232 | return None 233 | 234 | else: 235 | # Not hard-coded so use the recurrence relation (which will convert this 236 | # to spherical coordinates). 237 | # return EvalSHSlow(l, m, dx, dy, dz) 238 | raise NotImplementedError 239 | 240 | 241 | def spherical_uniform_sampling(sample_count, device="cpu"): 242 | # See: https://www.bogotobogo.com/Algorithms/uniform_distribution_sphere.php 243 | theta = torch.acos(2.0 * torch.rand([sample_count]) - 1.0) 244 | phi = 2.0 * math.pi * torch.rand([sample_count]) 245 | return theta.to(device), phi.to(device) 246 | 247 | 248 | def ProjectFunction(order: int, sperical_func: Callable, sample_count: int, device="cpu"): 249 | assert order >= 0, "Order must be at least zero." 250 | assert sample_count > 0, "Sample count must be at least one." 251 | 252 | # This is the approach demonstrated in [1] and is useful for arbitrary 253 | # functions on the sphere that are represented analytically. 
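    # (Added annotation, not in the original source.) Each coefficient below is
    # a Monte Carlo estimate of the projection integral
    #     c_{l,m} = \int_{S^2} f(w) Y_{l,m}(w) dw
    #             ~ (4 * pi / sample_count) * sum_i f(w_i) Y_{l,m}(w_i),
    # which is exactly what the loop accumulates before the final
    # 4 * pi / sample_count weighting.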
254 | coeffs = torch.zeros([GetCoefficientCount(order)], dtype=torch.float32).to(device) 255 | 256 | # generate sample_count uniformly and stratified samples over the sphere 257 | # See http://www.bogotobogo.com/Algorithms/uniform_distribution_sphere.php 258 | theta, phi = spherical_uniform_sampling(sample_count, device=device) 259 | dirs = spher2cart(theta, phi) 260 | 261 | # evaluate the analytic function for the current spherical coords 262 | func_value = sperical_func(dirs) 263 | 264 | # evaluate the SH basis functions up to band O, scale them by the 265 | # function's value and accumulate them over all generated samples 266 | for l in range(order + 1): # end inclusive 267 | for m in range(-l, l + 1): # end inclusive 268 | coeffs[GetIndex(l, m)] = sum(func_value * EvalSH(l, m, dirs)) 269 | 270 | # scale by the probability of a particular sample, which is 271 | # 4pi/sample_count. 4pi for the surface area of a unit sphere, and 272 | # 1/sample_count for the number of samples drawn uniformly. 273 | weight = 4.0 * math.pi / sample_count 274 | coeffs *= weight 275 | return coeffs 276 | 277 | 278 | def ProjectFunctionNeRF(order: int, sperical_func: Callable, batch_size: int, sample_count: int, device="cpu"): 279 | assert order >= 0, "Order must be at least zero." 280 | assert sample_count > 0, "Sample count must be at least one." 281 | C = 3 # rgb channels 282 | 283 | # This is the approach demonstrated in [1] and is useful for arbitrary 284 | # functions on the sphere that are represented analytically. 285 | coeffs = torch.zeros([batch_size, C, GetCoefficientCount(order)], dtype=torch.float32).to(device) 286 | 287 | # generate sample_count uniformly and stratified samples over the sphere 288 | # See http://www.bogotobogo.com/Algorithms/uniform_distribution_sphere.php 289 | theta, phi = spherical_uniform_sampling(sample_count, device=device) 290 | dirs = spher2cart(theta, phi) 291 | 292 | # evaluate the analytic function for the current spherical coords 293 | func_value, others = sperical_func(dirs) # [batch_size, sample_count, C] 294 | 295 | # evaluate the SH basis functions up to band O, scale them by the 296 | # function's value and accumulate them over all generated samples 297 | for l in range(order + 1): # end inclusive 298 | for m in range(-l, l + 1): # end inclusive 299 | coeffs[:, :, GetIndex(l, m)] = torch.einsum("bsc,s->bc", func_value, EvalSH(l, m, dirs)) 300 | 301 | # scale by the probability of a particular sample, which is 302 | # 4pi/sample_count. 4pi for the surface area of a unit sphere, and 303 | # 1/sample_count for the number of samples drawn uniformly. 304 | weight = 4.0 * math.pi / sample_count 305 | coeffs *= weight 306 | return coeffs, others 307 | 308 | def ProjectFunctionNeRFSparse( 309 | order: int, 310 | spherical_func: Callable, 311 | sample_count: int, 312 | device="cpu", 313 | ): 314 | assert order >= 0, "Order must be at least zero." 315 | assert sample_count > 0, "Sample count must be at least one." 
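    # (Added annotation, not in the original source.) Unlike ProjectFunction and
    # ProjectFunctionNeRF above, which Monte Carlo-average func_value * Y_{l,m},
    # this variant solves the least-squares system basis_vals @ coeffs = func_value,
    # which can recover good coefficients from far fewer direction samples.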
316 | C = 3 # rgb channels 317 | 318 | # generate sample_count uniformly and stratified samples over the sphere 319 | # See http://www.bogotobogo.com/Algorithms/uniform_distribution_sphere.php 320 | theta, phi = spherical_uniform_sampling(sample_count, device=device) 321 | dirs = spher2cart(theta, phi) # [sample_count, 3] 322 | 323 | # evaluate the analytic function for the current spherical coords 324 | func_value, others = spherical_func(dirs) # func_value [batch_size, sample_count, C] 325 | 326 | batch_size = func_value.shape[0] 327 | 328 | coeff_count = GetCoefficientCount(order) 329 | basis_vals = torch.empty( 330 | [sample_count, coeff_count], dtype=torch.float32 331 | ).to(device) 332 | 333 | # evaluate the SH basis functions up to band O, scale them by the 334 | # function's value and accumulate them over all generated samples 335 | for l in range(order + 1): # end inclusive 336 | for m in range(-l, l + 1): # end inclusive 337 | basis_vals[:, GetIndex(l, m)] = EvalSH(l, m, dirs) 338 | 339 | basis_vals = basis_vals.view( 340 | sample_count, coeff_count) # [sample_count, coeff_count] 341 | func_value = func_value.transpose(0, 1).reshape( 342 | sample_count, batch_size * C) # [sample_count, batch_size * C] 343 | soln = torch.lstsq(func_value, basis_vals).solution[:basis_vals.size(1)] 344 | soln = soln.T.reshape(batch_size, C, -1) 345 | return soln, others 346 | 347 | -------------------------------------------------------------------------------- /octree/nerf/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Modifications Copyright 2021 The PlenOctree Authors. 3 | # Original Copyright 2021 The Google Research Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 
17 | # Lint as: python3
18 | """Utility functions."""
19 | import collections
20 | import os
21 | from os import path
22 | from absl import flags
23 | import math
24 | 
25 | import torch
26 | import torch.nn.functional as F
27 | 
28 | import numpy as np
29 | from PIL import Image
30 | import yaml
31 | from tqdm import tqdm
32 | from octree.nerf import datasets
33 | 
34 | INTERNAL = False
35 | 
36 | Rays = collections.namedtuple("Rays", ("origins", "directions", "viewdirs"))
37 | 
38 | 
39 | def namedtuple_map(fn, tup):
40 |     """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
41 |     return type(tup)(*map(fn, tup))
42 | 
43 | 
44 | def define_flags():
45 |     """Define flags for both training and evaluation modes."""
46 |     flags.DEFINE_string("train_dir", None, "where to store ckpts and logs")
47 |     flags.DEFINE_string("data_dir", None, "input data directory.")
48 |     flags.DEFINE_string("config", None, "using config files to set hyperparameters.")
49 | 
50 |     # Dataset Flags
51 |     # TODO(pratuls): rename to dataset_loader and consider cleaning up
52 |     flags.DEFINE_enum(
53 |         "dataset",
54 |         "blender",
55 |         list(k for k in datasets.dataset_dict.keys()),
56 |         "The type of dataset fed to nerf.",
57 |     )
58 |     flags.DEFINE_bool(
59 |         "image_batching", False, "sample rays in a batch from different images."
60 |     )
61 |     flags.DEFINE_bool(
62 |         "white_bkgd",
63 |         True,
64 |         "using white color as default background. " "(used in the blender dataset only)",
65 |     )
66 |     flags.DEFINE_integer(
67 |         "batch_size", 1024, "the number of rays in a mini-batch (for training)."
68 |     )
69 |     flags.DEFINE_integer(
70 |         "factor", 4, "the downsample factor of images, 0 for no downsample."
71 |     )
72 |     flags.DEFINE_bool("spherify", False, "set for spherical 360 scenes.")
73 |     flags.DEFINE_bool(
74 |         "render_path",
75 |         False,
76 |         "render generated path if set true. " "(used in the llff dataset only)",
77 |     )
78 |     flags.DEFINE_integer(
79 |         "llffhold",
80 |         8,
81 |         "will hold out every Nth image as the LLFF test set. "
82 |         "(used in the llff dataset only)",
83 |     )
84 | 
85 |     # Model Flags
86 |     flags.DEFINE_string("model", "nerf", "name of model to use.")
87 |     flags.DEFINE_float("near", 2.0, "near clip of volumetric rendering.")
88 |     flags.DEFINE_float("far", 6.0, "far clip of volumetric rendering.")
89 |     flags.DEFINE_integer("net_depth", 8, "depth of the first part of MLP.")
90 |     flags.DEFINE_integer("net_width", 256, "width of the first part of MLP.")
91 |     flags.DEFINE_integer("net_depth_condition", 1, "depth of the second part of MLP.")
92 |     flags.DEFINE_integer("net_width_condition", 128, "width of the second part of MLP.")
93 |     flags.DEFINE_float("weight_decay_mult", 0, "The multiplier on weight decay")
94 |     flags.DEFINE_integer(
95 |         "skip_layer",
96 |         4,
97 |         "add a skip connection to the output vector of every " "skip_layer layers.",
98 |     )
99 |     flags.DEFINE_integer("num_rgb_channels", 3, "the number of RGB channels.")
100 |     flags.DEFINE_integer("num_sigma_channels", 1, "the number of density channels.")
101 |     flags.DEFINE_bool("randomized", True, "use randomized stratified sampling.")
102 |     flags.DEFINE_integer(
103 |         "min_deg_point", 0, "Minimum degree of positional encoding for points."
104 |     )
105 |     flags.DEFINE_integer(
106 |         "max_deg_point", 10, "Maximum degree of positional encoding for points."
107 |     )
108 |     flags.DEFINE_integer("deg_view", 4, "Degree of positional encoding for viewdirs.")
109 |     flags.DEFINE_integer(
110 |         "num_coarse_samples",
111 |         64,
112 |         "the number of samples on each ray for the coarse model.",
113 |     )
114 |     flags.DEFINE_integer(
115 |         "num_fine_samples", 128, "the number of samples on each ray for the fine model."
116 |     )
117 |     flags.DEFINE_bool("use_viewdirs", True, "use view directions as a condition.")
118 |     flags.DEFINE_integer("sh_deg", -1, "set to use spherical harmonics output of the given order.")
119 |     flags.DEFINE_integer("sg_dim", -1, "set to use spherical Gaussians (SG).")
120 |     flags.DEFINE_float(
121 |         "noise_std",
122 |         None,
123 |         "std dev of noise added to regularize the sigma output "
124 |         "(used in the llff dataset only).",
125 |     )
126 |     flags.DEFINE_bool(
127 |         "lindisp", False, "sampling linearly in disparity rather than depth."
128 |     )
129 |     flags.DEFINE_string(
130 |         "net_activation", "ReLU", "activation function used within the MLP."
131 |     )
132 |     flags.DEFINE_string(
133 |         "rgb_activation", "Sigmoid", "activation function used to produce RGB."
134 |     )
135 |     flags.DEFINE_string(
136 |         "sigma_activation", "ReLU", "activation function used to produce density."
137 |     )
138 |     flags.DEFINE_bool(
139 |         "legacy_posenc_order",
140 |         False,
141 |         "If True, revert the positional encoding feature order to an older version of this codebase.",
142 |     )
143 | 
144 |     # Train Flags
145 |     flags.DEFINE_float("lr_init", 5e-4, "The initial learning rate.")
146 |     flags.DEFINE_float("lr_final", 5e-6, "The final learning rate.")
147 |     flags.DEFINE_integer(
148 |         "lr_delay_steps",
149 |         0,
150 |         "The number of steps at the beginning of "
151 |         "training to reduce the learning rate by lr_delay_mult.",
152 |     )
153 |     flags.DEFINE_float(
154 |         "lr_delay_mult",
155 |         1.0,
156 |         "A multiplier on the learning rate when the step " "is < lr_delay_steps.",
157 |     )
158 |     flags.DEFINE_integer("max_steps", 1000000, "the number of optimization steps.")
159 |     flags.DEFINE_integer(
160 |         "save_every", 5000, "the number of steps between checkpoint saves."
161 |     )
162 |     flags.DEFINE_integer(
163 |         "print_every", 500, "the number of steps between reports to tensorboard."
164 |     )
165 |     flags.DEFINE_integer(
166 |         "render_every",
167 |         10000,
168 |         "the number of steps between test-image renders; "
169 |         "best kept a multiple of 100 for accurate step-time records.",
170 |     )
171 |     flags.DEFINE_integer(
172 |         "gc_every", 10000, "the number of steps between python garbage collection runs."
173 |     )
174 |     flags.DEFINE_float(
175 |         "sparsity_weight",
176 |         1e-3,
177 |         "Sparsity loss weight.",
178 |     )
179 |     flags.DEFINE_float(
180 |         "sparsity_length",
181 |         0.05,
182 |         "Sparsity loss 'length' for the alpha calculation.",
183 |     )
184 |     flags.DEFINE_float(
185 |         "sparsity_radius",
186 |         1.5,
187 |         "Sparsity loss point-sampling box half side length.",
188 |     )
189 |     flags.DEFINE_integer(
190 |         "sparsity_npoints",
191 |         10000,
192 |         "Number of samples for the sparsity loss.",
193 |     )
194 | 
195 |     # Eval Flags
196 |     flags.DEFINE_bool(
197 |         "eval_once",
198 |         True,
199 |         "evaluate the model only once if true, otherwise keep evaluating new "
200 |         "checkpoints as they appear.",
201 |     )
202 |     flags.DEFINE_bool("save_output", True, "save predicted images to disk if True.")
203 |     flags.DEFINE_integer(
204 |         "chunk",
205 |         81920,
206 |         "the size of chunks for evaluation inference; set to a value that "
207 |         "fits your GPU/TPU memory.",
208 |     )
209 | 
210 |     # Octree flags
211 |     flags.DEFINE_float(
212 |         'renderer_step_size',
213 |         1e-4,
214 |         'step size epsilon in volume rendering: '
215 |         '1e-3 = fast setting, 1e-4 = usual setting, 1e-5 = high setting; lower is better.')
216 |     flags.DEFINE_bool(
217 |         'no_early_stop',
218 |         False,
219 |         'If set, does not use early stopping; slows down rendering slightly.')
220 | 
221 | 
222 | def update_flags(args):
223 |     """Update the flags in `args` with the contents of the config YAML file."""
224 |     if args.config is None:
225 |         return
226 |     pth = args.config + ".yaml"
227 |     with open_file(pth, "r") as fin:
228 |         configs = yaml.load(fin, Loader=yaml.FullLoader)
229 |     # Only allow args to be updated if they already exist.
230 |     invalid_args = list(set(configs.keys()) - set(dir(args)))
231 |     if invalid_args:
232 |         raise ValueError(f"Invalid args {invalid_args} in {pth}.")
233 |     args.__dict__.update(configs)
234 | 
235 | 
236 | def check_flags(args, require_data=True, require_batch_size_div=False):
237 |     if args.train_dir is None:
238 |         raise ValueError("train_dir must be set, but is None.")
239 |     if require_data and args.data_dir is None:
240 |         raise ValueError("data_dir must be set, but is None.")
241 |     if require_batch_size_div and args.batch_size % torch.cuda.device_count() != 0:
242 |         raise ValueError("Batch size must be divisible by the number of devices.")
243 | 
244 | 
245 | def set_random_seed(seed):
246 |     torch.manual_seed(seed)
247 |     np.random.seed(seed)
248 | 
249 | 
250 | def open_file(pth, mode="r"):
251 |     if not INTERNAL:
252 |         pth = path.expanduser(pth)
253 |     return open(pth, mode=mode)
254 | 
255 | 
256 | def file_exists(pth):
257 |     if not INTERNAL:
258 |         return path.exists(pth)
259 | 
260 | 
261 | def listdir(pth):
262 |     if not INTERNAL:
263 |         return os.listdir(pth)
264 | 
265 | 
266 | def isdir(pth):
267 |     if not INTERNAL:
268 |         return path.isdir(pth)
269 | 
270 | 
271 | def makedirs(pth):
272 |     if not INTERNAL:
273 |         os.makedirs(pth, exist_ok=True)
274 | 
275 | 
276 | @torch.no_grad()
277 | def eval_points(fn, points, chunk=720720, to_cpu=True):
278 |     """Evaluate at given points (in test mode).
279 |     Viewdirs are currently not supported.
280 | 
281 |     Args:
282 |         fn: function
283 |         points: torch.tensor [..., 3]
284 |         chunk: int, the size of chunks to render sequentially.
285 |         to_cpu: bool, if True, return concatenated numpy arrays on the CPU.
286 |     Returns:
287 |         rgb: torch.tensor or np.array.
288 |         sigmas: torch.tensor or np.array.
289 | """ 290 | num_points = points.shape[0] 291 | rgbs, sigmas = [], [] 292 | 293 | for i in tqdm(range(0, num_points, chunk)): 294 | chunk_points = points[i : i + chunk] 295 | rgb, sigma = fn(chunk_points, None) 296 | if to_cpu: 297 | rgb = rgb.detach().cpu().numpy() 298 | sigma = sigma.detach().cpu().numpy() 299 | rgbs.append(rgb) 300 | sigmas.append(sigma) 301 | if to_cpu: 302 | rgbs = np.concatenate(rgbs, axis=0) 303 | sigmas = np.concatenate(sigmas, axis=0) 304 | else: 305 | rgbs = torch.cat(rgbs, dim=0) 306 | sigmas = torch.cat(sigmas, dim=0) 307 | return rgbs, sigmas 308 | 309 | 310 | def compute_psnr(mse): 311 | """Compute psnr value given mse (we assume the maximum pixel value is 1). 312 | 313 | Args: 314 | mse: float, mean square error of pixels. 315 | 316 | Returns: 317 | psnr: float, the psnr value. 318 | """ 319 | return -10.0 * torch.log(mse) / np.log(10.0) 320 | 321 | 322 | def compute_ssim( 323 | img0, 324 | img1, 325 | max_val, 326 | filter_size=11, 327 | filter_sigma=1.5, 328 | k1=0.01, 329 | k2=0.03, 330 | return_map=False, 331 | ): 332 | """Computes SSIM from two images. 333 | 334 | This function was modeled after tf.image.ssim, and should produce comparable 335 | output. 336 | 337 | Args: 338 | img0: torch.tensor. An image of size [..., width, height, num_channels]. 339 | img1: torch.tensor. An image of size [..., width, height, num_channels]. 340 | max_val: float > 0. The maximum magnitude that `img0` or `img1` can have. 341 | filter_size: int >= 1. Window size. 342 | filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering. 343 | k1: float > 0. One of the SSIM dampening parameters. 344 | k2: float > 0. One of the SSIM dampening parameters. 345 | return_map: Bool. If True, will cause the per-pixel SSIM "map" to returned 346 | 347 | Returns: 348 | Each image's mean SSIM, or a tensor of individual values if `return_map`. 349 | """ 350 | device = img0.device 351 | ori_shape = img0.size() 352 | width, height, num_channels = ori_shape[-3:] 353 | img0 = img0.view(-1, width, height, num_channels).permute(0, 3, 1, 2) 354 | img1 = img1.view(-1, width, height, num_channels).permute(0, 3, 1, 2) 355 | batch_size = img0.shape[0] 356 | 357 | # Construct a 1D Gaussian blur filter. 358 | hw = filter_size // 2 359 | shift = (2 * hw - filter_size + 1) / 2 360 | f_i = ((torch.arange(filter_size, device=device) - hw + shift) / filter_sigma) ** 2 361 | filt = torch.exp(-0.5 * f_i) 362 | filt /= torch.sum(filt) 363 | 364 | # Blur in x and y (faster than the 2D convolution). 365 | # z is a tensor of size [B, H, W, C] 366 | filt_fn1 = lambda z: F.conv2d( 367 | z, filt.view(1, 1, -1, 1).repeat(num_channels, 1, 1, 1), 368 | padding=[hw, 0], groups=num_channels) 369 | filt_fn2 = lambda z: F.conv2d( 370 | z, filt.view(1, 1, 1, -1).repeat(num_channels, 1, 1, 1), 371 | padding=[0, hw], groups=num_channels) 372 | 373 | # Vmap the blurs to the tensor size, and then compose them. 374 | filt_fn = lambda z: filt_fn1(filt_fn2(z)) 375 | mu0 = filt_fn(img0) 376 | mu1 = filt_fn(img1) 377 | mu00 = mu0 * mu0 378 | mu11 = mu1 * mu1 379 | mu01 = mu0 * mu1 380 | sigma00 = filt_fn(img0 ** 2) - mu00 381 | sigma11 = filt_fn(img1 ** 2) - mu11 382 | sigma01 = filt_fn(img0 * img1) - mu01 383 | 384 | # Clip the variances and covariances to valid values. 
385 |     # Variance must be non-negative:
386 |     sigma00 = torch.clamp(sigma00, min=0.0)
387 |     sigma11 = torch.clamp(sigma11, min=0.0)
388 |     sigma01 = torch.sign(sigma01) * torch.min(
389 |         torch.sqrt(sigma00 * sigma11), torch.abs(sigma01)
390 |     )
391 | 
392 |     c1 = (k1 * max_val) ** 2
393 |     c2 = (k2 * max_val) ** 2
394 |     numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
395 |     denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
396 |     ssim_map = numer / denom
397 |     ssim = torch.mean(ssim_map.reshape([-1, num_channels*width*height]), dim=-1)
398 |     return ssim_map if return_map else ssim
399 | 
400 | 
401 | def generate_rays(w, h, focal, camtoworlds, equirect=False):
402 |     """
403 |     Generate perspective camera rays. Principal point is at center.
404 |     Args:
405 |         w: int image width
406 |         h: int image height
407 |         focal: float real focal length
408 |         camtoworlds: np.ndarray [B, 4, 4] c2w homogeneous poses
409 |         equirect: if true, generates spherical rays instead of pinhole
410 |     Returns:
411 |         rays: Rays, a namedtuple(origins [B, 3], directions [B, 3], viewdirs [B, 3])
412 |     """
413 |     x, y = np.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking
414 |         np.arange(w, dtype=np.float32),  # X-Axis (columns)
415 |         np.arange(h, dtype=np.float32),  # Y-Axis (rows)
416 |         indexing="xy",
417 |     )
418 | 
419 |     if equirect:
420 |         uv = np.stack([x * (2.0 / w) - 1.0, y * (2.0 / h) - 1.0], axis=-1)
421 |         camera_dirs = equirect2xyz(uv)
422 |     else:
423 |         camera_dirs = np.stack(
424 |             [
425 |                 (x - w * 0.5) / focal,
426 |                 -(y - h * 0.5) / focal,
427 |                 -np.ones_like(x),
428 |             ],
429 |             axis=-1,
430 |         )
431 | 
432 |     # camera_dirs = camera_dirs / np.linalg.norm(camera_dirs, axis=-1, keepdims=True)
433 | 
434 |     c2w = camtoworlds[:, None, None, :3, :3]
435 |     camera_dirs = camera_dirs[None, Ellipsis, None]
436 |     directions = np.matmul(c2w, camera_dirs)[Ellipsis, 0]
437 |     origins = np.broadcast_to(
438 |         camtoworlds[:, None, None, :3, -1], directions.shape
439 |     )
440 |     norms = np.linalg.norm(directions, axis=-1, keepdims=True)
441 |     viewdirs = directions / norms
442 |     rays = Rays(
443 |         origins=origins, directions=directions, viewdirs=viewdirs
444 |     )
445 |     return rays
446 | 
447 | 
448 | def eval_octree(t, dataset, args, want_lpips=True, want_frames=False):
449 |     import svox
450 |     w, h, focal = dataset.w, dataset.h, dataset.focal
451 |     if 'llff' in args.config and (not args.spherify):
452 |         ndc_config = svox.NDCConfig(width=w, height=h, focal=focal)
453 |     else:
454 |         ndc_config = None
455 | 
456 |     r = svox.VolumeRenderer(
457 |         t, step_size=args.renderer_step_size, ndc=ndc_config)
458 | 
459 |     print('Evaluating octree')
460 |     device = t.data.device
461 |     if want_lpips:
462 |         import lpips
463 |         lpips_vgg = lpips.LPIPS(net="vgg").eval().to(device)
464 | 
465 |     avg_psnr = 0.0
466 |     avg_ssim = 0.0
467 |     avg_lpips = 0.0
468 |     out_frames = []
469 |     for idx in tqdm(range(dataset.size)):
470 |         c2w = torch.from_numpy(dataset.camtoworlds[idx]).float().to(device)
471 |         im_gt_ten = torch.from_numpy(dataset.images[idx]).float().to(device)
472 | 
473 |         im = r.render_persp(
474 |             c2w, width=w, height=h, fx=focal, fast=not args.no_early_stop)
475 |         im.clamp_(0.0, 1.0)
476 | 
477 |         mse = ((im - im_gt_ten) ** 2).mean()
478 |         psnr = compute_psnr(mse).mean()
479 |         ssim = compute_ssim(im, im_gt_ten, max_val=1.0).mean()
480 | 
481 |         avg_psnr += psnr.item()
482 |         avg_ssim += ssim.item()
483 |         if want_lpips:
484 |             lpips_i = lpips_vgg(im_gt_ten.permute([2, 0, 1]).contiguous(),
485 |                                 im.permute([2, 0, 1]).contiguous(), normalize=True)
486 |             avg_lpips += lpips_i.item()
487 | 
488 |         if want_frames:
489 |             im = im.cpu()
490 |             # vis = np.hstack((im_gt_ten.cpu().numpy(), im.cpu().numpy()))
491 |             vis = im.numpy()  # frame for the output video
492 |             vis = (vis * 255).astype(np.uint8)
493 |             out_frames.append(vis)
494 | 
495 |     avg_psnr /= dataset.size
496 |     avg_ssim /= dataset.size
497 |     avg_lpips /= dataset.size
498 |     return avg_psnr, avg_ssim, avg_lpips, out_frames
499 | 
500 | 
501 | def memlog(device='cuda'):
502 |     # Memory debugging
503 |     print(torch.cuda.memory_summary(device))
504 |     import gc
505 |     for obj in gc.get_objects():
506 |         try:
507 |             if torch.is_tensor(obj) or (
508 |                     hasattr(obj, 'data') and torch.is_tensor(obj.data)):
509 |                 if str(obj.device) != 'cpu':
510 |                     print(obj.device, '{: 10}'.format(obj.numel()),
511 |                           obj.dtype,
512 |                           obj.size(), type(obj))
513 |         except Exception:
514 |             pass
515 | 
--------------------------------------------------------------------------------
/octree/optimization.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 | """Optimize a plenoctree through finetuning on the train set.
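The tree leaf values (SH coefficients and densities) are optimized directly
against the training images with a plain photometric MSE loss; no network is
involved at this stage.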
24 | 
25 | Usage:
26 | 
27 | export DATA_ROOT=./data/NeRF/nerf_synthetic/
28 | export CKPT_ROOT=./data/Plenoctree/checkpoints/syn_sh16
29 | export SCENE=chair
30 | export CONFIG_FILE=nerf_sh/config/blender
31 | 
32 | python -m octree.optimization \
33 |     --input $CKPT_ROOT/$SCENE/tree.npz \
34 |     --config $CONFIG_FILE \
35 |     --data_dir $DATA_ROOT/$SCENE/ \
36 |     --output $CKPT_ROOT/$SCENE/octrees/tree_opt.npz
37 | """
38 | import svox
39 | import torch
40 | import torch.cuda
41 | import numpy as np
42 | import json
43 | import imageio
44 | import os.path as osp
45 | import os
46 | from argparse import ArgumentParser
47 | from tqdm import tqdm
48 | from torch.optim import SGD, Adam
49 | from warnings import warn
50 | 
51 | from absl import app
52 | from absl import flags
53 | 
54 | from octree.nerf import datasets
55 | from octree.nerf import utils
56 | 
57 | FLAGS = flags.FLAGS
58 | 
59 | utils.define_flags()
60 | 
61 | flags.DEFINE_string(
62 |     "input",
63 |     "./tree.npz",
64 |     "Input octree npz from extraction.py",
65 | )
66 | flags.DEFINE_string(
67 |     "output",
68 |     "./tree_opt.npz",
69 |     "Output octree npz",
70 | )
71 | flags.DEFINE_integer(
72 |     'render_interval',
73 |     0,
74 |     'interval (in validation images) between saved renders; 0 to disable')
75 | flags.DEFINE_integer(
76 |     'val_interval',
77 |     2,
78 |     'number of epochs between validation runs')
79 | flags.DEFINE_integer(
80 |     'num_epochs',
81 |     80,
82 |     'epochs to train for')
83 | flags.DEFINE_bool(
84 |     'sgd',
85 |     True,
86 |     'use SGD optimizer instead of Adam')
87 | flags.DEFINE_float(
88 |     'lr',
89 |     1e7,
90 |     'optimizer step size')
91 | flags.DEFINE_float(
92 |     'sgd_momentum',
93 |     0.0,
94 |     'sgd momentum')
95 | flags.DEFINE_bool(
96 |     'sgd_nesterov',
97 |     False,
98 |     'use Nesterov momentum with SGD')
99 | flags.DEFINE_string(
100 |     "write_vid",
101 |     None,
102 |     "If specified, writes rendered video to given path (*.mp4)",
103 | )
104 | 
105 | # Manual 'val' set
106 | flags.DEFINE_bool(
107 |     "split_train",
108 |     None,
109 |     "If specified, splits the train set instead of loading the val set",
110 | )
111 | flags.DEFINE_float(
112 |     "split_holdout_prop",
113 |     0.2,
114 |     "Proportion of images to hold out if split_train is set",
115 | )
116 | 
117 | # Do not save since it is slow
118 | flags.DEFINE_bool(
119 |     "nosave",
120 |     False,
121 |     "If set, does not save (for speed)",
122 | )
123 | 
124 | flags.DEFINE_bool(
125 |     "continue_on_decrease",
126 |     False,
127 |     "If set, continues training even if validation PSNR decreases",
128 | )
129 | 
130 | device = "cuda" if torch.cuda.is_available() else "cpu"
131 | torch.autograd.set_detect_anomaly(True)  # note: anomaly detection slows training
132 | 
133 | 
134 | def main(unused_argv):
135 |     utils.set_random_seed(20200823)
136 |     utils.update_flags(FLAGS)
137 | 
138 |     def get_data(stage):
139 |         assert stage in ["train", "val", "test"]
140 |         dataset = datasets.get_dataset(stage, FLAGS)
141 |         focal = dataset.focal
142 |         all_c2w = dataset.camtoworlds
143 |         all_gt = dataset.images.reshape(-1, dataset.h, dataset.w, 3)
144 |         all_c2w = torch.from_numpy(all_c2w).float().to(device)
145 |         all_gt = torch.from_numpy(all_gt).float()
146 |         return focal, all_c2w, all_gt
147 | 
148 |     focal, train_c2w, train_gt = get_data("train")
149 |     if FLAGS.split_train:
150 |         test_sz = int(train_c2w.size(0) * FLAGS.split_holdout_prop)
151 |         print('Splitting train set into train/val manually, holding out', test_sz)
152 |         perm = torch.randperm(train_c2w.size(0))
153 |         test_c2w = train_c2w[perm[:test_sz]]
154 |         test_gt = train_gt[perm[:test_sz]]
155 |         train_c2w = train_c2w[perm[test_sz:]]
156 |         train_gt = train_gt[perm[test_sz:]]
157 |     else:
158 |         print('Using given val set')
159 |         test_focal, test_c2w, test_gt = get_data("val")
160 |         assert focal == test_focal
161 |     H, W = train_gt[0].shape[:2]
162 | 
163 |     vis_dir = osp.splitext(FLAGS.input)[0] + '_render'
164 |     os.makedirs(vis_dir, exist_ok=True)
165 | 
166 |     print('N3Tree load')
167 |     t = svox.N3Tree.load(FLAGS.input, map_location=device)
168 |     # t.nan_to_num_()
169 | 
170 |     if 'llff' in FLAGS.config:
171 |         ndc_config = svox.NDCConfig(width=W, height=H, focal=focal)
172 |     else:
173 |         ndc_config = None
174 |     r = svox.VolumeRenderer(t, step_size=FLAGS.renderer_step_size, ndc=ndc_config)
175 | 
176 |     if FLAGS.sgd:
177 |         print('Using SGD, lr', FLAGS.lr)
178 |         if FLAGS.lr < 1.0:
179 |             warn('For SGD please adjust LR to about 1e7')
180 |         optimizer = SGD(t.parameters(), lr=FLAGS.lr, momentum=FLAGS.sgd_momentum,
181 |                         nesterov=FLAGS.sgd_nesterov)
182 |     else:
183 |         adam_eps = 1e-4 if t.data.dtype is torch.float16 else 1e-8
184 |         print('Using Adam, eps', adam_eps, 'lr', FLAGS.lr)
185 |         optimizer = Adam(t.parameters(), lr=FLAGS.lr, eps=adam_eps)
186 | 
187 |     n_train_imgs = len(train_c2w)
188 |     n_test_imgs = len(test_c2w)
189 | 
190 |     def run_test_step(i):
191 |         print('Evaluating')
192 |         with torch.no_grad():
193 |             tpsnr = 0.0
194 |             for j, (c2w, im_gt) in enumerate(zip(test_c2w, test_gt)):
195 |                 im = r.render_persp(c2w, height=H, width=W, fx=focal, fast=False)
196 |                 im = im.cpu().clamp_(0.0, 1.0)
197 | 
198 |                 mse = ((im - im_gt) ** 2).mean()
199 |                 psnr = -10.0 * np.log(mse) / np.log(10.0)
200 |                 tpsnr += psnr.item()
201 | 
202 |                 if FLAGS.render_interval > 0 and j % FLAGS.render_interval == 0:
203 |                     vis = torch.cat((im_gt, im), dim=1)
204 |                     vis = (vis * 255).numpy().astype(np.uint8)
205 |                     imageio.imwrite(f"{vis_dir}/{i:04}_{j:04}.png", vis)
206 |             tpsnr /= n_test_imgs
207 |             return tpsnr
208 | 
209 |     best_validation_psnr = run_test_step(0)
210 |     print('** initial val psnr ', best_validation_psnr)
211 |     best_t = None
212 |     for i in range(FLAGS.num_epochs):
213 |         print('epoch', i)
214 |         tpsnr = 0.0
215 |         for j, (c2w, im_gt) in tqdm(enumerate(zip(train_c2w, train_gt)), total=n_train_imgs):
216 |             im = r.render_persp(c2w, height=H, width=W, fx=focal, cuda=True)
217 |             im_gt_ten = im_gt.to(device=device)
218 |             im = torch.clamp(im, 0.0, 1.0)
219 |             mse = ((im - im_gt_ten) ** 2).mean()
220 |             im_gt_ten = None
221 | 
222 |             optimizer.zero_grad()
223 |             t.data.grad = None  # manually clearing the grad helps save memory, oddly enough
224 |             mse.backward()
225 |             # print('mse', mse, t.data.grad.min(), t.data.grad.max())
226 |             optimizer.step()
227 |             # t.data.data -= eta * t.data.grad
228 |             psnr = -10.0 * np.log(mse.detach().cpu()) / np.log(10.0)
229 |             tpsnr += psnr.item()
230 |         tpsnr /= n_train_imgs
231 |         print('** train_psnr', tpsnr)
232 | 
233 |         if i % FLAGS.val_interval == FLAGS.val_interval - 1 or i == FLAGS.num_epochs - 1:
234 |             validation_psnr = run_test_step(i + 1)
235 |             print('** val psnr ', validation_psnr, 'best', best_validation_psnr)
236 |             if validation_psnr > best_validation_psnr:
237 |                 best_validation_psnr = validation_psnr
238 |                 best_t = t.clone(device='cpu')  # SVOX 0.2.22
239 |                 print('')
240 |             elif not FLAGS.continue_on_decrease:
241 |                 print('Stopping early: validation PSNR decreased')
242 |                 break
243 |     if not FLAGS.nosave:
244 |         if best_t is not None:
245 |             print('Saving best model to', FLAGS.output)
246 |             best_t.save(FLAGS.output, compress=False)
247 |         else:
248 |             print('Did not improve upon initial model')
249 | 
250 | if __name__ == "__main__":
251 |     app.run(main)
252 | 
--------------------------------------------------------------------------------
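Before moving on to the task manager below, here is a quick sketch of how a
finetuned octree is consumed at render time, based on the svox calls used in
optimization.py and octree/nerf/utils.py above (the checkpoint path, pose,
image size, and focal length are placeholders):

```
import torch
import svox

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the finetuned octree and wrap it in a volume renderer.
t = svox.N3Tree.load("./tree_opt.npz", map_location=device)  # placeholder path
r = svox.VolumeRenderer(t, step_size=1e-4)  # 1e-4 = the usual renderer_step_size

c2w = torch.eye(4, device=device)  # placeholder camera-to-world pose
with torch.no_grad():
    im = r.render_persp(c2w, width=800, height=800, fx=1111.0)  # placeholder intrinsics
    im = im.clamp(0.0, 1.0).cpu().numpy()  # [H, W, 3] image in [0, 1]
```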
/octree/task_manager.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 | """
24 | Multi-GPU parallel octree conversion pipeline for running hyperparameter search.
25 | Make a file tasks.json describing the tasks, e.g.
26 | {
27 |     "data_root": "/home/sxyu/data",
28 |     "train_root": "/home/sxyu/proj/jaxnerf/jaxnerf/train/SH16",
29 |     "tasks": [{
30 |         "octree_name": "oct_chair_bb1_2",
31 |         "train_dir": "chair",
32 |         "data_dir": "nerf_synthetic/chair",
33 |         "config": "sh",
34 |         "extr_flags": ["--bbox_from_data", "--bbox_scale", "1.2"],
35 |         "opt_flags": [],
36 |         "eval_flags": []
37 |     },
38 |     ...]
39 | }
40 | 
41 | Then, run
42 | python -m octree.task_manager tasks.json --gpus='space delimited list of gpus to use'
43 | 
44 | For each task, the final octree is saved to
45 | <train_dir>/octrees/<octree_name>/tree.npz
46 | If you specify --keep_raw, the above is the raw tree, and the optimized tree is saved to
47 | <train_dir>/octrees/<octree_name>/tree_opt.npz
48 | 
49 | Capacity, raw eval PSNR/SSIM/LPIPS, and optimized eval PSNR/SSIM/LPIPS are saved to
50 | <train_dir>/octrees/<octree_name>/results.txt
51 | """
52 | import argparse
53 | import sys
54 | import os
55 | import os.path as osp
56 | import subprocess
57 | import concurrent.futures
58 | import json
59 | from multiprocessing import Process, Queue
60 | 
61 | parser = argparse.ArgumentParser()
62 | parser.add_argument("task_json", type=str)
63 | parser.add_argument("--gpus", type=str, required=True,
64 |                     help="space delimited GPU id list (pre CUDA_VISIBLE_DEVICES)")
65 | parser.add_argument("--keep_raw", action='store_true',
66 |                     help="do not overwrite raw octree (takes extra disk space)")
67 | args = parser.parse_args()
68 | 
69 | def convert_one(env, train_dir, data_dir, config, octree_name,
70 |                 extr_flags, opt_flags=[], eval_flags=[]):
71 |     octree_store_dir = osp.join(train_dir, 'octrees', octree_name)
72 |     octree_file = osp.join(octree_store_dir, "tree.npz")
73 |     octree_opt_file = osp.join(octree_store_dir,
74 |                                "tree_opt.npz") if args.keep_raw else octree_file
75 |     config_name = f"{config}"
76 |     os.makedirs(octree_store_dir, exist_ok=True)
77 |     extr_base_cmd = [
78 |         "python", "-u", "-m", "octree.extraction",
79 |         "--train_dir", train_dir,
80 |         "--config", config_name, "--is_jaxnerf_ckpt",
81 |         "--output", octree_file,
82 |         "--data_dir", data_dir
83 |     ]
84 |     opt_base_cmd = [
85 |         "python", "-u", "-m", "octree.optimization",
86 |         "--config", config_name, "--input", octree_file,
87 |         "--output", octree_opt_file,
88 |         "--data_dir", data_dir
89 |     ]
90 |     eval_base_cmd = [
91 |         "python", "-u", "-m", "octree.evaluation",
92 |         "--config", config_name, "--input", octree_opt_file,
93 |         "--data_dir", data_dir
94 |     ]
95 |     out_file_path = osp.join(octree_store_dir, 'results.txt')
96 | 
97 |     with open(out_file_path, 'w') as out_file:
98 |         print('********************************************')
99 |         print('! Extract', train_dir, octree_name)
100 |         extr_cmd = ' '.join(extr_base_cmd + extr_flags)
101 |         print(extr_cmd)
102 |         extr_ret = subprocess.check_output(extr_cmd, shell=True, env=env).decode(
103 |             sys.stdout.encoding)
104 |         with open('pextract.txt', 'w') as f:
105 |             f.write(extr_ret)
106 | 
107 |         extr_ret = extr_ret.split('\n')
108 |         svox_str = extr_ret[-9]
109 |         capacity = int(svox_str.split()[3].split(':')[1].split('/')[0])
110 | 
111 |         parse_metrics = lambda x: map(float, x.split()[2::2])
112 |         psnr, ssim, lpips = parse_metrics(extr_ret[-2])
113 |         print(': ', octree_name, 'RAW capacity',
114 |               capacity, 'PSNR', psnr, 'SSIM', ssim, 'LPIPS', lpips)
115 |         out_file.write(f'{capacity}\n{psnr:.10f} {ssim:.10f} {lpips:.10f}\n')
116 | 
117 |         print('! Optimize', train_dir, octree_name)
118 |         opt_cmd = ' '.join(opt_base_cmd + opt_flags)
119 |         print(opt_cmd)
120 |         subprocess.call(opt_cmd, shell=True, env=env)
121 | 
122 |         if osp.exists(octree_opt_file):
123 |             print('! Eval', train_dir, octree_name)
124 |             eval_cmd = ' '.join(eval_base_cmd + eval_flags)
125 |             print(eval_cmd)
126 |             eval_ret = subprocess.check_output(eval_cmd, shell=True, env=env).decode(
127 |                 sys.stdout.encoding)
128 |             eval_ret = eval_ret.split('\n')
129 | 
130 |             epsnr, essim, elpips = parse_metrics(eval_ret[-2])
131 |             print(':', octree_name, 'OPT capacity',
132 |                   capacity, 'PSNR', epsnr, 'SSIM', essim, 'LPIPS', elpips)
133 |             out_file.write(f'{epsnr:.10f} {essim:.10f} {elpips:.10f}\n')
134 |         else:
135 |             print('! Eval skipped')
136 |             out_file.write(f'{psnr:.10f} {ssim:.10f} {lpips:.10f}\n')
137 | 
138 | 
139 | 
140 | def process_main(device, queue):
141 |     # Set CUDA_VISIBLE_DEVICES programmatically
142 |     env = os.environ.copy()
143 |     env["CUDA_VISIBLE_DEVICES"] = str(device)
144 |     while True:
145 |         task = queue.get()
146 |         if len(task) == 0:
147 |             break
148 |         convert_one(env, **task)
149 | 
150 | if __name__ == '__main__':
151 |     with open(args.task_json, 'r') as f:
152 |         tasks_file = json.load(f)
153 |     all_tasks = tasks_file.get('tasks', [])
154 |     data_root = tasks_file['data_root']
155 |     train_root = tasks_file['train_root']
156 |     pqueue = Queue()
157 |     # scene_tasks are expanded once per scene ({%} is replaced by the scene name);
158 |     # see the example tasks.json after requirements.txt below
159 |     if 'scene_tasks' in tasks_file:
160 |         symb = '{%}'
161 |         scenes = tasks_file['scenes']
162 |         for scene_task in tasks_file['scene_tasks']:
163 |             for scene in scenes:
164 |                 task = scene_task.copy()
165 |                 task['data_dir'] = scene_task['data_dir'].replace(symb, scene)
166 |                 task['train_dir'] = scene_task['train_dir'].replace(symb, scene)
167 |                 task['octree_name'] = scene_task['octree_name'].replace(symb, scene)
168 |                 all_tasks.append(task)
169 | 
170 |     print(len(all_tasks), 'total tasks')
171 | 
172 |     for task in all_tasks:
173 |         task['train_dir'] = osp.join(train_root, task['train_dir'])
174 |         task['data_dir'] = osp.join(data_root, task['data_dir'])
175 |         octrees_dir = osp.join(task['data_dir'], 'octrees')
176 |         os.makedirs(octrees_dir, exist_ok=True)
177 |         # sanity check
178 |         assert os.path.exists(task['train_dir']), task['train_dir']
179 |         assert os.path.exists(task['data_dir']), task['data_dir']
180 | 
181 |     for task in all_tasks:
182 |         pqueue.put(task)
183 | 
184 |     args.gpus = list(map(int, args.gpus.split()))
185 |     print('GPUS:', args.gpus)
186 | 
187 |     # Enqueue one empty sentinel task per worker so that every process exits;
188 |     # a single sentinel would release only one of the workers.
189 |     for _ in args.gpus:
190 |         pqueue.put({})
191 | 
192 |     all_procs = []
193 |     for i, gpu in enumerate(args.gpus):
194 |         process = Process(target=process_main, args=(gpu, pqueue))
195 |         process.daemon = True
196 |         process.start()
197 |         all_procs.append(process)
198 | 
199 |     for i, gpu in enumerate(args.gpus):
200 |         all_procs[i].join()
201 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorboard
2 | imageio
3 | imageio-ffmpeg
4 | ipdb
5 | lpips
6 | jax
7 | jaxlib
8 | flax
9 | opencv-python
10 | Pillow
11 | pyyaml
12 | tensorflow==2.3.1
13 | pymcubes
14 | svox>=0.2.26
--------------------------------------------------------------------------------
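For reference, the `scene_tasks` expansion in `octree/task_manager.py` (handled
in its `__main__` block but not shown in the docstring example) can be driven by
a file like the following sketch, where every name and path is illustrative
only:

```
{
    "data_root": "/home/user/data",
    "train_root": "/home/user/train/SH16",
    "scenes": ["chair", "drums", "lego"],
    "scene_tasks": [{
        "octree_name": "oct_{%}",
        "train_dir": "{%}",
        "data_dir": "nerf_synthetic/{%}",
        "config": "sh",
        "extr_flags": ["--bbox_from_data", "--bbox_scale", "1.2"],
        "opt_flags": [],
        "eval_flags": []
    }]
}
```

Each entry is copied once per scene with `{%}` substituted, then dispatched to
the worker pool, e.g. `python -m octree.task_manager tasks.json --gpus="0 1"`.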