├── .gitignore
├── LICENSE.txt
├── README.md
├── environment.yml
├── nerf_sh
│   ├── __init__.py
│   ├── config
│   │   ├── blender.yaml
│   │   ├── misc
│   │   │   ├── og_nerf.yaml
│   │   │   ├── proj.yaml
│   │   │   └── sg.yaml
│   │   └── tt.yaml
│   ├── eval.py
│   ├── gen_mesh.py
│   ├── gen_video.py
│   ├── nerf
│   │   ├── __init__.py
│   │   ├── datasets.py
│   │   ├── model_utils.py
│   │   ├── models.py
│   │   ├── sg.py
│   │   ├── sh.py
│   │   └── utils.py
│   ├── parse_timing.py
│   └── train.py
├── octree
│   ├── compression.py
│   ├── config
│   │   ├── syn_sg25.json
│   │   ├── syn_sh16.json
│   │   └── tt_sh25.json
│   ├── evaluation.py
│   ├── extraction.py
│   ├── nerf
│   │   ├── __init__.py
│   │   ├── datasets.py
│   │   ├── model_utils.py
│   │   ├── models.py
│   │   ├── sh_proj.py
│   │   └── utils.py
│   ├── optimization.py
│   └── task_manager.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | data
3 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright The PlenOctree Authors 2021
2 |
3 | Redistribution and use in source and binary forms, with or without
4 | modification, are permitted provided that the following conditions are met:
5 |
6 | 1. Redistributions of source code must retain the above copyright notice,
7 | this list of conditions and the following disclaimer.
8 |
9 | 2. Redistributions in binary form must reproduce the above copyright notice,
10 | this list of conditions and the following disclaimer in the documentation
11 | and/or other materials provided with the distribution.
12 |
13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
23 | POSSIBILITY OF SUCH DAMAGE.
24 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PlenOctrees Official Repo: NeRF-SH training and conversion
2 |
3 | This repository contains code to train NeRF-SH and
4 | to extract the PlenOctree, constituting part of the code release for:
5 |
6 | PlenOctrees for Real Time Rendering of Neural Radiance Fields
7 | Alex Yu, Ruilong Li, Matthew Tancik, Hao Li, Ren Ng, Angjoo Kanazawa
8 |
9 | https://alexyu.net/plenoctrees
10 |
11 | ```
12 | @inproceedings{yu2021plenoctrees,
13 | title={{PlenOctrees} for Real-time Rendering of Neural Radiance Fields},
14 | author={Alex Yu and Ruilong Li and Matthew Tancik and Hao Li and Ren Ng and Angjoo Kanazawa},
15 | year={2021},
16 | booktitle={ICCV},
17 | }
18 | ```
19 |
20 | Please see the following repository for our C++ PlenOctrees volume renderer:
21 | https://github.com/sxyu/volrend
22 |
23 | ## Setup
24 |
25 | Please use conda for a replicable environment.
26 | ```
27 | conda env create -f environment.yml
28 | conda activate plenoctree
29 | pip install --upgrade pip
30 | ```
31 |
32 | Or you can install the dependencies manually:
33 | ```
34 | conda install pytorch torchvision cudatoolkit=11.0 -c pytorch
35 | conda install tqdm
36 | pip install -r requirements.txt
37 | ```
38 |
39 | [Optional] Install GPU and TPU support for Jax. This is useful for NeRF-SH training.
40 | Remember to **change cuda110 to your CUDA version**, e.g. cuda102 for CUDA 10.2.
41 | ```
42 | pip install --upgrade jax jaxlib==0.1.65+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
43 | ```
44 |
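You can sanity-check that JAX sees your accelerator before training (optional; `jax.devices()` is standard JAX API):
```
import jax
print(jax.devices())  # should list GPU/TPU devices, not just CPU
```
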
45 | ## NeRF-SH Training
46 |
47 | We release our trained NeRF-SH models as well as converted plenoctrees at
48 | [Google Drive](https://drive.google.com/drive/folders/1J0lRiDn_wOiLVpCraf6jM7vvCwDr9Dmx?usp=sharing).
49 | You can also use the following commands to reproduce the NeRF-SH models.
50 |
51 | Training and evaluation on the **NeRF-Synthetic dataset** ([Google Drive](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)):
52 | ```
53 | export DATA_ROOT=./data/NeRF/nerf_synthetic/
54 | export CKPT_ROOT=./data/Plenoctree/checkpoints/syn_sh16/
55 | export SCENE=chair
56 | export CONFIG_FILE=nerf_sh/config/blender
57 |
58 | python -m nerf_sh.train \
59 | --train_dir $CKPT_ROOT/$SCENE/ \
60 | --config $CONFIG_FILE \
61 | --data_dir $DATA_ROOT/$SCENE/
62 |
63 | python -m nerf_sh.eval \
64 | --chunk 4096 \
65 | --train_dir $CKPT_ROOT/$SCENE/ \
66 | --config $CONFIG_FILE \
67 | --data_dir $DATA_ROOT/$SCENE/
68 | ```
69 | Note that for `SCENE=mic`, we adopt a warmup learning rate schedule (`--lr_delay_steps 50000 --lr_delay_mult 0.01`) to avoid unstable initialization.
70 |
71 |
72 | Training and evaluation on the **TanksAndTemple dataset**
73 | ([Download Link](https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip)) from the [NSVF](https://github.com/facebookresearch/NSVF) paper:
74 | ```
75 | export DATA_ROOT=./data/TanksAndTemple/
76 | export CKPT_ROOT=./data/Plenoctree/checkpoints/tt_sh25/
77 | export SCENE=Barn
78 | export CONFIG_FILE=nerf_sh/config/tt
79 |
80 | python -m nerf_sh.train \
81 | --train_dir $CKPT_ROOT/$SCENE/ \
82 | --config $CONFIG_FILE \
83 | --data_dir $DATA_ROOT/$SCENE/
84 |
85 | python -m nerf_sh.eval \
86 | --chunk 4096 \
87 | --train_dir $CKPT_ROOT/$SCENE/ \
88 | --config $CONFIG_FILE \
89 | --data_dir $DATA_ROOT/$SCENE/
90 | ```
91 |
92 | ## PlenOctrees Conversion and Optimization
93 |
94 | Before converting the NeRF-SH models into plenoctrees, you should already have the
95 | NeRF-SH models trained/downloaded and placed at `./data/Plenoctree/checkpoints/{syn_sh16, tt_sh25}/`.
96 | Also make sure you have the training data placed at
97 | `./data/NeRF/nerf_synthetic` and/or `./data/TanksAndTemple`.
98 |
99 | To reproduce our results in the paper, you can simply run:
100 | ```
101 | # NeRF-Synthetic dataset
102 | python -m octree.task_manager octree/config/syn_sh16.json --gpus="0 1 2 3"
103 |
104 | # TanksAndTemple dataset
105 | python -m octree.task_manager octree/config/tt_sh25.json --gpus="0 1 2 3"
106 | ```
107 | The above command will parallelize all scenes in the dataset across the GPUs you set. The json files
108 | contain dedicated hyper-parameters tuned for better performance (PSNR, SSIM, LPIPS). In this setting, a 24GB GPU is
109 | needed for each scene, and on average the process takes about 15 minutes to finish. The converted plenoctree
110 | will be saved to `./data/Plenoctree/checkpoints/{syn_sh16, tt_sh25}/$SCENE/octrees/`.
111 |
112 |
113 | Below is a more straightforward script for demonstration purposes:
114 | ```
115 | export DATA_ROOT=./data/NeRF/nerf_synthetic/
116 | export CKPT_ROOT=./data/Plenoctree/checkpoints/syn_sh16
117 | export SCENE=chair
118 | export CONFIG_FILE=nerf_sh/config/blender
119 |
120 | python -m octree.extraction \
121 | --train_dir $CKPT_ROOT/$SCENE/ --is_jaxnerf_ckpt \
122 | --config $CONFIG_FILE \
123 | --data_dir $DATA_ROOT/$SCENE/ \
124 | --output $CKPT_ROOT/$SCENE/octrees/tree.npz
125 |
126 | python -m octree.optimization \
127 | --input $CKPT_ROOT/$SCENE/tree.npz \
128 | --config $CONFIG_FILE \
129 | --data_dir $DATA_ROOT/$SCENE/ \
130 | --output $CKPT_ROOT/$SCENE/octrees/tree_opt.npz
131 |
132 | python -m octree.evaluation \
133 | --input $CKPT_ROOT/$SCENE/octrees/tree_opt.npz \
134 | --config $CONFIG_FILE \
135 | --data_dir $DATA_ROOT/$SCENE/
136 |
137 | # [Optional] Only used for in-browser viewing.
138 | python -m octree.compression \
139 | $CKPT_ROOT/$SCENE/octrees/tree_opt.npz \
140 | --out_dir $CKPT_ROOT/$SCENE/ \
141 | --overwrite
142 | ```
143 |
144 | ## MISC
145 |
146 | ### Project Vanilla NeRF to PlenOctree
147 |
148 | A vanilla trained NeRF can also be converted to a plenoctree for fast inference. To mimic the
149 | view-independency property of a NeRF-SH model, we project the vanilla NeRF model onto SH basis functions
150 | by sampling view directions for every point in space. Though this makes converting a vanilla NeRF to
151 | a plenoctree possible, the projection process inevitably loses some quality, even with a large number
152 | of sampled view directions (which takes hours to finish). So we recommend directly training a NeRF-SH model end-to-end.
153 |
154 | Below is an example of projecting a trained vanilla NeRF model from the
155 | [JaxNeRF repo](https://github.com/google-research/google-research/tree/master/jaxnerf)
156 | ([Download Link](http://storage.googleapis.com/gresearch/jaxnerf/jaxnerf_pretrained_models.zip)) to a plenoctree.
157 | After extraction, you can optimize, evaluate, and compress the plenoctree as usual:
158 | ```
159 | export DATA_ROOT=./data/NeRF/nerf_synthetic/
160 | export CKPT_ROOT=./data/JaxNeRF/jaxnerf_models/blender/
161 | export SCENE=drums
162 | export CONFIG_FILE=nerf_sh/config/misc/proj
163 |
164 | python -m octree.extraction \
165 | --train_dir $CKPT_ROOT/$SCENE/ --is_jaxnerf_ckpt \
166 | --config $CONFIG_FILE \
167 | --data_dir $DATA_ROOT/$SCENE/ \
168 | --output $CKPT_ROOT/$SCENE/octrees/tree.npz \
169 | --projection_samples 100 \
170 | --radius 1.3
171 | ```
172 | Note `--projection_samples` controls how many sampled view directions are used. More sampled view directions give better
173 | projection quality but take longer to finish. For example, for the `drums` scene
174 | in the NeRF-Synthetic dataset, `100 / 10000` sampled view directions take about `2 mins / 2 hours` to finish the plenoctree extraction.
175 | This produces *raw* plenoctrees with `PSNR=22.49 / 23.84` (before optimization). For comparison, extraction from a NeRF-SH model produces
176 | a *raw* plenoctree with `PSNR=25.01`.
177 |
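For intuition, the projection at each 3D point is a Monte Carlo estimate of the SH projection integral, c_k ≈ (4π / N) Σ_i c(d_i) Y_k(d_i) over N sampled unit directions d_i. Below is a minimal NumPy sketch of that idea (editorial; `radiance_fn` and `sh_basis_fn` are hypothetical helpers standing in for the NeRF color head and the SH basis, and the repository's actual implementation lives in `octree/nerf/sh_proj.py`):
```
import numpy as np

def project_point_to_sh(radiance_fn, sh_basis_fn, n_dirs=100, seed=0):
    """Monte Carlo SH projection of view-dependent color at one 3D point.

    radiance_fn: (N, 3) unit dirs -> (N, 3) RGB at the fixed point.
    sh_basis_fn: (N, 3) unit dirs -> (N, K) SH basis values.
    Returns (K, 3) SH coefficients, one set per color channel.
    """
    rng = np.random.default_rng(seed)
    dirs = rng.normal(size=(n_dirs, 3))
    dirs /= np.linalg.norm(dirs, axis=-1, keepdims=True)  # uniform on sphere
    rgb = radiance_fn(dirs)    # (N, 3)
    basis = sh_basis_fn(dirs)  # (N, K)
    # c_k = integral of c(d) * Y_k(d) over the sphere
    #     ~ (4 * pi / N) * sum_i c(d_i) * Y_k(d_i)
    return (4.0 * np.pi / n_dirs) * basis.T @ rgb
```
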
178 | ### List of possible improvements
179 |
180 | In the interest of reproducibility, the parameters used in the paper are also used here.
181 | For future work, we recommend trying the changes in mip-NeRF
182 | for improved stability and quality:
183 |
184 | - Centered pixels (+ 0.5 on x, y) when generating rays
185 | - Use shifted SoftPlus instead of ReLU for density (including for octree optimization)
186 | - Pad the RGB sigmoid output (avoid low gradient region near 0/1 color); see the sketch below
187 | - Multi-scale training from mip-NeRF
188 |
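For the two activation items above, here is a minimal sketch of shifted SoftPlus density and padded sigmoid color, following mip-NeRF's conventions (editorial; function names are illustrative):
```
import jax
import jax.numpy as jnp

def shifted_softplus(raw_sigma):
    # log(1 + exp(x - 1)): smooth and positive, with no dead-gradient
    # region at raw_sigma = 0, unlike ReLU.
    return jnp.logaddexp(raw_sigma - 1.0, 0.0)

def padded_sigmoid(raw_rgb, padding=0.001):
    # Widen the sigmoid range to [-padding, 1 + padding] so pure 0/1
    # colors do not require saturated (low-gradient) pre-activations.
    return (1.0 + 2.0 * padding) * jax.nn.sigmoid(raw_rgb) - padding
```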
189 |
190 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | # run: conda env create -f environment.yml
2 | name: plenoctree
3 | channels:
4 | - pytorch
5 | - defaults
6 | dependencies:
7 | - python=3.8.8
8 | - numpy>=1.16.4,<1.19.0
9 | - pip
10 | - pip:
11 | - dotmap
12 | - imageio
13 | - imageio-ffmpeg
14 | - ipdb
15 | - pretrainedmodels
16 | - lpips
17 | - jax==0.2.9
18 | - jaxlib>=0.1.57
19 | - flax>=0.3.1
20 | - opencv-python>=4.4.0
21 | - Pillow>=7.2.0
22 | - pyyaml>=5.3.1
23 | - tensorboard>=2.4.0
24 | - tensorflow>=2.3.1
25 |     - pymcubes
26 |     - svox>=0.2.28
27 |     - scipy>=1.6.0
28 |     - torch>=1.7.0,<=1.7.1
29 |     - tqdm
30 |
--------------------------------------------------------------------------------
/nerf_sh/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/nerf_sh/config/blender.yaml:
--------------------------------------------------------------------------------
1 | dataset: blender
2 | image_batching: false
3 | factor: 0
4 | num_coarse_samples: 64
5 | num_fine_samples: 128
6 | use_viewdirs: false
7 | white_bkgd: true
8 | batch_size: 1024
9 | sh_deg: 3
10 | randomized: true
11 | max_steps: 2000000
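# Editorial note: sh_deg 3 means each sample predicts 3 * (3 + 1)**2 = 48 raw
# RGB channels (16 SH coefficients per color channel); use_viewdirs stays
# false because view dependence is carried by the SH basis instead.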
12 |
--------------------------------------------------------------------------------
/nerf_sh/config/misc/og_nerf.yaml:
--------------------------------------------------------------------------------
1 | dataset: blender
2 | image_batching: false
3 | factor: 0
4 | num_coarse_samples: 64
5 | num_fine_samples: 128
6 | use_viewdirs: true
7 | white_bkgd: true
8 | batch_size: 1024
9 | randomized: true
10 | sparsity_weight: 0.0
11 | max_steps: 2000000
12 |
--------------------------------------------------------------------------------
/nerf_sh/config/misc/proj.yaml:
--------------------------------------------------------------------------------
1 | dataset: blender
2 | image_batching: false
3 | factor: 0
4 | num_coarse_samples: 64
5 | num_fine_samples: 128
6 | use_viewdirs: true
7 | white_bkgd: true
8 | batch_size: 1024
9 | sh_deg: 4
10 | randomized: true
11 | max_steps: 2000000
12 |
--------------------------------------------------------------------------------
/nerf_sh/config/misc/sg.yaml:
--------------------------------------------------------------------------------
1 | dataset: blender
2 | image_batching: false
3 | factor: 0
4 | num_coarse_samples: 64
5 | num_fine_samples: 128
6 | use_viewdirs: false
7 | white_bkgd: true
8 | batch_size: 1024
9 | sg_dim: 25
10 | randomized: true
11 | max_steps: 2000000
12 |
--------------------------------------------------------------------------------
/nerf_sh/config/tt.yaml:
--------------------------------------------------------------------------------
1 | # Tanks and Temples dataset
2 | dataset: nsvf
3 | image_batching: false
4 | factor: 0
5 | num_coarse_samples: 64
6 | num_fine_samples: 128
7 | use_viewdirs: false
8 | white_bkgd: true # No alpha channel
9 | batch_size: 1024
10 | randomized: true
11 | sh_deg: 4
12 | max_steps: 2000000
13 | near: 0.0
14 | far: 4.0
15 | sparsity_radius: 5.0
16 | sparsity_length: 0.2
17 |
--------------------------------------------------------------------------------
/nerf_sh/eval.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Modifications Copyright 2021 The PlenOctree Authors.
3 | # Original Copyright 2021 The Google Research Authors.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Lint as: python3
18 | """Evaluation script for Nerf."""
19 |
20 | import os
21 | # Get rid of ugly TF logs
22 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
23 |
24 | import functools
25 | from os import path
26 |
27 | from absl import app
28 | from absl import flags
29 | import flax
30 | from flax.metrics import tensorboard
31 | from flax.training import checkpoints
32 | import jax
33 | from jax import random
34 | import numpy as np
35 |
36 | from nerf_sh.nerf import datasets
37 | from nerf_sh.nerf import models
38 | from nerf_sh.nerf import utils
39 |
40 | FLAGS = flags.FLAGS
41 |
42 | utils.define_flags()
43 |
44 |
45 | def main(unused_argv):
46 | rng = random.PRNGKey(20200823)
47 | rng, key = random.split(rng)
48 |
49 | utils.update_flags(FLAGS)
50 | utils.check_flags(FLAGS)
51 |
52 | dataset = datasets.get_dataset("test", FLAGS)
53 | model, state = models.get_model_state(key, FLAGS, restore=False)
54 |
55 | # Rendering is forced to be deterministic even if training was randomized, as
56 | # this eliminates "speckle" artifacts.
57 | render_pfn = utils.get_render_pfn(model, randomized=False)
58 |
59 | # Compiling to the CPU because it's faster and more accurate.
60 | ssim_fn = jax.jit(functools.partial(utils.compute_ssim, max_val=1.0), backend="cpu")
61 |
62 | last_step = 0
63 | out_dir = path.join(
64 | FLAGS.train_dir, "path_renders" if FLAGS.render_path else "test_preds"
65 | )
66 | if not FLAGS.eval_once:
67 | summary_writer = tensorboard.SummaryWriter(path.join(FLAGS.train_dir, "eval"))
68 | while True:
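        # Editorial note: this loop polls train_dir and re-runs evaluation
        # whenever a checkpoint newer than last_step appears (continuous
        # eval); with --eval_once it exits after a single pass instead.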
69 | print('Loading model')
70 | state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
71 | step = int(state.optimizer.state.step)
72 | if step <= last_step:
73 | continue
74 | if FLAGS.save_output and (not utils.isdir(out_dir)):
75 | utils.makedirs(out_dir)
76 | psnrs = []
77 | ssims = []
78 | if not FLAGS.eval_once:
79 | showcase_index = np.random.randint(0, dataset.size)
80 | for idx in range(dataset.size):
81 | print(f"Evaluating {idx+1}/{dataset.size}")
82 | batch = next(dataset)
83 | if idx % FLAGS.approx_eval_skip != 0:
84 | continue
85 | pred_color, pred_disp, pred_acc = utils.render_image(
86 | functools.partial(render_pfn, state.optimizer.target),
87 | batch["rays"],
88 | rng,
89 | FLAGS.dataset == "llff",
90 | chunk=FLAGS.chunk,
91 | )
92 | if jax.host_id() != 0: # Only record via host 0.
93 | continue
94 | if not FLAGS.eval_once and idx == showcase_index:
95 | showcase_color = pred_color
96 | showcase_disp = pred_disp
97 | showcase_acc = pred_acc
98 | if not FLAGS.render_path:
99 | showcase_gt = batch["pixels"]
100 | # if not FLAGS.render_path:
101 | # psnr = utils.compute_psnr(((pred_color - batch["pixels"]) ** 2).mean())
102 | # ssim = ssim_fn(pred_color, batch["pixels"])
103 | # print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")
104 | # psnrs.append(float(psnr))
105 | # ssims.append(float(ssim))
106 | if FLAGS.save_output:
107 | utils.save_img(pred_color, path.join(out_dir, "{:03d}.png".format(idx)))
108 | utils.save_img(
109 | pred_disp[Ellipsis, 0],
110 | path.join(out_dir, "disp_{:03d}.png".format(idx)),
111 | )
112 | if (not FLAGS.eval_once) and (jax.host_id() == 0):
113 | summary_writer.image("pred_color", showcase_color, step)
114 | summary_writer.image("pred_disp", showcase_disp, step)
115 | summary_writer.image("pred_acc", showcase_acc, step)
116 | # if not FLAGS.render_path:
117 | # summary_writer.scalar("psnr", np.mean(np.array(psnrs)), step)
118 | # summary_writer.scalar("ssim", np.mean(np.array(ssims)), step)
119 | # summary_writer.image("target", showcase_gt, step)
120 | # if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0):
121 | # with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f:
122 | # f.write("{}".format(np.mean(np.array(psnrs))))
123 | # with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f:
124 | # f.write("{}".format(np.mean(np.array(ssims))))
125 | # with utils.open_file(path.join(out_dir, f"psnrs_{step}.txt"), "w") as f:
126 | # f.write(" ".join([str(v) for v in psnrs]))
127 | # with utils.open_file(path.join(out_dir, f"ssims_{step}.txt"), "w") as f:
128 | # f.write(" ".join([str(v) for v in ssims]))
129 | if FLAGS.eval_once:
130 | break
131 | if int(step) >= FLAGS.max_steps:
132 | break
133 | last_step = step
134 |
135 |
136 | if __name__ == "__main__":
137 | app.run(main)
138 |
--------------------------------------------------------------------------------
/nerf_sh/gen_mesh.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 | import os
24 | # Get rid of ugly TF logs
25 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
26 |
27 | from absl import app
28 | from absl import flags
29 | import mcubes
30 |
31 | import jax
32 | from jax import config
33 | from jax import random
34 | import jax.numpy as jnp
35 | import numpy as np
36 | import flax
37 | from flax.training import checkpoints
38 |
39 | import functools
40 |
41 | from nerf_sh.nerf import models
42 | from nerf_sh.nerf import utils
43 | from nerf_sh.nerf.utils import host0_print as h0print
44 |
45 | FLAGS = flags.FLAGS
46 |
47 | utils.define_flags()
48 |
49 | flags.DEFINE_string(
50 | "reso",
51 | "300 300 300",
52 | "Marching cube resolution in each dimension: x y z",
53 | )
54 | flags.DEFINE_string(
55 | "c1",
56 | "-2 -2 -2",
57 |     "Marching cubes bounds lower corner in x y z OR single number",
58 | )
59 | flags.DEFINE_string(
60 | "c2",
61 | "2 2 2",
62 | "Marching cubes bounds upper corner in x y z OR single number",
63 | )
64 | flags.DEFINE_float(
65 | "iso", 6.0, "Marching cubes isosurface"
66 | )
67 | flags.DEFINE_bool(
68 | "coarse",
69 | False,
70 |     "Force use of the coarse network (else depends on renderer n_fine in conf)",
71 | )
72 | flags.DEFINE_integer(
73 | "point_chunk",
74 | 720720,
75 | "Chunk (batch) size of points for evaluation. NOTE: --chunk will be ignored",
76 | )
77 | # TODO: implement color
78 | # flags.DEFINE_bool(
79 | # "color",
80 | # False,
81 | # "Generate colored mesh."
82 | # )
83 |
84 |
85 | config.parse_flags_with_absl()
86 |
87 |
88 | def marching_cubes(
89 | fn,
90 | c1,
91 | c2,
92 | reso,
93 | isosurface,
94 | chunk,
95 | ):
96 | """
97 | Run marching cubes on network. Uses PyMCubes.
98 | Args:
99 |     fn: main NeRF type network
100 | c1: list corner 1 of marching cube bounds x,y,z
101 | c2: list corner 2 of marching cube bounds x,y,z (all > c1)
102 | reso: list resolutions of marching cubes x,y,z
103 | isosurface: float sigma-isosurface of marching cubes
104 | """
105 | grid = np.vstack(
106 | np.meshgrid(
107 | *(np.linspace(lo, hi, sz, dtype=np.float32)
108 | for lo, hi, sz in zip(c1, c2, reso)),
109 | indexing="ij"
110 | )
111 | ).reshape(3, -1).T
112 |
113 | h0print("* Evaluating sigma @", grid.shape[0], "points")
114 | rgbs, sigmas = utils.eval_points(
115 | fn,
116 | grid,
117 | chunk,
118 | )
119 | sigmas = sigmas.reshape(*reso)
120 | del rgbs
121 |
122 | if jax.host_id() == 0:
123 | print("* Running marching cubes")
124 | vertices, triangles = mcubes.marching_cubes(sigmas, isosurface)
125 | # Scale
126 | c1, c2 = np.array(c1), np.array(c2)
127 | vertices *= (c2 - c1) / np.array(reso)
128 |
129 | return vertices + c1, triangles
130 | return None, None
131 |
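# Editorial note: mcubes.marching_cubes returns vertices in voxel-index
# coordinates [0, reso); the scale by (c2 - c1) / reso plus the c1 offset
# above maps them back into the world-space box given by --c1 / --c2.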
132 |
133 | def save_obj(vertices, triangles, path, vert_rgb=None):
134 | """
135 | Save OBJ file, optionally with vertex colors.
136 | This version is faster than PyMCubes and supports color.
137 | Taken from PIFu.
138 | :param vertices (N, 3)
139 | :param triangles (N, 3)
140 | :param vert_rgb (N, 3) rgb
141 | """
142 | file = open(path, "w")
143 | if vert_rgb is None:
144 | # No color
145 | for v in vertices:
146 | file.write("v %.4f %.4f %.4f\n" % (v[0], v[1], v[2]))
147 | else:
148 | # Color
149 | for idx, v in enumerate(vertices):
150 | c = vert_rgb[idx]
151 | file.write(
152 | "v %.4f %.4f %.4f %.4f %.4f %.4f\n"
153 | % (v[0], v[1], v[2], c[0], c[1], c[2])
154 | )
155 | for f in triangles:
156 | f_plus = f + 1
157 | file.write("f %d %d %d\n" % (f_plus[0], f_plus[1], f_plus[2]))
158 | file.close()
159 |
160 |
161 | def main(unused_argv):
162 | rng = random.PRNGKey(20200823)
163 |
164 | utils.update_flags(FLAGS)
165 | utils.check_flags(FLAGS, require_data=False)
166 |
167 | reso = list(map(int, FLAGS.reso.split()))
168 | if len(reso) == 1:
169 | reso *= 3
170 | c1 = list(map(float, FLAGS.c1.split()))
171 | if len(c1) == 1:
172 | c1 *= 3
173 | c2 = list(map(float, FLAGS.c2.split()))
174 | if len(c2) == 1:
175 | c2 *= 3
176 |
177 | rng, key = random.split(rng)
178 |
179 | h0print('* Creating model')
180 | model, state = models.get_model_state(key, FLAGS)
181 | h0print('* Eval reso', FLAGS.reso, 'coarse?', FLAGS.coarse)
182 |
183 | eval_points_pfn = utils.get_eval_points_pfn(model, raw_rgb=True,
184 | coarse=FLAGS.coarse)
185 |
186 | verts, faces = marching_cubes(
187 | functools.partial(eval_points_pfn, state.optimizer.target),
188 | c1=c1, c2=c2, reso=reso, isosurface=FLAGS.iso, chunk=FLAGS.point_chunk
189 | )
190 |
191 | if jax.host_id() == 0:
192 | mesh_path = os.path.join(FLAGS.train_dir, 'mesh.obj')
193 | print(' Saving to', mesh_path)
194 | save_obj(verts, faces, mesh_path)
195 |
196 |
197 | if __name__ == "__main__":
198 | app.run(main)
199 |
--------------------------------------------------------------------------------
/nerf_sh/gen_video.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 | """
24 | Simple video generation script (for 360 Blender scene only)
25 | """
26 | import os
27 | # Get rid of ugly TF logs
28 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
29 |
30 | from absl import app
31 | from absl import flags
32 |
33 | import jax
34 | from jax import config
35 | from jax import random
36 | import jax.numpy as jnp
37 | import numpy as np
38 | import flax
39 | from flax.training import checkpoints
40 |
41 | import functools
42 |
43 | from nerf_sh.nerf import models
44 | from nerf_sh.nerf import utils
45 | from nerf_sh.nerf.utils import host0_print as h0print
46 |
47 | import imageio
48 |
49 | FLAGS = flags.FLAGS
50 |
51 | utils.define_flags()
52 |
53 | flags.DEFINE_float(
54 | "elevation",
55 | -30.0,
56 | "Elevation angle (negative is above)",
57 | )
58 | flags.DEFINE_integer(
59 | "num_views",
60 | 40,
61 | "The number of views to generate.",
62 | )
63 | flags.DEFINE_integer(
64 | "height",
65 | 800,
66 |     "Output image height in pixels.",
67 | )
68 | flags.DEFINE_integer(
69 | "width",
70 | 800,
71 |     "Output image width in pixels.",
72 | )
73 | flags.DEFINE_float(
74 | "camera_angle_x",
75 | 0.7,
76 | "The camera angle in rad in x direction (used to get focal length).",
77 | short_name='A',
78 | )
79 | flags.DEFINE_string(
80 | "intrin",
81 | None,
82 | "Intrinsics file. If set, overrides camera_angle_x",
83 | )
84 | flags.DEFINE_float(
85 | "radius",
86 | 4.0,
87 | "Radius to origin of camera path.",
88 | )
89 | flags.DEFINE_integer(
90 | "fps",
91 | 20,
92 | "FPS of generated video",
93 | )
94 | flags.DEFINE_integer(
95 | "up_axis",
96 | 1,
97 | "up axis for camera views; 1-6: Z up/Z down/Y up/Y down/X up/X down; " +
98 | "same effect as pressing number keys in volrend",
99 | )
100 | flags.DEFINE_string(
101 | "write_poses",
102 | None,
103 |     "If set, write camera poses to the given file as a (4N x 4) array; otherwise poses are not written",
104 | )
105 |
106 | config.parse_flags_with_absl()
107 |
108 | def main(unused_argv):
109 | rng = random.PRNGKey(20200823)
110 |
111 | utils.update_flags(FLAGS)
112 | utils.check_flags(FLAGS, require_data=False)
113 |
114 | rng, key = random.split(rng)
115 |
116 | h0print('* Generating poses')
117 | render_poses = np.stack(
118 | [
119 | utils.pose_spherical(angle, FLAGS.elevation, FLAGS.radius, FLAGS.up_axis - 1)
120 | for angle in np.linspace(-180, 180, FLAGS.num_views + 1)[:-1]
121 | ],
122 | 0,
123 | ) # (NV, 4, 4)
124 |
125 | if FLAGS.write_poses:
126 | np.savetxt(FLAGS.write_poses, render_poses.reshape(-1, 4))
127 | print('Saved poses to', FLAGS.write_poses)
128 |
129 | h0print('* Generating rays')
130 | focal = 0.5 * FLAGS.width / np.tan(0.5 * FLAGS.camera_angle_x)
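    # Editorial note: standard pinhole relation focal = (W / 2) / tan(fov_x / 2),
    # i.e. half the image width over the tangent of half the horizontal FoV.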
131 |
132 | if FLAGS.intrin is not None:
133 | print('Load focal length from intrin file')
134 | K : np.ndarray = np.loadtxt(FLAGS.intrin)
135 | focal = (K[0, 0] + K[1, 1]) * 0.5
136 |
137 | rays = utils.generate_rays(FLAGS.width, FLAGS.height, focal, render_poses)
138 |
139 | h0print('* Creating model')
140 | model, state = models.get_model_state(key, FLAGS)
141 | render_pfn = utils.get_render_pfn(model, randomized=False)
142 |
143 | h0print('* Rendering')
144 |
145 | vid_name = "e{:03}".format(int(-FLAGS.elevation * 10))
146 | video_dir = os.path.join(FLAGS.train_dir, 'video', vid_name)
147 | frames_dir = os.path.join(video_dir, 'frames')
148 | h0print(' Saving to', video_dir)
149 | utils.makedirs(frames_dir)
150 |
151 | frames = []
152 | for i in range(FLAGS.num_views):
153 |         h0print(f'** View {i+1}/{FLAGS.num_views} = {i / FLAGS.num_views * 100:.1f}%')
154 | pred_color, pred_disp, pred_acc = utils.render_image(
155 | functools.partial(render_pfn, state.optimizer.target),
156 | utils.to_device(utils.namedtuple_map(lambda x: x[i], rays)),
157 | rng,
158 | FLAGS.dataset == "llff",
159 | chunk=FLAGS.chunk,
160 | )
161 | if jax.host_id() == 0:
162 | utils.save_img(pred_color, os.path.join(frames_dir, f'{i:04}.png'))
163 | frames.append(np.array(pred_color))
164 |
165 | if jax.host_id() == 0:
166 | frames = np.stack(frames)
167 | vid_path = os.path.join(video_dir, "video.mp4")
168 | print('* Writing video', vid_path)
169 | imageio.mimwrite(
170 | vid_path, (np.clip(frames, 0.0, 1.0) * 255).astype(np.uint8),
171 | fps=FLAGS.fps, quality=8
172 | )
173 | print('* Done')
174 |
175 | if __name__ == "__main__":
176 | app.run(main)
177 |
--------------------------------------------------------------------------------
/nerf_sh/nerf/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/nerf_sh/nerf/model_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Modifications Copyright 2021 The PlenOctree Authors.
3 | # Original Copyright 2021 The Google Research Authors.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Lint as: python3
18 | """Helper functions/classes for model definition."""
19 |
20 | import functools
21 | from typing import Any, Callable
22 |
23 | from flax import linen as nn
24 | import jax
25 | from jax import lax
26 | from jax import random
27 | import jax.numpy as jnp
28 |
29 |
30 | class MLP(nn.Module):
31 | """A simple MLP."""
32 |
33 | net_depth: int = 8 # The depth of the first part of MLP.
34 | net_width: int = 256 # The width of the first part of MLP.
35 | net_depth_condition: int = 1 # The depth of the second part of MLP.
36 | net_width_condition: int = 128 # The width of the second part of MLP.
37 | net_activation: Callable[Ellipsis, Any] = nn.relu # The activation function.
38 | skip_layer: int = 4 # The layer to add skip layers to.
39 | num_rgb_channels: int = 3 # The number of RGB channels.
40 | num_sigma_channels: int = 1 # The number of sigma channels.
41 |
42 | @nn.compact
43 | def __call__(self, x, condition=None):
44 | """Evaluate the MLP.
45 |
46 | Args:
47 | x: jnp.ndarray(float32), [batch, num_samples, feature], points.
48 | condition: jnp.ndarray(float32), [batch, feature], if not None, this
49 | variable will be part of the input to the second part of the MLP
50 | concatenated with the output vector of the first part of the MLP. If
51 | None, only the first part of the MLP will be used with input x. In the
52 | original paper, this variable is the view direction.
53 |
54 | Returns:
55 | raw_rgb: jnp.ndarray(float32), with a shape of
56 | [batch, num_samples, num_rgb_channels].
57 | raw_sigma: jnp.ndarray(float32), with a shape of
58 | [batch, num_samples, num_sigma_channels].
59 | """
60 | feature_dim = x.shape[-1]
61 | num_samples = x.shape[1]
62 | x = x.reshape([-1, feature_dim])
63 | dense_layer = functools.partial(
64 | nn.Dense, kernel_init=jax.nn.initializers.glorot_uniform()
65 | )
66 | inputs = x
67 | for i in range(self.net_depth):
68 | x = dense_layer(self.net_width)(x)
69 | x = self.net_activation(x)
70 | if i % self.skip_layer == 0 and i > 0:
71 | x = jnp.concatenate([x, inputs], axis=-1)
72 | raw_sigma = dense_layer(self.num_sigma_channels)(x).reshape(
73 | [-1, num_samples, self.num_sigma_channels]
74 | )
75 |
76 | if condition is not None:
77 | # Output of the first part of MLP.
78 | bottleneck = dense_layer(self.net_width)(x)
79 | # Broadcast condition from [batch, feature] to
80 | # [batch, num_samples, feature] since all the samples along the same ray
81 | # have the same viewdir.
82 | condition = jnp.tile(condition[:, None, :], (1, num_samples, 1))
83 | # Collapse the [batch, num_samples, feature] tensor to
84 | # [batch * num_samples, feature] so that it can be fed into nn.Dense.
85 | condition = condition.reshape([-1, condition.shape[-1]])
86 | x = jnp.concatenate([bottleneck, condition], axis=-1)
87 | # Here use 1 extra layer to align with the original nerf model.
88 | for i in range(self.net_depth_condition):
89 | x = dense_layer(self.net_width_condition)(x)
90 | x = self.net_activation(x)
91 | raw_rgb = dense_layer(self.num_rgb_channels)(x).reshape(
92 | [-1, num_samples, self.num_rgb_channels]
93 | )
94 | return raw_rgb, raw_sigma
95 |
96 |
97 | def cast_rays(z_vals, origins, directions):
98 | return (
99 | origins[Ellipsis, None, :]
100 | + z_vals[Ellipsis, None] * directions[Ellipsis, None, :]
101 | )
102 |
103 |
104 | def sample_along_rays(
105 | key, origins, directions, num_samples, near, far, randomized, lindisp
106 | ):
107 | """Stratified sampling along the rays.
108 |
109 | Args:
110 | key: jnp.ndarray, random generator key.
111 | origins: jnp.ndarray(float32), [batch_size, 3], ray origins.
112 | directions: jnp.ndarray(float32), [batch_size, 3], ray directions.
113 | num_samples: int.
114 | near: float, near clip.
115 | far: float, far clip.
116 | randomized: bool, use randomized stratified sampling.
117 | lindisp: bool, sampling linearly in disparity rather than depth.
118 |
119 | Returns:
120 | z_vals: jnp.ndarray, [batch_size, num_samples], sampled z values.
121 | points: jnp.ndarray, [batch_size, num_samples, 3], sampled points.
122 | """
123 | batch_size = origins.shape[0]
124 |
125 | t_vals = jnp.linspace(0.0, 1.0, num_samples)
126 | if lindisp:
127 | z_vals = 1.0 / (1.0 / near * (1.0 - t_vals) + 1.0 / far * t_vals)
128 | else:
129 | z_vals = near * (1.0 - t_vals) + far * t_vals
130 |
131 | if randomized:
132 | mids = 0.5 * (z_vals[Ellipsis, 1:] + z_vals[Ellipsis, :-1])
133 | upper = jnp.concatenate([mids, z_vals[Ellipsis, -1:]], -1)
134 | lower = jnp.concatenate([z_vals[Ellipsis, :1], mids], -1)
135 | t_rand = random.uniform(key, [batch_size, num_samples])
136 | z_vals = lower + (upper - lower) * t_rand
137 | else:
138 | # Broadcast z_vals to make the returned shape consistent.
139 | z_vals = jnp.broadcast_to(z_vals[None, Ellipsis], [batch_size, num_samples])
140 |
141 | coords = cast_rays(z_vals, origins, directions)
142 | return z_vals, coords
143 |
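# Editorial note: with randomized=True, sample_along_rays draws one uniform
# sample per stratum [lower_i, upper_i] around the evenly spaced depths
# (classic NeRF stratified sampling); otherwise every ray reuses the same
# deterministic linspace depths.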
144 |
145 | def posenc(x, min_deg, max_deg, legacy_posenc_order=False):
146 | """Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].
147 |
148 | Instead of computing [sin(x), cos(x)], we use the trig identity
149 | cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
150 |
151 | Args:
152 | x: jnp.ndarray, variables to be encoded. Note that x should be in [-pi, pi].
153 | min_deg: int, the minimum (inclusive) degree of the encoding.
154 | max_deg: int, the maximum (exclusive) degree of the encoding.
155 | legacy_posenc_order: bool, keep the same ordering as the original tf code.
156 |
157 | Returns:
158 | encoded: jnp.ndarray, encoded variables.
159 | """
160 | if min_deg == max_deg:
161 | return x
162 | scales = jnp.array([2 ** i for i in range(min_deg, max_deg)])
163 | if legacy_posenc_order:
164 | xb = x[Ellipsis, None, :] * scales[:, None]
165 | four_feat = jnp.reshape(
166 | jnp.sin(jnp.stack([xb, xb + 0.5 * jnp.pi], -2)), list(x.shape[:-1]) + [-1]
167 | )
168 | else:
169 | xb = jnp.reshape(
170 | (x[Ellipsis, None, :] * scales[:, None]), list(x.shape[:-1]) + [-1]
171 | )
172 | four_feat = jnp.sin(jnp.concatenate([xb, xb + 0.5 * jnp.pi], axis=-1))
173 | return jnp.concatenate([x] + [four_feat], axis=-1)
174 |
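# Editorial note: for x with trailing dimension 3 and min_deg=0, max_deg=10,
# posenc returns 3 + 3 * 2 * 10 = 63 channels: the identity features plus sin
# and cos of each coordinate at the 10 octave scales 2^0 .. 2^9.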
175 |
176 | def volumetric_rendering(rgb, sigma, z_vals, dirs, white_bkgd):
177 | """Volumetric Rendering Function.
178 |
179 | Args:
180 | rgb: jnp.ndarray(float32), color, [batch_size, num_samples, 3]
181 | sigma: jnp.ndarray(float32), density, [batch_size, num_samples, 1].
182 | z_vals: jnp.ndarray(float32), [batch_size, num_samples].
183 | dirs: jnp.ndarray(float32), [batch_size, 3].
184 | white_bkgd: bool.
185 |
186 | Returns:
187 | comp_rgb: jnp.ndarray(float32), [batch_size, 3].
188 | disp: jnp.ndarray(float32), [batch_size].
189 | acc: jnp.ndarray(float32), [batch_size].
190 | weights: jnp.ndarray(float32), [batch_size, num_samples]
191 | """
192 | eps = 1e-10
193 | dists = jnp.concatenate(
194 | [
195 | z_vals[Ellipsis, 1:] - z_vals[Ellipsis, :-1],
196 | jnp.broadcast_to([1e10], z_vals[Ellipsis, :1].shape),
197 | ],
198 | -1,
199 | )
200 | dists = dists * jnp.linalg.norm(dirs[Ellipsis, None, :], axis=-1)
201 | # Note that we're quietly turning sigma from [..., 0] to [...].
202 | alpha = 1.0 - jnp.exp(-sigma[Ellipsis, 0] * dists)
203 | accum_prod = jnp.concatenate(
204 | [
205 | jnp.ones_like(alpha[Ellipsis, :1], alpha.dtype),
206 | jnp.cumprod(1.0 - alpha[Ellipsis, :-1] + eps, axis=-1),
207 | ],
208 | axis=-1,
209 | )
210 | weights = alpha * accum_prod
211 |
212 | comp_rgb = (weights[Ellipsis, None] * rgb).sum(axis=-2)
213 | depth = (weights * z_vals).sum(axis=-1)
214 | acc = weights.sum(axis=-1) # Alpha
215 | # Equivalent to (but slightly more efficient and stable than):
216 | # disp = 1 / max(eps, where(acc > eps, depth / acc, 0))
217 | inv_eps = 1 / eps
218 | disp = acc / depth
219 | disp = jnp.where((disp > 0) & (disp < inv_eps) & (acc > eps), disp, inv_eps)
220 | if white_bkgd:
221 | comp_rgb = comp_rgb + (1.0 - acc[Ellipsis, None])
222 | return comp_rgb, disp, acc, weights
223 |
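# Editorial note: volumetric_rendering implements the standard NeRF
# quadrature: alpha_i = 1 - exp(-sigma_i * delta_i), T_i = prod_{j<i}
# (1 - alpha_j), w_i = T_i * alpha_i, comp_rgb = sum_i w_i * rgb_i, and
# acc = sum_i w_i is used to composite onto the white background.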
224 |
225 | def piecewise_constant_pdf(key, bins, weights, num_samples, randomized):
226 | """Piecewise-Constant PDF sampling.
227 |
228 | Args:
229 | key: jnp.ndarray(float32), [2,], random number generator.
230 | bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
231 | weights: jnp.ndarray(float32), [batch_size, num_bins].
232 | num_samples: int, the number of samples.
233 | randomized: bool, use randomized samples.
234 |
235 | Returns:
236 | z_samples: jnp.ndarray(float32), [batch_size, num_samples].
237 | """
238 | # Pad each weight vector (only if necessary) to bring its sum to `eps`. This
239 | # avoids NaNs when the input is zeros or small, but has no effect otherwise.
240 | eps = 1e-5
241 | weight_sum = jnp.sum(weights, axis=-1, keepdims=True)
242 | padding = jnp.maximum(0, eps - weight_sum)
243 | weights += padding / weights.shape[-1]
244 | weight_sum += padding
245 |
246 | # Compute the PDF and CDF for each weight vector, while ensuring that the CDF
247 | # starts with exactly 0 and ends with exactly 1.
248 | pdf = weights / weight_sum
249 | cdf = jnp.minimum(1, jnp.cumsum(pdf[Ellipsis, :-1], axis=-1))
250 | cdf = jnp.concatenate(
251 | [
252 | jnp.zeros(list(cdf.shape[:-1]) + [1]),
253 | cdf,
254 | jnp.ones(list(cdf.shape[:-1]) + [1]),
255 | ],
256 | axis=-1,
257 | )
258 |
259 | # Draw uniform samples.
260 | if randomized:
261 | # Note that `u` is in [0, 1) --- it can be zero, but it can never be 1.
262 | u = random.uniform(key, list(cdf.shape[:-1]) + [num_samples])
263 | else:
264 | # Match the behavior of random.uniform() by spanning [0, 1-eps].
265 | u = jnp.linspace(0.0, 1.0 - jnp.finfo("float32").eps, num_samples)
266 | u = jnp.broadcast_to(u, list(cdf.shape[:-1]) + [num_samples])
267 |
268 | # Identify the location in `cdf` that corresponds to a random sample.
269 | # The final `True` index in `mask` will be the start of the sampled interval.
270 | mask = u[Ellipsis, None, :] >= cdf[Ellipsis, :, None]
271 |
272 | def find_interval(x):
273 | # Grab the value where `mask` switches from True to False, and vice versa.
274 | # This approach takes advantage of the fact that `x` is sorted.
275 | x0 = jnp.max(jnp.where(mask, x[Ellipsis, None], x[Ellipsis, :1, None]), -2)
276 | x1 = jnp.min(jnp.where(~mask, x[Ellipsis, None], x[Ellipsis, -1:, None]), -2)
277 | return x0, x1
278 |
279 | bins_g0, bins_g1 = find_interval(bins)
280 | cdf_g0, cdf_g1 = find_interval(cdf)
281 |
282 | t = jnp.clip(jnp.nan_to_num((u - cdf_g0) / (cdf_g1 - cdf_g0), 0), 0, 1)
283 | samples = bins_g0 + t * (bins_g1 - bins_g0)
284 |
285 | # Prevent gradient from backprop-ing through `samples`.
286 | return lax.stop_gradient(samples)
287 |
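# Editorial note: piecewise_constant_pdf is inverse-CDF sampling: each uniform
# u is located within the coarse CDF and mapped back to a depth by linear
# interpolation inside its bin, so fine samples concentrate where the coarse
# weights (and hence the density) are large.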
288 |
289 | def sample_pdf(
290 | key, bins, weights, origins, directions, z_vals, num_samples, randomized
291 | ):
292 | """Hierarchical sampling.
293 |
294 | Args:
295 | key: jnp.ndarray(float32), [2,], random number generator.
296 | bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
297 | weights: jnp.ndarray(float32), [batch_size, num_bins].
298 | origins: jnp.ndarray(float32), [batch_size, 3], ray origins.
299 | directions: jnp.ndarray(float32), [batch_size, 3], ray directions.
300 | z_vals: jnp.ndarray(float32), [batch_size, num_coarse_samples].
301 | num_samples: int, the number of samples.
302 | randomized: bool, use randomized samples.
303 |
304 | Returns:
305 | z_vals: jnp.ndarray(float32),
306 | [batch_size, num_coarse_samples + num_fine_samples].
307 | points: jnp.ndarray(float32),
308 | [batch_size, num_coarse_samples + num_fine_samples, 3].
309 | """
310 | z_samples = piecewise_constant_pdf(key, bins, weights, num_samples, randomized)
311 | # Compute united z_vals and sample points
312 | z_vals = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1)
313 | coords = cast_rays(z_vals, origins, directions)
314 | return z_vals, coords
315 |
316 |
317 | def add_gaussian_noise(key, raw, noise_std, randomized):
318 |     """Adds gaussian noise to `raw`, which can be used to regularize it.
319 |
320 | Args:
321 | key: jnp.ndarray(float32), [2,], random number generator.
322 | raw: jnp.ndarray(float32), arbitrary shape.
323 | noise_std: float, The standard deviation of the noise to be added.
324 | randomized: bool, add noise if randomized is True.
325 |
326 | Returns:
327 | raw + noise: jnp.ndarray(float32), with the same shape as `raw`.
328 | """
329 | if (noise_std is not None) and randomized:
330 | return raw + random.normal(key, raw.shape, dtype=raw.dtype) * noise_std
331 | else:
332 | return raw
333 |
--------------------------------------------------------------------------------
/nerf_sh/nerf/models.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Modifications Copyright 2021 The PlenOctree Authors.
3 | # Original Copyright 2021 The Google Research Authors.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Lint as: python3
18 | """Different model implementation plus a general port for all the models."""
19 | from typing import Any, Callable
20 | import flax
21 | from flax import linen as nn
22 | from jax import random
23 | import jax.numpy as jnp
24 |
25 | from nerf_sh.nerf import model_utils
26 | from nerf_sh.nerf import utils
27 | from nerf_sh.nerf import sh
28 | from nerf_sh.nerf import sg
29 |
30 |
31 | def get_model(key, args):
32 | """A helper function that wraps around a 'model zoo'."""
33 | model_dict = {
34 | "nerf": construct_nerf,
35 | }
36 | return model_dict[args.model](key, args)
37 |
38 | def get_model_state(key, args, restore=True):
39 | """
40 | Helper for loading model with get_model & creating optimizer &
41 | optionally restoring checkpoint to reduce boilerplate
42 | """
43 | model, variables = get_model(key, args)
44 | optimizer = flax.optim.Adam(args.lr_init).create(variables)
45 | state = utils.TrainState(optimizer=optimizer)
46 | if restore:
47 | from flax.training import checkpoints
48 | state = checkpoints.restore_checkpoint(args.train_dir, state)
49 | return model, state
50 |
51 |
52 | class NerfModel(nn.Module):
53 | """Nerf NN Model with both coarse and fine MLPs."""
54 |
55 | num_coarse_samples: int # The number of samples for the coarse nerf.
56 | num_fine_samples: int # The number of samples for the fine nerf.
57 | use_viewdirs: bool # If True, use viewdirs as an input.
58 | sh_deg: int # If != -1, use spherical harmonics output up to given degree
59 | sg_dim: int # If != -1, use spherical gaussians output of given dimension
60 | near: float # The distance to the near plane
61 | far: float # The distance to the far plane
62 | noise_std: float # The std dev of noise added to raw sigma.
63 | net_depth: int # The depth of the first part of MLP.
64 | net_width: int # The width of the first part of MLP.
65 | net_depth_condition: int # The depth of the second part of MLP.
66 | net_width_condition: int # The width of the second part of MLP.
67 | net_activation: Callable[Ellipsis, Any] # MLP activation
68 | skip_layer: int # How often to add skip connections.
69 | num_rgb_channels: int # The number of RGB channels.
70 | num_sigma_channels: int # The number of density channels.
71 | white_bkgd: bool # If True, use a white background.
72 | min_deg_point: int # The minimum degree of positional encoding for positions.
73 | max_deg_point: int # The maximum degree of positional encoding for positions.
74 | deg_view: int # The degree of positional encoding for viewdirs.
75 | lindisp: bool # If True, sample linearly in disparity rather than in depth.
76 | rgb_activation: Callable[Ellipsis, Any] # Output RGB activation.
77 | sigma_activation: Callable[Ellipsis, Any] # Output sigma activation.
78 | legacy_posenc_order: bool # Keep the same ordering as the original tf code.
79 |
80 | def setup(self):
81 | # Construct the "coarse" MLP. Weird name is for
82 | # compatibility with 'compact' version
83 | self.MLP_0 = model_utils.MLP(
84 | net_depth=self.net_depth,
85 | net_width=self.net_width,
86 | net_depth_condition=self.net_depth_condition,
87 | net_width_condition=self.net_width_condition,
88 | net_activation=self.net_activation,
89 | skip_layer=self.skip_layer,
90 | num_rgb_channels=self.num_rgb_channels,
91 | num_sigma_channels=self.num_sigma_channels,
92 | )
93 |
94 | # Construct the "fine" MLP.
95 | self.MLP_1 = model_utils.MLP(
96 | net_depth=self.net_depth,
97 | net_width=self.net_width,
98 | net_depth_condition=self.net_depth_condition,
99 | net_width_condition=self.net_width_condition,
100 | net_activation=self.net_activation,
101 | skip_layer=self.skip_layer,
102 | num_rgb_channels=self.num_rgb_channels,
103 | num_sigma_channels=self.num_sigma_channels,
104 | )
105 |
106 | # Construct global learnable variables for spherical gaussians.
107 | if self.sg_dim > 0:
108 | key1, key2 = random.split(random.PRNGKey(0), 2)
109 | self.sg_lambda = self.variable(
110 | "params", "sg_lambda",
111 | lambda x: jnp.ones([x], jnp.float32), self.sg_dim)
112 | self.sg_mu_spher = self.variable(
113 | "params", "sg_mu_spher",
114 | lambda x: jnp.concatenate([
115 | random.uniform(key1, [x, 1]) * jnp.pi, # theta
116 | random.uniform(key2, [x, 1]) * jnp.pi * 2, # phi
117 | ], axis=-1), self.sg_dim)
118 |
119 | def _quick_init(self):
120 | points = jnp.zeros((1, 1, 3), dtype=jnp.float32)
121 | points_enc = model_utils.posenc(
122 | points,
123 | self.min_deg_point,
124 | self.max_deg_point,
125 | self.legacy_posenc_order,
126 | )
127 | if self.use_viewdirs:
128 | viewdirs = jnp.zeros((1, 1, 3), dtype=jnp.float32)
129 | viewdirs_enc = model_utils.posenc(
130 | viewdirs,
131 | 0,
132 | self.deg_view,
133 | self.legacy_posenc_order,
134 | )
135 | self.MLP_0(points_enc, viewdirs_enc)
136 | if self.num_fine_samples > 0:
137 | self.MLP_1(points_enc, viewdirs_enc)
138 | else:
139 | self.MLP_0(points_enc)
140 | if self.num_fine_samples > 0:
141 | self.MLP_1(points_enc)
142 |
143 | def eval_points_raw(self, points, viewdirs=None, coarse=False):
144 | """
145 |         Evaluate at points, returning rgb and sigma.
146 | If sh_deg >= 0 / sg_dim > 0 then this will return
147 | spherical harmonic / spherical gaussians / anisotropic spherical gaussians
148 | coeffs for RGB. Please see eval_points for alternate
149 | version which always returns RGB.
150 | Args:
151 | points: jnp.ndarray [B, 3]
152 | viewdirs: jnp.ndarray [B, 3]
153 | coarse: if true, uses coarse MLP
154 | Returns:
155 | raw_rgb: jnp.ndarray [B, 3 * (sh_deg + 1)**2 or 3 or 3 * sg_dim]
156 | raw_sigma: jnp.ndarray [B, 1]
157 | """
158 | points = points[None]
159 | points_enc = model_utils.posenc(
160 | points,
161 | self.min_deg_point,
162 | self.max_deg_point,
163 | self.legacy_posenc_order,
164 | )
165 | if self.num_fine_samples > 0 and not coarse:
166 | mlp = self.MLP_1
167 | else:
168 | mlp = self.MLP_0
169 | if self.use_viewdirs:
170 | assert viewdirs is not None
171 | viewdirs = viewdirs[None]
172 | viewdirs_enc = model_utils.posenc(
173 | viewdirs,
174 | 0,
175 | self.deg_view,
176 | self.legacy_posenc_order,
177 | )
178 | raw_rgb, raw_sigma = mlp(points_enc, viewdirs_enc)
179 | else:
180 | raw_rgb, raw_sigma = mlp(points_enc)
181 | return raw_rgb[0], raw_sigma[0]
182 |
183 | def eval_points(self, points, viewdirs=None, coarse=False):
184 | """
185 | Evaluate at points, converting spherical harmonics rgb to
186 | rgb via viewdirs if applicable. Exists since jax does not allow
187 | size to depend on input.
188 | Args:
189 | points: jnp.ndarray [B, 3]
190 | viewdirs: jnp.ndarray [B, 3]
191 | coarse: if true, uses coarse MLP
192 | Returns:
193 | rgb: jnp.ndarray [B, 3]
194 | sigma: jnp.ndarray [B, 1]
195 | """
196 | raw_rgb, raw_sigma = self.eval_points_raw(points, viewdirs, coarse)
197 | if self.sh_deg >= 0:
198 | assert viewdirs is not None
199 | # (256, 64, 48) (256, 3)
200 | raw_rgb = sh.eval_sh(self.sh_deg, raw_rgb.reshape(
201 | *raw_rgb.shape[:-1],
202 | -1,
203 | (self.sh_deg + 1) ** 2), viewdirs[:, None])
204 | elif self.sg_dim > 0:
205 | assert viewdirs is not None
206 | sg_lambda = self.sg_lambda.value
207 | sg_mu_spher = self.sg_mu_spher.value
208 | sg_coeffs = raw_rgb.reshape(*raw_rgb.shape[:-1], -1, self.sg_dim)
209 | raw_rgb = sg.eval_sg(
210 | sg_lambda, sg_mu_spher, sg_coeffs, viewdirs[:, None])
211 |
212 | rgb = self.rgb_activation(raw_rgb)
213 | sigma = self.sigma_activation(raw_sigma)
214 | return rgb, sigma
215 |
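    # Editorial note: with sh_deg = 3 the MLP emits 3 * (3 + 1)**2 = 48 raw
    # channels per point (cf. the "(256, 64, 48)" shape comment above);
    # eval_points reshapes them to [..., 3, 16] and contracts with the 16 SH
    # basis values of each view direction to recover RGB.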
216 | def __call__(self, rng_0, rng_1, rays, randomized):
217 | """Nerf Model.
218 |
219 | Args:
220 | rng_0: jnp.ndarray, random number generator for coarse model sampling.
221 | rng_1: jnp.ndarray, random number generator for fine model sampling.
222 | rays: util.Rays, a namedtuple of ray origins, directions, and viewdirs.
223 | randomized: bool, use randomized stratified sampling.
224 |
225 | Returns:
226 | ret: list, [(rgb_coarse, disp_coarse, acc_coarse), (rgb, disp, acc)]
227 | """
228 | # Stratified sampling along rays
229 | key, rng_0 = random.split(rng_0)
230 | z_vals, samples = model_utils.sample_along_rays(
231 | key,
232 | rays.origins,
233 | rays.directions,
234 | self.num_coarse_samples,
235 | self.near,
236 | self.far,
237 | randomized,
238 | self.lindisp,
239 | )
240 | samples_enc = model_utils.posenc(
241 | samples,
242 | self.min_deg_point,
243 | self.max_deg_point,
244 | self.legacy_posenc_order,
245 | )
246 |
247 | # Point attribute predictions
248 | if self.use_viewdirs:
249 | viewdirs_enc = model_utils.posenc(
250 | rays.viewdirs,
251 | 0,
252 | self.deg_view,
253 | self.legacy_posenc_order,
254 | )
255 | raw_rgb, raw_sigma = self.MLP_0(samples_enc, viewdirs_enc)
256 | else:
257 | raw_rgb, raw_sigma = self.MLP_0(samples_enc)
258 | # Add noises to regularize the density predictions if needed
259 | key, rng_0 = random.split(rng_0)
260 | raw_sigma = model_utils.add_gaussian_noise(
261 | key,
262 | raw_sigma,
263 | self.noise_std,
264 | randomized,
265 | )
266 |
267 | if self.sh_deg >= 0:
268 | # (256, 64, 48) (256, 3)
269 | raw_rgb = sh.eval_sh(self.sh_deg, raw_rgb.reshape(
270 | *raw_rgb.shape[:-1],
271 | -1,
272 | (self.sh_deg + 1) ** 2), rays.viewdirs[:, None])
273 | elif self.sg_dim > 0:
274 | sg_lambda = self.sg_lambda.value
275 | sg_mu_spher = self.sg_mu_spher.value
276 | sg_coeffs = raw_rgb.reshape(*raw_rgb.shape[:-1], -1, self.sg_dim)
277 | raw_rgb = sg.eval_sg(
278 | sg_lambda, sg_mu_spher, sg_coeffs, rays.viewdirs[:, None])
279 |
280 | rgb = self.rgb_activation(raw_rgb)
281 | sigma = self.sigma_activation(raw_sigma)
282 |
283 | # Volumetric rendering.
284 | comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
285 | rgb,
286 | sigma,
287 | z_vals,
288 | rays.directions,
289 | white_bkgd=self.white_bkgd,
290 | )
291 | ret = [
292 | (comp_rgb, disp, acc),
293 | ]
294 | # Hierarchical sampling based on coarse predictions
295 | if self.num_fine_samples > 0:
296 | z_vals_mid = 0.5 * (z_vals[Ellipsis, 1:] + z_vals[Ellipsis, :-1])
297 | key, rng_1 = random.split(rng_1)
298 | z_vals, samples = model_utils.sample_pdf(
299 | key,
300 | z_vals_mid,
301 | weights[Ellipsis, 1:-1],
302 | rays.origins,
303 | rays.directions,
304 | z_vals,
305 | self.num_fine_samples,
306 | randomized,
307 | )
308 | samples_enc = model_utils.posenc(
309 | samples,
310 | self.min_deg_point,
311 | self.max_deg_point,
312 | self.legacy_posenc_order,
313 | )
314 |
315 | if self.use_viewdirs:
316 | raw_rgb, raw_sigma = self.MLP_1(samples_enc, viewdirs_enc)
317 | else:
318 | raw_rgb, raw_sigma = self.MLP_1(samples_enc)
319 | key, rng_1 = random.split(rng_1)
320 | raw_sigma = model_utils.add_gaussian_noise(
321 | key,
322 | raw_sigma,
323 | self.noise_std,
324 | randomized,
325 | )
326 | if self.sh_deg >= 0:
327 | raw_rgb = sh.eval_sh(self.sh_deg, raw_rgb.reshape(
328 | *raw_rgb.shape[:-1],
329 | -1,
330 | (self.sh_deg + 1) ** 2), rays.viewdirs[:, None])
331 | elif self.sg_dim > 0:
332 | sg_lambda = self.sg_lambda.value
333 | sg_mu_spher = self.sg_mu_spher.value
334 | sg_coeffs = raw_rgb.reshape(*raw_rgb.shape[:-1], -1, self.sg_dim)
335 | raw_rgb = sg.eval_sg(
336 | sg_lambda, sg_mu_spher, sg_coeffs, rays.viewdirs[:, None])
337 |
338 | rgb = self.rgb_activation(raw_rgb)
339 | sigma = self.sigma_activation(raw_sigma)
340 | comp_rgb, disp, acc, unused_weights = model_utils.volumetric_rendering(
341 | rgb,
342 | sigma,
343 | z_vals,
344 | rays.directions,
345 | white_bkgd=self.white_bkgd,
346 | )
347 | ret.append((comp_rgb, disp, acc))
348 | return ret
349 |
350 |
351 | def construct_nerf(key, args):
352 | """Construct a Neural Radiance Field.
353 |
354 | Args:
355 | key: jnp.ndarray. Random number generator.
356 | args: FLAGS class. Hyperparameters of nerf.
357 |
358 | Returns:
359 | model: nn.Model. Nerf model with parameters.
360 |     init_variables: Initial model variables from model.init.
361 | """
362 | net_activation = getattr(nn, str(args.net_activation))
363 | rgb_activation = getattr(nn, str(args.rgb_activation))
364 | sigma_activation = getattr(nn, str(args.sigma_activation))
365 |
366 | # Assert that rgb_activation always produces outputs in [0, 1], and
367 |     # sigma_activation always produces non-negative outputs.
368 | x = jnp.exp(jnp.linspace(-90, 90, 1024))
369 | x = jnp.concatenate([-x[::-1], x], 0)
370 |
371 | rgb = rgb_activation(x)
372 | if jnp.any(rgb < 0) or jnp.any(rgb > 1):
373 | raise NotImplementedError(
374 | "Choice of rgb_activation `{}` produces colors outside of [0, 1]".format(
375 | args.rgb_activation
376 | )
377 | )
378 |
379 | sigma = sigma_activation(x)
380 | if jnp.any(sigma < 0):
381 | raise NotImplementedError(
382 | "Choice of sigma_activation `{}` produces negative densities".format(
383 | args.sigma_activation
384 | )
385 | )
386 | num_rgb_channels = args.num_rgb_channels
387 | # TODO cleanup assert
388 | if args.sh_deg >= 0:
389 | assert not args.use_viewdirs and args.sg_dim == -1, (
390 | "You can only use up to one of: SH, SG or use_viewdirs.")
391 | num_rgb_channels *= (args.sh_deg + 1) ** 2
392 | elif args.sg_dim > 0:
393 | assert not args.use_viewdirs and args.sh_deg == -1, (
394 | "You can only use up to one of: SH, SG or use_viewdirs.")
395 | num_rgb_channels *= args.sg_dim
396 |
397 | model = NerfModel(
398 | min_deg_point=args.min_deg_point,
399 | max_deg_point=args.max_deg_point,
400 | deg_view=args.deg_view,
401 | num_coarse_samples=args.num_coarse_samples,
402 | num_fine_samples=args.num_fine_samples,
403 | use_viewdirs=args.use_viewdirs,
404 | sh_deg=args.sh_deg,
405 | sg_dim=args.sg_dim,
406 | near=args.near,
407 | far=args.far,
408 | noise_std=args.noise_std,
409 | white_bkgd=args.white_bkgd,
410 | net_depth=args.net_depth,
411 | net_width=args.net_width,
412 | net_depth_condition=args.net_depth_condition,
413 | net_width_condition=args.net_width_condition,
414 | skip_layer=args.skip_layer,
415 | num_rgb_channels=num_rgb_channels,
416 | num_sigma_channels=args.num_sigma_channels,
417 | lindisp=args.lindisp,
418 | net_activation=net_activation,
419 | rgb_activation=rgb_activation,
420 | sigma_activation=sigma_activation,
421 | legacy_posenc_order=args.legacy_posenc_order,
422 | )
423 | key1, key = random.split(key)
424 | init_variables = model.init(
425 | key1,
426 | method=model._quick_init,
427 | )
428 | return model, init_variables
429 |
--------------------------------------------------------------------------------
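(A quick arithmetic check of the channel bookkeeping in construct_nerf above; the
concrete numbers are illustrative, not repo defaults.)

# With 3 base RGB channels and sh_deg = 3, the MLP emits
# 3 * (3 + 1) ** 2 = 48 coefficient channels -- consistent with the
# "(256, 64, 48)" shape noted in the comments above (256 rays,
# 64 samples per ray, 48 SH coefficient channels).
num_rgb_channels = 3
sh_deg = 3
num_rgb_channels *= (sh_deg + 1) ** 2
assert num_rgb_channels == 48
--------------------------------------------------------------------------------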
/nerf_sh/nerf/sg.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 | import jax
24 | import jax.numpy as jnp
25 |
26 |
27 | def spher2cart(r, theta, phi):
28 | """Convert spherical coordinates into Cartesian coordinates."""
29 | x = r * jnp.sin(theta) * jnp.cos(phi)
30 | y = r * jnp.sin(theta) * jnp.sin(phi)
31 | z = r * jnp.cos(theta)
32 | return jnp.stack([x, y, z], axis=-1)
33 |
34 |
35 | def eval_sg(sg_lambda, sg_mu, sg_coeffs, dirs):
36 | """
37 | Evaluate spherical gaussians at unit directions
38 | using learnable SG basis.
39 | Works with jnp.
40 | ... Can be 0 or more batch dimensions.
41 |     N is the number of SG basis functions used.
42 |
43 |     output = \sum_i coeffs_i * \exp(lambda_i * (dot(mu_i, dirs) - 1))
44 |
45 | Args:
46 | sg_lambda: The sharpness of the SG lobes. [N] or [..., N]
47 | sg_mu: The directions of the SG lobes. [N, 3 or 2] or [..., N, 3 or 2]
48 |         sg_coeffs: The coefficients of the SG (lobe amplitudes). [..., C, N]
49 | dirs: unit directions [..., 3]
50 |
51 | Returns:
52 | [..., C]
53 | """
54 | sg_lambda = jax.nn.softplus(sg_lambda) # force lambda > 0
55 | # spherical coordinates -> Cartesian coordinates
56 | if sg_mu.shape[-1] == 2:
57 | theta, phi = sg_mu[..., 0], sg_mu[..., 1]
58 | sg_mu = spher2cart(1.0, theta, phi) # [..., N, 3]
59 | product = jnp.einsum(
60 | "...ij,...j->...i", sg_mu, dirs) # [..., N]
61 | basis = jnp.exp(jnp.einsum(
62 | "...i,...i->...i", sg_lambda, product - 1)) # [..., N]
63 | output = jnp.einsum(
64 | "...ki,...i->...k", sg_coeffs, basis) # [..., C]
65 | output /= sg_lambda.shape[-1]
66 | return output
67 |
68 |
69 | def euler2mat(angle):
70 | """Convert euler angles to rotation matrix.
71 |
72 | Args:
73 | angle: rotation angle along 3 axis (in radians). [..., 3]
74 | Returns:
75 | Rotation matrix corresponding to the euler angles. [..., 3, 3]
76 | """
77 | x, y, z = angle[..., 0], angle[..., 1], angle[..., 2]
78 | cosz = jnp.cos(z)
79 | sinz = jnp.sin(z)
80 | cosy = jnp.cos(y)
81 | siny = jnp.sin(y)
82 | cosx = jnp.cos(x)
83 | sinx = jnp.sin(x)
84 | zeros = jnp.zeros_like(z)
85 | ones = jnp.ones_like(z)
86 | zmat = jnp.stack([jnp.stack([cosz, -sinz, zeros], axis=-1),
87 | jnp.stack([sinz, cosz, zeros], axis=-1),
88 | jnp.stack([zeros, zeros, ones], axis=-1)], axis=-1)
89 | ymat = jnp.stack([jnp.stack([ cosy, zeros, siny], axis=-1),
90 | jnp.stack([zeros, ones, zeros], axis=-1),
91 | jnp.stack([-siny, zeros, cosy], axis=-1)], axis=-1)
92 | xmat = jnp.stack([jnp.stack([ ones, zeros, zeros], axis=-1),
93 | jnp.stack([zeros, cosx, -sinx], axis=-1),
94 | jnp.stack([zeros, sinx, cosx], axis=-1)], axis=-1)
95 | rotMat = jnp.einsum("...ij,...jk,...kq->...iq", xmat, ymat, zmat)
96 | return rotMat
97 |
--------------------------------------------------------------------------------
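(A minimal usage sketch for eval_sg above; lobe count, batch size, and values are
illustrative assumptions, not repo code.)

import jax.numpy as jnp
from nerf_sh.nerf import sg

N, C, B = 4, 3, 256                      # SG lobes, color channels, query points
sg_lambda = jnp.ones((N,))               # lobe sharpness (softplus applied inside)
sg_mu_spher = jnp.zeros((N, 2))          # lobe directions as (theta, phi)
sg_coeffs = jnp.ones((B, C, N))          # per-point lobe amplitudes
dirs = jnp.tile(jnp.array([0.0, 0.0, 1.0]), (B, 1))        # unit view directions
rgb = sg.eval_sg(sg_lambda, sg_mu_spher, sg_coeffs, dirs)  # -> [B, C]
--------------------------------------------------------------------------------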
/nerf_sh/nerf/sh.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 |
24 | C0 = 0.28209479177387814
25 | C1 = 0.4886025119029199
26 | C2 = [
27 | 1.0925484305920792,
28 | -1.0925484305920792,
29 | 0.31539156525252005,
30 | -1.0925484305920792,
31 | 0.5462742152960396
32 | ]
33 | C3 = [
34 | -0.5900435899266435,
35 | 2.890611442640554,
36 | -0.4570457994644658,
37 | 0.3731763325901154,
38 | -0.4570457994644658,
39 | 1.445305721320277,
40 | -0.5900435899266435
41 | ]
42 | C4 = [
43 | 2.5033429417967046,
44 | -1.7701307697799304,
45 | 0.9461746957575601,
46 | -0.6690465435572892,
47 | 0.10578554691520431,
48 | -0.6690465435572892,
49 | 0.47308734787878004,
50 | -1.7701307697799304,
51 | 0.6258357354491761,
52 | ]
53 |
54 | def eval_sh(deg, sh, dirs):
55 | """
56 | Evaluate spherical harmonics at unit directions
57 | using hardcoded SH polynomials.
58 | Works with torch/np/jnp.
59 | ... Can be 0 or more batch dimensions.
60 |
61 | Args:
62 |         deg: int SH deg. Currently, 0-4 supported
63 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
64 | dirs: jnp.ndarray unit directions [..., 3]
65 |
66 | Returns:
67 | [..., C]
68 | """
69 | assert deg <= 4 and deg >= 0
70 | assert (deg + 1) ** 2 == sh.shape[-1]
71 | C = sh.shape[-2]
72 |
73 | result = C0 * sh[..., 0]
74 | if deg > 0:
75 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
76 | result = (result -
77 | C1 * y * sh[..., 1] +
78 | C1 * z * sh[..., 2] -
79 | C1 * x * sh[..., 3])
80 | if deg > 1:
81 | xx, yy, zz = x * x, y * y, z * z
82 | xy, yz, xz = x * y, y * z, x * z
83 | result = (result +
84 | C2[0] * xy * sh[..., 4] +
85 | C2[1] * yz * sh[..., 5] +
86 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
87 | C2[3] * xz * sh[..., 7] +
88 | C2[4] * (xx - yy) * sh[..., 8])
89 |
90 | if deg > 2:
91 | result = (result +
92 | C3[0] * y * (3 * xx - yy) * sh[..., 9] +
93 | C3[1] * xy * z * sh[..., 10] +
94 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
95 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
96 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
97 | C3[5] * z * (xx - yy) * sh[..., 14] +
98 | C3[6] * x * (xx - 3 * yy) * sh[..., 15])
99 | if deg > 3:
100 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
101 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
102 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
103 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
104 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
105 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
106 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
107 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
108 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
109 | return result
110 |
--------------------------------------------------------------------------------
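(A minimal usage sketch for eval_sh above; degree, batch size, and directions are
illustrative. For deg = 0 the result reduces to C0 * sh[..., 0], which makes a quick
sanity check.)

import numpy as np
from nerf_sh.nerf import sh

deg, C, B = 2, 3, 8
coeffs = np.random.randn(B, C, (deg + 1) ** 2)     # [..., C, (deg + 1) ** 2]
dirs = np.tile(np.array([0.0, 0.0, 1.0]), (B, 1))  # unit directions, [..., 3]
rgb = sh.eval_sh(deg, coeffs, dirs)                # -> [B, C]
--------------------------------------------------------------------------------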
/nerf_sh/parse_timing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 | """
24 | A utility to parse timings files, which are saved automatically during training
25 | (checkpoint_dir/timings.txt).
26 | Provided training was not stopped and restarted, this allows the total training time to be measured.
27 | """
28 | import argparse
29 | from datetime import datetime
30 |
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument("file", type=str)
33 | parser.add_argument("--times", "-t", type=int, default=[], nargs='+')
34 | args = parser.parse_args()
35 |
36 | with open(args.file, 'r') as f:
37 |     lines = f.readlines()
38 | lines = [line.strip() for line in lines]
39 | lines = [line.split() for line in lines if len(line)]
40 | lines = {int(line[0]): datetime.fromisoformat(line[1]) for line in lines}
41 |
42 | if not args.times:
43 | print(list(lines.keys()))
44 |
45 | for t in args.times:
46 | print((lines[t] - lines[0]).total_seconds() / 3600)
47 |
--------------------------------------------------------------------------------
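(Example invocations of parse_timing.py above; the checkpoint path and step number
are assumptions. The first call lists the step numbers present in timings.txt; the
second prints hours elapsed from step 0 to step 50000.)

Usage:

    python nerf_sh/parse_timing.py <train_dir>/timings.txt
    python nerf_sh/parse_timing.py <train_dir>/timings.txt --times 50000
--------------------------------------------------------------------------------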
/nerf_sh/train.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Modifications Copyright 2021 The PlenOctree Authors.
3 | # Original Copyright 2021 The Google Research Authors.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Lint as: python3
18 | """Training script for Nerf."""
19 |
20 | import os
21 | # Get rid of ugly TF logs
22 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
23 |
24 | import functools
25 | import gc
26 | import time
27 | from absl import app
28 | from absl import flags
29 | import flax
30 | import flax.linen as nn
31 | from flax.metrics import tensorboard
32 | from flax.training import checkpoints
33 | import jax
34 | from jax import config
35 | from jax import random
36 | import jax.numpy as jnp
37 | import numpy as np
38 |
39 |
40 | from nerf_sh.nerf import datasets
41 | from nerf_sh.nerf import models
42 | from nerf_sh.nerf import utils
43 | from nerf_sh.nerf.utils import host0_print as h0print
44 |
45 | FLAGS = flags.FLAGS
46 |
47 | utils.define_flags()
48 | config.parse_flags_with_absl()
49 |
50 |
51 | def train_step(model, rng, state, batch, lr):
52 | """One optimization step.
53 |
54 | Args:
55 | model: The linen model.
56 | rng: jnp.ndarray, random number generator.
57 | state: utils.TrainState, state of the model/optimizer.
58 | batch: dict, a mini-batch of data for training.
59 | lr: float, real-time learning rate.
60 |
61 | Returns:
62 | new_state: utils.TrainState, new training state.
63 |     stats: utils.Stats. Training statistics (loss, psnr, loss_c, psnr_c, etc.).
64 | rng: jnp.ndarray, updated random number generator.
65 | """
66 | rng, key_0, key_1, key_2 = random.split(rng, 4)
67 |
68 | def loss_fn(variables):
69 | rays = batch["rays"]
70 | ret = model.apply(variables, key_0, key_1, rays, FLAGS.randomized)
71 | if len(ret) not in (1, 2):
72 | raise ValueError(
73 | "ret should contain either 1 set of output (coarse only), or 2 sets"
74 | "of output (coarse as ret[0] and fine as ret[1])."
75 | )
76 |
77 | if FLAGS.sparsity_weight > 0.0:
78 | rng, key = random.split(key_2)
79 | sp_points = random.uniform(key, (FLAGS.sparsity_npoints, 3), minval=-FLAGS.sparsity_radius, maxval=FLAGS.sparsity_radius)
80 | sp_rgb, sp_sigma = model.apply(variables, sp_points, method=model.eval_points_raw)
81 | del sp_rgb
82 | sp_sigma = nn.relu(sp_sigma)
83 | loss_sp = FLAGS.sparsity_weight * (1.0 - jnp.exp(- FLAGS.sparsity_length * sp_sigma).mean())
84 | else:
85 | loss_sp = 0.0
86 |
87 | # The main prediction is always at the end of the ret list.
88 | rgb, unused_disp, unused_acc = ret[-1]
89 | loss = ((rgb - batch["pixels"][Ellipsis, :3]) ** 2).mean()
90 | psnr = utils.compute_psnr(loss)
91 | if len(ret) > 1:
92 | # If there are both coarse and fine predictions, we compute the loss for
93 | # the coarse prediction (ret[0]) as well.
94 | rgb_c, unused_disp_c, unused_acc_c = ret[0]
95 | loss_c = ((rgb_c - batch["pixels"][Ellipsis, :3]) ** 2).mean()
96 | psnr_c = utils.compute_psnr(loss_c)
97 | else:
98 | loss_c = 0.0
99 | psnr_c = 0.0
100 |
101 | def tree_sum_fn(fn):
102 | return jax.tree_util.tree_reduce(
103 | lambda x, y: x + fn(y), variables, initializer=0
104 | )
105 |
106 | weight_l2 = tree_sum_fn(lambda z: jnp.sum(z ** 2)) / tree_sum_fn(
107 | lambda z: jnp.prod(jnp.array(z.shape))
108 | )
109 |
110 | stats = utils.Stats(
111 | loss=loss, psnr=psnr, loss_c=loss_c, loss_sp=loss_sp,
112 | psnr_c=psnr_c, weight_l2=weight_l2
113 | )
114 | return loss + loss_c + loss_sp + FLAGS.weight_decay_mult * weight_l2, stats
115 |
116 | (_, stats), grad = jax.value_and_grad(loss_fn, has_aux=True)(state.optimizer.target)
117 | grad = jax.lax.pmean(grad, axis_name="batch")
118 | stats = jax.lax.pmean(stats, axis_name="batch")
119 | new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)
120 | new_state = state.replace(optimizer=new_optimizer)
121 | return new_state, stats, rng
122 |
123 |
124 | def main(unused_argv):
125 | rng = random.PRNGKey(20200823)
126 | # Shift the numpy random seed by host_id() to shuffle data loaded by different
127 | # hosts.
128 | np.random.seed(20201473 + jax.host_id())
129 | rng, key = random.split(rng)
130 |
131 | utils.update_flags(FLAGS)
132 | utils.check_flags(FLAGS, require_batch_size_div=True)
133 |
134 | utils.makedirs(FLAGS.train_dir)
135 | render_dir = os.path.join(FLAGS.train_dir, 'render')
136 | utils.makedirs(render_dir)
137 |
138 | # TEMP
139 | timings_file = open(os.path.join(FLAGS.train_dir, 'timings.txt'), 'a')
140 | from datetime import datetime
141 | def write_ts_now(step):
142 | timings_file.write(f"{step} {datetime.now().isoformat()}\n")
143 | timings_file.flush()
144 | write_ts_now(0)
145 |
146 | h0print('* Load train data')
147 | dataset = datasets.get_dataset("train", FLAGS)
148 | h0print('* Load test data')
149 | test_dataset = datasets.get_dataset("test", FLAGS)
150 |
151 | h0print('* Load model')
152 | model, state = models.get_model_state(key, FLAGS)
153 |
154 | learning_rate_fn = functools.partial(
155 | utils.learning_rate_decay,
156 | lr_init=FLAGS.lr_init,
157 | lr_final=FLAGS.lr_final,
158 | max_steps=FLAGS.max_steps,
159 | lr_delay_steps=FLAGS.lr_delay_steps,
160 | lr_delay_mult=FLAGS.lr_delay_mult,
161 | )
162 |
163 | train_pstep = jax.pmap(
164 | functools.partial(train_step, model),
165 | axis_name="batch",
166 | in_axes=(0, 0, 0, None),
167 | donate_argnums=(2,),
168 | )
169 |
170 | render_pfn = utils.get_render_pfn(model, randomized=FLAGS.randomized)
171 |
172 | # Compiling to the CPU because it's faster and more accurate.
173 | ssim_fn = jax.jit(functools.partial(utils.compute_ssim, max_val=1.0), backend="cpu")
174 |
175 |     # Resume training at the step of the last checkpoint.
176 | init_step = state.optimizer.state.step + 1
177 | state = flax.jax_utils.replicate(state)
178 |
179 | if jax.host_id() == 0:
180 | summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)
181 |
182 | h0print('* Prefetch')
183 | # Prefetch_buffer_size = 3 x batch_size
184 | pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
185 |     n_local_devices = jax.local_device_count()
186 |     rng = rng + jax.host_id() # Make random seed separate across hosts.
187 |     keys = random.split(rng, n_local_devices) # For pmapping RNG keys.
188 | gc.disable() # Disable automatic garbage collection for efficiency.
189 | stats_trace = []
190 |
191 |
192 | reset_timer = True
193 | for step, batch in zip(range(init_step, FLAGS.max_steps + 1), pdataset):
194 | if reset_timer:
195 | t_loop_start = time.time()
196 | reset_timer = False
197 | lr = learning_rate_fn(step)
198 | state, stats, keys = train_pstep(keys, state, batch, lr)
199 | if jax.host_id() == 0:
200 | stats_trace.append(stats)
201 | if step % FLAGS.gc_every == 0:
202 | gc.collect()
203 |
204 | # Log training summaries. This is put behind a host_id check because in
205 | # multi-host evaluation, all hosts need to run inference even though we
206 | # only use host 0 to record results.
207 | if jax.host_id() == 0:
208 | if step % FLAGS.print_every == 0:
209 | summary_writer.scalar("train_loss", stats.loss[0], step)
210 | summary_writer.scalar("train_psnr", stats.psnr[0], step)
211 | summary_writer.scalar("train_loss_coarse", stats.loss_c[0], step)
212 | summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0], step)
213 | if FLAGS.sparsity_weight > 0.0:
214 | summary_writer.scalar("train_sparse_loss", stats.loss_sp[0], step)
215 | summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
216 | avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
217 | avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
218 | stats_trace = []
219 | summary_writer.scalar("train_avg_loss", avg_loss, step)
220 | summary_writer.scalar("train_avg_psnr", avg_psnr, step)
221 | summary_writer.scalar("learning_rate", lr, step)
222 | steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
223 | reset_timer = True
224 | rays_per_sec = FLAGS.batch_size * steps_per_sec
225 | summary_writer.scalar("train_steps_per_sec", steps_per_sec, step)
226 | summary_writer.scalar("train_rays_per_sec", rays_per_sec, step)
227 | precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
228 | print(
229 | ("{:" + "{:d}".format(precision) + "d}").format(step)
230 | + f"/{FLAGS.max_steps:d}: "
231 | + f"i_loss={stats.loss[0]:0.4f}, "
232 | + f"avg_loss={avg_loss:0.4f}, "
233 | + f"weight_l2={stats.weight_l2[0]:0.2e}, "
234 | + f"lr={lr:0.2e}, "
235 | + f"{rays_per_sec:0.0f} rays/sec"
236 | )
237 | if step % FLAGS.save_every == 0:
238 | print('* Saving')
239 | state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
240 | checkpoints.save_checkpoint(
241 | FLAGS.train_dir, state_to_save, int(step), keep=200
242 | )
243 |
244 | # Test-set evaluation.
245 | if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
246 | # We reuse the same random number generator from the optimization step
247 | # here on purpose so that the visualization matches what happened in
248 | # training.
249 | h0print('\n* Rendering')
250 | t_eval_start = time.time()
251 | eval_variables = jax.device_get(
252 | jax.tree_map(lambda x: x[0], state)
253 | ).optimizer.target
254 | test_case = next(test_dataset)
255 | pred_color, pred_disp, pred_acc = utils.render_image(
256 | functools.partial(render_pfn, eval_variables),
257 | test_case["rays"],
258 | keys[0],
259 | FLAGS.dataset == "llff",
260 | chunk=FLAGS.chunk,
261 | )
262 |
263 | # Log eval summaries on host 0.
264 | if jax.host_id() == 0:
265 | write_ts_now(step)
266 | psnr = utils.compute_psnr(
267 | ((pred_color - test_case["pixels"]) ** 2).mean()
268 | )
269 | ssim = ssim_fn(pred_color, test_case["pixels"])
270 | eval_time = time.time() - t_eval_start
271 | num_rays = jnp.prod(jnp.array(test_case["rays"].directions.shape[:-1]))
272 | rays_per_sec = num_rays / eval_time
273 | summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
274 | print(f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec")
275 | summary_writer.scalar("test_psnr", psnr, step)
276 | summary_writer.scalar("test_ssim", ssim, step)
277 | # print(pred_color.shape, pred_disp.shape, pred_acc.shape,
278 | # test_case["pixels"].shape)
279 | # print(pred_color.dtype, pred_disp.dtype, pred_acc.dtype,
280 | # test_case["pixels"].dtype)
281 | # print(pred_color.min(), pred_color.max(),
282 | # pred_disp.min(), pred_disp.max(),
283 | # pred_acc.min(), pred_acc.max(),
284 | # test_case['pixels'].min(), test_case['pixels'].max())
285 | # 0 1. 0.0 1.0 0.90906805 1.0000007 0.0 1.0
286 |
287 | # (800, 800, 3) (800, 800, 1) (800, 800, 1) (800, 800, 3)
288 | # float32 float32 float32 float32
289 |
290 |                 vis_list = [test_case["pixels"],
291 | pred_color,
292 | np.repeat(pred_disp, 3, axis=-1),
293 | np.repeat(pred_acc, 3, axis=-1)]
294 | out_path = os.path.join(render_dir, '{:010}.png'.format(step))
295 | utils.save_img(np.hstack(vis_list), out_path)
296 | print(' Rendering saved to ', out_path)
297 |
298 |                 # Save renderings to disk instead of TensorBoard,
299 |                 # since TensorBoard loads very slowly once it has many images
300 |
301 | # summary_writer.image("test_pred_color", pred_color, step)
302 | # summary_writer.image("test_pred_disp", pred_disp, step)
303 | # summary_writer.image("test_pred_acc", pred_acc, step)
304 | # summary_writer.image("test_target", test_case["pixels"], step)
305 |
306 | if FLAGS.max_steps % FLAGS.save_every != 0:
307 | state = jax.device_get(jax.tree_map(lambda x: x[0], state))
308 | checkpoints.save_checkpoint(
309 | FLAGS.train_dir, state, int(FLAGS.max_steps), keep=200
310 | )
311 |
312 |
313 | if __name__ == "__main__":
314 | app.run(main)
315 |
--------------------------------------------------------------------------------
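(The sparsity regularizer inside train_step above, isolated as a standalone sketch.
It penalizes expected opacity 1 - exp(-length * sigma) at points sampled uniformly in
a cube of the given radius; all constants below are illustrative stand-ins for FLAGS,
and the random sigma is a stand-in for model.eval_points_raw.)

import jax.numpy as jnp
from jax import random

key_pts, key_sigma = random.split(random.PRNGKey(0))
sparsity_weight, sparsity_length = 0.01, 0.05
sparsity_radius, sparsity_npoints = 1.5, 1024

sp_points = random.uniform(
    key_pts, (sparsity_npoints, 3),
    minval=-sparsity_radius, maxval=sparsity_radius)
sp_sigma = jnp.abs(random.normal(key_sigma, (sparsity_npoints, 1)))
loss_sp = sparsity_weight * (1.0 - jnp.exp(-sparsity_length * sp_sigma).mean())
--------------------------------------------------------------------------------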
/octree/compression.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 | """Compress a plenoctree.
24 |
25 | Includes quantization using the median-cut algorithm.
26 |
27 | Usage:
28 | python compression.py x.npz [y.npz ...]
29 | """
30 | import sys
31 | import numpy as np
32 | import os.path as osp
33 | import torch
34 | from svox.helpers import _get_c_extension
35 | from tqdm import tqdm
36 | import os
37 | import argparse
38 |
39 | @torch.no_grad()
40 | def main():
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument('input', type=str, nargs='+', default=None, help='Input npz(s)')
43 | parser.add_argument('--noquant', action='store_true',
44 | help='Disable quantization')
45 | parser.add_argument('--bits', type=int, default=16,
46 | help='Quantization bits (order)')
47 | parser.add_argument('--out_dir', type=str, default='min_alt',
48 | help='Where to write compressed npz')
49 | parser.add_argument('--overwrite', action='store_true',
50 | help='Overwrite existing compressed npz')
51 | parser.add_argument('--weighted', action='store_true',
52 | help='Use weighted median cut (seems quite useless)')
53 | parser.add_argument('--sigma_thresh', type=float, default=2.0,
54 | help='Kill voxels under this sigma')
55 | parser.add_argument('--retain', type=int, default=0,
56 |                         help='Do not compress the first x SH coeffs; needed for some scenes to keep acceptable quality')
57 |
58 | args = parser.parse_args()
59 |
60 | _C = _get_c_extension()
61 | os.makedirs(args.out_dir, exist_ok=True)
62 |
63 | if args.noquant:
64 | print('Quantization disabled, only applying deflate')
65 | else:
66 | print('Quantization enabled')
67 |
68 | for fname in args.input:
69 | fname_c = osp.join(args.out_dir, osp.basename(fname))
70 | print('Compressing', fname, 'to', fname_c)
71 | if not args.overwrite and osp.exists(fname_c):
72 | print(' > skip')
73 | continue
74 |
75 | z = np.load(fname)
76 |
77 | if not args.noquant:
78 | if 'quant_colors' in z.files:
79 | print(' > skip since source already compressed')
80 | continue
81 | z = dict(z)
82 | del z['parent_depth']
83 | del z['geom_resize_fact']
84 | del z['n_free']
85 | del z['n_internal']
86 | del z['depth_limit']
87 |
88 | if not args.noquant:
89 | data = torch.from_numpy(z['data'])
90 | sigma = data[..., -1].reshape(-1)
91 | snz = sigma > args.sigma_thresh
92 | sigma[~snz] = 0.0
93 |
94 | data = data[..., :-1]
95 | N = data.size(1)
96 | basis_dim = data.size(-1) // 3
97 |
98 | data = data.reshape(-1, 3, basis_dim).float()[snz].unbind(-1)
99 | if args.retain:
100 | retained = data[:args.retain]
101 | data = data[args.retain:]
102 | else:
103 | retained = None
104 |
105 | all_quant_colors = []
106 | all_quant_maps = []
107 |
108 | if args.weighted:
109 |             weights = 1.0 - torch.exp(-0.01 * sigma.float())
110 | else:
111 | weights = torch.empty((0,))
112 |
113 | for i, d in tqdm(enumerate(data), total=len(data)):
114 | colors, color_id_map = _C.quantize_median_cut(d.contiguous(),
115 | weights,
116 | args.bits)
117 | color_id_map_full = np.zeros((snz.shape[0],), dtype=np.uint16)
118 | color_id_map_full[snz] = color_id_map
119 |
120 | all_quant_colors.append(colors.numpy().astype(np.float16))
121 | all_quant_maps.append(color_id_map_full.reshape(-1, N, N, N).astype(np.uint16))
122 | quant_map = np.stack(all_quant_maps, axis=0)
123 | quant_colors = np.stack(all_quant_colors, axis=0)
124 | del all_quant_maps
125 | del all_quant_colors
126 | z['quant_colors'] = quant_colors
127 | z['quant_map'] = quant_map
128 | z['sigma'] = sigma.reshape(-1, N, N, N)
129 | if args.retain:
130 | all_retained = []
131 | for i in range(args.retain):
132 | retained_wz = np.zeros((snz.shape[0], 3), dtype=np.float16)
133 | retained_wz[snz] = retained[i]
134 | all_retained.append(retained_wz.reshape(-1, N, N, N, 3))
135 | all_retained = np.stack(all_retained, axis=0)
136 | del retained
137 | z['data_retained'] = all_retained
138 | del z['data']
139 | np.savez_compressed(fname_c, **z)
140 | print(' > Size', osp.getsize(fname) // (1024 * 1024), 'MB ->',
141 | osp.getsize(fname_c) // (1024 * 1024), 'MB')
142 |
143 |
144 | if __name__ == '__main__':
145 | main()
146 |
--------------------------------------------------------------------------------
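(A quick way to inspect the output of compression.py above, assuming min_alt/x.npz
was produced with quantization enabled; the key names match what the script writes.)

import numpy as np

z = np.load('min_alt/x.npz')
print(z.files)                                           # includes 'quant_colors', 'quant_map', 'sigma'
print(z['quant_colors'].shape, z['quant_colors'].dtype)  # float16 palettes
print(z['quant_map'].shape, z['quant_map'].dtype)        # uint16 palette indices
--------------------------------------------------------------------------------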
/octree/config/syn_sg25.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_root": "./data/NeRF/nerf_synthetic/",
3 | "train_root": "./data/Plenoctree/checkpoints/syn_sg25/",
4 | "scenes": ["chair", "drums", "ficus", "hotdog", "lego", "ship"],
5 | "scene_tasks": [{
6 | "octree_name": "",
7 | "train_dir": "{%}",
8 | "data_dir": "{%}",
9 | "config": "nerf_sh/config/misc/sg",
10 | "extr_flags": [
11 | "--autoscale",
12 | "--scale_alpha_thresh", "0.1",
13 | "--radius", "1.4",
14 | "--samples_per_cell", "256",
15 | "--no_early_stop",
16 | "--renderer_step_size", "1e-5"],
17 | "opt_flags": [
18 | "--num_epochs", "80",
19 | "--sgd",
20 | "--lr", "1e9",
21 | "--no_early_stop",
22 | "--renderer_step_size", "1e-5"],
23 | "eval_flags": [
24 | "--renderer_step_size", "1e-5"]
25 | }],
26 | "tasks": [{
27 | "octree_name": "",
28 | "train_dir": "materials",
29 | "data_dir": "materials",
30 | "config": "nerf_sh/config/misc/sg",
31 | "extr_flags": [
32 | "--autoscale",
33 | "--bbox_scale", "1.2",
34 | "--scale_alpha_thresh", "0.1",
35 | "--radius", "1.4",
36 | "--samples_per_cell", "256",
37 | "--no_early_stop",
38 | "--renderer_step_size", "1e-5"],
39 | "opt_flags": [
40 | "--num_epochs", "80",
41 | "--sgd",
42 | "--lr", "1e9",
43 | "--no_early_stop",
44 | "--renderer_step_size", "1e-5"],
45 | "eval_flags": [
46 | "--renderer_step_size", "1e-5"]
47 | },{
48 | "octree_name": "",
49 | "train_dir": "mic",
50 | "data_dir": "mic",
51 | "config": "nerf_sh/config/misc/sg",
52 | "extr_flags": [
53 | "--autoscale",
54 | "--scale_alpha_thresh", "0.1",
55 | "--radius", "1.6",
56 | "--samples_per_cell", "256",
57 | "--no_early_stop",
58 | "--renderer_step_size", "1e-5"],
59 | "opt_flags": [
60 | "--num_epochs", "80",
61 | "--sgd",
62 | "--lr", "1e9",
63 | "--no_early_stop",
64 | "--renderer_step_size", "1e-5"],
65 | "eval_flags": [
66 | "--renderer_step_size", "1e-5"]
67 | }]
68 | }
69 |
--------------------------------------------------------------------------------
/octree/config/syn_sh16.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_root": "./data/NeRF/nerf_synthetic/",
3 | "train_root": "./data/Plenoctree/checkpoints/syn_sh16/",
4 | "scenes": ["chair", "drums", "ficus", "hotdog", "lego", "ship"],
5 | "scene_tasks": [{
6 | "octree_name": "",
7 | "train_dir": "{%}",
8 | "data_dir": "{%}",
9 | "config": "nerf_sh/config/blender",
10 | "extr_flags": [
11 | "--autoscale",
12 | "--scale_alpha_thresh", "0.1",
13 | "--radius", "1.4",
14 | "--samples_per_cell", "256",
15 | "--no_early_stop",
16 | "--renderer_step_size", "1e-5"],
17 | "opt_flags": [
18 | "--num_epochs", "80",
19 | "--sgd",
20 | "--lr", "1e7",
21 | "--no_early_stop",
22 | "--renderer_step_size", "1e-5"],
23 | "eval_flags": [
24 | "--renderer_step_size", "1e-5"]
25 | }],
26 | "tasks": [{
27 | "octree_name": "",
28 | "train_dir": "materials",
29 | "data_dir": "materials",
30 | "config": "nerf_sh/config/blender",
31 | "extr_flags": [
32 | "--autoscale",
33 | "--bbox_scale", "1.1",
34 | "--scale_alpha_thresh", "0.1",
35 | "--radius", "1.4",
36 | "--samples_per_cell", "256",
37 | "--no_early_stop",
38 | "--renderer_step_size", "1e-5"],
39 | "opt_flags": [
40 | "--num_epochs", "80",
41 | "--sgd",
42 | "--lr", "1e7",
43 | "--no_early_stop",
44 | "--renderer_step_size", "1e-5"],
45 | "eval_flags": [
46 | "--renderer_step_size", "1e-5"]
47 | },{
48 | "octree_name": "",
49 | "train_dir": "mic",
50 | "data_dir": "mic",
51 | "config": "nerf_sh/config/blender",
52 | "extr_flags": [
53 | "--autoscale",
54 | "--scale_alpha_thresh", "0.1",
55 | "--radius", "1.6",
56 | "--samples_per_cell", "256",
57 | "--no_early_stop",
58 | "--renderer_step_size", "1e-5"],
59 | "opt_flags": [
60 | "--num_epochs", "80",
61 | "--sgd",
62 | "--lr", "1e7",
63 | "--no_early_stop",
64 | "--renderer_step_size", "1e-5"],
65 | "eval_flags": [
66 | "--renderer_step_size", "1e-5"]
67 | }]
68 | }
69 |
--------------------------------------------------------------------------------
/octree/config/tt_sh25.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_root": "./data/TanksAndTemple",
3 | "train_root": "./data/Plenoctree/checkpoints/tt_sh25/",
4 | "scenes": ["Barn", "Caterpillar", "Family", "Truck"],
5 | "scene_tasks": [{
6 | "octree_name": "",
7 | "train_dir": "{%}",
8 | "data_dir": "{%}",
9 | "config": "nerf_sh/config/tt",
10 | "extr_flags": [
11 | "--autoscale",
12 | "--scale_alpha_thresh", "0.1",
13 | "--bbox_from_data",
14 | "--data_bbox_scale", "1.2",
15 | "--bbox_scale", "1.0",
16 | "--samples_per_cell", "256",
17 | "--chunk", "8192",
18 | "--no_early_stop",
19 | "--renderer_step_size", "1e-5"
20 | ],
21 | "opt_flags": [
22 | "--num_epochs", "40",
23 | "--sgd",
24 | "--lr", "1.5e6",
25 | "--renderer_step_size", "1e-5",
26 | "--split_train",
27 | "--split_holdout_prop", "0.1"
28 | ],
29 | "eval_flags": [
30 | "--renderer_step_size", "1e-5"
31 | ]
32 | }],
33 | "tasks": [{
34 | "octree_name": "",
35 | "train_dir": "Ignatius",
36 | "data_dir": "Ignatius",
37 | "config": "nerf_sh/config/tt",
38 | "extr_flags": [
39 | "--autoscale",
40 | "--scale_alpha_thresh", "0.1",
41 | "--bbox_from_data",
42 | "--data_bbox_scale", "1.2",
43 | "--bbox_scale", "1.25",
44 | "--samples_per_cell", "256",
45 | "--chunk", "8192",
46 | "--no_early_stop",
47 | "--renderer_step_size", "1e-5"
48 | ],
49 | "opt_flags": [
50 | "--num_epochs", "40",
51 | "--sgd",
52 | "--lr", "1.5e6",
53 | "--renderer_step_size", "1e-5",
54 | "--split_train",
55 | "--split_holdout_prop", "0.1"
56 | ],
57 | "eval_flags": [
58 | "--renderer_step_size", "1e-5"
59 | ]
60 | }]
61 | }
62 |
--------------------------------------------------------------------------------
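(The "{%}" placeholder in the scene_tasks entries above appears to be substituted
with each name in "scenes", yielding one concrete task per scene. A sketch of that
expansion, assuming this is what octree/task_manager.py does.)

import json

with open('octree/config/syn_sh16.json') as f:
    cfg = json.load(f)

for scene in cfg['scenes']:
    for task in cfg['scene_tasks']:
        # Substitute the scene name into string-valued fields only.
        concrete = {k: v.replace('{%}', scene) if isinstance(v, str) else v
                    for k, v in task.items()}
        print(concrete['train_dir'], concrete['data_dir'], concrete['config'])
--------------------------------------------------------------------------------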
/octree/evaluation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 | """Evluate a plenoctree on test set.
24 |
25 | Usage:
26 |
27 | export DATA_ROOT=./data/NeRF/nerf_synthetic/
28 | export CKPT_ROOT=./data/PlenOctree/checkpoints/syn_sh16
29 | export SCENE=chair
30 | export CONFIG_FILE=nerf_sh/config/blender
31 |
32 | python -m octree.evaluation \
33 | --input $CKPT_ROOT/$SCENE/octrees/tree_opt.npz \
34 | --config $CONFIG_FILE \
35 | --data_dir $DATA_ROOT/$SCENE/
36 | """
37 | import torch
38 | import numpy as np
39 | import os
40 | from absl import app
41 | from absl import flags
42 | from tqdm import tqdm
43 | import imageio
44 |
45 | from octree.nerf import models
46 | from octree.nerf import utils
47 | from octree.nerf import datasets
48 |
49 | import svox
50 |
51 | FLAGS = flags.FLAGS
52 |
53 | utils.define_flags()
54 |
55 | flags.DEFINE_string(
56 | "input",
57 | "./tree_opt.npz",
58 | "Input octree npz from optimization.py",
59 | )
60 | flags.DEFINE_string(
61 | "write_vid",
62 | None,
63 | "If specified, writes rendered video to given path (*.mp4)",
64 | )
65 | flags.DEFINE_string(
66 | "write_images",
67 | None,
68 | "If specified, writes images to given path (*.png)",
69 | )
70 |
71 | device = "cuda" if torch.cuda.is_available() else "cpu"
72 |
73 | @torch.no_grad()
74 | def main(unused_argv):
75 | utils.set_random_seed(20200823)
76 | utils.update_flags(FLAGS)
77 |
78 | dataset = datasets.get_dataset("test", FLAGS)
79 |
80 | print('N3Tree load', FLAGS.input)
81 | t = svox.N3Tree.load(FLAGS.input, map_location=device)
82 |
83 | avg_psnr, avg_ssim, avg_lpips, out_frames = utils.eval_octree(t, dataset, FLAGS,
84 | want_lpips=True,
85 | want_frames=FLAGS.write_vid is not None or FLAGS.write_images is not None)
86 | print('Average PSNR', avg_psnr, 'SSIM', avg_ssim, 'LPIPS', avg_lpips)
87 |
88 | if FLAGS.write_vid is not None and len(out_frames):
89 | print('Writing to', FLAGS.write_vid)
90 | imageio.mimwrite(FLAGS.write_vid, out_frames)
91 |
92 | if FLAGS.write_images is not None and len(out_frames):
93 | print('Writing to', FLAGS.write_images)
94 | os.makedirs(FLAGS.write_images, exist_ok=True)
95 | for idx, frame in tqdm(enumerate(out_frames)):
96 | imageio.imwrite(os.path.join(FLAGS.write_images, f"{idx:03d}.png"), frame)
97 |
98 | if __name__ == "__main__":
99 | app.run(main)
100 |
--------------------------------------------------------------------------------
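(A minimal sketch of loading an extracted octree directly with svox, mirroring the
load in evaluation.py above; the npz path is an assumption.)

import torch
import svox

device = "cuda" if torch.cuda.is_available() else "cpu"
t = svox.N3Tree.load("./octrees/tree_opt.npz", map_location=device)
print(t)              # tree summary, as printed at the end of extraction.py
print(t.data_format)  # e.g. SH16, matching the data_format set at extraction time
--------------------------------------------------------------------------------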
/octree/extraction.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 | """Extract a plenoctree from a trained NeRF-SH model.
24 |
25 | Usage:
26 |
27 | export DATA_ROOT=./data/NeRF/nerf_synthetic/
28 | export CKPT_ROOT=./data/PlenOctree/checkpoints/syn_sh16
29 | export SCENE=chair
30 | export CONFIG_FILE=nerf_sh/config/blender
31 |
32 | python -m octree.extraction \
33 | --train_dir $CKPT_ROOT/$SCENE/ --is_jaxnerf_ckpt \
34 | --config $CONFIG_FILE \
35 | --data_dir $DATA_ROOT/$SCENE/ \
36 | --output $CKPT_ROOT/$SCENE/octrees/tree.npz
37 | """
38 | import os
39 | # Get rid of ugly TF logs
40 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
41 |
42 | import torch
43 | import torch.nn.functional as F
44 | import numpy as np
45 | import os.path as osp
46 |
47 | from absl import app
48 | from absl import flags
49 |
50 | from octree.nerf import models
51 | from octree.nerf import utils
52 | from octree.nerf import datasets
53 | from octree.nerf import sh_proj
54 |
55 | from svox import N3Tree
56 | from svox import NDCConfig, VolumeRenderer
57 | from svox.helpers import _get_c_extension
58 | from tqdm import tqdm
59 |
60 | _C = _get_c_extension()
61 |
62 | FLAGS = flags.FLAGS
63 |
64 | utils.define_flags()
65 |
66 | flags.DEFINE_string(
67 | "output",
68 | "./tree.npz",
69 | "Output file",
70 | )
71 | flags.DEFINE_string(
72 | "center",
73 | "0 0 0",
74 | "Center of volume in x y z OR single number",
75 | )
76 | flags.DEFINE_string(
77 | "radius",
78 | "1.5",
79 | "1/2 side length of volume",
80 | )
81 | flags.DEFINE_float(
82 | "alpha_thresh",
83 | 0.01,
84 | "Alpha threshold to keep a voxel in initial sigma thresholding",
85 | )
86 | flags.DEFINE_float(
87 | "max_refine_prop",
88 | 0.5,
89 | "Max proportion of cells to refine",
90 | )
91 | flags.DEFINE_float(
92 | "z_min",
93 | None,
94 | "Discard z axis points below this value, for NDC use",
95 | )
96 | flags.DEFINE_float(
97 | "z_max",
98 | None,
99 | "Discard z axis points above this value, for NDC use",
100 | )
101 | flags.DEFINE_integer(
102 | "tree_branch_n",
103 | 2,
104 | "Tree branch factor (2=octree)",
105 | )
106 | flags.DEFINE_integer(
107 | "init_grid_depth",
108 | 8,
109 | "Initial evaluation grid (2^{x+1} voxel grid)",
110 | )
111 | flags.DEFINE_integer(
112 | "samples_per_cell",
113 | 8,
114 | "Samples per cell in step 2 (3D antialiasing)",
115 | short_name='S',
116 | )
117 | flags.DEFINE_bool(
118 | "is_jaxnerf_ckpt",
119 | False,
120 | "Whether the ckpt is from jaxnerf or not.",
121 | )
122 | flags.DEFINE_enum(
123 | "masking_mode",
124 | "weight",
125 | ["sigma", "weight"],
126 | "How to calculate mask when building the octree",
127 | )
128 | flags.DEFINE_float(
129 | "weight_thresh",
130 | 0.001,
131 | "Weight threshold to keep a voxel",
132 | )
133 | flags.DEFINE_integer(
134 | "projection_samples",
135 | 10000,
136 | "Number of rays to sample for SH projection.",
137 | )
138 |
139 | # Load bbox from dataset
140 | flags.DEFINE_bool(
141 | "bbox_from_data",
142 | False,
143 | "Use bounding box from dataset if possible",
144 | )
145 | flags.DEFINE_float(
146 | "data_bbox_scale",
147 | 1.0,
148 | "Scaling factor to apply to the bounding box from dataset (before autoscale), " +
149 | "if bbox_from_data is used",
150 | )
151 | flags.DEFINE_bool(
152 | "autoscale",
153 | False,
154 | "Automatic scaling, after bbox_from_data",
155 | )
156 | flags.DEFINE_bool(
157 | "bbox_cube",
158 | False,
159 | "Force bbox to be a cube",
160 | )
161 | flags.DEFINE_float(
162 | "bbox_scale",
163 | 1.0,
164 | "Scaling factor to apply to the bounding box at the end (after load, autoscale)",
165 | )
166 | flags.DEFINE_float(
167 | "scale_alpha_thresh",
168 | 0.01,
169 | "Alpha threshold to keep a voxel in initial sigma thresholding for autoscale",
170 | )
171 | # For integrated eval (to avoid slow load)
172 | flags.DEFINE_bool(
173 | "eval",
174 | True,
175 | "Evaluate after building the octree",
176 | )
177 |
178 | device = "cuda" if torch.cuda.is_available() else "cpu"
179 |
180 |
181 | def calculate_grid_weights(dataset, sigmas, reso, invradius, offset):
182 | w, h, focal = dataset.w, dataset.h, dataset.focal
183 |
184 | opts = _C.RenderOptions()
185 | opts.step_size = FLAGS.renderer_step_size
186 | opts.sigma_thresh = 0.0
187 | if 'llff' in FLAGS.config and (not FLAGS.spherify):
188 | ndc_config = NDCConfig(width=w, height=h, focal=focal)
189 | opts.ndc_width = ndc_config.width
190 | opts.ndc_height = ndc_config.height
191 | opts.ndc_focal = ndc_config.focal
192 | else:
193 | opts.ndc_width = -1
194 |
195 | cam = _C.CameraSpec()
196 | cam.fx = focal
197 | cam.fy = focal
198 | cam.width = w
199 | cam.height = h
200 |
201 | grid_data = sigmas.reshape((reso, reso, reso))
202 | maximum_weight = torch.zeros_like(grid_data)
203 | for idx in tqdm(range(dataset.size)):
204 | cam.c2w = torch.from_numpy(dataset.camtoworlds[idx]).float().to(sigmas.device)
205 | grid_weight, grid_hit = _C.grid_weight_render(
206 | grid_data,
207 | cam,
208 | opts,
209 | offset,
210 | invradius,
211 | )
212 | maximum_weight = torch.max(maximum_weight, grid_weight)
213 |
214 | return maximum_weight
215 |
216 |
217 | def project_nerf_to_sh(nerf, sh_deg, points):
218 | """
219 | Args:
220 | points: [N, 3]
221 | Returns:
222 | coeffs for rgb. [N, C * (sh_deg + 1)**2]
223 | """
224 | nerf.use_viewdirs = True
225 |
226 |     def _spherical_func(viewdirs):
227 | # points: [num_points, 3]
228 | # viewdirs: [num_rays, 3]
229 | # raw_rgb: [num_points, num_rays, 3]
230 | # sigma: [num_points]
231 | raw_rgb, sigma = nerf.eval_points_raw(points, viewdirs, cross_broadcast=True)
232 | return raw_rgb, sigma
233 |
234 | coeffs, sigma = sh_proj.ProjectFunctionNeRF(
235 | order=sh_deg,
236 |         sperical_func=_spherical_func,
237 | batch_size=points.shape[0],
238 | sample_count=FLAGS.projection_samples,
239 | device=points.device)
240 |
241 | return coeffs.reshape([points.shape[0], -1]), sigma
242 |
243 |
244 | def auto_scale(args, center, radius, nerf):
245 | print('* Step 0: Auto scale')
246 | reso = 2 ** args.init_grid_depth
247 |
248 | radius = torch.tensor(radius, dtype=torch.float32)
249 | center = torch.tensor(center, dtype=torch.float32)
250 | scale = 0.5 / radius
251 | offset = 0.5 * (1.0 - center / radius)
252 |
253 | arr = (torch.arange(0, reso, dtype=torch.float32) + 0.5) / reso
254 | xx = (arr - offset[0]) / scale[0]
255 | yy = (arr - offset[1]) / scale[1]
256 | zz = (arr - offset[2]) / scale[2]
257 | if args.z_min is not None:
258 | zz = zz[zz >= args.z_min]
259 | if args.z_max is not None:
260 | zz = zz[zz <= args.z_max]
261 |
262 | grid = torch.stack(torch.meshgrid(xx, yy, zz)).reshape(3, -1).T
263 |
264 | out_chunks = []
265 | for i in tqdm(range(0, grid.shape[0], args.chunk)):
266 | grid_chunk = grid[i:i+args.chunk].cuda()
267 | if nerf.use_viewdirs:
268 | fake_viewdirs = torch.zeros([grid_chunk.shape[0], 3], device=grid_chunk.device)
269 | else:
270 | fake_viewdirs = None
271 | rgb, sigma = nerf.eval_points_raw(grid_chunk, fake_viewdirs)
272 | del grid_chunk
273 | out_chunks.append(sigma.squeeze(-1))
274 | sigmas = torch.cat(out_chunks, 0)
275 | del out_chunks
276 |
277 | approx_delta = 2.0 / reso
278 | sigma_thresh = -np.log(1.0 - args.scale_alpha_thresh) / approx_delta
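    # Derivation: a step of length delta has alpha = 1 - exp(-sigma * delta), so
    # requiring alpha >= scale_alpha_thresh is sigma >= -log(1 - thresh) / delta.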
279 | mask = sigmas >= sigma_thresh
280 |
281 | grid = grid[mask]
282 | del mask
283 |
284 | lc = grid.min(dim=0)[0] - 0.5 / reso
285 | uc = grid.max(dim=0)[0] + 0.5 / reso
286 | return ((lc + uc) * 0.5).tolist(), ((uc - lc) * 0.5).tolist()
287 |
288 | def step1(args, tree, nerf, dataset):
289 | print('* Step 1: Grid eval')
290 | reso = 2 ** (args.init_grid_depth + 1)
291 | offset = tree.offset.cpu()
292 | scale = tree.invradius.cpu()
293 |
294 | arr = (torch.arange(0, reso, dtype=torch.float32) + 0.5) / reso
295 | xx = (arr - offset[0]) / scale[0]
296 | yy = (arr - offset[1]) / scale[1]
297 | zz = (arr - offset[2]) / scale[2]
298 | if args.z_min is not None:
299 | zz = zz[zz >= args.z_min]
300 | if args.z_max is not None:
301 | zz = zz[zz <= args.z_max]
302 |
303 | grid = torch.stack(torch.meshgrid(xx, yy, zz)).reshape(3, -1).T
304 | print('init grid', grid.shape)
305 |
306 | approx_delta = 2.0 / reso
307 | sigma_thresh = -np.log(1.0 - args.alpha_thresh) / approx_delta
308 |
309 | out_chunks = []
310 | for i in tqdm(range(0, grid.shape[0], args.chunk)):
311 | grid_chunk = grid[i:i+args.chunk].cuda()
312 | if nerf.use_viewdirs:
313 | fake_viewdirs = torch.zeros([grid_chunk.shape[0], 3], device=grid_chunk.device)
314 | else:
315 | fake_viewdirs = None
316 | rgb, sigma = nerf.eval_points_raw(grid_chunk, fake_viewdirs)
317 | del grid_chunk
318 | out_chunks.append(sigma.squeeze(-1))
319 | sigmas = torch.cat(out_chunks, 0)
320 | del out_chunks
321 |
322 | if FLAGS.masking_mode == "sigma":
323 | mask = sigmas >= sigma_thresh
324 | elif FLAGS.masking_mode == "weight":
325 | print ("* Calculating grid weights")
326 | grid_weights = calculate_grid_weights(dataset,
327 | sigmas, reso, tree.invradius, tree.offset)
328 | mask = grid_weights.reshape(-1) >= FLAGS.weight_thresh
329 | del grid_weights
330 | else:
331 | raise ValueError
332 | del sigmas
333 |
334 | grid = grid[mask]
335 | del mask
336 | print(grid.shape, grid.min(), grid.max())
337 | grid = grid.cuda()
338 |
339 | torch.cuda.empty_cache()
340 | print(' Building octree')
341 | for i in range(args.init_grid_depth - 1):
342 | tree[grid].refine()
343 | refine_chunk = 2000000
344 | if grid.shape[0] <= refine_chunk:
345 | tree[grid].refine()
346 | else:
347 | # Do last layer separately
348 | grid = grid.cpu()
349 | for j in tqdm(range(0, grid.shape[0], refine_chunk)):
350 | tree[grid[j:j+refine_chunk].cuda()].refine()
351 | print(tree)
352 |
353 | assert tree.max_depth == args.init_grid_depth
354 |
355 | def step2(args, tree, nerf):
356 | print('* Step 2: AA', args.samples_per_cell)
357 |
358 | leaf_mask = tree.depths.cpu() == tree.max_depth
359 | leaf_ind = torch.where(leaf_mask)[0]
360 | del leaf_mask
361 |
362 | if args.use_viewdirs:
363 | chunk_size = args.chunk // (args.samples_per_cell * args.projection_samples // 10)
364 | else:
365 | chunk_size = args.chunk // (args.samples_per_cell)
366 |
367 | for i in tqdm(range(0, leaf_ind.size(0), chunk_size)):
368 | chunk_inds = leaf_ind[i:i+chunk_size]
369 | points = tree[chunk_inds].sample(args.samples_per_cell) # (n_cells, n_samples, 3)
370 | points = points.view(-1, 3)
371 |
372 | if not args.use_viewdirs: # trained NeRF-SH/SG model returns rgb as coeffs
373 | rgb, sigma = nerf.eval_points_raw(points)
374 | else: # vanilla NeRF model returns rgb, so we project them into coeffs (only SH supported)
375 | rgb, sigma = project_nerf_to_sh(nerf, args.sh_deg, points)
376 |
377 | if tree.data_format.format == tree.data_format.RGBA:
378 |             rgb = rgb.reshape(-1, args.samples_per_cell, tree.data_dim - 1)
379 |             sigma = sigma.reshape(-1, args.samples_per_cell, 1)
380 | sigma_avg = sigma.mean(dim=1)
381 |
382 | reso = 2 ** (args.init_grid_depth + 1)
383 | approx_delta = 2.0 / reso
384 | alpha = 1.0 - torch.exp(-approx_delta * sigma)
385 | msum = alpha.sum(dim=1)
386 | rgb_avg = (rgb * alpha).sum(dim=1) / msum
387 | rgb_avg[msum[..., 0] < 1e-3] = 0
388 | rgba = torch.cat([rgb_avg, sigma_avg], dim=-1)
389 | del rgb, sigma
390 | else:
391 | rgba = torch.cat([rgb, sigma], dim=-1)
392 | del rgb, sigma
393 | rgba = rgba.reshape(-1, args.samples_per_cell, tree.data_dim).mean(dim=1)
394 | tree[chunk_inds] = rgba
395 |
396 | def euler2mat(angle):
397 | """Convert euler angles to rotation matrix.
398 |
399 | Args:
400 | angle: rotation angle along 3 axis (in radians). [..., 3]
401 | Returns:
402 | Rotation matrix corresponding to the euler angles. [..., 3, 3]
403 | """
404 | x, y, z = angle[..., 0], angle[..., 1], angle[..., 2]
405 | cosz = torch.cos(z)
406 | sinz = torch.sin(z)
407 | cosy = torch.cos(y)
408 | siny = torch.sin(y)
409 | cosx = torch.cos(x)
410 | sinx = torch.sin(x)
411 | zeros = torch.zeros_like(z)
412 | ones = torch.ones_like(z)
413 | zmat = torch.stack([torch.stack([cosz, -sinz, zeros], dim=-1),
414 | torch.stack([sinz, cosz, zeros], dim=-1),
415 | torch.stack([zeros, zeros, ones], dim=-1)], dim=-1)
416 | ymat = torch.stack([torch.stack([ cosy, zeros, siny], dim=-1),
417 | torch.stack([zeros, ones, zeros], dim=-1),
418 | torch.stack([-siny, zeros, cosy], dim=-1)], dim=-1)
419 | xmat = torch.stack([torch.stack([ ones, zeros, zeros], dim=-1),
420 | torch.stack([zeros, cosx, -sinx], dim=-1),
421 | torch.stack([zeros, sinx, cosx], dim=-1)], dim=-1)
422 | rotMat = torch.einsum("...ij,...jk,...kq->...iq", xmat, ymat, zmat)
423 | return rotMat
424 |
425 | @torch.no_grad()
426 | def main(unused_argv):
427 | utils.set_random_seed(20200823)
428 | utils.update_flags(FLAGS)
429 |
430 | print('* Loading NeRF')
431 | nerf = models.get_model_state(FLAGS, device=device, restore=True)
432 | nerf.eval()
433 |
434 | data_format = None
435 | extra_data = None
436 | if FLAGS.sg_dim > 0:
437 | data_format = f'SG{FLAGS.sg_dim}'
438 | assert FLAGS.sg_global
439 | extra_data = torch.cat((
440 | F.softplus(nerf.sg_lambda[:, None]),
441 | sh_proj.spher2cart(nerf.sg_mu_spher[:, 0], nerf.sg_mu_spher[:, 1])
442 | ), dim=-1)
443 | elif FLAGS.sh_deg > 0:
444 | data_format = f'SH{(FLAGS.sh_deg + 1) ** 2}'
445 | if data_format is not None:
446 | print('Detected format:', data_format)
447 |
448 | base_dir = osp.dirname(FLAGS.output)
449 | if base_dir:
450 | os.makedirs(base_dir, exist_ok=True)
451 |
452 | assert FLAGS.data_dir # Dataset is required now
453 | dataset = datasets.get_dataset("train", FLAGS)
454 |
455 | if FLAGS.bbox_from_data:
456 | assert dataset.bbox is not None # Dataset must be NSVF
457 | center = (dataset.bbox[:3] + dataset.bbox[3:6]) * 0.5
458 | radius = (dataset.bbox[3:6] - dataset.bbox[:3]) * 0.5 * FLAGS.data_bbox_scale
459 | print('Bounding box from data: c', center, 'r', radius)
460 | else:
461 | center = list(map(float, FLAGS.center.split()))
462 | if len(center) == 1:
463 | center *= 3
464 | radius = list(map(float, FLAGS.radius.split()))
465 | if len(radius) == 1:
466 | radius *= 3
467 |
468 | if FLAGS.autoscale:
469 | center, radius = auto_scale(FLAGS, center, radius, nerf)
470 | print('Autoscale result center', center, 'radius', radius)
471 |
472 | radius = [r * FLAGS.bbox_scale for r in radius]
473 | if FLAGS.bbox_cube:
474 | radius = [max(radius)] * 3
475 |
476 | num_rgb_channels = FLAGS.num_rgb_channels
477 | if FLAGS.sh_deg >= 0:
478 | assert FLAGS.sg_dim == -1, (
479 | "You can only use up to one of: SH or SG")
480 | num_rgb_channels *= (FLAGS.sh_deg + 1) ** 2
481 | elif FLAGS.sg_dim > 0:
482 | assert FLAGS.sh_deg == -1, (
483 | "You can only use up to one of: SH or SG")
484 | num_rgb_channels *= FLAGS.sg_dim
485 | data_dim = 1 + num_rgb_channels # alpha + rgb
486 | print('data dim is', data_dim)
487 |
488 | print('* Creating model')
489 | tree = N3Tree(N=FLAGS.tree_branch_n,
490 | data_dim=data_dim,
491 | init_refine=0,
492 | init_reserve=500000,
493 | geom_resize_fact=1.0,
494 | depth_limit=FLAGS.init_grid_depth,
495 | radius=radius,
496 | center=center,
497 | data_format=data_format,
498 | extra_data=extra_data,
499 | map_location=device)
500 |
501 | step1(FLAGS, tree, nerf, dataset)
502 | step2(FLAGS, tree, nerf)
503 | tree[:, -1:].relu_()
504 | tree.shrink_to_fit()
505 | print(tree)
506 |
507 | del dataset.images
508 | print('* Saving', FLAGS.output)
509 | tree.save(FLAGS.output, compress=False) # Faster saving
510 |
511 | if FLAGS.eval:
512 | dataset = datasets.get_dataset("test", FLAGS)
513 | print('* Evaluation (before fine tune)')
514 | avg_psnr, avg_ssim, avg_lpips, out_frames = utils.eval_octree(tree,
515 | dataset, FLAGS, want_lpips=True)
516 | print('Average PSNR', avg_psnr, 'SSIM', avg_ssim, 'LPIPS', avg_lpips)
517 |
518 |
519 | if __name__ == "__main__":
520 | app.run(main)
521 |
--------------------------------------------------------------------------------
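
For reference, a minimal standalone sketch (not from the repo; the helper name and shapes are illustrative) of the alpha-weighted leaf averaging that `step2` performs above: each sample's color is weighted by the opacity it would contribute over an approximate step size, so nearly empty samples do not dilute the cell color.

```
import torch

def average_cell_rgba(rgb, sigma, approx_delta):
    # rgb: [n_cells, n_samples, C]; sigma: [n_cells, n_samples, 1]
    alpha = 1.0 - torch.exp(-approx_delta * sigma)   # per-sample opacity
    msum = alpha.sum(dim=1)                          # [n_cells, 1]
    rgb_avg = (rgb * alpha).sum(dim=1) / msum        # opacity-weighted mean
    rgb_avg[msum[..., 0] < 1e-3] = 0                 # zero out near-empty cells
    return torch.cat([rgb_avg, sigma.mean(dim=1)], dim=-1)

# e.g. 4 leaf cells, 256 samples each, step size for a 512^3 grid on [-1, 1]
rgba = average_cell_rgba(torch.rand(4, 256, 3), torch.rand(4, 256, 1), 2.0 / 512)
print(rgba.shape)  # torch.Size([4, 4])
```
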
/octree/nerf/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sxyu/plenoctree/f0b82631199a1aa7dc9ce263c08980a3a7504014/octree/nerf/__init__.py
--------------------------------------------------------------------------------
/octree/nerf/datasets.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Modifications Copyright 2021 The PlenOctree Authors.
3 | # Original Copyright 2021 The Google Research Authors.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Lint as: python3
18 | """Different datasets implementation plus a general port for all the datasets."""
19 | INTERNAL = False # pylint: disable=g-statement-before-imports
20 | import json
21 | import os
22 | from os import path
23 |
24 | if not INTERNAL:
25 | import cv2 # pylint: disable=g-import-not-at-top
26 | import numpy as np
27 | from PIL import Image
28 | from tqdm import tqdm
29 |
30 | from octree.nerf import utils
31 |
32 |
33 | def get_dataset(split, args):
34 | return dataset_dict[args.dataset](split, args)
35 |
36 |
37 | def convert_to_ndc(origins, directions, focal, w, h, near=1.0):
38 | """Convert a set of rays to NDC coordinates."""
39 | # Shift ray origins to near plane
40 | t = -(near + origins[Ellipsis, 2]) / directions[Ellipsis, 2]
41 | origins = origins + t[Ellipsis, None] * directions
42 |
43 | dx, dy, dz = tuple(np.moveaxis(directions, -1, 0))
44 | ox, oy, oz = tuple(np.moveaxis(origins, -1, 0))
45 |
46 | # Projection
47 | o0 = -((2 * focal) / w) * (ox / oz)
48 | o1 = -((2 * focal) / h) * (oy / oz)
49 | o2 = 1 + 2 * near / oz
50 |
51 | d0 = -((2 * focal) / w) * (dx / dz - ox / oz)
52 | d1 = -((2 * focal) / h) * (dy / dz - oy / oz)
53 | d2 = -2 * near / oz
54 |
55 | origins = np.stack([o0, o1, o2], -1)
56 | directions = np.stack([d0, d1, d2], -1)
57 | return origins, directions
58 |
59 |
60 | class Dataset():
61 | """Dataset Base Class."""
62 |
63 | def __init__(self, split, args, prefetch=True):
64 | super(Dataset, self).__init__()
65 | self.split = split
66 | self._general_init(args)
67 |
68 | @property
69 | def size(self):
70 | return self.n_examples
71 |
72 | def _general_init(self, args):
73 | bbox_path = path.join(args.data_dir, 'bbox.txt')
74 | if os.path.isfile(bbox_path):
75 | self.bbox = np.loadtxt(bbox_path)[:-1]
76 | else:
77 | self.bbox = None
78 | self._load_renderings(args)
79 |
80 |
81 | class Blender(Dataset):
82 | """Blender Dataset."""
83 |
84 | def _load_renderings(self, args):
85 | """Load images from disk."""
86 | if args.render_path:
87 | raise ValueError("render_path cannot be used for the blender dataset.")
88 | with utils.open_file(
89 | path.join(args.data_dir, "transforms_{}.json".format(self.split)), "r"
90 | ) as fp:
91 | meta = json.load(fp)
92 | images = []
93 | cams = []
94 | print(' Load Blender', args.data_dir, 'split', self.split)
95 | for i in tqdm(range(len(meta["frames"]))):
96 | frame = meta["frames"][i]
97 | fname = os.path.join(args.data_dir, frame["file_path"] + ".png")
98 | with utils.open_file(fname, "rb") as imgin:
99 | image = np.array(Image.open(imgin), dtype=np.float32) / 255.0
100 | if args.factor == 2:
101 | [halfres_h, halfres_w] = [hw // 2 for hw in image.shape[:2]]
102 | image = cv2.resize(
103 | image, (halfres_w, halfres_h), interpolation=cv2.INTER_AREA
104 | )
105 | elif args.factor > 0:
106 | raise ValueError(
107 |                     "Blender dataset only supports factor=0 or 2, "
108 |                     "got {}.".format(args.factor)
109 | )
110 | cams.append(frame["transform_matrix"])
111 | if args.white_bkgd:
112 | mask = image[..., -1:]
113 | image = image[..., :3] * mask + (1.0 - mask)
114 | else:
115 | image = image[..., :3]
116 | images.append(image)
117 | self.images = np.stack(images, axis=0)
118 | self.h, self.w = self.images.shape[1:3]
119 | self.resolution = self.h * self.w
120 | self.camtoworlds = np.stack(cams, axis=0).astype(np.float32)
121 | camera_angle_x = float(meta["camera_angle_x"])
122 | self.focal = 0.5 * self.w / np.tan(0.5 * camera_angle_x)
123 | self.n_examples = self.images.shape[0]
124 |
125 |
126 | class LLFF(Dataset):
127 | """LLFF Dataset."""
128 |
129 | def _load_renderings(self, args):
130 | """Load images from disk."""
131 | args.data_dir = path.expanduser(args.data_dir)
132 | print(' Load LLFF', args.data_dir, 'split', self.split)
133 | # Load images.
134 | imgdir_suffix = ""
135 | if args.factor > 0:
136 | imgdir_suffix = "_{}".format(args.factor)
137 | factor = args.factor
138 | else:
139 | factor = 1
140 | imgdir = path.join(args.data_dir, "images" + imgdir_suffix)
141 | if not utils.file_exists(imgdir):
142 | raise ValueError("Image folder {} doesn't exist.".format(imgdir))
143 | imgfiles = [
144 | path.join(imgdir, f)
145 | for f in sorted(utils.listdir(imgdir))
146 | if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png")
147 | ]
148 | images = []
149 | for imgfile in imgfiles:
150 | with utils.open_file(imgfile, "rb") as imgin:
151 | image = np.array(Image.open(imgin), dtype=np.float32) / 255.0
152 | images.append(image)
153 | images = np.stack(images, axis=-1)
154 |
155 | # Load poses and bds.
156 | with utils.open_file(path.join(args.data_dir, "poses_bounds.npy"), "rb") as fp:
157 | poses_arr = np.load(fp)
158 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
159 | bds = poses_arr[:, -2:].transpose([1, 0])
160 | if poses.shape[-1] != images.shape[-1]:
161 | raise RuntimeError(
162 | "Mismatch between imgs {} and poses {}".format(
163 | images.shape[-1], poses.shape[-1]
164 | )
165 | )
166 |
167 | # Update poses according to downsampling.
168 | poses[:2, 4, :] = np.array(images.shape[:2]).reshape([2, 1])
169 | poses[2, 4, :] = poses[2, 4, :] * 1.0 / factor
170 |
171 | # Correct rotation matrix ordering and move variable dim to axis 0.
172 | poses = np.concatenate(
173 | [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1
174 | )
175 | poses = np.moveaxis(poses, -1, 0).astype(np.float32)
176 | images = np.moveaxis(images, -1, 0)
177 | bds = np.moveaxis(bds, -1, 0).astype(np.float32)
178 |
179 | # Rescale according to a default bd factor.
180 | scale = 1.0 / (bds.min() * 0.75)
181 | poses[:, :3, 3] *= scale
182 | bds *= scale
183 |
184 | # Recenter poses.
185 | poses = self._recenter_poses(poses)
186 |
187 | # Generate a spiral/spherical ray path for rendering videos.
188 | if args.spherify:
189 | poses = self._generate_spherical_poses(poses, bds)
190 | self.spherify = True
191 | else:
192 | self.spherify = False
193 | if not args.spherify and self.split == "test":
194 | self._generate_spiral_poses(poses, bds)
195 |
196 | # Select the split.
197 | i_test = np.arange(images.shape[0])[:: args.llffhold]
198 | i_train = np.array(
199 | [i for i in np.arange(int(images.shape[0])) if i not in i_test]
200 | )
201 | if self.split == "train":
202 | indices = i_train
203 | else:
204 | indices = i_test
205 | images = images[indices]
206 | poses = poses[indices]
207 |
208 | self.images = images
209 | self.camtoworlds = poses[:, :3, :4]
210 | self.focal = poses[0, -1, -1]
211 | self.h, self.w = images.shape[1:3]
212 | self.resolution = self.h * self.w
213 | if args.render_path:
214 | self.n_examples = self.render_poses.shape[0]
215 | else:
216 | self.n_examples = images.shape[0]
217 |
218 | def _generate_rays(self):
219 | """Generate normalized device coordinate rays for llff."""
220 | if self.split == "test":
221 | n_render_poses = self.render_poses.shape[0]
222 | self.camtoworlds = np.concatenate(
223 | [self.render_poses, self.camtoworlds], axis=0
224 | )
225 |
226 | super()._generate_rays()
227 |
228 | if not self.spherify:
229 | ndc_origins, ndc_directions = convert_to_ndc(
230 | self.rays.origins, self.rays.directions, self.focal, self.w, self.h
231 | )
232 | self.rays = utils.Rays(
233 | origins=ndc_origins,
234 | directions=ndc_directions,
235 | viewdirs=self.rays.viewdirs,
236 | )
237 |
238 | # Split poses from the dataset and generated poses
239 | if self.split == "test":
240 | self.camtoworlds = self.camtoworlds[n_render_poses:]
241 | split = [np.split(r, [n_render_poses], 0) for r in self.rays]
242 | split0, split1 = zip(*split)
243 | self.render_rays = utils.Rays(*split0)
244 | self.rays = utils.Rays(*split1)
245 |
246 | def _recenter_poses(self, poses):
247 | """Recenter poses according to the original NeRF code."""
248 | poses_ = poses.copy()
249 | bottom = np.reshape([0, 0, 0, 1.0], [1, 4])
250 | c2w = self._poses_avg(poses)
251 | c2w = np.concatenate([c2w[:3, :4], bottom], -2)
252 | bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
253 | poses = np.concatenate([poses[:, :3, :4], bottom], -2)
254 | poses = np.linalg.inv(c2w) @ poses
255 | poses_[:, :3, :4] = poses[:, :3, :4]
256 | poses = poses_
257 | return poses
258 |
259 | def _poses_avg(self, poses):
260 | """Average poses according to the original NeRF code."""
261 | hwf = poses[0, :3, -1:]
262 | center = poses[:, :3, 3].mean(0)
263 | vec2 = self._normalize(poses[:, :3, 2].sum(0))
264 | up = poses[:, :3, 1].sum(0)
265 | c2w = np.concatenate([self._viewmatrix(vec2, up, center), hwf], 1)
266 | return c2w
267 |
268 | def _viewmatrix(self, z, up, pos):
269 | """Construct lookat view matrix."""
270 | vec2 = self._normalize(z)
271 | vec1_avg = up
272 | vec0 = self._normalize(np.cross(vec1_avg, vec2))
273 | vec1 = self._normalize(np.cross(vec2, vec0))
274 | m = np.stack([vec0, vec1, vec2, pos], 1)
275 | return m
276 |
277 | def _normalize(self, x):
278 | """Normalization helper function."""
279 | return x / np.linalg.norm(x)
280 |
281 | def _generate_spiral_poses(self, poses, bds):
282 | """Generate a spiral path for rendering."""
283 | c2w = self._poses_avg(poses)
284 | # Get average pose.
285 | up = self._normalize(poses[:, :3, 1].sum(0))
286 | # Find a reasonable "focus depth" for this dataset.
287 | close_depth, inf_depth = bds.min() * 0.9, bds.max() * 5.0
288 | dt = 0.75
289 | mean_dz = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth))
290 | focal = mean_dz
291 | # Get radii for spiral path.
292 | tt = poses[:, :3, 3]
293 | rads = np.percentile(np.abs(tt), 90, 0)
294 | c2w_path = c2w
295 | n_views = 120
296 | n_rots = 2
297 | # Generate poses for spiral path.
298 | render_poses = []
299 | rads = np.array(list(rads) + [1.0])
300 | hwf = c2w_path[:, 4:5]
301 | zrate = 0.5
302 | for theta in np.linspace(0.0, 2.0 * np.pi * n_rots, n_views + 1)[:-1]:
303 | c = np.dot(
304 | c2w[:3, :4],
305 | (
306 | np.array(
307 | [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0]
308 | )
309 | * rads
310 | ),
311 | )
312 | z = self._normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.0])))
313 | render_poses.append(np.concatenate([self._viewmatrix(z, up, c), hwf], 1))
314 | self.render_poses = np.array(render_poses).astype(np.float32)[:, :3, :4]
315 |
316 | def _generate_spherical_poses(self, poses, bds):
317 | """Generate a 360 degree spherical path for rendering."""
318 | # pylint: disable=g-long-lambda
319 | p34_to_44 = lambda p: np.concatenate(
320 | [p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1
321 | )
322 | rays_d = poses[:, :3, 2:3]
323 | rays_o = poses[:, :3, 3:4]
324 |
325 | def min_line_dist(rays_o, rays_d):
326 | a_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
327 | b_i = -a_i @ rays_o
328 | pt_mindist = np.squeeze(
329 | -np.linalg.inv((np.transpose(a_i, [0, 2, 1]) @ a_i).mean(0))
330 | @ (b_i).mean(0)
331 | )
332 | return pt_mindist
333 |
334 | pt_mindist = min_line_dist(rays_o, rays_d)
335 | center = pt_mindist
336 | up = (poses[:, :3, 3] - center).mean(0)
337 | vec0 = self._normalize(up)
338 | vec1 = self._normalize(np.cross([0.1, 0.2, 0.3], vec0))
339 | vec2 = self._normalize(np.cross(vec0, vec1))
340 | pos = center
341 | c2w = np.stack([vec1, vec2, vec0, pos], 1)
342 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4])
343 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))
344 | sc = 1.0 / rad
345 | poses_reset[:, :3, 3] *= sc
346 | bds *= sc
347 | rad *= sc
348 | centroid = np.mean(poses_reset[:, :3, 3], 0)
349 | zh = centroid[2]
350 | radcircle = np.sqrt(rad ** 2 - zh ** 2)
351 | new_poses = []
352 |
353 | for th in np.linspace(0.0, 2.0 * np.pi, 120):
354 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
355 | up = np.array([0, 0, -1.0])
356 | vec2 = self._normalize(camorigin)
357 | vec0 = self._normalize(np.cross(vec2, up))
358 | vec1 = self._normalize(np.cross(vec2, vec0))
359 | pos = camorigin
360 | p = np.stack([vec0, vec1, vec2, pos], 1)
361 | new_poses.append(p)
362 |
363 | new_poses = np.stack(new_poses, 0)
364 | new_poses = np.concatenate(
365 | [
366 | new_poses,
367 | np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape),
368 | ],
369 | -1,
370 | )
371 | poses_reset = np.concatenate(
372 | [
373 | poses_reset[:, :3, :4],
374 | np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape),
375 | ],
376 | -1,
377 | )
378 | if self.split == "test":
379 | self.render_poses = new_poses[:, :3, :4]
380 | return poses_reset
381 |
382 |
383 | class NSVF(Dataset):
384 | """NSVF Generic Dataset."""
385 |
386 | def _load_renderings(self, args):
387 | """Load images from disk."""
388 | if args.render_path:
389 | raise ValueError("render_path cannot be used for the NSVF dataset.")
390 | args.data_dir = path.expanduser(args.data_dir)
391 | K : np.ndarray = np.loadtxt(path.join(args.data_dir, "intrinsics.txt"))
392 | pose_files = sorted(os.listdir(path.join(args.data_dir, 'pose')))
393 | img_files = sorted(os.listdir(path.join(args.data_dir, 'rgb')))
394 |
395 | if self.split == 'train':
396 | pose_files = [x for x in pose_files if x.startswith('0_')]
397 | img_files = [x for x in img_files if x.startswith('0_')]
398 | elif self.split == 'val':
399 | pose_files = [x for x in pose_files if x.startswith('1_')]
400 | img_files = [x for x in img_files if x.startswith('1_')]
401 | elif self.split == 'test':
402 | test_pose_files = [x for x in pose_files if x.startswith('2_')]
403 | test_img_files = [x for x in img_files if x.startswith('2_')]
404 | if len(test_pose_files) == 0:
405 | test_pose_files = [x for x in pose_files if x.startswith('1_')]
406 | test_img_files = [x for x in img_files if x.startswith('1_')]
407 | pose_files = test_pose_files
408 | img_files = test_img_files
409 |
410 | images = []
411 | cams = []
412 |
413 | cam_trans = np.diag(np.array([1, -1, -1, 1], dtype=np.float32))
414 |
415 | assert len(img_files) == len(pose_files)
416 | print(' Load NSVF', args.data_dir, 'split', self.split, 'num_images', len(img_files))
417 | for img_fname, pose_fname in tqdm(zip(img_files, pose_files), total=len(img_files)):
418 | img_fname = path.join(args.data_dir, 'rgb', img_fname)
419 | with utils.open_file(img_fname, "rb") as imgin:
420 | image = np.array(Image.open(imgin), dtype=np.float32) / 255.0
421 | cam_mtx = np.loadtxt(path.join(args.data_dir, 'pose', pose_fname)) @ cam_trans
422 | cams.append(cam_mtx) # C2W
423 | if image.shape[-1] == 4:
424 | # Alpha channel available
425 | if args.white_bkgd:
426 | mask = image[..., -1:]
427 | image = image[..., :3] * mask + (1.0 - mask)
428 | else:
429 | image = image[..., :3]
430 | if args.factor > 1:
431 | [rsz_h, rsz_w] = [hw // args.factor for hw in image.shape[:2]]
432 | image = cv2.resize(
433 | image, (rsz_w, rsz_h), interpolation=cv2.INTER_AREA
434 | )
435 |
436 | images.append(image)
437 | self.images = np.stack(images, axis=0)
438 | self.n_examples, self.h, self.w = self.images.shape[:3]
439 | self.resolution = self.h * self.w
440 | self.camtoworlds = np.stack(cams, axis=0).astype(np.float32)
441 |         # We assume fx and fy are the same
442 | self.focal = (K[0, 0] + K[1, 1]) * 0.5
443 | if args.factor > 1:
444 | self.focal /= args.factor
445 |
446 |
447 | dataset_dict = {
448 | "blender": Blender,
449 | "llff": LLFF,
450 | "nsvf": NSVF,
451 | }
452 |
--------------------------------------------------------------------------------
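
A quick sanity check of `convert_to_ndc` above, sketched with an arbitrary focal length and image size: forward-facing rays in the usual NeRF convention (camera looking down -z) are shifted onto the near plane before projection, so every origin lands at NDC depth -1, with depth approaching +1 at infinity.

```
import numpy as np
from octree.nerf.datasets import convert_to_ndc

origins = np.random.randn(4, 3)
directions = np.tile(np.array([0.1, -0.2, -1.0]), (4, 1))  # roughly down -z
ndc_o, ndc_d = convert_to_ndc(origins, directions, focal=555.0, w=800, h=800)
print(ndc_o[:, 2])               # all -1.0: origins sit on the near plane
print(ndc_o.shape, ndc_d.shape)  # (4, 3) (4, 3)
```
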
/octree/nerf/model_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Modifications Copyright 2021 The PlenOctree Authors.
3 | # Original Copyright 2021 The Google Research Authors.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Lint as: python3
18 | """Helper functions/classes for model definition."""
19 |
20 | import functools
21 | from typing import Any, Callable
22 | import math
23 |
24 | import torch
25 | import torch.nn as nn
26 |
27 |
28 | def dense_layer(in_features, out_features):
29 | layer = nn.Linear(in_features, out_features)
30 | # The initialization matters!
31 | nn.init.xavier_uniform_(layer.weight)
32 | nn.init.zeros_(layer.bias)
33 | return layer
34 |
35 |
36 | class MLP(nn.Module):
37 | """A simple MLP."""
38 |
39 | def __init__(
40 | self,
41 | net_depth: int = 8, # The depth of the first part of MLP.
42 | net_width: int = 256, # The width of the first part of MLP.
43 | net_depth_condition: int = 1, # The depth of the second part of MLP.
44 | net_width_condition: int = 128, # The width of the second part of MLP.
45 | net_activation: Callable[Ellipsis, Any] = nn.ReLU(), # The activation function.
46 | skip_layer: int = 4, # The layer to add skip layers to.
47 | num_rgb_channels: int = 3, # The number of RGB channels.
48 | num_sigma_channels: int = 1, # The number of sigma channels.
49 | input_dim: int = 63, # The number of input tensor channels.
50 | condition_dim: int = 27, # The number of conditional tensor channels.
51 | ):
52 | super(MLP, self).__init__()
53 | self.net_depth = net_depth
54 | self.net_width = net_width
55 | self.net_depth_condition = net_depth_condition
56 | self.net_width_condition = net_width_condition
57 | self.net_activation = net_activation
58 | self.skip_layer = skip_layer
59 | self.num_rgb_channels = num_rgb_channels
60 | self.num_sigma_channels = num_sigma_channels
61 | self.input_dim = input_dim
62 | self.condition_dim = condition_dim
63 |
64 | self.input_layers = nn.ModuleList()
65 | in_features = self.input_dim
66 | for i in range(self.net_depth):
67 | self.input_layers.append(
68 | dense_layer(in_features, self.net_width)
69 | )
70 | if i % self.skip_layer == 0 and i > 0:
71 | in_features = self.net_width + self.input_dim
72 | else:
73 | in_features = self.net_width
74 | self.sigma_layer = dense_layer(in_features, self.num_sigma_channels)
75 |
76 | if self.condition_dim > 0:
77 | self.bottleneck_layer = dense_layer(in_features, self.net_width)
78 | self.condition_layers = nn.ModuleList()
79 | in_features = self.net_width + self.condition_dim
80 | for i in range(self.net_depth_condition):
81 | self.condition_layers.append(
82 | dense_layer(in_features, self.net_width_condition)
83 | )
84 | in_features = self.net_width_condition
85 | self.rgb_layer = dense_layer(in_features, self.num_rgb_channels)
86 |
87 | def forward(self, x, condition=None, cross_broadcast=False):
88 | """Evaluate the MLP.
89 |
90 | Args:
91 | x: torch.tensor(float32), [batch, num_samples, feature], points.
92 | condition: torch.tensor(float32),
93 | [batch, feature] or [batch, num_samples, feature] or [batch, num_rays, feature],
94 | if not None, this variable will be part of the input to the second part of the MLP
95 | concatenated with the output vector of the first part of the MLP. If
96 | None, only the first part of the MLP will be used with input x. In the
97 |           original paper, this variable is the view direction. Note that when the shape of
98 |           this tensor is [batch, num_rays, feature], where `num_rays` != `num_samples`, this
99 |           function will cross-broadcast all rays with all samples, and the `cross_broadcast`
100 |          option must be set to `True`.
101 | cross_broadcast: if true, cross broadcast the x tensor and the condition
102 | tensor.
103 |
104 | Returns:
105 | raw_rgb: torch.tensor(float32), with a shape of
106 | [batch, num_samples, num_rgb_channels]. If `cross_broadcast` is true, the return
107 | shape would be [batch, num_samples, num_rays, num_rgb_channels].
108 | raw_sigma: torch.tensor(float32), with a shape of
109 | [batch, num_samples, num_sigma_channels]. If `cross_broadcast` is true, the return
110 |         shape would be [batch, num_samples, num_rays, num_sigma_channels].
111 | """
112 | batch_size = x.shape[0]
113 | feature_dim = x.shape[-1]
114 | num_samples = x.shape[1]
115 | x = x.view([-1, feature_dim])
116 | inputs = x
117 | for i in range(self.net_depth):
118 | x = self.input_layers[i](x)
119 | x = self.net_activation(x)
120 | if i % self.skip_layer == 0 and i > 0:
121 | x = torch.cat([x, inputs], dim=-1)
122 | raw_sigma = self.sigma_layer(x).view(
123 | [-1, num_samples, self.num_sigma_channels]
124 | )
125 |
126 | if condition is not None:
127 | # Output of the first part of MLP.
128 | bottleneck = self.bottleneck_layer(x)
129 | # Broadcast condition from [batch, feature] to
130 | # [batch, num_samples, feature] since all the samples along the same ray
131 | # have the same viewdir.
132 | if len(condition.shape) == 2 and (not cross_broadcast):
133 | condition = condition[:, None, :].repeat(1, num_samples, 1)
134 | # Broadcast samples from [batch, num_samples, feature]
135 | # and condition from [batch, num_rays, feature] to
136 | # [batch, num_samples, num_rays, feature] since in this case each point
137 |             # is passed by all the rays. This option is used for projecting a
138 |             # trained vanilla NeRF to NeRF-SH.
139 | if cross_broadcast:
140 | condition = condition.view([batch_size, -1, condition.shape[-1]])
141 | num_rays = condition.shape[1]
142 | condition = condition[:, None, :, :].repeat(1, num_samples, 1, 1)
143 | bottleneck = bottleneck.view([batch_size, -1, bottleneck.shape[-1]])
144 | bottleneck = bottleneck[:, :, None, :].repeat(1, 1, num_rays, 1)
145 | # Collapse the [batch, num_samples, (num_rays,) feature] tensor to
146 | # [batch * num_samples (* num_rays), feature] so that it can be fed into nn.Dense.
147 | x = torch.cat([
148 | bottleneck.view([-1, bottleneck.shape[-1]]),
149 | condition.view([-1, condition.shape[-1]])], dim=-1)
150 | # Here use 1 extra layer to align with the original nerf model.
151 | for i in range(self.net_depth_condition):
152 | x = self.condition_layers[i](x)
153 | x = self.net_activation(x)
154 | raw_rgb = self.rgb_layer(x).view(
155 | [batch_size, num_samples, self.num_rgb_channels] if not cross_broadcast else \
156 | [batch_size, num_samples, num_rays, self.num_rgb_channels]
157 | )
158 | return raw_rgb, raw_sigma
159 |
160 |
161 | def posenc(x, min_deg, max_deg, legacy_posenc_order=False):
162 | """Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].
163 |
164 | Instead of computing [sin(x), cos(x)], we use the trig identity
165 | cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
166 |
167 | Args:
168 | x: torch.tensor, variables to be encoded. Note that x should be in [-pi, pi].
169 | min_deg: int, the minimum (inclusive) degree of the encoding.
170 | max_deg: int, the maximum (exclusive) degree of the encoding.
171 | legacy_posenc_order: bool, keep the same ordering as the original tf code.
172 |
173 | Returns:
174 | encoded: torch.tensor, encoded variables.
175 | """
176 | if min_deg == max_deg:
177 | return x
178 | scales = torch.tensor([2 ** i for i in range(min_deg, max_deg)],
179 | dtype=x.dtype, device=x.device)
180 | if legacy_posenc_order:
181 | xb = x[Ellipsis, None, :] * scales[:, None]
182 | four_feat = torch.reshape(
183 | torch.sin(torch.stack([xb, xb + 0.5 * math.pi], -2)), list(x.shape[:-1]) + [-1]
184 | )
185 | else:
186 | xb = torch.reshape(
187 | (x[Ellipsis, None, :] * scales[:, None]), list(x.shape[:-1]) + [-1]
188 | )
189 | four_feat = torch.sin(torch.cat([xb, xb + 0.5 * math.pi], dim=-1))
190 | return torch.cat([x] + [four_feat], dim=-1)
191 |
192 |
--------------------------------------------------------------------------------
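
A worked example of `posenc` above: with degrees [0, 10) each coordinate contributes sin and cos at 10 frequencies, plus the raw input, so a 3D point encodes to 3 * (1 + 2 * 10) = 63 channels, which is exactly the MLP's default `input_dim`.

```
import torch
from octree.nerf import model_utils

x = torch.rand(2, 5, 3)  # [batch, num_samples, 3]
enc = model_utils.posenc(x, min_deg=0, max_deg=10)
print(enc.shape)  # torch.Size([2, 5, 63]) = 3 raw + 3 * 2 * 10 encoded
```
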
/octree/nerf/models.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Modifications Copyright 2021 The PlenOctree Authors.
3 | # Original Copyright 2021 The Google Research Authors.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Lint as: python3
18 | """Different model implementation plus a general port for all the models."""
19 | import os, glob
20 | import inspect
21 | from typing import Any, Callable
22 | import math
23 |
24 | import torch
25 | import torch.nn as nn
26 |
27 | from octree.nerf import model_utils
28 |
29 |
30 | def get_model(args):
31 | """A helper function that wraps around a 'model zoo'."""
32 | model_dict = {
33 | "nerf": construct_nerf,
34 | }
35 | return model_dict[args.model](args)
36 |
37 |
38 | def get_model_state(args, device="cpu", restore=True):
39 | """
40 |     Helper for loading the model with get_model and optionally
41 |     restoring a checkpoint, to reduce boilerplate.
42 | """
43 | model = get_model(args).to(device)
44 | if restore:
45 | if args.is_jaxnerf_ckpt:
46 | model = restore_model_state_from_jaxnerf(args, model)
47 | else:
48 | model = restore_model_state(args, model)
49 | return model
50 |
51 |
52 | def restore_model_state(args, model):
53 | """
54 | Helper for restoring checkpoint.
55 | """
56 | ckpt_paths = sorted(
57 | glob.glob(os.path.join(args.train_dir, "*.ckpt")))
58 | if len(ckpt_paths) > 0:
59 | ckpt_path = ckpt_paths[-1]
60 | ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
61 | model.load_state_dict(ckpt["model"])
62 |         print(f"* restore ckpt from {ckpt_path}.")
63 | return model
64 |
65 |
66 | def restore_model_state_from_jaxnerf(args, model):
67 | """
68 | Helper for restoring checkpoint for jaxnerf.
69 | """
70 | from flax.training import checkpoints
71 | ckpt_paths = sorted(
72 | glob.glob(os.path.join(args.train_dir, "checkpoint_*")))
73 |
74 | if len(ckpt_paths) > 0:
75 | ckpt_dict = checkpoints.restore_checkpoint(
76 | args.train_dir, target=None)["optimizer"]["target"]["params"]
77 | state_dict = {}
78 |
79 | def _init_layer(from_name, to_name):
80 | state_dict[f"MLP_0.{to_name}.weight"] = \
81 | ckpt_dict["MLP_0"][f"{from_name}"]["kernel"].T
82 | state_dict[f"MLP_0.{to_name}.bias"] = \
83 | ckpt_dict["MLP_0"][f"{from_name}"]["bias"]
84 | state_dict[f"MLP_1.{to_name}.weight"] = \
85 | ckpt_dict["MLP_1"][f"{from_name}"]["kernel"].T
86 | state_dict[f"MLP_1.{to_name}.bias"] = \
87 | ckpt_dict["MLP_1"][f"{from_name}"]["bias"]
88 | pass
89 |
90 | # init all layers
91 | for i in range(model.net_depth):
92 | _init_layer(f"Dense_{i}", f"input_layers.{i}")
93 | i += 1
94 | _init_layer(f"Dense_{i}", f"sigma_layer")
95 | if model.use_viewdirs:
96 | i += 1
97 | _init_layer(f"Dense_{i}", f"bottleneck_layer")
98 | for j in range(model.net_depth_condition):
99 | i += 1
100 | _init_layer(f"Dense_{i}", f"condition_layers.{j}")
101 | i += 1
102 | _init_layer(f"Dense_{i}", f"rgb_layer")
103 |
104 | # support SG
105 | if model.sg_dim > 0:
106 | state_dict["sg_lambda"] = ckpt_dict["sg_lambda"]
107 | state_dict["sg_mu_spher"] = ckpt_dict["sg_mu_spher"]
108 |
109 | for key in state_dict.keys():
110 | state_dict[key] = torch.from_numpy(state_dict[key].copy())
111 | model.load_state_dict(state_dict)
112 |         print(f"* restore ckpt from {args.train_dir}")
113 | return model
114 |
115 |
116 | class NerfModel(nn.Module):
117 | """Nerf NN Model with both coarse and fine MLPs."""
118 |
119 | def __init__(
120 | self,
121 | num_coarse_samples: int = 64, # The number of samples for the coarse nerf.
122 | num_fine_samples: int = 128, # The number of samples for the fine nerf.
123 | use_viewdirs: bool = True, # If True, use viewdirs as an input.
124 | sh_deg: int = -1, # If != -1, use spherical harmonics output of given order
125 | sg_dim: int = -1, # If != -1, use spherical gaussians output of given dimension
126 | near: float = 2.0, # The distance to the near plane
127 | far: float = 6.0, # The distance to the far plane
128 | noise_std: float = 0.0, # The std dev of noise added to raw sigma.
129 | net_depth: int = 8, # The depth of the first part of MLP.
130 | net_width: int = 256, # The width of the first part of MLP.
131 | net_depth_condition: int = 1, # The depth of the second part of MLP.
132 | net_width_condition: int = 128, # The width of the second part of MLP.
133 | net_activation: Callable[Ellipsis, Any] = nn.ReLU(), # MLP activation
134 | skip_layer: int = 4, # How often to add skip connections.
135 | num_rgb_channels: int = 3, # The number of RGB channels.
136 | num_sigma_channels: int = 1, # The number of density channels.
137 | white_bkgd: bool = True, # If True, use a white background.
138 | min_deg_point: int = 0, # The minimum degree of positional encoding for positions.
139 | max_deg_point: int = 10, # The maximum degree of positional encoding for positions.
140 | deg_view: int = 4, # The degree of positional encoding for viewdirs.
141 | lindisp: bool = False, # If True, sample linearly in disparity rather than in depth.
142 | rgb_activation: Callable[Ellipsis, Any] = nn.Sigmoid(), # Output RGB activation.
143 | sigma_activation: Callable[Ellipsis, Any] = nn.ReLU(), # Output sigma activation.
144 | legacy_posenc_order: bool = False, # Keep the same ordering as the original tf code.
145 | ):
146 | super(NerfModel, self).__init__()
147 | self.num_coarse_samples = num_coarse_samples
148 | self.num_fine_samples = num_fine_samples
149 | self.use_viewdirs = use_viewdirs
150 | self.sh_deg = sh_deg
151 | self.sg_dim = sg_dim
152 | self.near = near
153 | self.far = far
154 | self.noise_std = noise_std
155 | self.net_depth = net_depth
156 | self.net_width = net_width
157 | self.net_depth_condition = net_depth_condition
158 | self.net_width_condition = net_width_condition
159 | self.net_activation = net_activation
160 | self.skip_layer = skip_layer
161 | self.num_rgb_channels = num_rgb_channels
162 | self.num_sigma_channels = num_sigma_channels
163 | self.white_bkgd = white_bkgd
164 | self.min_deg_point = min_deg_point
165 | self.max_deg_point = max_deg_point
166 | self.deg_view = deg_view
167 | self.lindisp = lindisp
168 | self.rgb_activation = rgb_activation
169 | self.sigma_activation = sigma_activation
170 | self.legacy_posenc_order = legacy_posenc_order
171 | # Construct the "coarse" MLP. Weird name is for
172 | # compatibility with 'compact' version
173 | self.MLP_0 = model_utils.MLP(
174 | net_depth = self.net_depth,
175 | net_width = self.net_width,
176 | net_depth_condition = self.net_depth_condition,
177 | net_width_condition = self.net_width_condition,
178 | net_activation = self.net_activation,
179 | skip_layer = self.skip_layer,
180 | num_rgb_channels = self.num_rgb_channels,
181 | num_sigma_channels = self.num_sigma_channels,
182 | input_dim=3 * (1 + 2 * (self.max_deg_point - self.min_deg_point)),
183 | condition_dim=3 * (1 + 2 * self.deg_view) if self.use_viewdirs else 0)
184 | # Construct the "fine" MLP.
185 | self.MLP_1 = model_utils.MLP(
186 | net_depth = self.net_depth,
187 | net_width = self.net_width,
188 | net_depth_condition = self.net_depth_condition,
189 | net_width_condition = self.net_width_condition,
190 | net_activation = self.net_activation,
191 | skip_layer = self.skip_layer,
192 | num_rgb_channels = self.num_rgb_channels,
193 | num_sigma_channels = self.num_sigma_channels,
194 | input_dim=3 * (1 + 2 * (self.max_deg_point - self.min_deg_point)),
195 | condition_dim=3 * (1 + 2 * self.deg_view) if self.use_viewdirs else 0)
196 |
197 | # Construct learnable variables for spherical gaussians.
198 | if self.sg_dim > 0:
199 | self.register_parameter(
200 | "sg_lambda",
201 | nn.Parameter(torch.ones([self.sg_dim]))
202 | )
203 | self.register_parameter(
204 | "sg_mu_spher",
205 | nn.Parameter(torch.stack([
206 | torch.rand([self.sg_dim]) * math.pi, # theta
207 | torch.rand([self.sg_dim]) * math.pi * 2 # phi
208 | ], dim=-1))
209 | )
210 |
211 | def eval_points_raw(self, points, viewdirs=None, coarse=False, cross_broadcast=False):
212 | """
213 |         Evaluate at points, returning rgb and sigma.
214 |         If sh_deg >= 0 then this will return spherical harmonic
215 |         coeffs for RGB. Please see eval_points for an alternate
216 |         version which always returns RGB.
217 |
218 | Args:
219 | points: torch.tensor [B, 3]
220 | viewdirs: torch.tensor [B, 3]. if cross_broadcast = True, it can be [M, 3].
221 | coarse: if true, uses coarse MLP.
222 | cross_broadcast: if true, cross broadcast between points and viewdirs.
223 |
224 | Returns:
225 | raw_rgb: torch.tensor [B, 3 * (sh_deg + 1)**2 or 3]. if cross_broadcast = True, it
226 | returns [B, M, 3 * (sh_deg + 1)**2 or 3]
227 | raw_sigma: torch.tensor [B, 1]
228 | """
229 | points = points[None]
230 | points_enc = model_utils.posenc(
231 | points,
232 | self.min_deg_point,
233 | self.max_deg_point,
234 | self.legacy_posenc_order,
235 | )
236 | if self.num_fine_samples > 0 and not coarse:
237 | mlp = self.MLP_1
238 | else:
239 | mlp = self.MLP_0
240 | if self.use_viewdirs:
241 | assert viewdirs is not None
242 | viewdirs = viewdirs[None]
243 | viewdirs_enc = model_utils.posenc(
244 | viewdirs,
245 | 0,
246 | self.deg_view,
247 | self.legacy_posenc_order,
248 | )
249 | raw_rgb, raw_sigma = mlp(points_enc, viewdirs_enc, cross_broadcast=cross_broadcast)
250 | else:
251 | raw_rgb, raw_sigma = mlp(points_enc)
252 | return raw_rgb[0], raw_sigma[0]
253 |
254 |
255 | def construct_nerf(args):
256 | """Construct a Neural Radiance Field.
257 |
258 | Args:
259 | args: FLAGS class. Hyperparameters of nerf.
260 |
261 | Returns:
262 |     model: nn.Module. NeRF model with parameters. (This torch port
263 |       returns only the model; there is no separate flax state.)
264 | """
265 | net_activation = getattr(nn, str(args.net_activation))
266 | if inspect.isclass(net_activation):
267 | net_activation = net_activation()
268 | rgb_activation = getattr(nn, str(args.rgb_activation))
269 | if inspect.isclass(rgb_activation):
270 | rgb_activation = rgb_activation()
271 | sigma_activation = getattr(nn, str(args.sigma_activation))
272 | if inspect.isclass(sigma_activation):
273 | sigma_activation = sigma_activation()
274 |
275 | # Assert that rgb_activation always produces outputs in [0, 1], and
276 |   # sigma_activation always produces non-negative outputs.
277 | x = torch.exp(torch.linspace(-90, 90, 1024))
278 | x = torch.cat([-x, x], dim=0)
279 |
280 | rgb = rgb_activation(x)
281 | if torch.any(rgb < 0) or torch.any(rgb > 1):
282 | raise NotImplementedError(
283 | "Choice of rgb_activation `{}` produces colors outside of [0, 1]".format(
284 | args.rgb_activation
285 | )
286 | )
287 |
288 | sigma = sigma_activation(x)
289 | if torch.any(sigma < 0):
290 | raise NotImplementedError(
291 | "Choice of sigma_activation `{}` produces negative densities".format(
292 | args.sigma_activation
293 | )
294 | )
295 |
296 | num_rgb_channels = args.num_rgb_channels
297 | if not args.use_viewdirs:
298 | if args.sh_deg >= 0:
299 | assert args.sg_dim == -1, (
300 | "You can only use up to one of: SH or SG.")
301 | num_rgb_channels *= (args.sh_deg + 1) ** 2
302 | elif args.sg_dim > 0:
303 | assert args.sh_deg == -1, (
304 | "You can only use up to one of: SH or SG.")
305 | num_rgb_channels *= args.sg_dim
306 |
307 | model = NerfModel(
308 | min_deg_point=args.min_deg_point,
309 | max_deg_point=args.max_deg_point,
310 | deg_view=args.deg_view,
311 | num_coarse_samples=args.num_coarse_samples,
312 | num_fine_samples=args.num_fine_samples,
313 | use_viewdirs=args.use_viewdirs,
314 | sh_deg=args.sh_deg,
315 | sg_dim=args.sg_dim,
316 | near=args.near,
317 | far=args.far,
318 | noise_std=args.noise_std,
319 | white_bkgd=args.white_bkgd,
320 | net_depth=args.net_depth,
321 | net_width=args.net_width,
322 | net_depth_condition=args.net_depth_condition,
323 | net_width_condition=args.net_width_condition,
324 | skip_layer=args.skip_layer,
325 | num_rgb_channels=num_rgb_channels,
326 | num_sigma_channels=args.num_sigma_channels,
327 | lindisp=args.lindisp,
328 | net_activation=net_activation,
329 | rgb_activation=rgb_activation,
330 | sigma_activation=sigma_activation,
331 | legacy_posenc_order=args.legacy_posenc_order,
332 | )
333 | return model
334 |
--------------------------------------------------------------------------------
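
A minimal usage sketch for `NerfModel.eval_points_raw` above (the shapes are illustrative). Note the outputs are pre-activation: rgb still needs `rgb_activation` and sigma needs `sigma_activation` before rendering.

```
import torch
from octree.nerf.models import NerfModel

model = NerfModel()  # defaults: use_viewdirs=True, plain RGB output
pts = torch.rand(8, 3) * 2 - 1  # query points in [-1, 1]^3
dirs = torch.nn.functional.normalize(torch.randn(8, 3), dim=-1)
with torch.no_grad():
    raw_rgb, raw_sigma = model.eval_points_raw(pts, dirs)
print(raw_rgb.shape, raw_sigma.shape)  # torch.Size([8, 3]) torch.Size([8, 1])
```
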
/octree/nerf/sh_proj.py:
--------------------------------------------------------------------------------
1 | # Modifications Copyright 2021 The PlenOctree Authors.
2 | # Original Copyright 2015 Google Inc. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Sperical harmonics projection related functions
16 |
17 | Some codes are borrowed from:
18 | https://github.com/google/spherical-harmonics/blob/master/sh/spherical_harmonics.cc
19 | """
20 | from typing import Callable
21 | import math
22 | import torch
23 |
24 | kHardCodedOrderLimit = 4
25 |
26 |
27 | def spher2cart(theta, phi):
28 | """Convert spherical coordinates into Cartesian coordinates (radius 1)."""
29 | r = torch.sin(theta)
30 | x = r * torch.cos(phi)
31 | y = r * torch.sin(phi)
32 | z = torch.cos(theta)
33 | return torch.stack([x, y, z], dim=-1)
34 |
35 |
36 | # Get the total number of coefficients for a function represented by
37 | # all spherical harmonic basis of degree <= @order (it is a point of
38 | # confusion that the order of an SH refers to its degree and not the order).
39 | def GetCoefficientCount(order: int):
40 | return (order + 1) ** 2
41 |
42 |
43 | # Get the one dimensional index associated with a particular degree @l
44 | # and order @m. This is the index that can be used to access the Coeffs
45 | # returned by SHSolver.
46 | def GetIndex(l: int, m: int):
47 | return l * (l + 1) + m
48 |
49 |
50 | # Hardcoded spherical harmonic functions for low orders (l is the first number
51 | # and m is the second number (sign encoded as a preceding 'p' or 'n')).
52 | #
53 | # As polynomials they are evaluated more efficiently in Cartesian coordinates,
54 | # assuming that @{dx, dy, dz} is unit. This is not verified for efficiency.
55 |
56 | def HardcodedSH00(dx, dy, dz):
57 | # 0.5 * sqrt(1/pi)
58 | return 0.28209479177387814 + (dx * 0) # keep the shape
59 |
60 | def HardcodedSH1n1(dx, dy, dz):
61 | # -sqrt(3/(4pi)) * y
62 | return -0.4886025119029199 * dy
63 |
64 | def HardcodedSH10(dx, dy, dz):
65 | # sqrt(3/(4pi)) * z
66 | return 0.4886025119029199 * dz
67 |
68 | def HardcodedSH1p1(dx, dy, dz):
69 | # -sqrt(3/(4pi)) * x
70 | return -0.4886025119029199 * dx
71 |
72 | def HardcodedSH2n2(dx, dy, dz):
73 | # 0.5 * sqrt(15/pi) * x * y
74 | return 1.0925484305920792 * dx * dy
75 |
76 | def HardcodedSH2n1(dx, dy, dz):
77 | # -0.5 * sqrt(15/pi) * y * z
78 | return -1.0925484305920792 * dy * dz
79 |
80 | def HardcodedSH20(dx, dy, dz):
81 | # 0.25 * sqrt(5/pi) * (-x^2-y^2+2z^2)
82 | return 0.31539156525252005 * (-dx * dx - dy * dy + 2.0 * dz * dz)
83 |
84 | def HardcodedSH2p1(dx, dy, dz):
85 | # -0.5 * sqrt(15/pi) * x * z
86 | return -1.0925484305920792 * dx * dz
87 |
88 | def HardcodedSH2p2(dx, dy, dz):
89 | # 0.25 * sqrt(15/pi) * (x^2 - y^2)
90 | return 0.5462742152960396 * (dx * dx - dy * dy)
91 |
92 | def HardcodedSH3n3(dx, dy, dz):
93 | # -0.25 * sqrt(35/(2pi)) * y * (3x^2 - y^2)
94 | return -0.5900435899266435 * dy * (3.0 * dx * dx - dy * dy)
95 |
96 | def HardcodedSH3n2(dx, dy, dz):
97 | # 0.5 * sqrt(105/pi) * x * y * z
98 | return 2.890611442640554 * dx * dy * dz
99 |
100 | def HardcodedSH3n1(dx, dy, dz):
101 | # -0.25 * sqrt(21/(2pi)) * y * (4z^2-x^2-y^2)
102 | return -0.4570457994644658 * dy * (4.0 * dz * dz - dx * dx - dy * dy)
103 |
104 | def HardcodedSH30(dx, dy, dz):
105 | # 0.25 * sqrt(7/pi) * z * (2z^2 - 3x^2 - 3y^2)
106 | return 0.3731763325901154 * dz * (2.0 * dz * dz - 3.0 * dx * dx - 3.0 * dy * dy)
107 |
108 | def HardcodedSH3p1(dx, dy, dz):
109 | # -0.25 * sqrt(21/(2pi)) * x * (4z^2-x^2-y^2)
110 | return -0.4570457994644658 * dx * (4.0 * dz * dz - dx * dx - dy * dy)
111 |
112 | def HardcodedSH3p2(dx, dy, dz):
113 | # 0.25 * sqrt(105/pi) * z * (x^2 - y^2)
114 | return 1.445305721320277 * dz * (dx * dx - dy * dy)
115 |
116 | def HardcodedSH3p3(dx, dy, dz):
117 | # -0.25 * sqrt(35/(2pi)) * x * (x^2-3y^2)
118 | return -0.5900435899266435 * dx * (dx * dx - 3.0 * dy * dy)
119 |
120 | def HardcodedSH4n4(dx, dy, dz):
121 | # 0.75 * sqrt(35/pi) * x * y * (x^2-y^2)
122 | return 2.5033429417967046 * dx * dy * (dx * dx - dy * dy)
123 |
124 | def HardcodedSH4n3(dx, dy, dz):
125 | # -0.75 * sqrt(35/(2pi)) * y * z * (3x^2-y^2)
126 | return -1.7701307697799304 * dy * dz * (3.0 * dx * dx - dy * dy)
127 |
128 | def HardcodedSH4n2(dx, dy, dz):
129 | # 0.75 * sqrt(5/pi) * x * y * (7z^2-1)
130 | return 0.9461746957575601 * dx * dy * (7.0 * dz * dz - 1.0)
131 |
132 | def HardcodedSH4n1(dx, dy, dz):
133 | # -0.75 * sqrt(5/(2pi)) * y * z * (7z^2-3)
134 | return -0.6690465435572892 * dy * dz * (7.0 * dz * dz - 3.0)
135 |
136 | def HardcodedSH40(dx, dy, dz):
137 | # 3/16 * sqrt(1/pi) * (35z^4-30z^2+3)
138 | z2 = dz * dz
139 | return 0.10578554691520431 * (35.0 * z2 * z2 - 30.0 * z2 + 3.0)
140 |
141 | def HardcodedSH4p1(dx, dy, dz):
142 | # -0.75 * sqrt(5/(2pi)) * x * z * (7z^2-3)
143 | return -0.6690465435572892 * dx * dz * (7.0 * dz * dz - 3.0)
144 |
145 | def HardcodedSH4p2(dx, dy, dz):
146 | # 3/8 * sqrt(5/pi) * (x^2 - y^2) * (7z^2 - 1)
147 | return 0.47308734787878004 * (dx * dx - dy * dy) * (7.0 * dz * dz - 1.0)
148 |
149 | def HardcodedSH4p3(dx, dy, dz):
150 | # -0.75 * sqrt(35/(2pi)) * x * z * (x^2 - 3y^2)
151 | return -1.7701307697799304 * dx * dz * (dx * dx - 3.0 * dy * dy)
152 |
153 | def HardcodedSH4p4(dx, dy, dz):
154 | # 3/16*sqrt(35/pi) * (x^2 * (x^2 - 3y^2) - y^2 * (3x^2 - y^2))
155 | x2 = dx * dx
156 | y2 = dy * dy
157 | return 0.6258357354491761 * (x2 * (x2 - 3.0 * y2) - y2 * (3.0 * x2 - y2))
158 |
159 |
160 | def EvalSH(l: int, m: int, dirs):
161 | """
162 | Args:
163 | dirs: array [..., 3]. works with torch/jnp/np
164 | Return:
165 | array [...]
166 | """
167 | if l <= kHardCodedOrderLimit:
168 | # Validate l and m here (don't do it generally since EvalSHSlow also
169 | # checks it if we delegate to that function).
170 | assert l >= 0, "l must be at least 0."
171 | assert -l <= m and m <= l, "m must be between -l and l."
172 | dx = dirs[..., 0]
173 | dy = dirs[..., 1]
174 | dz = dirs[..., 2]
175 |
176 | if l == 0:
177 | return HardcodedSH00(dx, dy, dz)
178 | elif l == 1:
179 | if m == -1:
180 | return HardcodedSH1n1(dx, dy, dz)
181 | elif m == 0:
182 | return HardcodedSH10(dx, dy, dz)
183 | elif m == 1:
184 | return HardcodedSH1p1(dx, dy, dz)
185 | elif l == 2:
186 | if m == -2:
187 | return HardcodedSH2n2(dx, dy, dz)
188 | elif m == -1:
189 | return HardcodedSH2n1(dx, dy, dz)
190 | elif m == 0:
191 | return HardcodedSH20(dx, dy, dz)
192 | elif m == 1:
193 | return HardcodedSH2p1(dx, dy, dz)
194 | elif m == 2:
195 | return HardcodedSH2p2(dx, dy, dz)
196 | elif l == 3:
197 | if m == -3:
198 | return HardcodedSH3n3(dx, dy, dz)
199 | elif m == -2:
200 | return HardcodedSH3n2(dx, dy, dz)
201 | elif m == -1:
202 | return HardcodedSH3n1(dx, dy, dz)
203 | elif m == 0:
204 | return HardcodedSH30(dx, dy, dz)
205 | elif m == 1:
206 | return HardcodedSH3p1(dx, dy, dz)
207 | elif m == 2:
208 | return HardcodedSH3p2(dx, dy, dz)
209 | elif m == 3:
210 | return HardcodedSH3p3(dx, dy, dz)
211 | elif l == 4:
212 | if m == -4:
213 | return HardcodedSH4n4(dx, dy, dz)
214 | elif m == -3:
215 | return HardcodedSH4n3(dx, dy, dz)
216 | elif m == -2:
217 | return HardcodedSH4n2(dx, dy, dz)
218 | elif m == -1:
219 | return HardcodedSH4n1(dx, dy, dz)
220 | elif m == 0:
221 | return HardcodedSH40(dx, dy, dz)
222 | elif m == 1:
223 | return HardcodedSH4p1(dx, dy, dz)
224 | elif m == 2:
225 | return HardcodedSH4p2(dx, dy, dz)
226 | elif m == 3:
227 | return HardcodedSH4p3(dx, dy, dz)
228 | elif m == 4:
229 | return HardcodedSH4p4(dx, dy, dz)
230 |
231 |         # This is unreachable given the asserts above, but kept as a fallback.
232 | return None
233 |
234 | else:
235 | # Not hard-coded so use the recurrence relation (which will convert this
236 | # to spherical coordinates).
237 | # return EvalSHSlow(l, m, dx, dy, dz)
238 | raise NotImplementedError
239 |
240 |
241 | def spherical_uniform_sampling(sample_count, device="cpu"):
242 | # See: https://www.bogotobogo.com/Algorithms/uniform_distribution_sphere.php
243 | theta = torch.acos(2.0 * torch.rand([sample_count]) - 1.0)
244 | phi = 2.0 * math.pi * torch.rand([sample_count])
245 | return theta.to(device), phi.to(device)
246 |
247 |
248 | def ProjectFunction(order: int, sperical_func: Callable, sample_count: int, device="cpu"):
249 | assert order >= 0, "Order must be at least zero."
250 | assert sample_count > 0, "Sample count must be at least one."
251 |
252 | # This is the approach demonstrated in [1] and is useful for arbitrary
253 | # functions on the sphere that are represented analytically.
254 | coeffs = torch.zeros([GetCoefficientCount(order)], dtype=torch.float32).to(device)
255 |
256 |     # generate sample_count uniformly distributed samples over the sphere
257 | # See http://www.bogotobogo.com/Algorithms/uniform_distribution_sphere.php
258 | theta, phi = spherical_uniform_sampling(sample_count, device=device)
259 | dirs = spher2cart(theta, phi)
260 |
261 | # evaluate the analytic function for the current spherical coords
262 | func_value = sperical_func(dirs)
263 |
264 | # evaluate the SH basis functions up to band O, scale them by the
265 | # function's value and accumulate them over all generated samples
266 | for l in range(order + 1): # end inclusive
267 | for m in range(-l, l + 1): # end inclusive
268 |             coeffs[GetIndex(l, m)] = torch.sum(func_value * EvalSH(l, m, dirs))
269 |
270 | # scale by the probability of a particular sample, which is
271 | # 4pi/sample_count. 4pi for the surface area of a unit sphere, and
272 | # 1/sample_count for the number of samples drawn uniformly.
273 | weight = 4.0 * math.pi / sample_count
274 | coeffs *= weight
275 | return coeffs
276 |
277 |
278 | def ProjectFunctionNeRF(order: int, sperical_func: Callable, batch_size: int, sample_count: int, device="cpu"):
279 | assert order >= 0, "Order must be at least zero."
280 | assert sample_count > 0, "Sample count must be at least one."
281 | C = 3 # rgb channels
282 |
283 | # This is the approach demonstrated in [1] and is useful for arbitrary
284 | # functions on the sphere that are represented analytically.
285 | coeffs = torch.zeros([batch_size, C, GetCoefficientCount(order)], dtype=torch.float32).to(device)
286 |
287 |     # generate sample_count uniformly distributed samples over the sphere
288 | # See http://www.bogotobogo.com/Algorithms/uniform_distribution_sphere.php
289 | theta, phi = spherical_uniform_sampling(sample_count, device=device)
290 | dirs = spher2cart(theta, phi)
291 |
292 | # evaluate the analytic function for the current spherical coords
293 | func_value, others = sperical_func(dirs) # [batch_size, sample_count, C]
294 |
295 | # evaluate the SH basis functions up to band O, scale them by the
296 | # function's value and accumulate them over all generated samples
297 | for l in range(order + 1): # end inclusive
298 | for m in range(-l, l + 1): # end inclusive
299 | coeffs[:, :, GetIndex(l, m)] = torch.einsum("bsc,s->bc", func_value, EvalSH(l, m, dirs))
300 |
301 | # scale by the probability of a particular sample, which is
302 | # 4pi/sample_count. 4pi for the surface area of a unit sphere, and
303 | # 1/sample_count for the number of samples drawn uniformly.
304 | weight = 4.0 * math.pi / sample_count
305 | coeffs *= weight
306 | return coeffs, others
307 |
308 | def ProjectFunctionNeRFSparse(
309 | order: int,
310 | spherical_func: Callable,
311 | sample_count: int,
312 | device="cpu",
313 | ):
314 | assert order >= 0, "Order must be at least zero."
315 | assert sample_count > 0, "Sample count must be at least one."
316 | C = 3 # rgb channels
317 |
318 |     # generate sample_count uniformly distributed samples over the sphere
319 | # See http://www.bogotobogo.com/Algorithms/uniform_distribution_sphere.php
320 | theta, phi = spherical_uniform_sampling(sample_count, device=device)
321 | dirs = spher2cart(theta, phi) # [sample_count, 3]
322 |
323 | # evaluate the analytic function for the current spherical coords
324 | func_value, others = spherical_func(dirs) # func_value [batch_size, sample_count, C]
325 |
326 | batch_size = func_value.shape[0]
327 |
328 | coeff_count = GetCoefficientCount(order)
329 | basis_vals = torch.empty(
330 | [sample_count, coeff_count], dtype=torch.float32
331 | ).to(device)
332 |
333 | # evaluate the SH basis functions up to band O, scale them by the
334 | # function's value and accumulate them over all generated samples
335 | for l in range(order + 1): # end inclusive
336 | for m in range(-l, l + 1): # end inclusive
337 | basis_vals[:, GetIndex(l, m)] = EvalSH(l, m, dirs)
338 |
339 | basis_vals = basis_vals.view(
340 | sample_count, coeff_count) # [sample_count, coeff_count]
341 | func_value = func_value.transpose(0, 1).reshape(
342 | sample_count, batch_size * C) # [sample_count, batch_size * C]
343 |     soln = torch.lstsq(func_value, basis_vals).solution[:basis_vals.size(1)]  # solution is padded; keep first coeff_count rows
344 | soln = soln.T.reshape(batch_size, C, -1)
345 | return soln, others
346 |
347 |
--------------------------------------------------------------------------------
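
To illustrate the Monte Carlo projection above, a sketch projecting a simple "overhead light" f(d) = max(d_z, 0) onto SH of order 2; the energy concentrates in the constant band and the l=1, m=0 band (`sperical_func` below matches the parameter name used in the file):

```
import torch
from octree.nerf import sh_proj

func = lambda dirs: torch.clamp(dirs[..., 2], min=0.0)  # f(d) = max(d_z, 0)
coeffs = sh_proj.ProjectFunction(order=2, sperical_func=func, sample_count=10000)
print(coeffs.shape)  # torch.Size([9]) == (order + 1) ** 2
```
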
/octree/nerf/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Modifications Copyright 2021 The PlenOctree Authors.
3 | # Original Copyright 2021 The Google Research Authors.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Lint as: python3
18 | """Utility functions."""
19 | import collections
20 | import os
21 | from os import path
22 | from absl import flags
23 | import math
24 |
25 | import torch
26 | import torch.nn.functional as F
27 |
28 | import numpy as np
29 | from PIL import Image
30 | import yaml
31 | from tqdm import tqdm
32 | from octree.nerf import datasets
33 |
34 | INTERNAL = False
35 |
36 | Rays = collections.namedtuple("Rays", ("origins", "directions", "viewdirs"))
37 |
38 |
39 | def namedtuple_map(fn, tup):
40 | """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
41 | return type(tup)(*map(fn, tup))
42 |
43 |
44 | def define_flags():
45 | """Define flags for both training and evaluation modes."""
46 | flags.DEFINE_string("train_dir", None, "where to store ckpts and logs")
47 | flags.DEFINE_string("data_dir", None, "input data directory.")
48 | flags.DEFINE_string("config", None, "using config files to set hyperparameters.")
49 |
50 | # Dataset Flags
51 | # TODO(pratuls): rename to dataset_loader and consider cleaning up
52 | flags.DEFINE_enum(
53 | "dataset",
54 | "blender",
55 |         list(datasets.dataset_dict.keys()),
56 |         "The type of dataset fed to nerf.",
57 | )
58 | flags.DEFINE_bool(
59 | "image_batching", False, "sample rays in a batch from different images."
60 | )
61 | flags.DEFINE_bool(
62 | "white_bkgd",
63 | True,
64 | "using white color as default background." "(used in the blender dataset only)",
65 | )
66 | flags.DEFINE_integer(
67 | "batch_size", 1024, "the number of rays in a mini-batch (for training)."
68 | )
69 | flags.DEFINE_integer(
70 | "factor", 4, "the downsample factor of images, 0 for no downsample."
71 | )
72 | flags.DEFINE_bool("spherify", False, "set for spherical 360 scenes.")
73 | flags.DEFINE_bool(
74 | "render_path",
75 | False,
76 | "render generated path if set true." "(used in the llff dataset only)",
77 | )
78 | flags.DEFINE_integer(
79 | "llffhold",
80 | 8,
81 | "will take every 1/N images as LLFF test set."
82 | "(used in the llff dataset only)",
83 | )
84 |
85 | # Model Flags
86 | flags.DEFINE_string("model", "nerf", "name of model to use.")
87 | flags.DEFINE_float("near", 2.0, "near clip of volumetric rendering.")
88 | flags.DEFINE_float("far", 6.0, "far clip of volumentric rendering.")
89 | flags.DEFINE_integer("net_depth", 8, "depth of the first part of MLP.")
90 | flags.DEFINE_integer("net_width", 256, "width of the first part of MLP.")
91 | flags.DEFINE_integer("net_depth_condition", 1, "depth of the second part of MLP.")
92 | flags.DEFINE_integer("net_width_condition", 128, "width of the second part of MLP.")
93 | flags.DEFINE_float("weight_decay_mult", 0, "The multiplier on weight decay")
94 | flags.DEFINE_integer(
95 | "skip_layer",
96 | 4,
97 | "add a skip connection to the output vector of every" "skip_layer layers.",
98 | )
99 | flags.DEFINE_integer("num_rgb_channels", 3, "the number of RGB channels.")
100 | flags.DEFINE_integer("num_sigma_channels", 1, "the number of density channels.")
101 | flags.DEFINE_bool("randomized", True, "use randomized stratified sampling.")
102 | flags.DEFINE_integer(
103 | "min_deg_point", 0, "Minimum degree of positional encoding for points."
104 | )
105 | flags.DEFINE_integer(
106 | "max_deg_point", 10, "Maximum degree of positional encoding for points."
107 | )
108 | flags.DEFINE_integer("deg_view", 4, "Degree of positional encoding for viewdirs.")
109 | flags.DEFINE_integer(
110 | "num_coarse_samples",
111 | 64,
112 | "the number of samples on each ray for the coarse model.",
113 | )
114 | flags.DEFINE_integer(
115 | "num_fine_samples", 128, "the number of samples on each ray for the fine model."
116 | )
117 | flags.DEFINE_bool("use_viewdirs", True, "use view directions as a condition.")
118 | flags.DEFINE_integer("sh_deg", -1, "set to use spherical harmonics output of given order.")
119 | flags.DEFINE_integer("sg_dim", -1, "set to use spherical gaussians (SG).")
120 | flags.DEFINE_float(
121 | "noise_std",
122 | None,
123 | "std dev of noise added to regularize sigma output."
124 | "(used in the llff dataset only)",
125 | )
126 | flags.DEFINE_bool(
127 | "lindisp", False, "sampling linearly in disparity rather than depth."
128 | )
129 | flags.DEFINE_string(
130 | "net_activation", "ReLU", "activation function used within the MLP."
131 | )
132 | flags.DEFINE_string(
133 | "rgb_activation", "Sigmoid", "activation function used to produce RGB."
134 | )
135 | flags.DEFINE_string(
136 | "sigma_activation", "ReLU", "activation function used to produce density."
137 | )
138 | flags.DEFINE_bool(
139 | "legacy_posenc_order",
140 | False,
141 | "If True, revert the positional encoding feature order to an older version of this codebase.",
142 | )
143 |
144 | # Train Flags
145 | flags.DEFINE_float("lr_init", 5e-4, "The initial learning rate.")
146 | flags.DEFINE_float("lr_final", 5e-6, "The final learning rate.")
147 | flags.DEFINE_integer(
148 | "lr_delay_steps",
149 | 0,
150 | "The number of steps at the beginning of "
151 | "training to reduce the learning rate by lr_delay_mult",
152 | )
153 | flags.DEFINE_float(
154 | "lr_delay_mult",
155 | 1.0,
156 | "A multiplier on the learning rate when the step " "is < lr_delay_steps",
157 | )
158 | flags.DEFINE_integer("max_steps", 1000000, "the number of optimization steps.")
159 | flags.DEFINE_integer(
160 | "save_every", 5000, "the number of steps to save a checkpoint."
161 | )
162 | flags.DEFINE_integer(
163 | "print_every", 500, "the number of steps between reports to tensorboard."
164 | )
165 | flags.DEFINE_integer(
166 | "render_every",
167 | 10000,
168 | "the number of steps to render a test image,"
169 | "better to be x00 for accurate step time record.",
170 | )
171 | flags.DEFINE_integer(
172 | "gc_every", 10000, "the number of steps to run python garbage collection."
173 | )
174 | flags.DEFINE_float(
175 | "sparsity_weight",
176 | 1e-3,
177 | "Sparsity loss weight",
178 | )
179 | flags.DEFINE_float(
180 | "sparsity_length",
181 | 0.05,
182 | "Sparsity loss 'length' for alpha calculation",
183 | )
184 | flags.DEFINE_float(
185 | "sparsity_radius",
186 | 1.5,
187 | "Sparsity loss point sampling box 1/2 side length",
188 | )
189 | flags.DEFINE_integer(
190 | "sparsity_npoints",
191 | 10000,
192 | "Number of samples for sparsity loss",
193 | )
194 |
195 | # Eval Flags
196 | flags.DEFINE_bool(
197 | "eval_once",
198 | True,
199 | "evaluate the model only once if true, otherwise keeping evaluating new"
200 | "checkpoints if there's any.",
201 | )
202 | flags.DEFINE_bool("save_output", True, "save predicted images to disk if True.")
203 | flags.DEFINE_integer(
204 | "chunk",
205 | 81920,
206 | "the size of chunks for evaluation inferences, set to the value that"
207 | "fits your GPU/TPU memory.",
208 | )
209 |
210 | # Octree flags
211 | flags.DEFINE_float(
212 | 'renderer_step_size',
213 | 1e-4,
214 |     'step size epsilon in volume rendering. '
215 |     '1e-3 = fast setting, 1e-4 = usual setting, 1e-5 = high quality setting; lower is better')
216 | flags.DEFINE_bool(
217 | 'no_early_stop',
218 | False,
219 | 'If set, does not use early stopping; slows down rendering slightly')
220 |
221 |
222 | def update_flags(args):
223 | """Update the flags in `args` with the contents of the config YAML file."""
224 | if args.config is None:
225 | return
226 |     pth = args.config + ".yaml"
227 | with open_file(pth, "r") as fin:
228 | configs = yaml.load(fin, Loader=yaml.FullLoader)
229 | # Only allow args to be updated if they already exist.
230 | invalid_args = list(set(configs.keys()) - set(dir(args)))
231 | if invalid_args:
232 | raise ValueError(f"Invalid args {invalid_args} in {pth}.")
233 | args.__dict__.update(configs)
234 |
235 |
236 | def check_flags(args, require_data=True, require_batch_size_div=False):
237 | if args.train_dir is None:
238 | raise ValueError("train_dir must be set. None set now.")
239 | if require_data and args.data_dir is None:
240 | raise ValueError("data_dir must be set. None set now.")
241 | if require_batch_size_div and args.batch_size % torch.cuda.device_count() != 0:
242 | raise ValueError("Batch size must be divisible by the number of devices.")
243 |
244 |
245 | def set_random_seed(seed):
246 | torch.manual_seed(seed)
247 | np.random.seed(seed)
248 |
249 |
250 | def open_file(pth, mode="r"):
251 | if not INTERNAL:
252 | pth = path.expanduser(pth)
253 | return open(pth, mode=mode)
254 |
255 |
256 | def file_exists(pth):
257 | if not INTERNAL:
258 | return path.exists(pth)
259 |
260 |
261 | def listdir(pth):
262 | if not INTERNAL:
263 | return os.listdir(pth)
264 |
265 |
266 | def isdir(pth):
267 | if not INTERNAL:
268 | return path.isdir(pth)
269 |
270 |
271 | def makedirs(pth):
272 | if not INTERNAL:
273 | os.makedirs(pth, exist_ok=True)
274 |
275 |
276 | @torch.no_grad()
277 | def eval_points(fn, points, chunk=720720, to_cpu=True):
278 | """Evaluate at given points (in test mode).
279 |     Viewdirs are currently not supported.
280 |
281 | Args:
282 |       fn: callable taking (points, viewdirs) and returning (rgb, sigma).
283 |       points: torch.tensor [..., 3]
284 |       chunk: int, the size of chunks to render sequentially.
285 |       to_cpu: bool, if True, return numpy arrays on the CPU.
286 | Returns:
287 | rgb: torch.tensor or np.array.
288 | sigmas: torch.tensor or np.array.
289 | """
290 | num_points = points.shape[0]
291 | rgbs, sigmas = [], []
292 |
293 | for i in tqdm(range(0, num_points, chunk)):
294 | chunk_points = points[i : i + chunk]
295 | rgb, sigma = fn(chunk_points, None)
296 | if to_cpu:
297 | rgb = rgb.detach().cpu().numpy()
298 | sigma = sigma.detach().cpu().numpy()
299 | rgbs.append(rgb)
300 | sigmas.append(sigma)
301 | if to_cpu:
302 | rgbs = np.concatenate(rgbs, axis=0)
303 | sigmas = np.concatenate(sigmas, axis=0)
304 | else:
305 | rgbs = torch.cat(rgbs, dim=0)
306 | sigmas = torch.cat(sigmas, dim=0)
307 | return rgbs, sigmas
308 |
309 |
310 | def compute_psnr(mse):
311 | """Compute psnr value given mse (we assume the maximum pixel value is 1).
312 |
313 | Args:
314 | mse: float, mean square error of pixels.
315 |
316 | Returns:
317 | psnr: float, the psnr value.
318 | """
319 | return -10.0 * torch.log(mse) / np.log(10.0)
320 |
321 |
322 | def compute_ssim(
323 | img0,
324 | img1,
325 | max_val,
326 | filter_size=11,
327 | filter_sigma=1.5,
328 | k1=0.01,
329 | k2=0.03,
330 | return_map=False,
331 | ):
332 | """Computes SSIM from two images.
333 |
334 | This function was modeled after tf.image.ssim, and should produce comparable
335 | output.
336 |
337 | Args:
338 | img0: torch.tensor. An image of size [..., width, height, num_channels].
339 | img1: torch.tensor. An image of size [..., width, height, num_channels].
340 | max_val: float > 0. The maximum magnitude that `img0` or `img1` can have.
341 | filter_size: int >= 1. Window size.
342 | filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering.
343 | k1: float > 0. One of the SSIM dampening parameters.
344 | k2: float > 0. One of the SSIM dampening parameters.
345 |     return_map: Bool. If True, will cause the per-pixel SSIM "map" to be returned.
346 |
347 | Returns:
348 | Each image's mean SSIM, or a tensor of individual values if `return_map`.
349 | """
350 | device = img0.device
351 | ori_shape = img0.size()
352 | width, height, num_channels = ori_shape[-3:]
353 | img0 = img0.view(-1, width, height, num_channels).permute(0, 3, 1, 2)
354 | img1 = img1.view(-1, width, height, num_channels).permute(0, 3, 1, 2)
355 | batch_size = img0.shape[0]
356 |
357 | # Construct a 1D Gaussian blur filter.
358 | hw = filter_size // 2
359 | shift = (2 * hw - filter_size + 1) / 2
360 | f_i = ((torch.arange(filter_size, device=device) - hw + shift) / filter_sigma) ** 2
361 | filt = torch.exp(-0.5 * f_i)
362 | filt /= torch.sum(filt)
363 |
364 | # Blur in x and y (faster than the 2D convolution).
365 |     # z is a tensor of size [B, C, W, H]
366 | filt_fn1 = lambda z: F.conv2d(
367 | z, filt.view(1, 1, -1, 1).repeat(num_channels, 1, 1, 1),
368 | padding=[hw, 0], groups=num_channels)
369 | filt_fn2 = lambda z: F.conv2d(
370 | z, filt.view(1, 1, 1, -1).repeat(num_channels, 1, 1, 1),
371 | padding=[0, hw], groups=num_channels)
372 |
373 |     # Compose the two 1D blurs into a separable 2D blur.
374 | filt_fn = lambda z: filt_fn1(filt_fn2(z))
375 | mu0 = filt_fn(img0)
376 | mu1 = filt_fn(img1)
377 | mu00 = mu0 * mu0
378 | mu11 = mu1 * mu1
379 | mu01 = mu0 * mu1
380 | sigma00 = filt_fn(img0 ** 2) - mu00
381 | sigma11 = filt_fn(img1 ** 2) - mu11
382 | sigma01 = filt_fn(img0 * img1) - mu01
383 |
384 | # Clip the variances and covariances to valid values.
385 | # Variance must be non-negative:
386 | sigma00 = torch.clamp(sigma00, min=0.0)
387 | sigma11 = torch.clamp(sigma11, min=0.0)
388 | sigma01 = torch.sign(sigma01) * torch.min(
389 | torch.sqrt(sigma00 * sigma11), torch.abs(sigma01)
390 | )
391 |
392 | c1 = (k1 * max_val) ** 2
393 | c2 = (k2 * max_val) ** 2
394 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
395 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
396 | ssim_map = numer / denom
397 | ssim = torch.mean(ssim_map.reshape([-1, num_channels*width*height]), dim=-1)
398 | return ssim_map if return_map else ssim
399 |
400 |
401 | def generate_rays(w, h, focal, camtoworlds, equirect=False):
402 | """
403 | Generate perspective camera rays. Principal point is at center.
404 | Args:
405 | w: int image width
406 |         h: int image height
407 |         focal: float real focal length
408 |         camtoworlds: np.ndarray [B, 4, 4] c2w homogeneous poses
409 | equirect: if true, generates spherical rays instead of pinhole
410 | Returns:
411 | rays: Rays a namedtuple(origins [B, 3], directions [B, 3], viewdirs [B, 3])
412 | """
413 | x, y = np.meshgrid( # pylint: disable=unbalanced-tuple-unpacking
414 | np.arange(w, dtype=np.float32), # X-Axis (columns)
415 | np.arange(h, dtype=np.float32), # Y-Axis (rows)
416 | indexing="xy",
417 | )
418 |
419 | if equirect:
420 | uv = np.stack([x * (2.0 / w) - 1.0, y * (2.0 / h) - 1.0], axis=-1)
421 | camera_dirs = equirect2xyz(uv)
422 | else:
423 | camera_dirs = np.stack(
424 | [
425 | (x - w * 0.5) / focal,
426 | -(y - h * 0.5) / focal,
427 | -np.ones_like(x),
428 | ],
429 | axis=-1,
430 | )
431 |
432 | # camera_dirs = camera_dirs / np.linalg.norm(camera_dirs, axis=-1, keepdims=True)
433 |
434 | c2w = camtoworlds[:, None, None, :3, :3]
435 | camera_dirs = camera_dirs[None, Ellipsis, None]
436 | directions = np.matmul(c2w, camera_dirs)[Ellipsis, 0]
437 | origins = np.broadcast_to(
438 | camtoworlds[:, None, None, :3, -1], directions.shape
439 | )
440 | norms = np.linalg.norm(directions, axis=-1, keepdims=True)
441 | viewdirs = directions / norms
442 | rays = Rays(
443 | origins=origins, directions=directions, viewdirs=viewdirs
444 | )
445 | return rays
446 |
447 |
448 | def eval_octree(t, dataset, args, want_lpips=True, want_frames=False):
449 | import svox
450 | w, h, focal = dataset.w, dataset.h, dataset.focal
451 | if 'llff' in args.config and (not args.spherify):
452 | ndc_config = svox.NDCConfig(width=w, height=h, focal=focal)
453 | else:
454 | ndc_config = None
455 |
456 | r = svox.VolumeRenderer(
457 | t, step_size=args.renderer_step_size, ndc=ndc_config)
458 |
459 | print('Evaluating octree')
460 | device = t.data.device
461 | if want_lpips:
462 | import lpips
463 | lpips_vgg = lpips.LPIPS(net="vgg").eval().to(device)
464 |
465 | avg_psnr = 0.0
466 | avg_ssim = 0.0
467 | avg_lpips = 0.0
468 | out_frames = []
469 | for idx in tqdm(range(dataset.size)):
470 | c2w = torch.from_numpy(dataset.camtoworlds[idx]).float().to(device)
471 | im_gt_ten = torch.from_numpy(dataset.images[idx]).float().to(device)
472 |
473 | im = r.render_persp(
474 | c2w, width=w, height=h, fx=focal, fast=not args.no_early_stop)
475 | im.clamp_(0.0, 1.0)
476 |
477 | mse = ((im - im_gt_ten) ** 2).mean()
478 | psnr = compute_psnr(mse).mean()
479 | ssim = compute_ssim(im, im_gt_ten, max_val=1.0).mean()
480 |
481 | avg_psnr += psnr.item()
482 | avg_ssim += ssim.item()
483 | if want_lpips:
484 | lpips_i = lpips_vgg(im_gt_ten.permute([2, 0, 1]).contiguous(),
485 | im.permute([2, 0, 1]).contiguous(), normalize=True)
486 | avg_lpips += lpips_i.item()
487 |
488 | if want_frames:
489 | im = im.cpu()
490 | # vis = np.hstack((im_gt_ten.cpu().numpy(), im.cpu().numpy()))
491 |             vis = im.cpu().numpy()  # rendered frame only (GT omitted)
492 | vis = (vis * 255).astype(np.uint8)
493 | out_frames.append(vis)
494 |
495 | avg_psnr /= dataset.size
496 | avg_ssim /= dataset.size
497 | avg_lpips /= dataset.size
498 | return avg_psnr, avg_ssim, avg_lpips, out_frames
499 |
500 |
501 | def memlog(device='cuda'):
502 | # Memory debugging
503 | print(torch.cuda.memory_summary(device))
504 | import gc
505 | for obj in gc.get_objects():
506 | try:
507 | if torch.is_tensor(obj) or (
508 | hasattr(obj, 'data') and torch.is_tensor(obj.data)):
509 | if str(obj.device) != 'cpu':
510 | print(obj.device, '{: 10}'.format(obj.numel()),
511 | obj.dtype,
512 | obj.size(), type(obj))
513 | except:
514 | pass
515 |
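516 | 
517 | # Quick numeric sanity check (editor's sketch, not part of the original
518 | # file): an MSE of 1e-3 corresponds to a PSNR of 30 dB, and an image
519 | # compared against itself should score an SSIM of (almost) exactly 1.
520 | if __name__ == "__main__":
521 |     print(compute_psnr(torch.tensor(1e-3)))  # tensor(30.)
522 |     img = torch.rand(32, 32, 3)
523 |     print(compute_ssim(img, img, max_val=1.0))        # ~1.0
524 |     print(compute_ssim(img, 1.0 - img, max_val=1.0))  # much lower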
--------------------------------------------------------------------------------
/octree/optimization.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 | """Optimize a plenoctree through finetuning on train set.
24 |
25 | Usage:
26 |
27 | export DATA_ROOT=./data/NeRF/nerf_synthetic/
28 | export CKPT_ROOT=./data/PlenOctree/checkpoints/syn_sh16
29 | export SCENE=chair
30 | export CONFIG_FILE=nerf_sh/config/blender
31 |
32 | python -m octree.optimization \
33 | --input $CKPT_ROOT/$SCENE/tree.npz \
34 | --config $CONFIG_FILE \
35 | --data_dir $DATA_ROOT/$SCENE/ \
36 | --output $CKPT_ROOT/$SCENE/octrees/tree_opt.npz
37 | """
38 | import svox
39 | import torch
40 | import torch.cuda
41 | import numpy as np
42 | import json
43 | import imageio
44 | import os.path as osp
45 | import os
46 | from argparse import ArgumentParser
47 | from tqdm import tqdm
48 | from torch.optim import SGD, Adam
49 | from warnings import warn
50 |
51 | from absl import app
52 | from absl import flags
53 |
54 | from octree.nerf import datasets
55 | from octree.nerf import utils
56 |
57 | FLAGS = flags.FLAGS
58 |
59 | utils.define_flags()
60 |
61 | flags.DEFINE_string(
62 | "input",
63 | "./tree.npz",
64 | "Input octree npz from extraction.py",
65 | )
66 | flags.DEFINE_string(
67 | "output",
68 | "./tree_opt.npz",
69 | "Output octree npz",
70 | )
71 | flags.DEFINE_integer(
72 | 'render_interval',
73 | 0,
74 |     'save every Nth validation image render to disk (0 to disable)')
75 | flags.DEFINE_integer(
76 | 'val_interval',
77 | 2,
78 |     'number of epochs between validation runs')
79 | flags.DEFINE_integer(
80 | 'num_epochs',
81 | 80,
82 | 'epochs to train for')
83 | flags.DEFINE_bool(
84 | 'sgd',
85 | True,
86 | 'use SGD optimizer instead of Adam')
87 | flags.DEFINE_float(
88 | 'lr',
89 | 1e7,
90 | 'optimizer step size')
91 | flags.DEFINE_float(
92 | 'sgd_momentum',
93 | 0.0,
94 | 'sgd momentum')
95 | flags.DEFINE_bool(
96 | 'sgd_nesterov',
97 | False,
98 |     'use Nesterov momentum with SGD')
99 | flags.DEFINE_string(
100 | "write_vid",
101 | None,
102 | "If specified, writes rendered video to given path (*.mp4)",
103 | )
104 |
105 | # Manual 'val' set
106 | flags.DEFINE_bool(
107 | "split_train",
108 | None,
109 | "If specified, splits train set instead of loading val set",
110 | )
111 | flags.DEFINE_float(
112 | "split_holdout_prop",
113 | 0.2,
114 | "Proportion of images to hold out if split_train is set",
115 | )
116 |
117 | # Do not save since it is slow
118 | flags.DEFINE_bool(
119 | "nosave",
120 | False,
121 | "If set, does not save (for speed)",
122 | )
123 |
124 | flags.DEFINE_bool(
125 | "continue_on_decrease",
126 | False,
127 | "If set, continues training even if validation PSNR decreases",
128 | )
129 |
130 | device = "cuda" if torch.cuda.is_available() else "cpu"
131 | torch.autograd.set_detect_anomaly(True)
132 |
133 |
134 | def main(unused_argv):
135 | utils.set_random_seed(20200823)
136 | utils.update_flags(FLAGS)
137 |
138 | def get_data(stage):
139 | assert stage in ["train", "val", "test"]
140 | dataset = datasets.get_dataset(stage, FLAGS)
141 | focal = dataset.focal
142 | all_c2w = dataset.camtoworlds
143 | all_gt = dataset.images.reshape(-1, dataset.h, dataset.w, 3)
144 | all_c2w = torch.from_numpy(all_c2w).float().to(device)
145 | all_gt = torch.from_numpy(all_gt).float()
146 | return focal, all_c2w, all_gt
147 |
148 | focal, train_c2w, train_gt = get_data("train")
149 | if FLAGS.split_train:
150 | test_sz = int(train_c2w.size(0) * FLAGS.split_holdout_prop)
151 | print('Splitting train to train/val manually, holdout', test_sz)
152 | perm = torch.randperm(train_c2w.size(0))
153 | test_c2w = train_c2w[perm[:test_sz]]
154 | test_gt = train_gt[perm[:test_sz]]
155 | train_c2w = train_c2w[perm[test_sz:]]
156 | train_gt = train_gt[perm[test_sz:]]
157 | else:
158 | print('Using given val set')
159 | test_focal, test_c2w, test_gt = get_data("val")
160 | assert focal == test_focal
161 | H, W = train_gt[0].shape[:2]
162 |
163 | vis_dir = osp.splitext(FLAGS.input)[0] + '_render'
164 | os.makedirs(vis_dir, exist_ok=True)
165 |
166 | print('N3Tree load')
167 | t = svox.N3Tree.load(FLAGS.input, map_location=device)
168 | # t.nan_to_num_()
169 |
170 | if 'llff' in FLAGS.config:
171 | ndc_config = svox.NDCConfig(width=W, height=H, focal=focal)
172 | else:
173 | ndc_config = None
174 | r = svox.VolumeRenderer(t, step_size=FLAGS.renderer_step_size, ndc=ndc_config)
175 |
176 | if FLAGS.sgd:
177 | print('Using SGD, lr', FLAGS.lr)
178 | if FLAGS.lr < 1.0:
179 | warn('For SGD please adjust LR to about 1e7')
180 | optimizer = SGD(t.parameters(), lr=FLAGS.lr, momentum=FLAGS.sgd_momentum,
181 | nesterov=FLAGS.sgd_nesterov)
182 | else:
183 | adam_eps = 1e-4 if t.data.dtype is torch.float16 else 1e-8
184 | print('Using Adam, eps', adam_eps, 'lr', FLAGS.lr)
185 | optimizer = Adam(t.parameters(), lr=FLAGS.lr, eps=adam_eps)
186 |
187 | n_train_imgs = len(train_c2w)
188 | n_test_imgs = len(test_c2w)
189 |
190 | def run_test_step(i):
191 | print('Evaluating')
192 | with torch.no_grad():
193 | tpsnr = 0.0
194 | for j, (c2w, im_gt) in enumerate(zip(test_c2w, test_gt)):
195 | im = r.render_persp(c2w, height=H, width=W, fx=focal, fast=False)
196 | im = im.cpu().clamp_(0.0, 1.0)
197 |
198 | mse = ((im - im_gt) ** 2).mean()
199 | psnr = -10.0 * np.log(mse) / np.log(10.0)
200 | tpsnr += psnr.item()
201 |
202 | if FLAGS.render_interval > 0 and j % FLAGS.render_interval == 0:
203 | vis = torch.cat((im_gt, im), dim=1)
204 | vis = (vis * 255).numpy().astype(np.uint8)
205 | imageio.imwrite(f"{vis_dir}/{i:04}_{j:04}.png", vis)
206 | tpsnr /= n_test_imgs
207 | return tpsnr
208 |
209 | best_validation_psnr = run_test_step(0)
210 | print('** initial val psnr ', best_validation_psnr)
211 | best_t = None
212 | for i in range(FLAGS.num_epochs):
213 | print('epoch', i)
214 | tpsnr = 0.0
215 | for j, (c2w, im_gt) in tqdm(enumerate(zip(train_c2w, train_gt)), total=n_train_imgs):
216 | im = r.render_persp(c2w, height=H, width=W, fx=focal, cuda=True)
217 | im_gt_ten = im_gt.to(device=device)
218 | im = torch.clamp(im, 0.0, 1.0)
219 | mse = ((im - im_gt_ten) ** 2).mean()
220 | im_gt_ten = None
221 |
222 | optimizer.zero_grad()
223 | t.data.grad = None # This helps save memory weirdly enough
224 | mse.backward()
225 | # print('mse', mse, t.data.grad.min(), t.data.grad.max())
226 | optimizer.step()
227 | # t.data.data -= eta * t.data.grad
228 | psnr = -10.0 * np.log(mse.detach().cpu()) / np.log(10.0)
229 | tpsnr += psnr.item()
230 | tpsnr /= n_train_imgs
231 | print('** train_psnr', tpsnr)
232 |
233 | if i % FLAGS.val_interval == FLAGS.val_interval - 1 or i == FLAGS.num_epochs - 1:
234 | validation_psnr = run_test_step(i + 1)
235 | print('** val psnr ', validation_psnr, 'best', best_validation_psnr)
236 | if validation_psnr > best_validation_psnr:
237 | best_validation_psnr = validation_psnr
238 | best_t = t.clone(device='cpu') # SVOX 0.2.22
239 | print('')
240 | elif not FLAGS.continue_on_decrease:
241 |                 print('Stopping early: validation PSNR decreased')
242 | break
243 | if not FLAGS.nosave:
244 | if best_t is not None:
245 | print('Saving best model to', FLAGS.output)
246 | best_t.save(FLAGS.output, compress=False)
247 | else:
248 | print('Did not improve upon initial model')
249 |
250 | if __name__ == "__main__":
251 | app.run(main)
252 |
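253 | 
254 | # Rendering the optimized tree afterwards (editor's sketch; the paths and
255 | # camera values below are placeholders, not from the original file). It
256 | # mirrors the svox calls used in main() above:
257 | #
258 | #   import svox, torch
259 | #   t = svox.N3Tree.load("tree_opt.npz", map_location="cuda")
260 | #   r = svox.VolumeRenderer(t, step_size=1e-4)
261 | #   c2w = torch.eye(4, device="cuda")  # placeholder camera pose
262 | #   im = r.render_persp(c2w, height=800, width=800, fx=1111.0)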
--------------------------------------------------------------------------------
/octree/task_manager.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 | """
24 | Multi-GPU parallel octree conversion pipeline for running hyperparameter search.
25 | Make a file tasks.json describing the tasks, e.g.
26 | {
27 | "data_root": "/home/sxyu/data",
28 | "train_root": "/home/sxyu/proj/jaxnerf/jaxnerf/train/SH16",
29 | "tasks": [{
30 | "octree_name": "oct_chair_bb1_2",
31 | "train_dir": "chair",
32 | "data_dir": "nerf_synthetic/chair",
33 | "config": "sh",
34 | "extr_flags": ["--bbox_from_data", "--bbox_scale", "1.2"],
35 | "opt_flags": [],
36 | "eval_flags": []
37 | },
38 | ...]
39 | }
40 |
41 | Then,
42 | python -m octree.task_manager tasks.json --gpus='space delimited list of gpus to use'
43 |
44 | For each task, the final octree is saved to
45 | <train_root>/<train_dir>/octrees/<octree_name>/tree.npz
46 | If you specify --keep_raw, the above is the raw tree and the optimized tree is saved to
47 | <train_root>/<train_dir>/octrees/<octree_name>/tree_opt.npz
48 | 
49 | Capacity, raw eval PSNR/SSIM/LPIPS, and optimized eval PSNR/SSIM/LPIPS are saved to
50 | <train_root>/<train_dir>/octrees/<octree_name>/results.txt
51 | """
52 | import argparse
53 | import sys
54 | import os
55 | import os.path as osp
56 | import subprocess
57 | import concurrent.futures
58 | import json
59 | from multiprocessing import Process, Queue
60 |
61 | parser = argparse.ArgumentParser()
62 | parser.add_argument("task_json", type=str)
63 | parser.add_argument("--gpus", type=str, required=True,
64 | help="space delimited GPU id list (pre CUDA_VISIBLE_DEVICES)")
65 | parser.add_argument("--keep_raw", action='store_true',
66 | help="do not overwrite raw octree (takes extra disk space)")
67 | args = parser.parse_args()
68 |
69 | def convert_one(env, train_dir, data_dir, config, octree_name,
70 | extr_flags, opt_flags=[], eval_flags=[]):
71 | octree_store_dir = osp.join(train_dir, 'octrees', octree_name)
72 | octree_file = osp.join(octree_store_dir, "tree.npz")
73 | octree_opt_file = osp.join(octree_store_dir,
74 | "tree_opt.npz") if args.keep_raw else octree_file
75 | config_name = f"{config}"
76 | os.makedirs(octree_store_dir, exist_ok=True)
77 | extr_base_cmd = [
78 | "python", "-u", "-m", "octree.extraction",
79 | "--train_dir", train_dir,
80 | "--config", config_name, "--is_jaxnerf_ckpt",
81 | "--output ", octree_file,
82 | "--data_dir", data_dir
83 | ]
84 | opt_base_cmd = [
85 | "python", "-u", "-m", "octree.optimization",
86 | "--config", config_name, "--input", octree_file,
87 | "--output", octree_opt_file,
88 | "--data_dir", data_dir
89 | ]
90 | eval_base_cmd = [
91 | "python", "-u", "-m", "octree.evaluation",
92 | "--config", config_name, "--input ", octree_opt_file,
93 | "--data_dir", data_dir
94 | ]
95 | out_file_path = osp.join(octree_store_dir, 'results.txt')
96 |
97 | with open(out_file_path, 'w') as out_file:
98 | print('********************************************')
99 | print('! Extract', train_dir, octree_name)
100 | extr_cmd = ' '.join(extr_base_cmd + extr_flags)
101 | print(extr_cmd)
102 | extr_ret = subprocess.check_output(extr_cmd, shell=True, env=env).decode(
103 | sys.stdout.encoding)
104 | with open('pextract.txt', 'w') as f:
105 | f.write(extr_ret)
106 |
107 | extr_ret = extr_ret.split('\n')
108 | svox_str = extr_ret[-9]
109 | capacity = int(svox_str.split()[3].split(':')[1].split('/')[0])
110 |
111 | parse_metrics = lambda x: map(float, x.split()[2::2])
112 | psnr, ssim, lpips = parse_metrics(extr_ret[-2])
113 | print(': ', octree_name, 'RAW capacity',
114 | capacity, 'PSNR', psnr, 'SSIM', ssim, 'LPIPS', lpips)
115 | out_file.write(f'{capacity}\n{psnr:.10f} {ssim:.10f} {lpips:.10f}\n')
116 |
117 | print('! Optimize', train_dir, octree_name)
118 | opt_cmd = ' '.join(opt_base_cmd + opt_flags)
119 | print(opt_cmd)
120 | subprocess.call(opt_cmd, shell=True, env=env)
121 |
122 | if osp.exists(octree_opt_file):
123 | print('! Eval', train_dir, octree_name)
124 | eval_cmd = ' '.join(eval_base_cmd + eval_flags)
125 | print(eval_cmd)
126 | eval_ret = subprocess.check_output(eval_cmd, shell=True, env=env).decode(
127 | sys.stdout.encoding)
128 | eval_ret = eval_ret.split('\n')
129 |
130 | epsnr, essim, elpips = parse_metrics(eval_ret[-2])
131 | print(':', octree_name, 'OPT capacity',
132 | capacity, 'PSNR', epsnr, 'SSIM', essim, 'LPIPS', elpips)
133 | out_file.write(f'{epsnr:.10f} {essim:.10f} {elpips:.10f}\n')
134 | else:
135 | print('! Eval skipped')
136 | out_file.write(f'{psnr:.10f} {ssim:.10f} {lpips:.10f}\n')
137 |
138 |
139 |
140 | def process_main(device, queue):
141 | # Set CUDA_VISIBLE_DEVICES programmatically
142 | env = os.environ.copy()
143 | env["CUDA_VISIBLE_DEVICES"] = str(device)
144 | while True:
145 | task = queue.get()
146 | if len(task) == 0:
147 | break
148 | convert_one(env, **task)
149 |
150 | if __name__=='__main__':
151 | with open(args.task_json, 'r') as f:
152 | tasks_file = json.load(f)
153 | all_tasks = tasks_file.get('tasks', [])
154 | data_root = tasks_file['data_root']
155 | train_root = tasks_file['train_root']
156 | pqueue = Queue()
157 |     # scene_tasks are expanded once per scene ({%} is replaced with the scene name)
158 | if 'scene_tasks' in tasks_file:
159 | symb = '{%}'
160 | scenes = tasks_file['scenes']
161 | for scene_task in tasks_file['scene_tasks']:
162 | for scene in scenes:
163 | task = scene_task.copy()
164 | task['data_dir'] = scene_task['data_dir'].replace(symb, scene)
165 | task['train_dir'] = scene_task['train_dir'].replace(symb, scene)
166 | task['octree_name'] = scene_task['octree_name'].replace(symb, scene)
167 | all_tasks.append(task)
168 |
169 | print(len(all_tasks), 'total tasks')
170 |
171 | for task in all_tasks:
172 | task['train_dir'] = osp.join(train_root, task['train_dir'])
173 | task['data_dir'] = osp.join(data_root, task['data_dir'])
174 | octrees_dir = osp.join(task['data_dir'], 'octrees')
175 | os.makedirs(octrees_dir, exist_ok=True)
176 |         # sanity check
177 | assert os.path.exists(task['train_dir']), task['train_dir']
178 | assert os.path.exists(task['data_dir']), task['data_dir']
179 |
180 | for task in all_tasks:
181 | pqueue.put(task)
182 | pqueue.put({})
183 |
184 | args.gpus = list(map(int, args.gpus.split()))
185 | print('GPUS:', args.gpus)
186 |
187 | all_procs = []
188 | for i, gpu in enumerate(args.gpus):
189 | process = Process(target=process_main, args=(gpu, pqueue))
190 | process.daemon = True
191 | process.start()
192 | all_procs.append(process)
193 |
194 | for i, gpu in enumerate(args.gpus):
195 | all_procs[i].join()
196 |
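197 | 
198 | # Writing a minimal tasks.json programmatically (editor's sketch; the
199 | # paths are placeholders). The structure mirrors the example in the
200 | # module docstring above:
201 | #
202 | #   import json
203 | #   tasks = {
204 | #       "data_root": "/path/to/data",
205 | #       "train_root": "/path/to/train/SH16",
206 | #       "tasks": [{
207 | #           "octree_name": "oct_chair_bb1_2",
208 | #           "train_dir": "chair",
209 | #           "data_dir": "nerf_synthetic/chair",
210 | #           "config": "sh",
211 | #           "extr_flags": ["--bbox_from_data", "--bbox_scale", "1.2"],
212 | #           "opt_flags": [],
213 | #           "eval_flags": []
214 | #       }]
215 | #   }
216 | #   with open("tasks.json", "w") as f:
217 | #       json.dump(tasks, f, indent=2)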
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorboard
2 | imageio
3 | imageio-ffmpeg
4 | ipdb
5 | lpips
6 | jax
7 | jaxlib
8 | flax
9 | opencv-python
10 | Pillow
11 | pyyaml
12 | tensorflow==2.3.1
13 | pymcubes
14 | svox>=0.2.26
15 |
--------------------------------------------------------------------------------