├── configs ├── render_config.gin ├── tat.gin ├── llff_raw_test.gin ├── 360.gin ├── 360_glo4.gin ├── debug.gin ├── llff_512.gin ├── blender_512.gin ├── llff_256.gin ├── blender_256.gin ├── blender_refnerf.gin └── llff_raw.gin ├── .gitignore ├── requirements.txt ├── internal ├── pycolmap │ ├── pycolmap │ │ ├── __init__.py │ │ ├── image.py │ │ ├── camera.py │ │ ├── database.py │ │ └── rotation.py │ ├── README.md │ ├── setup.py │ └── LICENSE ├── math.py ├── geopoly.py ├── utils.py ├── image.py ├── coord.py ├── ref_utils.py ├── render.py ├── configs.py └── vis.py ├── tests ├── utils_test.py ├── camera_utils_test.py ├── ref_utils_test.py ├── datasets_test.py ├── image_test.py ├── math_test.py ├── geopoly_test.py └── coord_test.py ├── scripts ├── run_all_unit_tests.sh ├── eval_360.sh ├── eval_llff.sh ├── eval_raw.sh ├── eval_blender.sh ├── eval_shinyblender.sh ├── train_llff.sh ├── train_raw.sh ├── train_blender.sh ├── train_shinyblender.sh ├── train_360.sh ├── render_360.sh ├── render_llff.sh ├── render_raw.sh └── local_colmap_and_resize.sh ├── CONTRIBUTING.md ├── render.py ├── eval.py ├── LICENSE └── README.md /configs/render_config.gin: -------------------------------------------------------------------------------- 1 | Config.render_path = True 2 | Config.render_path_frames = 480 3 | Config.render_video_fps = 60 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | internal/pycolmap 2 | __pycache__/ 3 | interal/__pycache__/ 4 | tests/__pycache__/ 5 | .DS_Store 6 | .vscode/ 7 | .idea/ 8 | __MACOSX/ 9 | -------------------------------------------------------------------------------- /configs/tat.gin: -------------------------------------------------------------------------------- 1 | # This config is meant to be run while overriding a 360*.gin config. 
2 | 3 | Config.dataset_loader = 'tat_nerfpp' 4 | Config.near = 0.1 5 | Config.far = 1e6 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | jax 3 | jaxlib 4 | flax 5 | opencv-python 6 | Pillow 7 | tensorboard 8 | tensorflow 9 | gin-config 10 | dm_pix 11 | rawpy 12 | mediapy 13 | -------------------------------------------------------------------------------- /configs/llff_raw_test.gin: -------------------------------------------------------------------------------- 1 | include 'experimental/users/barron/mipnerf360/configs/llff_raw.gin' 2 | 3 | Config.factor = 0 4 | Config.eval_raw_affine_cc = True 5 | Config.eval_crop_borders = 16 6 | Config.vis_decimate = 4 7 | -------------------------------------------------------------------------------- /internal/pycolmap/pycolmap/__init__.py: -------------------------------------------------------------------------------- 1 | from .camera import Camera 2 | from .database import COLMAPDatabase 3 | from .image import Image 4 | from .scene_manager import SceneManager 5 | from .rotation import Quaternion, DualQuaternion 6 | -------------------------------------------------------------------------------- /internal/pycolmap/README.md: -------------------------------------------------------------------------------- 1 | # pycolmap 2 | Python interface for COLMAP reconstructions, plus some convenient scripts for loading/modifying/converting reconstructions. 3 | 4 | This code does not, however, run reconstruction -- it only provides a convenient interface for handling COLMAP's output. 
5 | -------------------------------------------------------------------------------- /configs/360.gin: -------------------------------------------------------------------------------- 1 | Config.dataset_loader = 'llff' 2 | Config.near = 0.2 3 | Config.far = 1e6 4 | Config.factor = 4 5 | Config.batch_size = 1024 6 | 7 | Model.raydist_fn = @jnp.reciprocal 8 | Model.opaque_background = True 9 | 10 | PropMLP.warp_fn = @coord.contract 11 | PropMLP.net_depth = 4 12 | PropMLP.net_width = 256 13 | PropMLP.disable_density_normals = True 14 | PropMLP.disable_rgb = True 15 | 16 | NerfMLP.warp_fn = @coord.contract 17 | NerfMLP.net_depth = 8 18 | NerfMLP.net_width = 1024 19 | NerfMLP.disable_density_normals = True 20 | 21 | -------------------------------------------------------------------------------- /configs/360_glo4.gin: -------------------------------------------------------------------------------- 1 | Config.dataset_loader = 'llff' 2 | Config.near = 0.2 3 | Config.far = 1e6 4 | Config.factor = 4 5 | 6 | Model.raydist_fn = @jnp.reciprocal 7 | Model.num_glo_features = 4 8 | Model.opaque_background = True 9 | 10 | PropMLP.warp_fn = @coord.contract 11 | PropMLP.net_depth = 4 12 | PropMLP.net_width = 256 13 | PropMLP.disable_density_normals = True 14 | PropMLP.disable_rgb = True 15 | 16 | NerfMLP.warp_fn = @coord.contract 17 | NerfMLP.net_depth = 8 18 | NerfMLP.net_width = 1024 19 | NerfMLP.disable_density_normals = True 20 | -------------------------------------------------------------------------------- /configs/debug.gin: -------------------------------------------------------------------------------- 1 | # A short training schedule with no "warm up", useful for debugging. 
2 | Config.checkpoint_every = 1000 3 | Config.print_every = 100 4 | Config.train_render_every = 1000 5 | Config.lr_delay_mult = 0.1 6 | Config.lr_delay_steps = 500 7 | Config.batch_size = 2048 8 | Config.render_chunk_size = 2048 9 | Config.lr_init = 5e-4 10 | Config.lr_final = 5e-6 11 | Config.factor = 4 12 | Config.early_exit_steps = 3000 13 | 14 | PropMLP.net_depth = 2 15 | PropMLP.net_width = 64 16 | 17 | NerfMLP.net_depth = 4 18 | NerfMLP.net_width = 128 19 | -------------------------------------------------------------------------------- /configs/llff_512.gin: -------------------------------------------------------------------------------- 1 | Config.dataset_loader = 'llff' 2 | Config.near = 0. 3 | Config.far = 1. 4 | Config.factor = 4 5 | Config.forward_facing = True 6 | Config.adam_eps = 1e-8 7 | 8 | Model.ray_shape = 'cylinder' 9 | Model.opaque_background = True 10 | Model.num_levels = 2 11 | Model.num_prop_samples = 128 12 | Model.num_nerf_samples = 32 13 | 14 | PropMLP.net_depth = 4 15 | PropMLP.net_width = 256 16 | PropMLP.disable_density_normals = True 17 | PropMLP.disable_rgb = True 18 | 19 | NerfMLP.net_depth = 8 20 | NerfMLP.net_width = 512 21 | NerfMLP.disable_density_normals = True 22 | 23 | NerfMLP.max_deg_point = 16 24 | PropMLP.max_deg_point = 16 25 | -------------------------------------------------------------------------------- /configs/blender_512.gin: -------------------------------------------------------------------------------- 1 | Config.dataset_loader = 'blender' 2 | Config.batching = 'single_image' 3 | Config.near = 2 4 | Config.far = 6 5 | Config.eval_render_interval = 5 6 | Config.data_loss_type = 'mse' 7 | Config.adam_eps = 1e-8 8 | 9 | Model.num_levels = 2 10 | Model.num_prop_samples = 128 11 | Model.num_nerf_samples = 32 12 | 13 | PropMLP.net_depth = 4 14 | PropMLP.net_width = 256 15 | PropMLP.disable_density_normals = True 16 | PropMLP.disable_rgb = True 17 | 18 | NerfMLP.net_depth = 8 19 | NerfMLP.net_width = 512 20 | 
NerfMLP.disable_density_normals = True 21 | 22 | Config.distortion_loss_mult = 0. 23 | 24 | NerfMLP.max_deg_point = 16 25 | PropMLP.max_deg_point = 16 26 | -------------------------------------------------------------------------------- /internal/pycolmap/setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="pycolmap", 8 | version="0.0.1", 9 | author="True Price", 10 | description="PyColmap", 11 | long_description=long_description, 12 | long_description_content_type="text/markdown", 13 | url="https://github.com/google/nerfies/third_party/pycolmap", 14 | packages=setuptools.find_packages(), 15 | classifiers=[ 16 | "Programming Language :: Python :: 3", 17 | "License :: OSI Approved :: MIT License", 18 | "Operating System :: OS Independent", 19 | ], 20 | python_requires='>=3.6', 21 | ) 22 | -------------------------------------------------------------------------------- /configs/llff_256.gin: -------------------------------------------------------------------------------- 1 | Config.dataset_loader = 'llff' 2 | Config.near = 0. 3 | Config.far = 1. 
4 | Config.factor = 4 5 | Config.forward_facing = True 6 | Config.adam_eps = 1e-8 7 | 8 | Model.ray_shape = 'cylinder' 9 | Model.opaque_background = True 10 | Model.num_levels = 2 11 | Model.num_prop_samples = 128 12 | Model.num_nerf_samples = 32 13 | 14 | PropMLP.net_depth = 4 15 | PropMLP.net_width = 256 16 | PropMLP.basis_shape = 'octahedron' 17 | PropMLP.basis_subdivisions = 1 18 | PropMLP.disable_density_normals = True 19 | PropMLP.disable_rgb = True 20 | 21 | NerfMLP.net_depth = 8 22 | NerfMLP.net_width = 256 23 | NerfMLP.basis_shape = 'octahedron' 24 | NerfMLP.basis_subdivisions = 1 25 | NerfMLP.disable_density_normals = True 26 | 27 | NerfMLP.max_deg_point = 16 28 | PropMLP.max_deg_point = 16 29 | -------------------------------------------------------------------------------- /configs/blender_256.gin: -------------------------------------------------------------------------------- 1 | Config.dataset_loader = 'blender' 2 | Config.batching = 'single_image' 3 | Config.near = 2 4 | Config.far = 6 5 | Config.eval_render_interval = 5 6 | Config.data_loss_type = 'mse' 7 | Config.adam_eps = 1e-8 8 | 9 | Model.num_levels = 2 10 | Model.num_prop_samples = 128 11 | Model.num_nerf_samples = 32 12 | 13 | PropMLP.net_depth = 4 14 | PropMLP.net_width = 256 15 | PropMLP.basis_shape = 'octahedron' 16 | PropMLP.basis_subdivisions = 1 17 | PropMLP.disable_density_normals = True 18 | PropMLP.disable_rgb = True 19 | 20 | NerfMLP.net_depth = 8 21 | NerfMLP.net_width = 256 22 | NerfMLP.basis_shape = 'octahedron' 23 | NerfMLP.basis_subdivisions = 1 24 | NerfMLP.disable_density_normals = True 25 | 26 | Config.distortion_loss_mult = 0. 
27 | 28 | NerfMLP.max_deg_point = 16 29 | PropMLP.max_deg_point = 16 30 | -------------------------------------------------------------------------------- /tests/utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for utils.""" 16 | 17 | from absl.testing import absltest 18 | 19 | from internal import utils 20 | 21 | 22 | class UtilsTest(absltest.TestCase): 23 | 24 | def test_dummy_rays(self): 25 | """Ensures that the dummy Rays object is correctly initialized.""" 26 | rays = utils.dummy_rays() 27 | self.assertEqual(rays.origins.shape[-1], 3) 28 | 29 | 30 | if __name__ == '__main__': 31 | absltest.main() 32 | -------------------------------------------------------------------------------- /internal/pycolmap/pycolmap/image.py: -------------------------------------------------------------------------------- 1 | # Author: True Price 2 | 3 | import numpy as np 4 | 5 | #------------------------------------------------------------------------------- 6 | # 7 | # Image 8 | # 9 | #------------------------------------------------------------------------------- 10 | 11 | 12 | class Image: 13 | 14 | def __init__(self, name_, camera_id_, q_, tvec_): 15 | self.name = name_ 16 | self.camera_id = camera_id_ 17 | self.q = q_ 18 | self.tvec = tvec_ 19 | 20 | self.points2D = np.empty((0, 2), 
dtype=np.float64) 21 | self.point3D_ids = np.empty((0,), dtype=np.uint64) 22 | 23 | #--------------------------------------------------------------------------- 24 | 25 | def R(self): 26 | return self.q.ToR() 27 | 28 | #--------------------------------------------------------------------------- 29 | 30 | def C(self): 31 | return -self.R().T.dot(self.tvec) 32 | 33 | #--------------------------------------------------------------------------- 34 | 35 | @property 36 | def t(self): 37 | return self.tvec 38 | -------------------------------------------------------------------------------- /scripts/run_all_unit_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | 17 | python -m unittest tests.camera_utils_test 18 | python -m unittest tests.geopoly_test 19 | python -m unittest tests.stepfun_test 20 | python -m unittest tests.coord_test 21 | python -m unittest tests.image_test 22 | python -m unittest tests.ref_utils_test 23 | python -m unittest tests.utils_test 24 | python -m unittest tests.datasets_test 25 | python -m unittest tests.math_test 26 | python -m unittest tests.render_test 27 | -------------------------------------------------------------------------------- /scripts/eval_360.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | export CUDA_VISIBLE_DEVICES=0 17 | 18 | SCENE=gardenvase 19 | EXPERIMENT=360 20 | DATA_DIR=/usr/local/google/home/barron/tmp/nerf_data/nerf_real_360 21 | CHECKPOINT_DIR=/usr/local/google/home/barron/tmp/nerf_results/"$EXPERIMENT"/"$SCENE" 22 | 23 | python -m eval \ 24 | --gin_configs=configs/360.gin \ 25 | --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \ 26 | --gin_bindings="Config.checkpoint_dir = '${CHECKPOINT_DIR}'" \ 27 | --logtostderr 28 | -------------------------------------------------------------------------------- /scripts/eval_llff.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | export CUDA_VISIBLE_DEVICES=0 17 | 18 | SCENE=flower 19 | EXPERIMENT=llff 20 | DATA_DIR=/usr/local/google/home/barron/tmp/nerf_data/nerf_llff_data 21 | CHECKPOINT_DIR=/usr/local/google/home/barron/tmp/nerf_results/"$EXPERIMENT"/"$SCENE" 22 | 23 | python -m eval \ 24 | --gin_configs=configs/llff_256.gin \ 25 | --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \ 26 | --gin_bindings="Config.checkpoint_dir = '${CHECKPOINT_DIR}'" \ 27 | --logtostderr 28 | -------------------------------------------------------------------------------- /scripts/eval_raw.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | export CUDA_VISIBLE_DEVICES=0 17 | 18 | SCENE=nightpiano 19 | EXPERIMENT=raw 20 | DATA_DIR=/usr/local/google/home/barron/tmp/nerf_data/rawnerf/scenes 21 | CHECKPOINT_DIR=/usr/local/google/home/barron/tmp/nerf_results/"$EXPERIMENT"/"$SCENE" 22 | 23 | python -m eval \ 24 | --gin_configs=configs/llff_raw.gin \ 25 | --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \ 26 | --gin_bindings="Config.checkpoint_dir = '${CHECKPOINT_DIR}'" \ 27 | --logtostderr 28 | -------------------------------------------------------------------------------- /scripts/eval_blender.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | export CUDA_VISIBLE_DEVICES=0 17 | 18 | SCENE=ficus 19 | EXPERIMENT=blender 20 | DATA_DIR=/usr/local/google/home/barron/tmp/nerf_data/dors_nerf_synthetic 21 | CHECKPOINT_DIR=/usr/local/google/home/barron/tmp/nerf_results/"$EXPERIMENT"/"$SCENE" 22 | 23 | python -m eval \ 24 | --gin_configs=configs/blender_256.gin \ 25 | --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \ 26 | --gin_bindings="Config.checkpoint_dir = '${CHECKPOINT_DIR}'" \ 27 | --logtostderr 28 | -------------------------------------------------------------------------------- /scripts/eval_shinyblender.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | export CUDA_VISIBLE_DEVICES=0 17 | 18 | SCENE=toaster 19 | EXPERIMENT=shinyblender 20 | DATA_DIR=/usr/local/google/home/barron/tmp/nerf_data/dors_nerf_synthetic 21 | CHECKPOINT_DIR=/usr/local/google/home/barron/tmp/nerf_results/"$EXPERIMENT"/"$SCENE" 22 | 23 | python -m eval \ 24 | --gin_configs=configs/blender_refnerf.gin \ 25 | --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \ 26 | --gin_bindings="Config.checkpoint_dir = '${CHECKPOINT_DIR}'" \ 27 | --logtostderr 28 | -------------------------------------------------------------------------------- /scripts/train_llff.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | export CUDA_VISIBLE_DEVICES=0 17 | 18 | SCENE=flower 19 | EXPERIMENT=llff 20 | DATA_DIR=/usr/local/google/home/barron/tmp/nerf_data/nerf_llff_data 21 | CHECKPOINT_DIR=/usr/local/google/home/barron/tmp/nerf_results/"$EXPERIMENT"/"$SCENE" 22 | 23 | rm "$CHECKPOINT_DIR"/* 24 | python -m train \ 25 | --gin_configs=configs/llff_256.gin \ 26 | --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \ 27 | --gin_bindings="Config.checkpoint_dir = '${CHECKPOINT_DIR}'" \ 28 | --logtostderr 29 | -------------------------------------------------------------------------------- /scripts/train_raw.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | export CUDA_VISIBLE_DEVICES=0 17 | 18 | SCENE=nightpiano 19 | EXPERIMENT=raw 20 | DATA_DIR=/usr/local/google/home/barron/tmp/nerf_data/rawnerf/scenes 21 | CHECKPOINT_DIR=/usr/local/google/home/barron/tmp/nerf_results/"$EXPERIMENT"/"$SCENE" 22 | 23 | rm "$CHECKPOINT_DIR"/* 24 | python -m train \ 25 | --gin_configs=configs/llff_raw.gin \ 26 | --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \ 27 | --gin_bindings="Config.checkpoint_dir = '${CHECKPOINT_DIR}'" \ 28 | --logtostderr 29 | -------------------------------------------------------------------------------- /scripts/train_blender.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | export CUDA_VISIBLE_DEVICES=0 17 | 18 | SCENE=ficus 19 | EXPERIMENT=blender 20 | DATA_DIR=/usr/local/google/home/barron/tmp/nerf_data/nerf_synthetic 21 | CHECKPOINT_DIR=/usr/local/google/home/barron/tmp/nerf_results/"$EXPERIMENT"/"$SCENE" 22 | 23 | rm "$CHECKPOINT_DIR"/* 24 | python -m train \ 25 | --gin_configs=configs/blender_256.gin \ 26 | --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \ 27 | --gin_bindings="Config.checkpoint_dir = '${CHECKPOINT_DIR}'" \ 28 | --logtostderr 29 | -------------------------------------------------------------------------------- /internal/pycolmap/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 True Price, UNC Chapel Hill 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /scripts/train_shinyblender.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | export CUDA_VISIBLE_DEVICES=0 17 | 18 | SCENE=toaster 19 | EXPERIMENT=shinyblender 20 | DATA_DIR=/usr/local/google/home/barron/tmp/nerf_data/dors_nerf_synthetic 21 | CHECKPOINT_DIR=/usr/local/google/home/barron/tmp/nerf_results/"$EXPERIMENT"/"$SCENE" 22 | 23 | rm "$CHECKPOINT_DIR"/* 24 | python -m train \ 25 | --gin_configs=configs/blender_refnerf.gin \ 26 | --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \ 27 | --gin_bindings="Config.checkpoint_dir = '${CHECKPOINT_DIR}'" \ 28 | --logtostderr 29 | -------------------------------------------------------------------------------- /scripts/train_360.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | export CUDA_VISIBLE_DEVICES=0 17 | 18 | SCENE=garden 19 | EXPERIMENT=360_v2 20 | DATA_DIR=/media/hjx/DataDisk/360_v2 21 | CHECKPOINT_DIR=/home/hjx/Videos/mipnerf360/"$EXPERIMENT"/"$SCENE" 22 | 23 | # If running one of the indoor scenes, add 24 | # --gin_bindings="Config.factor = 2" 25 | 26 | rm "$CHECKPOINT_DIR"/* 27 | python -m train \ 28 | --gin_configs=configs/360.gin \ 29 | --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \ 30 | --gin_bindings="Config.checkpoint_dir = '${CHECKPOINT_DIR}'" \ 31 | --logtostderr 32 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | <https://cla.developers.google.com/> to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 
18 | 19 | ## Code Reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 30 | -------------------------------------------------------------------------------- /scripts/render_360.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | export CUDA_VISIBLE_DEVICES=0 17 | 18 | SCENE=gardenvase 19 | EXPERIMENT=360 20 | DATA_DIR=/usr/local/google/home/barron/tmp/nerf_data/nerf_real_360 21 | CHECKPOINT_DIR=/usr/local/google/home/barron/tmp/nerf_results/"$EXPERIMENT"/"$SCENE" 22 | 23 | python -m render \ 24 | --gin_configs=configs/360.gin \ 25 | --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \ 26 | --gin_bindings="Config.checkpoint_dir = '${CHECKPOINT_DIR}'" \ 27 | --gin_bindings="Config.render_path = True" \ 28 | --gin_bindings="Config.render_path_frames = 10" \ 29 | --gin_bindings="Config.render_dir = '${CHECKPOINT_DIR}/render/'" \ 30 | --gin_bindings="Config.render_video_fps = 2" \ 31 | --logtostderr 32 | -------------------------------------------------------------------------------- /scripts/render_llff.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
# Runs the `render` module with configs/llff_256.gin to render a 10-frame
# camera path (written as a 2 fps video) for a single scene.
#
# SCENE, EXPERIMENT, DATA_DIR and CHECKPOINT_DIR can be overridden from the
# environment. The fallbacks are the original hard-coded values, which are
# user-specific absolute paths and are unlikely to exist on another machine.
# Example:
#   SCENE=fern DATA_DIR=/data/nerf_llff_data scripts/render_llff.sh

export CUDA_VISIBLE_DEVICES=0

SCENE="${SCENE:-flower}"
EXPERIMENT="${EXPERIMENT:-llff}"
DATA_DIR="${DATA_DIR:-/usr/local/google/home/barron/tmp/nerf_data/nerf_llff_data}"
# Derived from EXPERIMENT and SCENE unless given explicitly.
CHECKPOINT_DIR="${CHECKPOINT_DIR:-/usr/local/google/home/barron/tmp/nerf_results/${EXPERIMENT}/${SCENE}}"

python -m render \
  --gin_configs=configs/llff_256.gin \
  --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \
  --gin_bindings="Config.checkpoint_dir = '${CHECKPOINT_DIR}'" \
  --gin_bindings="Config.render_path = True" \
  --gin_bindings="Config.render_path_frames = 10" \
  --gin_bindings="Config.render_dir = '${CHECKPOINT_DIR}/render/'" \
  --gin_bindings="Config.render_video_fps = 2" \
  --logtostderr
# Runs the `render` module with configs/llff_raw.gin to render a 10-frame
# camera path (written as a 2 fps video) for a single scene.
#
# SCENE, EXPERIMENT, DATA_DIR and CHECKPOINT_DIR can be overridden from the
# environment. The fallbacks are the original hard-coded values, which are
# user-specific absolute paths and are unlikely to exist on another machine.
# Example:
#   SCENE=nightpiano DATA_DIR=/data/rawnerf/scenes scripts/render_raw.sh

export CUDA_VISIBLE_DEVICES=0

SCENE="${SCENE:-nightpiano}"
EXPERIMENT="${EXPERIMENT:-raw}"
DATA_DIR="${DATA_DIR:-/usr/local/google/home/barron/tmp/nerf_data/rawnerf/scenes}"
# Derived from EXPERIMENT and SCENE unless given explicitly.
CHECKPOINT_DIR="${CHECKPOINT_DIR:-/usr/local/google/home/barron/tmp/nerf_results/${EXPERIMENT}/${SCENE}}"

python -m render \
  --gin_configs=configs/llff_raw.gin \
  --gin_bindings="Config.data_dir = '${DATA_DIR}/${SCENE}'" \
  --gin_bindings="Config.checkpoint_dir = '${CHECKPOINT_DIR}'" \
  --gin_bindings="Config.render_path = True" \
  --gin_bindings="Config.render_path_frames = 10" \
  --gin_bindings="Config.render_dir = '${CHECKPOINT_DIR}/render/'" \
  --gin_bindings="Config.render_video_fps = 2" \
  --logtostderr
25 | Model.single_jitter = False 26 | Model.resample_padding = 0.01 27 | 28 | NerfMLP.net_depth = 8 29 | NerfMLP.net_width = 256 30 | NerfMLP.net_depth_viewdirs = 8 31 | NerfMLP.basis_shape = 'octahedron' 32 | NerfMLP.basis_subdivisions = 1 33 | NerfMLP.disable_density_normals = False 34 | NerfMLP.enable_pred_normals = True 35 | NerfMLP.use_directional_enc = True 36 | NerfMLP.use_reflections = True 37 | NerfMLP.deg_view = 5 38 | NerfMLP.enable_pred_roughness = True 39 | NerfMLP.use_diffuse_color = True 40 | NerfMLP.use_specular_tint = True 41 | NerfMLP.use_n_dot_v = True 42 | NerfMLP.bottleneck_width = 128 43 | NerfMLP.density_bias = 0.5 44 | NerfMLP.max_deg_point = 16 45 | -------------------------------------------------------------------------------- /configs/llff_raw.gin: -------------------------------------------------------------------------------- 1 | # General LLFF settings 2 | 3 | Config.dataset_loader = 'llff' 4 | Config.near = 0. 5 | Config.far = 1. 6 | Config.factor = 4 7 | Config.forward_facing = True 8 | 9 | Model.ray_shape = 'cylinder' 10 | 11 | PropMLP.net_depth = 4 12 | PropMLP.net_width = 256 13 | PropMLP.basis_shape = 'octahedron' 14 | PropMLP.basis_subdivisions = 1 15 | PropMLP.disable_density_normals = True # Turn this off if using orientation loss. 16 | PropMLP.disable_rgb = True 17 | 18 | NerfMLP.net_depth = 8 19 | NerfMLP.net_width = 256 20 | NerfMLP.basis_shape = 'octahedron' 21 | NerfMLP.basis_subdivisions = 1 22 | NerfMLP.disable_density_normals = True # Turn this off if using orientation loss. 
23 | 24 | NerfMLP.max_deg_point = 16 25 | PropMLP.max_deg_point = 16 26 | 27 | Config.train_render_every = 5000 28 | 29 | 30 | ########################## RawNeRF specific settings ########################## 31 | 32 | Config.rawnerf_mode = True 33 | Config.data_loss_type = 'rawnerf' 34 | Config.apply_bayer_mask = True 35 | Model.learned_exposure_scaling = True 36 | 37 | Model.num_levels = 2 38 | Model.num_prop_samples = 128 # Using extra samples for now because of noise instability. 39 | Model.num_nerf_samples = 128 40 | Model.opaque_background = True 41 | 42 | # RGB activation we use for linear color outputs is exp(x - 5). 43 | NerfMLP.rgb_padding = 0. 44 | NerfMLP.rgb_activation = @math.safe_exp 45 | NerfMLP.rgb_bias = -5. 46 | PropMLP.rgb_padding = 0. 47 | PropMLP.rgb_activation = @math.safe_exp 48 | PropMLP.rgb_bias = -5. 49 | 50 | ## Experimenting with the various regularizers and losses: 51 | Config.interlevel_loss_mult = .0 # Turning off interlevel for now (default = 1.). 52 | Config.distortion_loss_mult = .01 # Distortion loss helps with floaters (default = .01). 53 | Config.orientation_loss_mult = 0. # Orientation loss also not great (try .01). 54 | Config.data_coarse_loss_mult = 0.1 # Setting this to match old MipNeRF. 55 | 56 | ## Density noise used in original NeRF: 57 | NerfMLP.density_noise = 1. 58 | PropMLP.density_noise = 1. 59 | 60 | ## Use a single MLP for all rounds of sampling: 61 | Model.single_mlp = True 62 | 63 | ## Some algorithmic settings to match the paper: 64 | Model.anneal_slope = 0. 65 | Model.dilation_multiplier = 0. 66 | Model.dilation_bias = 0. 
67 | Model.single_jitter = False 68 | NerfMLP.weight_init = 'glorot_uniform' 69 | PropMLP.weight_init = 'glorot_uniform' 70 | 71 | ## Training hyperparameters used in the paper: 72 | Config.batch_size = 16384 73 | Config.render_chunk_size = 16384 74 | Config.lr_init = 1e-3 75 | Config.lr_final = 1e-5 76 | Config.max_steps = 500000 77 | Config.checkpoint_every = 25000 78 | Config.lr_delay_steps = 2500 79 | Config.lr_delay_mult = 0.01 80 | Config.grad_max_norm = 0.1 81 | Config.grad_max_val = 0.1 82 | Config.adam_eps = 1e-8 83 | -------------------------------------------------------------------------------- /tests/camera_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for camera_utils.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from internal import camera_utils 20 | from jax import random 21 | import jax.numpy as jnp 22 | import numpy as np 23 | 24 | 25 | class CameraUtilsTest(parameterized.TestCase): 26 | 27 | def test_convert_to_ndc(self): 28 | rng = random.PRNGKey(0) 29 | for _ in range(10): 30 | # Random pinhole camera intrinsics. 31 | key, rng = random.split(rng) 32 | focal, width, height = random.uniform(key, (3,), minval=100., maxval=200.) 
33 | camtopix = camera_utils.intrinsic_matrix(focal, focal, width / 2., 34 | height / 2.) 35 | pixtocam = np.linalg.inv(camtopix) 36 | near = 1. 37 | 38 | # Random rays, pointing forward (negative z direction). 39 | num_rays = 1000 40 | key, rng = random.split(rng) 41 | origins = jnp.array([0., 0., 1.]) 42 | origins += random.uniform(key, (num_rays, 3), minval=-1., maxval=1.) 43 | directions = jnp.array([0., 0., -1.]) 44 | directions += random.uniform(key, (num_rays, 3), minval=-.5, maxval=.5) 45 | 46 | # Project world-space points along each ray into NDC space. 47 | t = jnp.linspace(0., 1., 10) 48 | pts_world = origins + t[:, None, None] * directions 49 | pts_ndc = jnp.stack([ 50 | -focal / (.5 * width) * pts_world[..., 0] / pts_world[..., 2], 51 | -focal / (.5 * height) * pts_world[..., 1] / pts_world[..., 2], 52 | 1. + 2. * near / pts_world[..., 2], 53 | ], 54 | axis=-1) 55 | 56 | # Get NDC space rays. 57 | origins_ndc, directions_ndc = camera_utils.convert_to_ndc( 58 | origins, directions, pixtocam, near) 59 | 60 | # Ensure that the NDC space points lie on the calculated rays. 61 | directions_ndc_norm = jnp.linalg.norm( 62 | directions_ndc, axis=-1, keepdims=True) 63 | directions_ndc_unit = directions_ndc / directions_ndc_norm 64 | projection = ((pts_ndc - origins_ndc) * directions_ndc_unit).sum(axis=-1) 65 | pts_ndc_proj = origins_ndc + directions_ndc_unit * projection[..., None] 66 | 67 | # pts_ndc should be close to their projections pts_ndc_proj onto the rays. 
#!/bin/bash
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Runs COLMAP (feature extraction, exhaustive matching, mapping) on the images
# in DATASET_PATH/images, then writes copies of those images downsampled to
# 50%, 25% and 12.5% into DATASET_PATH/images_2, images_4 and images_8.
#
# Usage: local_colmap_and_resize.sh DATASET_PATH [CAMERA_MODEL]

# Abort as soon as a COLMAP step fails instead of carrying on to later steps
# with a missing or partial reconstruction.
set -euo pipefail

if [[ $# -lt 1 ]]; then
  echo "Usage: $0 DATASET_PATH [CAMERA_MODEL]" >&2
  exit 1
fi

# Set to 0 (in the environment or here) if you do not have a GPU.
USE_GPU="${USE_GPU:-1}"
# Path to a directory `base/` with images in `base/images/`.
DATASET_PATH=$1
# Recommended CAMERA values: OPENCV for perspective, OPENCV_FISHEYE for fisheye.
CAMERA=${2:-OPENCV}

if [[ ! -d "$DATASET_PATH"/images ]]; then
  echo "Error: '$DATASET_PATH/images' is not a directory." >&2
  exit 1
fi


# Run COLMAP.

### Feature extraction

colmap feature_extractor \
    --database_path "$DATASET_PATH"/database.db \
    --image_path "$DATASET_PATH"/images \
    --ImageReader.single_camera 1 \
    --ImageReader.camera_model "$CAMERA" \
    --SiftExtraction.use_gpu "$USE_GPU"


### Feature matching

colmap exhaustive_matcher \
    --database_path "$DATASET_PATH"/database.db \
    --SiftMatching.use_gpu "$USE_GPU"

## Use if your scene has > 500 images
## Replace this path with your own local copy of the file.
## Download from: https://demuc.de/colmap/#download
# VOCABTREE_PATH=/usr/local/google/home/bmild/vocab_tree_flickr100K_words32K.bin
# colmap vocab_tree_matcher \
#     --database_path "$DATASET_PATH"/database.db \
#     --VocabTreeMatching.vocab_tree_path "$VOCABTREE_PATH" \
#     --SiftMatching.use_gpu "$USE_GPU"


### Bundle adjustment

# The default Mapper tolerance is unnecessarily large,
# decreasing it speeds up bundle adjustment steps.
mkdir -p "$DATASET_PATH"/sparse
colmap mapper \
    --database_path "$DATASET_PATH"/database.db \
    --image_path "$DATASET_PATH"/images \
    --output_path "$DATASET_PATH"/sparse \
    --Mapper.ba_global_function_tolerance=0.000001


### Image undistortion

## Use this if you want to undistort your images into ideal pinhole intrinsics.
# mkdir -p "$DATASET_PATH"/dense
# colmap image_undistorter \
#     --image_path "$DATASET_PATH"/images \
#     --input_path "$DATASET_PATH"/sparse/0 \
#     --output_path "$DATASET_PATH"/dense \
#     --output_type COLMAP


# Resize images.

# Copies the original images into DATASET_PATH/images_<factor> and shrinks each
# copy in place to <percent> of its original size.
#   $1: suffix of the output directory (e.g. 2 -> images_2)
#   $2: mogrify -resize argument (e.g. 50%)
resize_images() {
  local factor="$1"
  local percent="$2"
  local out_dir="$DATASET_PATH/images_$factor"

  mkdir -p "$out_dir"
  # Copy the *contents* of images/ so that a re-run refreshes out_dir rather
  # than nesting an images/ directory inside an existing one.
  cp -r "$DATASET_PATH"/images/. "$out_dir"/

  # NUL-delimited names keep files with spaces or quotes intact (parsing `ls`
  # output does not). Hidden files are skipped, as `ls` skipped them.
  # A file that mogrify cannot process only produces a warning, so the
  # remaining resolutions are still generated.
  find "$out_dir" -maxdepth 1 -type f ! -name '.*' -print0 \
    | xargs -0 -P 8 -I {} mogrify -resize "$percent" {} \
    || echo "Warning: some files in $out_dir could not be resized." >&2
}

resize_images 2 50%
resize_images 4 25%
resize_images 8 12.5%
class RefUtilsTest(absltest.TestCase):
  """Tests for `ref_utils.reflect` and `ref_utils.generate_dir_enc_fn`."""

  def test_reflection(self):
    """Make sure reflected vectors have the same angle from normals as input."""
    rng = random.PRNGKey(0)
    for shape in [(45, 3), (4, 7, 3)]:
      key, rng = random.split(rng)
      normals = random.normal(key, shape)
      key, rng = random.split(rng)
      directions = random.normal(key, shape)

      # Normalize normal vectors.
      normals = normals / (
          jnp.linalg.norm(normals, axis=-1, keepdims=True) + 1e-10)

      reflected_directions = ref_utils.reflect(directions, normals)

      # Reflection about the normal must preserve the component along it.
      cos_angle_original = jnp.sum(directions * normals, axis=-1)
      cos_angle_reflected = jnp.sum(reflected_directions * normals, axis=-1)

      np.testing.assert_allclose(
          cos_angle_original, cos_angle_reflected, atol=1E-5, rtol=1E-5)

  def test_spherical_harmonics(self):
    """Make sure the fast spherical harmonics are accurate."""
    shape = (12, 11, 13)

    # Generate random points on sphere.
    rng = random.PRNGKey(0)
    key1, key2 = random.split(rng)
    theta = random.uniform(key1, shape, minval=0.0, maxval=jnp.pi)
    phi = random.uniform(key2, shape, minval=0.0, maxval=2.0*jnp.pi)

    # Convert to Cartesian coordinates.
    x = jnp.sin(theta) * jnp.cos(phi)
    y = jnp.sin(theta) * jnp.sin(phi)
    z = jnp.cos(theta)
    xyz = jnp.stack([x, y, z], axis=-1)

    deg_view = 5
    de = ref_utils.generate_dir_enc_fn(deg_view)(xyz)
    de_scipy = generate_dir_enc_fn_scipy(deg_view)(theta, phi)

    # Only use atol. assert_allclose accepts |actual - desired| up to
    # atol + rtol * |desired|, so the relative term has to be zero for the
    # absolute bound to be the one that is enforced; a huge rtol would make
    # the comparison pass for every nonzero reference value.
    np.testing.assert_allclose(de, de_scipy, atol=0.02, rtol=0)
    self.assertFalse(jnp.any(jnp.isnan(de)))
class DummyDataset(datasets.Dataset):
  """Dataset stub filled with deterministic pseudo-random images and poses."""

  def _load_renderings(self, config):
    """Generates dummy image and pose data."""
    # Two 3x4 (height x width) examples with a focal length of 5.
    self._n_examples, self.height, self.width = 2, 3, 4
    self._resolution = self.height * self.width
    self.focal = 5.

    # Inverse of the intrinsics, with the principal point at the image center.
    camtopix = camera_utils.intrinsic_matrix(
        self.focal, self.focal, self.width * 0.5, self.height * 0.5)
    self.pixtocams = np.linalg.inv(camtopix)

    # A fixed seed keeps the generated data identical between runs.
    rng = random.PRNGKey(0)

    image_key, rng = random.split(rng)
    self.images = random.uniform(
        image_key, (self._n_examples, self.height, self.width, 3))

    # One pose per example, each built from its own 3x3 block of normal
    # samples whose rows are passed to viewmatrix() as separate arguments.
    pose_key, _ = random.split(rng)
    poses = []
    for k in random.split(pose_key, self._n_examples):
      rows = random.normal(k, (3, 3))
      poses.append(camera_utils.viewmatrix(*rows))
    self.camtoworlds = np.stack(poses, axis=0)
59 | for split in ['train', 'test']: 60 | dummy_dataset = DummyDataset(split, '', config) 61 | rays = dummy_dataset.peek().rays 62 | sh_gt = rays.origins.shape[:-1] 63 | for z in rays.__dict__.values(): 64 | if z is not None: 65 | self.assertEqual(z.shape[:-1], sh_gt) 66 | 67 | # Check test batch generation matches golden data. 68 | dummy_dataset = DummyDataset('test', '', config) 69 | batch = dummy_dataset.peek() 70 | 71 | rgb = batch.rgb.ravel() 72 | rgb_gt = np.array([ 73 | 0.5289556, 0.28869557, 0.24527192, 0.12083626, 0.8904066, 0.6259936, 74 | 0.57573485, 0.09355974, 0.8017353, 0.538651, 0.4998169, 0.42061496, 75 | 0.5591258, 0.00577283, 0.6804651, 0.9139203, 0.00444758, 0.96962905, 76 | 0.52956843, 0.38282406, 0.28777933, 0.6640035, 0.39736128, 0.99495006, 77 | 0.13100398, 0.7597165, 0.8532667, 0.67468107, 0.6804743, 0.26873016, 78 | 0.60699487, 0.5722265, 0.44482303, 0.6511061, 0.54807067, 0.09894073 79 | ]) 80 | np.testing.assert_allclose(rgb, rgb_gt, atol=1e-4, rtol=1e-4) 81 | 82 | ray_origins = batch.rays.origins.ravel() 83 | ray_origins_gt = np.array([ 84 | -0.20050469, -0.6451472, -0.8818224, -0.20050469, -0.6451472, 85 | -0.8818224, -0.20050469, -0.6451472, -0.8818224, -0.20050469, 86 | -0.6451472, -0.8818224, -0.20050469, -0.6451472, -0.8818224, 87 | -0.20050469, -0.6451472, -0.8818224, -0.20050469, -0.6451472, 88 | -0.8818224, -0.20050469, -0.6451472, -0.8818224, -0.20050469, 89 | -0.6451472, -0.8818224, -0.20050469, -0.6451472, -0.8818224, 90 | -0.20050469, -0.6451472, -0.8818224, -0.20050469, -0.6451472, -0.8818224 91 | ]) 92 | np.testing.assert_allclose( 93 | ray_origins, ray_origins_gt, atol=1e-4, rtol=1e-4) 94 | 95 | ray_dirs = batch.rays.directions.ravel() 96 | ray_dirs_gt = np.array([ 97 | 0.24370372, 0.89296186, -0.5227117, 0.05601424, 0.8468699, -0.57417226, 98 | -0.13167524, 0.8007779, -0.62563276, -0.31936473, 0.75468594, 99 | -0.67709327, 0.17780769, 0.96766925, -0.34928587, -0.0098818, 0.9215773, 100 | -0.4007464, -0.19757128, 
0.87548524, -0.4522069, -0.38526076, 101 | 0.82939327, -0.5036674, 0.11191163, 1.0423766, -0.17586003, -0.07577785, 102 | 0.9962846, -0.22732055, -0.26346734, 0.95019263, -0.2787811, 103 | -0.45115682, 0.90410066, -0.3302416 104 | ]) 105 | np.testing.assert_allclose(ray_dirs, ray_dirs_gt, atol=1e-4, rtol=1e-4) 106 | 107 | 108 | if __name__ == '__main__': 109 | absltest.main() 110 | -------------------------------------------------------------------------------- /internal/math.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
@jax.custom_jvp
def safe_exp(x):
  """jnp.exp() but with finite output and gradients for large inputs."""
  # Inputs are capped at 88 before exponentiating: jnp.exp(89) is infinity.
  capped = jnp.minimum(x, 88.)
  return jnp.exp(capped)


@safe_exp.defjvp
def safe_exp_jvp(primals, tangents):
  """Override safe_exp()'s gradient so that it's large when inputs are large.

  The derivative is defined as safe_exp(x) itself, so it stays finite and
  nonzero even for inputs beyond the cap applied in the forward pass.

  Args:
    primals: 1-tuple holding the input `x`.
    tangents: 1-tuple holding the tangent of `x`.

  Returns:
    A (primal_out, tangent_out) pair.
  """
  (x,) = primals
  (x_dot,) = tangents
  y = safe_exp(x)
  return y, y * x_dot
def sorted_interp(x, xp, fp):
  """A TPU-friendly version of interp(), where xp and fp must be sorted.

  Finds, for every query in `x`, the pair of neighbouring knots in `xp` that
  bracket it using only comparisons and max/min reductions, then linearly
  interpolates the matching values of `fp`. Queries outside the range of `xp`
  take the first or last value of `fp`.

  Args:
    x: the query coordinates, with the queries along the last axis.
    xp: the knot coordinates along the last axis, sorted in ascending order.
    fp: the values at the knots, same shape as `xp`, also sorted in ascending
      order (the interval search takes the max/min of masked values).

  Returns:
    The interpolated values, one per entry of `x`.
  """

  # Identify the location in `xp` that corresponds to each `x`.
  # The final `True` index in `mask` is the start of the matching interval.
  mask = x[..., None, :] >= xp[..., :, None]

  def find_interval(v):
    # Grab the value where `mask` switches from True to False, and vice versa.
    # This approach takes advantage of the fact that `v` is sorted: the largest
    # value at or below the query is the interval start, and the smallest value
    # above it is the interval end. The first/last element is the fallback when
    # no knot lies on that side of the query.
    v0 = jnp.max(jnp.where(mask, v[..., None], v[..., :1, None]), -2)
    v1 = jnp.min(jnp.where(~mask, v[..., None], v[..., -1:, None]), -2)
    return v0, v1

  fp0, fp1 = find_interval(fp)
  xp0, xp1 = find_interval(xp)

  # A degenerate interval (xp0 == xp1) gives 0/0 = NaN or +/-inf here; NaN is
  # mapped to 0 explicitly and the infinities are absorbed by the clip.
  offset = jnp.clip(
      jnp.nan_to_num((x - xp0) / (xp1 - xp0), nan=0.0), 0, 1)
  ret = fp0 + offset * (fp1 - fp0)
  return ret
def tesselate_geodesic(base_verts, base_faces, v, eps=1e-4):
  """Tesselate the vertices of a geodesic polyhedron.

  Every face is subdivided with the barycentric weights returned by
  compute_tesselation_weights(), the new points are normalized to unit
  length, and points that coincide (up to `eps`) are merged.

  Args:
    base_verts: tensor of floats, the vertex coordinates of the geodesic.
    base_faces: tensor of ints, the indices of the vertices of base_verts that
      constitute each face of the polyhedron.
    v: int, the factor of the tesselation (v==1 is a no-op).
    eps: float, a small value used to determine if two vertices are the same.

  Returns:
    verts: a tensor of floats, the coordinates of the tesselated vertices.

  Raises:
    ValueError: if `v` is not an integer (Python or NumPy).
  """
  # NumPy integer scalars are as valid a tesselation factor as Python ints.
  if not isinstance(v, (int, np.integer)):
    raise ValueError(f'v {v} must be an integer')
  tri_weights = compute_tesselation_weights(v)

  verts = []
  for base_face in base_faces:
    # Barycentric combination of the face's three corners, pushed back out
    # onto the unit sphere.
    new_verts = np.matmul(tri_weights, base_verts[base_face, :])
    new_verts /= np.sqrt(np.sum(new_verts**2, 1, keepdims=True))
    verts.append(new_verts)
  verts = np.concatenate(verts, 0)

  # Adjacent faces produce the same points along shared edges. Map every
  # vertex to the lowest index of a vertex within `eps` (squared distance) of
  # it, and keep one representative per group.
  sq_dist = compute_sq_dist(verts.T)
  assignment = np.array([np.min(np.argwhere(d <= eps)) for d in sq_dist])
  unique = np.unique(assignment)
  verts = verts[unique, :]

  return verts
87 | angular_tesselation: int, the number of times to tesselate the polyhedron, 88 | must be >= 1 (a value of 1 is a no-op to the polyhedron). 89 | remove_symmetries: bool, if True then remove the symmetric basis columns, 90 | which is usually a good idea because otherwise projections onto the basis 91 | will have redundant negative copies of each other. 92 | eps: float, a small number used to determine symmetries. 93 | 94 | Returns: 95 | basis: a matrix with shape [3, n]. 96 | """ 97 | if base_shape == 'icosahedron': 98 | a = (np.sqrt(5) + 1) / 2 99 | verts = np.array([(-1, 0, a), (1, 0, a), (-1, 0, -a), (1, 0, -a), (0, a, 1), 100 | (0, a, -1), (0, -a, 1), (0, -a, -1), (a, 1, 0), 101 | (-a, 1, 0), (a, -1, 0), (-a, -1, 0)]) / np.sqrt(a + 2) 102 | faces = np.array([(0, 4, 1), (0, 9, 4), (9, 5, 4), (4, 5, 8), (4, 8, 1), 103 | (8, 10, 1), (8, 3, 10), (5, 3, 8), (5, 2, 3), (2, 7, 3), 104 | (7, 10, 3), (7, 6, 10), (7, 11, 6), (11, 0, 6), (0, 1, 6), 105 | (6, 1, 10), (9, 0, 11), (9, 11, 2), (9, 2, 5), 106 | (7, 2, 11)]) 107 | verts = tesselate_geodesic(verts, faces, angular_tesselation) 108 | elif base_shape == 'octahedron': 109 | verts = np.array([(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0), 110 | (1, 0, 0)]) 111 | corners = np.array(list(itertools.product([-1, 1], repeat=3))) 112 | pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2) 113 | faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1) 114 | verts = tesselate_geodesic(verts, faces, angular_tesselation) 115 | else: 116 | raise ValueError(f'base_shape {base_shape} not supported') 117 | 118 | if remove_symmetries: 119 | # Remove elements of `verts` that are reflections of each other. 
@flax.struct.dataclass
class Pixels:
  """Per-pixel batch inputs; accepted by `Batch.rays` as an alternative to `Rays`.

  All tensors must have the same num_dims and first n-1 dims must match.
  """
  # Pixel coordinates, integer-valued judging by the names -- TODO confirm the
  # dtype and axis convention against the code that builds these.
  pix_x_int: _Array
  pix_y_int: _Array
  # The remaining required fields share their names with fields of `Rays`.
  lossmult: _Array
  near: _Array
  far: _Array
  cam_idx: _Array
  # Optional exposure fields; left as None when not supplied.
  exposure_idx: Optional[_Array] = None
  exposure_values: Optional[_Array] = None
Optional[_Array] = None 57 | exposure_values: Optional[_Array] = None 58 | 59 | 60 | # Dummy Rays object that can be used to initialize NeRF model. 61 | def dummy_rays(include_exposure_idx: bool = False, 62 | include_exposure_values: bool = False) -> Rays: 63 | data_fn = lambda n: jnp.zeros((1, n)) 64 | exposure_kwargs = {} 65 | if include_exposure_idx: 66 | exposure_kwargs['exposure_idx'] = data_fn(1).astype(jnp.int32) 67 | if include_exposure_values: 68 | exposure_kwargs['exposure_values'] = data_fn(1) 69 | return Rays( 70 | origins=data_fn(3), 71 | directions=data_fn(3), 72 | viewdirs=data_fn(3), 73 | radii=data_fn(1), 74 | imageplane=data_fn(2), 75 | lossmult=data_fn(1), 76 | near=data_fn(1), 77 | far=data_fn(1), 78 | cam_idx=data_fn(1).astype(jnp.int32), 79 | **exposure_kwargs) 80 | 81 | 82 | @flax.struct.dataclass 83 | class Batch: 84 | """Data batch for NeRF training or testing.""" 85 | rays: Union[Pixels, Rays] 86 | rgb: Optional[_Array] = None 87 | disps: Optional[_Array] = None 88 | normals: Optional[_Array] = None 89 | alphas: Optional[_Array] = None 90 | 91 | 92 | class DataSplit(enum.Enum): 93 | """Dataset split.""" 94 | TRAIN = 'train' 95 | TEST = 'test' 96 | 97 | 98 | class BatchingMethod(enum.Enum): 99 | """Draw rays randomly from a single image or all images, in each batch.""" 100 | ALL_IMAGES = 'all_images' 101 | SINGLE_IMAGE = 'single_image' 102 | 103 | 104 | def open_file(pth, mode='r'): 105 | return open(pth, mode=mode) 106 | 107 | 108 | def file_exists(pth): 109 | return os.path.exists(pth) 110 | 111 | 112 | def listdir(pth): 113 | return os.listdir(pth) 114 | 115 | 116 | def isdir(pth): 117 | return os.path.isdir(pth) 118 | 119 | 120 | def makedirs(pth): 121 | if not file_exists(pth): 122 | print(pth) 123 | os.makedirs(pth) 124 | 125 | 126 | def shard(xs): 127 | """Split data into shards for multiple devices along the first dimension.""" 128 | # shape = jax.tree_util.tree_map(lambda x: [(jax.local_device_count(), -1) + x.shape[1:], x.shape], 
xs) 129 | # print(shape) 130 | return jax.tree_util.tree_map( 131 | lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs) 132 | 133 | 134 | def unshard(x, padding=0): 135 | """Collect the sharded tensor to the shape before sharding.""" 136 | y = x.reshape([x.shape[0] * x.shape[1]] + list(x.shape[2:])) 137 | if padding > 0: 138 | y = y[:-padding] 139 | return y 140 | 141 | 142 | def load_img(pth: str) -> np.ndarray: 143 | """Load an image and cast to float32.""" 144 | with open_file(pth, 'rb') as f: 145 | image = np.array(Image.open(f), dtype=np.float32) 146 | return image 147 | 148 | 149 | def load_exif(pth: str) -> Dict[str, Any]: 150 | """Load EXIF data for an image.""" 151 | with open_file(pth, 'rb') as f: 152 | image_pil = Image.open(f) 153 | exif_pil = image_pil._getexif() # pylint: disable=protected-access 154 | if exif_pil is not None: 155 | exif = { 156 | ExifTags.TAGS[k]: v for k, v in exif_pil.items() if k in ExifTags.TAGS 157 | } 158 | else: 159 | exif = {} 160 | return exif 161 | 162 | 163 | def save_img_u8(img, pth): 164 | """Save an image (probably RGB) in [0, 1] to disk as a uint8 PNG.""" 165 | with open_file(pth, 'wb') as f: 166 | Image.fromarray( 167 | (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(np.uint8)).save( 168 | f, 'PNG') 169 | 170 | 171 | def save_img_f32(depthmap, pth): 172 | """Save an image (probably a depthmap) to disk as a float32 TIFF.""" 173 | with open_file(pth, 'wb') as f: 174 | Image.fromarray(np.nan_to_num(depthmap).astype(np.float32)).save(f, 'TIFF') 175 | -------------------------------------------------------------------------------- /internal/image.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
_Array = Union[np.ndarray, jnp.ndarray]


def mse_to_psnr(mse):
  """Convert an MSE to a PSNR (we assume the maximum pixel value is 1)."""
  db_per_nat = -10. / jnp.log(10.)
  return db_per_nat * jnp.log(mse)


def psnr_to_mse(psnr):
  """Convert a PSNR to an MSE (we assume the maximum pixel value is 1)."""
  nats_per_db = -0.1 * jnp.log(10.)
  return jnp.exp(nats_per_db * psnr)


def ssim_to_dssim(ssim):
  """Compute DSSIM given an SSIM."""
  return (1 - ssim) / 2


def dssim_to_ssim(dssim):
  """Compute SSIM given a DSSIM (the inverse of `ssim_to_dssim`)."""
  return 1 - 2 * dssim


def linear_to_srgb(linear: _Array,
                   eps: Optional[float] = None,
                   xnp: types.ModuleType = jnp) -> _Array:
  """Assumes `linear` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB.

  Args:
    linear: linear-light values.
    eps: lower clamp applied before the fractional power; defaults to the
      float32 machine epsilon of `xnp`.
    xnp: array module to compute with (`jnp` or `np`).

  Returns:
    The sRGB-encoded values.
  """
  if eps is None:
    eps = xnp.finfo(xnp.float32).eps
  # Linear segment near black, power-law segment everywhere else.
  toe = 323 / 25 * linear
  curve = (211 * xnp.maximum(eps, linear)**(5 / 12) - 11) / 200
  return xnp.where(linear <= 0.0031308, toe, curve)


def srgb_to_linear(srgb: _Array,
                   eps: Optional[float] = None,
                   xnp: types.ModuleType = jnp) -> _Array:
  """Assumes `srgb` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB.

  Args:
    srgb: sRGB-encoded values.
    eps: lower clamp applied before the fractional power; defaults to the
      float32 machine epsilon of `xnp`.
    xnp: array module to compute with (`jnp` or `np`).

  Returns:
    The linear-light values.
  """
  if eps is None:
    eps = xnp.finfo(xnp.float32).eps
  toe = 25 / 323 * srgb
  curve = xnp.maximum(eps, ((200 * srgb + 11) / (211)))**(12 / 5)
  return xnp.where(srgb <= 0.04045, toe, curve)


def downsample(img, factor):
  """Area downsample img (factor must evenly divide img height and width).

  Raises:
    ValueError: if `factor` does not divide both leading dimensions.
  """
  sh = img.shape
  rows, cols = sh[0], sh[1]
  if rows % factor != 0 or cols % factor != 0:
    raise ValueError(f'Downsampling factor {factor} does not '
                     f'evenly divide image shape {sh[:2]}')
  # Expose each factor x factor tile on its own pair of axes, then average it.
  tiles = img.reshape((rows // factor, factor, cols // factor, factor) + sh[2:])
  return tiles.mean((1, 3))


def color_correct(img, ref, num_iters=5, eps=0.5 / 255):
  """Warp `img` to match the colors in `ref`.

  A per-channel quadratic color transform is fit by least squares and applied
  to `img`; the fit is repeated `num_iters` times.

  Args:
    img: image to correct; the last axis is the channel axis.
    ref: reference image with the same number of channels.
    num_iters: number of fit-and-apply rounds.
    eps: values outside [eps, 1 - eps] are treated as clipped and ignored.

  Returns:
    The corrected image, clipped to [0, 1], with the shape of `img`.

  Raises:
    ValueError: if `img` and `ref` have different channel counts.
  """
  if img.shape[-1] != ref.shape[-1]:
    raise ValueError(
        f'img\'s {img.shape[-1]} and ref\'s {ref.shape[-1]} channels must match'
    )
  num_channels = img.shape[-1]
  img_mat = img.reshape([-1, num_channels])
  ref_mat = ref.reshape([-1, num_channels])

  def is_unclipped(z):
    # z \in [eps, 1-eps].
    return (z >= eps) & (z <= (1 - eps))

  mask0 = is_unclipped(img_mat)
  # The set of saturated pixels may change once a transformation is applied,
  # so the system is re-solved `num_iters` times with an updated estimate of
  # which pixels are saturated.
  for _ in range(num_iters):
    # Left hand side: a quadratic expansion of each pixel of `img`
    # (products of channel pairs, then the channels, then a constant).
    columns = [
        img_mat[:, c:(c + 1)] * img_mat[:, c:] for c in range(num_channels)
    ]
    columns.append(img_mat)
    columns.append(jnp.ones_like(img_mat[:, :1]))
    a_mat = jnp.concatenate(columns, axis=-1)
    warp = []
    for c in range(num_channels):
      # Right hand side: one color channel of `ref`.
      b = ref_mat[:, c]
      # Zero out rows that were saturated in the input or are saturated in the
      # current corrected estimate or in the reference.
      mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & is_unclipped(b)
      ma_mat = jnp.where(mask[:, None], a_mat, 0)
      mb = jnp.where(mask, b, 0)
      # np.linalg.lstsq is used instead of the jnp version because it is
      # significantly more stable in this case.
      w = np.linalg.lstsq(ma_mat, mb, rcond=-1)[0]
      assert jnp.all(jnp.isfinite(w))
      warp.append(w)
    warp = jnp.stack(warp, axis=-1)
    # Apply the warp to update img_mat.
    img_mat = jnp.clip(
        jnp.matmul(a_mat, warp, precision=jax.lax.Precision.HIGHEST), 0, 1)
  corrected_img = jnp.reshape(img_mat, img.shape)
  return corrected_img


class MetricHarness:
  """A helper class for evaluating several error metrics."""

  def __init__(self):
    # Jit the SSIM implementation once so repeated evaluations reuse it.
    self.ssim_fn = jax.jit(dm_pix.ssim)

  def __call__(self, rgb_pred, rgb_gt, name_fn=lambda s: s):
    """Evaluate the error between a predicted rgb image and the true image.

    Args:
      rgb_pred: predicted image.
      rgb_gt: ground-truth image.
      name_fn: maps a metric name to the key used in the returned dict.

    Returns:
      A dict holding the PSNR and SSIM as Python floats.
    """
    mse = ((rgb_pred - rgb_gt)**2).mean()
    metrics = {
        name_fn('psnr'): float(mse_to_psnr(mse)),
    }
    metrics[name_fn('ssim')] = float(self.ssim_fn(rgb_pred, rgb_gt))
    return metrics
def matmul(a, b):
  """jnp.matmul defaults to bfloat16, but this helper function doesn't."""
  return jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)


def _key_stream(seed):
  """Yield PRNG keys by repeatedly splitting a generator seeded with `seed`."""
  rng = random.PRNGKey(seed)
  while True:
    key, rng = random.split(rng)
    yield key


class ImageTest(absltest.TestCase):
  """Tests for the helpers in `internal.image`."""

  def test_color_correction(self):
    """Test that color correction can undo a CCM + quadratic warp + shift."""
    im_shape = (128, 128, 3)
    keys = _key_stream(0)
    for _ in range(10):
      # Construct a random image.
      im0 = random.uniform(next(keys), shape=im_shape, minval=0.1, maxval=0.9)

      # Construct a random linear + quadratic color transformation.
      ccm_scale = random.normal(next(keys)) / 10
      shift = random.normal(next(keys)) / 10
      sq_mult = random.normal(next(keys)) / 10
      ccm = jnp.eye(3) + random.normal(next(keys), shape=(3, 3)) * ccm_scale

      # Apply that random transformation to the image.
      mixed = matmul(jnp.reshape(im0, [-1, 3]), ccm).reshape(im0.shape)
      im1 = jnp.clip(mixed + sq_mult * im0**2 + shift, 0, 1)

      # Check that color correction recovers the randomly transformed image.
      im0_cc = image.color_correct(im0, im1)
      np.testing.assert_allclose(im0_cc, im1, atol=1E-5, rtol=1E-5)

  def test_psnr_mse_round_trip(self):
    """PSNR -> MSE -> PSNR is a no-op."""
    for psnr in [10., 20., 30.]:
      recovered = image.mse_to_psnr(image.psnr_to_mse(psnr))
      np.testing.assert_allclose(recovered, psnr, atol=1E-5, rtol=1E-5)

  def test_ssim_dssim_round_trip(self):
    """SSIM -> DSSIM -> SSIM is a no-op."""
    for ssim in [-0.9, 0, 0.9]:
      recovered = image.dssim_to_ssim(image.ssim_to_dssim(ssim))
      np.testing.assert_allclose(recovered, ssim, atol=1E-5, rtol=1E-5)

  def test_srgb_linearize(self):
    """The sRGB transfer functions invert each other with finite gradients."""
    x = jnp.linspace(-1, 3, 10000)  # Nobody should call this <0 but it works.
    directions = [
        (image.srgb_to_linear, image.linear_to_srgb),
        (image.linear_to_srgb, image.srgb_to_linear),
    ]
    # Check that the round-trip transformation is a no-op.
    for inner, outer in directions:
      np.testing.assert_allclose(outer(inner(x)), x, atol=1E-5, rtol=1E-5)
    # Check that gradients are finite.
    for fn in (image.linear_to_srgb, image.srgb_to_linear):
      grads = jax.vmap(jax.grad(fn))(x)
      self.assertTrue(jnp.all(jnp.isfinite(grads)))

  def test_srgb_to_linear_golden(self):
    """A lazy golden test for srgb_to_linear."""
    linear_gt = jnp.array([
        0.00000000, 0.00122856, 0.00245712, 0.00372513, 0.00526076, 0.00711347,
        0.00929964, 0.01183453, 0.01473243, 0.01800687, 0.02167065, 0.02573599,
        0.03021459, 0.03511761, 0.04045585, 0.04623971, 0.05247922, 0.05918410,
        0.06636375, 0.07402734, 0.08218378, 0.09084171, 0.10000957, 0.10969563,
        0.11990791, 0.13065430, 0.14194246, 0.15377994, 0.16617411, 0.17913227,
        0.19266140, 0.20676863, 0.22146071, 0.23674440, 0.25262633, 0.26911288,
        0.28621066, 0.30392596, 0.32226467, 0.34123330, 0.36083785, 0.38108405,
        0.40197787, 0.42352500, 0.44573134, 0.46860245, 0.49214387, 0.51636110,
        0.54125960, 0.56684470, 0.59312177, 0.62009590, 0.64777250, 0.67615650,
        0.70525320, 0.73506740, 0.76560410, 0.79686830, 0.82886493, 0.86159873,
        0.89507430, 0.92929670, 0.96427040, 1.00000000
    ])
    linear = image.srgb_to_linear(jnp.linspace(0, 1, 64))
    np.testing.assert_allclose(linear, linear_gt, atol=1E-5, rtol=1E-5)

  def test_mse_to_psnr_golden(self):
    """A lazy golden test for mse_to_psnr."""
    psnr_gt = jnp.array([
        43.429447, 42.740090, 42.050735, 41.361378, 40.6720240, 39.982666,
        39.293310, 38.603954, 37.914597, 37.225240, 36.5358850, 35.846527,
        35.157170, 34.467810, 33.778458, 33.089100, 32.3997460, 31.710388,
        31.021034, 30.331675, 29.642320, 28.952961, 28.2636070, 27.574250,
        26.884893, 26.195538, 25.506180, 24.816826, 24.1274700, 23.438112,
        22.748756, 22.059400, 21.370045, 20.680689, 19.9913310, 19.301975,
        18.612620, 17.923262, 17.233906, 16.544550, 15.8551940, 15.165837,
        14.4764805, 13.787125, 13.097769, 12.408413, 11.719056, 11.029700,
        10.3403420, 9.6509850, 8.9616290, 8.2722720, 7.5829163, 6.8935600,
        6.2042036, 5.5148473, 4.825491, 4.136135, 3.4467785, 2.7574227,
        2.0680661, 1.37871, 0.68935364, 0.
    ])
    psnr = image.mse_to_psnr(jnp.exp(jnp.linspace(-10, 0, 64)))
    np.testing.assert_allclose(psnr, psnr_gt, atol=1E-5, rtol=1E-5)


if __name__ == '__main__':
  absltest.main()
def contract(x):
  """Contracts points towards the origin (Eq 10 of arxiv.org/abs/2111.12077).

  Points whose squared norm along the last axis is at most 1 are returned
  unchanged; the rest are rescaled by (2 * |x| - 1) / |x|^2.
  """
  eps = jnp.finfo(jnp.float32).eps
  # Clamping to eps prevents non-finite gradients when x == 0.
  x_mag_sq = jnp.maximum(eps, jnp.sum(x**2, axis=-1, keepdims=True))
  scale = (2 * jnp.sqrt(x_mag_sq) - 1) / x_mag_sq
  return jnp.where(x_mag_sq <= 1, x, scale * x)


def inv_contract(z):
  """The inverse of contract()."""
  eps = jnp.finfo(jnp.float32).eps
  # Clamping to eps prevents non-finite gradients when z == 0.
  z_mag_sq = jnp.maximum(eps, jnp.sum(z**2, axis=-1, keepdims=True))
  denom = 2 * jnp.sqrt(z_mag_sq) - z_mag_sq
  return jnp.where(z_mag_sq <= 1, z, z / denom)


def track_linearize(fn, mean, cov):
  """Apply function `fn` to a set of means and covariances, ala a Kalman filter.

  We can analytically transform a Gaussian parameterized by `mean` and `cov`
  with a function `fn` by linearizing `fn` around `mean`, and taking advantage
  of the fact that Covar[Ax + y] = A(Covar[x])A^T (see
  https://cs.nyu.edu/~roweis/notes/gaussid.pdf for details).

  Args:
    fn: the function applied to the Gaussians parameterized by (mean, cov).
    mean: a tensor of means, where the last axis is the dimension.
    cov: a tensor of covariances, where the last two axes are the dimensions.

  Returns:
    fn_mean: the transformed means.
    fn_cov: the transformed covariances.

  Raises:
    ValueError: if `cov` does not have exactly one more axis than `mean`.
  """
  if len(cov.shape) != len(mean.shape) + 1:
    raise ValueError('cov must be non-diagonal')
  fn_mean, lin_fn = jax.linearize(fn, mean)
  # Push the linearized map through the last axis of `cov` twice, writing each
  # result to the second-to-last axis.
  apply_along_last = jax.vmap(lin_fn, -1, -2)
  fn_cov = apply_along_last(apply_along_last(cov))
  return fn_mean, fn_cov


def construct_ray_warps(fn, t_near, t_far):
  """Construct a bijection between metric distances and normalized distances.

  See the text around Equation 11 in https://arxiv.org/abs/2111.12077 for a
  detailed explanation.

  Args:
    fn: the function applied to ray distances. Either None (identity), the
      string 'piecewise', or a callable whose `__name__` is one of
      'reciprocal', 'log', 'exp', 'sqrt', 'square'.
    t_near: a tensor of near-plane distances.
    t_far: a tensor of far-plane distances.

  Returns:
    t_to_s: a function that maps distances to normalized distances in [0, 1].
    s_to_t: the inverse of t_to_s.
  """
  if fn is None:
    fn_fwd = lambda x: x
    fn_inv = lambda x: x
  elif fn == 'piecewise':
    # Piecewise spacing combining identity and 1/x functions to allow t_near=0.
    fn_fwd = lambda x: jnp.where(x < 1, .5 * x, 1 - .5 / x)
    fn_inv = lambda x: jnp.where(x < .5, 2 * x, .5 / (1 - x))
  else:
    # Look the inverse up by the forward function's name.
    inv_mapping = {
        'reciprocal': jnp.reciprocal,
        'log': jnp.exp,
        'exp': jnp.log,
        'sqrt': jnp.square,
        'square': jnp.sqrt
    }
    fn_fwd = fn
    fn_inv = inv_mapping[fn.__name__]

  s_near = fn_fwd(t_near)
  s_far = fn_fwd(t_far)

  def t_to_s(t):
    return (fn_fwd(t) - s_near) / (s_far - s_near)

  def s_to_t(s):
    return fn_inv(s * s_far + (1 - s) * s_near)

  return t_to_s, s_to_t


def expected_sin(mean, var):
  """Compute the mean of sin(x), x ~ N(mean, var)."""
  damping = jnp.exp(-0.5 * var)  # large var -> small value.
  return damping * math.safe_sin(mean)


def integrated_pos_enc(mean, var, min_deg, max_deg):
  """Encode a Gaussian with sinusoids scaled by 2^[min_deg, max_deg).

  Args:
    mean: tensor, the mean coordinates to be encoded
    var: tensor, the variance of the coordinates to be encoded.
    min_deg: int, the min degree of the encoding.
    max_deg: int, the max degree of the encoding.

  Returns:
    encoded: jnp.ndarray, encoded variables.
  """
  scales = 2**jnp.arange(min_deg, max_deg)
  flat_shape = mean.shape[:-1] + (-1,)
  scaled_mean = jnp.reshape(mean[..., None, :] * scales[:, None], flat_shape)
  scaled_var = jnp.reshape(var[..., None, :] * scales[:, None]**2, flat_shape)

  # The second half is shifted by pi/2, so it encodes cosines.
  phases = jnp.concatenate([scaled_mean, scaled_mean + 0.5 * jnp.pi], axis=-1)
  variances = jnp.concatenate([scaled_var] * 2, axis=-1)
  return expected_sin(phases, variances)


def lift_and_diagonalize(mean, cov, basis):
  """Project `mean` and `cov` onto basis and diagonalize the projected cov."""
  projected_cov = math.matmul(cov, basis)
  fn_mean = math.matmul(mean, basis)
  fn_cov_diag = jnp.sum(basis * projected_cov, axis=-2)
  return fn_mean, fn_cov_diag


def pos_enc(x, min_deg, max_deg, append_identity=True):
  """The positional encoding used by the original NeRF paper.

  Args:
    x: tensor whose last axis holds the coordinates to encode.
    min_deg: int, the min degree of the encoding.
    max_deg: int, the max degree of the encoding.
    append_identity: if True, `x` itself is prepended to the features.

  Returns:
    The encoded features, concatenated along the last axis.
  """
  scales = 2**jnp.arange(min_deg, max_deg)
  flat_shape = x.shape[:-1] + (-1,)
  scaled_x = jnp.reshape((x[..., None, :] * scales[:, None]), flat_shape)
  # Note that we're not using safe_sin, unlike IPE.
  four_feat = jnp.sin(
      jnp.concatenate([scaled_x, scaled_x + 0.5 * jnp.pi], axis=-1))
  if not append_identity:
    return four_feat
  return jnp.concatenate([x, four_feat], axis=-1)
def _factorial(n):
  """Return n! using the standard library.

  `np.math` was only an alias of the stdlib `math` module and is no longer
  available in NumPy 2.0, so `np.math.factorial` cannot be relied on. The
  import is local because the module-level name `math` in this file refers to
  `internal.math`, not the stdlib module.
  """
  from math import factorial  # pylint: disable=g-import-not-at-top
  return factorial(n)


def reflect(viewdirs, normals):
  """Reflect view directions about normals.

  The reflection of a vector v about a unit vector n is a vector u such that
  dot(v, n) = dot(u, n), and dot(u, u) = dot(v, v). The solution to these two
  equations is u = 2 dot(n, v) n - v.

  Args:
    viewdirs: [..., 3] array of view directions.
    normals: [..., 3] array of normal directions (assumed to be unit vectors).

  Returns:
    [..., 3] array of reflection directions.
  """
  return 2.0 * jnp.sum(
      normals * viewdirs, axis=-1, keepdims=True) * normals - viewdirs


def l2_normalize(x, eps=jnp.finfo(jnp.float32).eps):
  """Normalize x to unit length along last axis.

  The squared norm is clamped to `eps` from below, so an all-zero vector maps
  to zeros instead of NaN.
  """
  return x / jnp.sqrt(jnp.maximum(jnp.sum(x**2, axis=-1, keepdims=True), eps))


def compute_weighted_mae(weights, normals, normals_gt):
  """Compute weighted mean angular error, assuming normals are unit length.

  Returns:
    The weighted mean angle between `normals` and `normals_gt`, in degrees.
  """
  # Keep the cosine strictly inside (-1, 1) before taking arccos.
  one_eps = 1 - jnp.finfo(jnp.float32).eps
  return (weights * jnp.arccos(
      jnp.clip((normals * normals_gt).sum(-1), -one_eps,
               one_eps))).sum() / weights.sum() * 180.0 / jnp.pi


def generalized_binomial_coeff(a, k):
  """Compute generalized binomial coefficients.

  Args:
    a: real-valued upper argument.
    k: nonnegative integer lower argument.

  Returns:
    a * (a - 1) * ... * (a - k + 1) / k!.
  """
  return np.prod(a - np.arange(k)) / _factorial(k)


def assoc_legendre_coeff(l, m, k):
  """Compute associated Legendre polynomial coefficients.

  Returns the coefficient of the cos^k(theta)*sin^m(theta) term in the
  (l, m)th associated Legendre polynomial, P_l^m(cos(theta)).

  Args:
    l: associated Legendre polynomial degree.
    m: associated Legendre polynomial order.
    k: power of cos(theta).

  Returns:
    A float, the coefficient of the term corresponding to the inputs.
  """
  return ((-1)**m * 2**l * _factorial(l) / _factorial(k) /
          _factorial(l - k - m) *
          generalized_binomial_coeff(0.5 * (l + k + m - 1.0), l))


def sph_harm_coeff(l, m, k):
  """Compute spherical harmonic coefficients.

  This is `assoc_legendre_coeff(l, m, k)` scaled by
  sqrt((2l + 1) (l - m)! / (4 pi (l + m)!)).
  """
  return (np.sqrt(
      (2.0 * l + 1.0) * _factorial(l - m) /
      (4.0 * np.pi * _factorial(l + m))) * assoc_legendre_coeff(l, m, k))


def get_ml_array(deg_view):
  """Create an array with all pairs of (m, l) values to use in the encoding.

  Args:
    deg_view: number of degrees; the degrees used are l = 2**i for
      i in range(deg_view).

  Returns:
    An integer array of shape [2, n]; row 0 holds m and row 1 holds l.
  """
  ml_list = []
  for i in range(deg_view):
    l = 2**i
    # Only use nonnegative m values, later splitting real and imaginary parts.
    for m in range(l + 1):
      ml_list.append((m, l))

  # Convert list into a numpy array.
  ml_array = np.array(ml_list).T
  return ml_array


def generate_ide_fn(deg_view):
  """Generate integrated directional encoding (IDE) function.

  This function returns a function that computes the integrated directional
  encoding from Equations 6-8 of arxiv.org/abs/2112.03907.

  Args:
    deg_view: number of spherical harmonics degrees to use.

  Returns:
    A function for evaluating integrated directional encoding.

  Raises:
    ValueError: if deg_view is larger than 5.
  """
  if deg_view > 5:
    raise ValueError('Only deg_view of at most 5 is numerically stable.')

  ml_array = get_ml_array(deg_view)
  l_max = 2**(deg_view - 1)

  # Create a matrix corresponding to ml_array holding all coefficients, which,
  # when multiplied (from the right) by the z coordinate Vandermonde matrix,
  # results in the z component of the encoding.
  mat = np.zeros((l_max + 1, ml_array.shape[1]))
  for i, (m, l) in enumerate(ml_array.T):
    for k in range(l - m + 1):
      mat[k, i] = sph_harm_coeff(l, m, k)

  def integrated_dir_enc_fn(xyz, kappa_inv):
    """Function returning integrated directional encoding (IDE).

    Args:
      xyz: [..., 3] array of Cartesian coordinates of directions to evaluate at.
      kappa_inv: [..., 1] reciprocal of the concentration parameter of the von
        Mises-Fisher distribution.

    Returns:
      An array with the resulting IDE.
    """
    x = xyz[..., 0:1]
    y = xyz[..., 1:2]
    z = xyz[..., 2:3]

    # Compute z Vandermonde matrix.
    vmz = jnp.concatenate([z**i for i in range(mat.shape[0])], axis=-1)

    # Compute x+iy Vandermonde matrix.
    vmxy = jnp.concatenate([(x + 1j * y)**m for m in ml_array[0, :]], axis=-1)

    # Get spherical harmonics.
    sph_harms = vmxy * math.matmul(vmz, mat)

    # Apply attenuation function using the von Mises-Fisher distribution
    # concentration parameter, kappa.
    sigma = 0.5 * ml_array[1, :] * (ml_array[1, :] + 1)
    ide = sph_harms * jnp.exp(-sigma * kappa_inv)

    # Split into real and imaginary parts and return
    return jnp.concatenate([jnp.real(ide), jnp.imag(ide)], axis=-1)

  return integrated_dir_enc_fn


def generate_dir_enc_fn(deg_view):
  """Generate directional encoding (DE) function.

  Args:
    deg_view: number of spherical harmonics degrees to use.

  Returns:
    A function for evaluating directional encoding.
  """
  integrated_dir_enc_fn = generate_ide_fn(deg_view)

  def dir_enc_fn(xyz):
    """Function returning directional encoding (DE)."""
    # A zero kappa_inv turns the attenuation factor into exp(0) == 1.
    return integrated_dir_enc_fn(xyz, jnp.zeros_like(xyz[..., :1]))

  return dir_enc_fn
def safe_trig_harness(fn, max_exp):
  """Evaluate np.<fn> and math.safe_<fn> on a symmetric log-spaced grid.

  Args:
    fn: name of the trig function, e.g. 'sin' or 'cos'.
    max_exp: base-10 exponent of the largest magnitude in the grid.

  Returns:
    (y_true, y): the numpy reference values and the safe_<fn> values.
  """
  magnitudes = 10**np.linspace(-30, max_exp, 10000)
  x = np.concatenate([-magnitudes[::-1], np.array([0]), magnitudes])
  reference_fn = getattr(np, fn)
  safe_fn = getattr(math, 'safe_' + fn)
  return reference_fn(x), safe_fn(x)


def _key_stream(seed):
  """Yield PRNG keys by repeatedly splitting a generator seeded with `seed`."""
  rng = random.PRNGKey(seed)
  while True:
    key, rng = random.split(rng)
    yield key


class MathTest(parameterized.TestCase):
  """Tests for the helpers in `internal.math`."""

  def test_sin(self):
    """In [-1e10, 1e10] safe_sin and safe_cos are accurate."""
    trig_fns = ('sin', 'cos')
    for fn in trig_fns:
      y_true, y = safe_trig_harness(fn, 10)
      self.assertLess(jnp.max(jnp.abs(y - y_true)), 1e-4)
      self.assertFalse(jnp.any(jnp.isnan(y)))
    # Beyond that range it's less accurate but we just don't want it to be NaN.
    for fn in trig_fns:
      _, y = safe_trig_harness(fn, 60)
      self.assertFalse(jnp.any(jnp.isnan(y)))

  def test_safe_exp_correct(self):
    """math.safe_exp() should match np.exp() for not-huge values."""
    x = jnp.linspace(-80, 80, 10001)
    expected = jnp.exp(x)
    value = math.safe_exp(x)
    grad = jax.vmap(jax.grad(math.safe_exp))(x)
    # exp() is its own derivative, so both must match `expected`.
    np.testing.assert_allclose(value, expected)
    np.testing.assert_allclose(grad, expected)

  def test_safe_exp_finite(self):
    """math.safe_exp() behaves reasonably for huge values."""
    x = jnp.linspace(-100000, 100000, 10001)
    y = math.safe_exp(x)
    g = jax.vmap(jax.grad(math.safe_exp))(x)
    # `y` and `g` should both always be finite.
    for values in (y, g):
      self.assertTrue(jnp.all(jnp.isfinite(values)))
    # The derivative of exp() should be exp().
    np.testing.assert_allclose(y, g)
    # safe_exp()'s output and gradient should be monotonic.
    for values in (y, g):
      self.assertTrue(jnp.all(values[1:] >= values[:-1]))

  def test_learning_rate_decay(self):
    """The undelayed schedule hits its start, middle and end values."""
    tol = dict(atol=1E-5, rtol=1E-5)
    keys = _key_stream(0)
    for _ in range(10):
      lr_init = jnp.exp(random.normal(next(keys)) - 3)
      lr_final = lr_init * jnp.exp(random.normal(next(keys)) - 5)
      max_steps = int(jnp.ceil(100 + 100 * jnp.exp(random.normal(next(keys)))))

      lr_fn = functools.partial(
          math.learning_rate_decay,
          lr_init=lr_init,
          lr_final=lr_final,
          max_steps=max_steps)

      # Test that the rate at the beginning is the initial rate.
      np.testing.assert_allclose(lr_fn(0), lr_init, **tol)

      # Test that the rate at the end is the final rate.
      np.testing.assert_allclose(lr_fn(max_steps), lr_final, **tol)

      # Test that the rate at the middle is the geometric mean of the two rates.
      np.testing.assert_allclose(
          lr_fn(max_steps / 2), jnp.sqrt(lr_init * lr_final), **tol)

      # Test that the rate past the end is the final rate
      np.testing.assert_allclose(lr_fn(max_steps + 100), lr_final, **tol)

  def test_delayed_learning_rate_decay(self):
    """The delayed schedule starts scaled down and then rejoins the usual one."""
    tol = dict(atol=1E-5, rtol=1E-5)
    keys = _key_stream(0)
    for _ in range(10):
      lr_init = jnp.exp(random.normal(next(keys)) - 3)
      lr_final = lr_init * jnp.exp(random.normal(next(keys)) - 5)
      max_steps = int(jnp.ceil(100 + 100 * jnp.exp(random.normal(next(keys)))))
      lr_delay_steps = int(
          random.uniform(next(keys), minval=0.1, maxval=0.4) * max_steps)
      lr_delay_mult = jnp.exp(random.normal(next(keys)) - 3)

      lr_fn = functools.partial(
          math.learning_rate_decay,
          lr_init=lr_init,
          lr_final=lr_final,
          max_steps=max_steps,
          lr_delay_steps=lr_delay_steps,
          lr_delay_mult=lr_delay_mult)

      # Test that the rate at the beginning is the delayed initial rate.
      np.testing.assert_allclose(lr_fn(0), lr_delay_mult * lr_init, **tol)

      # Test that the rate at the end is the final rate.
      np.testing.assert_allclose(lr_fn(max_steps), lr_final, **tol)

      # Test that the rate at after the delay is over is the usual rate.
      undelayed = math.learning_rate_decay(lr_delay_steps, lr_init, lr_final,
                                           max_steps)
      np.testing.assert_allclose(lr_fn(lr_delay_steps), undelayed, **tol)

      # Test that the rate at the middle is the geometric mean of the two rates.
      np.testing.assert_allclose(
          lr_fn(max_steps / 2), jnp.sqrt(lr_init * lr_final), **tol)

      # Test that the rate past the end is the final rate
      np.testing.assert_allclose(lr_fn(max_steps + 100), lr_final, **tol)

  @parameterized.named_parameters(('', False), ('sort', True))
  def test_interp(self, sort):
    """interp / sorted_interp agree with a row-by-row jnp.interp."""
    n, d0, d1 = 100, 10, 20
    keys = _key_stream(0)

    x = random.normal(next(keys), [n, d0])
    xp = random.normal(next(keys), [n, d1])
    fp = random.normal(next(keys), [n, d1])

    if not sort:
      z = math.interp(x, xp, fp)
    else:
      xp = jnp.sort(xp, axis=-1)
      fp = jnp.sort(fp, axis=-1)
      z = math.sorted_interp(x, xp, fp)

    z_true = jnp.stack([jnp.interp(x[i], xp[i], fp[i]) for i in range(n)])
    np.testing.assert_allclose(z, z_true, atol=1e-5, rtol=1e-5)


if __name__ == '__main__':
  absltest.main()
def is_same_basis(x, y, tol=1e-10):
  """Check if `x` and `y` describe the same linear basis.

  Two bases are considered the same when their shapes are equal and every
  entry along either axis of the pairwise squared-distance matrix has exactly
  one counterpart within `tol`, where a vector and its negation are treated as
  equivalent.

  Args:
    x: array of basis vectors, passed to `geopoly.compute_sq_dist`.
    y: array of basis vectors to compare against `x`.
    tol: maximum squared distance for two vectors to count as a match.

  Returns:
    A truthy value iff the two inputs describe the same basis.
  """
  # Distance to either `y` or `-y`, whichever is closer, thresholded by `tol`.
  match = np.minimum(
      geopoly.compute_sq_dist(x, y), geopoly.compute_sq_dist(x, -y)) <= tol
  # Require a one-to-one matching: exactly one hit per row and per column.
  return (np.all(np.array(x.shape) == np.array(y.shape)) and
          np.all(np.sum(match, axis=0) == 1) and
          np.all(np.sum(match, axis=1) == 1))


class GeopolyTest(absltest.TestCase):
  """Unit tests for the `internal.geopoly` helpers."""

  def test_compute_sq_dist_reference(self):
    """Test against a simple reimplementation of compute_sq_dist."""
    num_points = 100
    num_dims = 10
    rng = random.PRNGKey(0)
    key, rng = random.split(rng)
    mat0 = jax.random.normal(key, [num_dims, num_points])
    key, rng = random.split(rng)
    mat1 = jax.random.normal(key, [num_dims, num_points])

    sq_dist = geopoly.compute_sq_dist(mat0, mat1)

    # Brute-force reference: squared distance between column i and column j.
    sq_dist_ref = np.zeros([num_points, num_points])
    for i in range(num_points):
      for j in range(num_points):
        sq_dist_ref[i, j] = np.sum((mat0[:, i] - mat1[:, j])**2)

    np.testing.assert_allclose(sq_dist, sq_dist_ref, atol=1e-5, rtol=1e-5)

  def test_compute_sq_dist_single_input(self):
    """Test that compute_sq_dist with a single input works correctly."""
    rng = random.PRNGKey(0)
    num_points = 100
    num_dims = 10
    key, rng = random.split(rng)
    mat0 = jax.random.normal(key, [num_dims, num_points])

    # Omitting the second argument must equal passing the first one twice.
    sq_dist = geopoly.compute_sq_dist(mat0)
    sq_dist_ref = geopoly.compute_sq_dist(mat0, mat0)
    np.testing.assert_allclose(sq_dist, sq_dist_ref)

  def test_compute_tesselation_weights_reference(self):
    """A reference implementation for triangle tesselation."""
    for v in range(1, 10):
      w = geopoly.compute_tesselation_weights(v)
      # Reference: all integer triples in [0, v]^3 summing to v, scaled by 1/v.
      perm = np.array(list(itertools.product(range(v + 1), repeat=3)))
      w_ref = perm[np.sum(perm, axis=-1) == v, :] / v
      # Check that the rows of `w` and `w_ref` match one-to-one.
      self.assertTrue(is_same_basis(w.T, w_ref.T))

  def test_generate_basis_golden(self):
    """A mediocre golden test against two arbitrary basis choices."""
    basis = geopoly.generate_basis('icosahedron', 2)
    basis_golden = np.array([[0.85065081, 0.00000000, 0.52573111],
                             [0.80901699, 0.50000000, 0.30901699],
                             [0.52573111, 0.85065081, 0.00000000],
                             [1.00000000, 0.00000000, 0.00000000],
                             [0.80901699, 0.50000000, -0.30901699],
                             [0.85065081, 0.00000000, -0.52573111],
                             [0.30901699, 0.80901699, -0.50000000],
                             [0.00000000, 0.52573111, -0.85065081],
                             [0.50000000, 0.30901699, -0.80901699],
                             [0.00000000, 1.00000000, 0.00000000],
                             [-0.52573111, 0.85065081, 0.00000000],
                             [-0.30901699, 0.80901699, -0.50000000],
                             [0.00000000, 0.52573111, 0.85065081],
                             [-0.30901699, 0.80901699, 0.50000000],
                             [0.30901699, 0.80901699, 0.50000000],
                             [0.50000000, 0.30901699, 0.80901699],
                             [0.50000000, -0.30901699, 0.80901699],
                             [0.00000000, 0.00000000, 1.00000000],
                             [-0.50000000, 0.30901699, 0.80901699],
                             [-0.80901699, 0.50000000, 0.30901699],
                             [-0.80901699, 0.50000000, -0.30901699]])
    self.assertTrue(is_same_basis(basis.T, basis_golden.T))

    basis = geopoly.generate_basis('octahedron', 4)
    basis_golden = np.array([[0.00000000, 0.00000000, -1.00000000],
                             [0.00000000, -0.31622777, -0.94868330],
                             [0.00000000, -0.70710678, -0.70710678],
                             [0.00000000, -0.94868330, -0.31622777],
                             [0.00000000, -1.00000000, 0.00000000],
                             [-0.31622777, 0.00000000, -0.94868330],
                             [-0.40824829, -0.40824829, -0.81649658],
                             [-0.40824829, -0.81649658, -0.40824829],
                             [-0.31622777, -0.94868330, 0.00000000],
                             [-0.70710678, 0.00000000, -0.70710678],
                             [-0.81649658, -0.40824829, -0.40824829],
                             [-0.70710678, -0.70710678, 0.00000000],
                             [-0.94868330, 0.00000000, -0.31622777],
                             [-0.94868330, -0.31622777, 0.00000000],
                             [-1.00000000, 0.00000000, 0.00000000],
                             [0.00000000, -0.31622777, 0.94868330],
                             [0.00000000, -0.70710678, 0.70710678],
                             [0.00000000, -0.94868330, 0.31622777],
                             [0.40824829, -0.40824829, 0.81649658],
                             [0.40824829, -0.81649658, 0.40824829],
                             [0.31622777, -0.94868330, 0.00000000],
                             [0.81649658, -0.40824829, 0.40824829],
                             [0.70710678, -0.70710678, 0.00000000],
                             [0.94868330, -0.31622777, 0.00000000],
                             [0.31622777, 0.00000000, -0.94868330],
                             [0.40824829, 0.40824829, -0.81649658],
                             [0.40824829, 0.81649658, -0.40824829],
                             [0.70710678, 0.00000000, -0.70710678],
                             [0.81649658, 0.40824829, -0.40824829],
                             [0.94868330, 0.00000000, -0.31622777],
                             [0.40824829, -0.40824829, -0.81649658],
                             [0.40824829, -0.81649658, -0.40824829],
                             [0.81649658, -0.40824829, -0.40824829]])
    self.assertTrue(is_same_basis(basis.T, basis_golden.T))


if __name__ == '__main__':
  absltest.main()
# Register the command-line flags shared by the train/eval/render scripts and
# let JAX pick up its own absl flags before `app.run` parses argv.
configs.define_common_flags()
jax.config.parse_flags_with_absl()


def create_videos(config, base_dir, out_dir, out_name, num_frames):
  """Creates videos out of the images saved to disk.

  One mp4 is written per rendered quantity ('color', 'normals', 'acc',
  'distance_mean', 'distance_median'). A quantity whose first frame is absent
  is skipped with a message; a quantity whose first frame exists but which has
  a later frame missing is an error.

  Args:
    config: the run configuration; `checkpoint_dir`, `render_dist_percentile`,
      `render_dist_curve_fn`, `render_video_fps` and `render_video_crf` are
      read from it.
    base_dir: directory the videos are written to (created if needed).
    out_dir: directory holding the per-frame images.
    out_name: tag appended to the video file prefix.
    num_frames: number of frames to put into each video.

  Raises:
    ValueError: if a frame in [0, num_frames) is missing for a quantity whose
      first frame exists.
  """
  names = [n for n in config.checkpoint_dir.split('/') if n]
  # Last two parts of checkpoint path are experiment name and scene name.
  exp_name, scene_name = names[-2:]
  video_prefix = f'{scene_name}_{exp_name}_{out_name}'

  zpad = max(3, len(str(num_frames - 1)))
  idx_to_str = lambda idx: str(idx).zfill(zpad)

  utils.makedirs(base_dir)

  # Load one example frame to get image shape and depth range.
  depth_file = os.path.join(out_dir, f'distance_mean_{idx_to_str(0)}.tiff')
  depth_frame = utils.load_img(depth_file)
  shape = depth_frame.shape
  p = config.render_dist_percentile
  distance_limits = np.percentile(depth_frame.flatten(), [p, 100 - p])
  lo, hi = [config.render_dist_curve_fn(x) for x in distance_limits]
  print(f'Video shape is {shape[:2]}')

  video_kwargs = {
      'shape': shape[:2],
      'codec': 'h264',
      'fps': config.render_video_fps,
      'crf': config.render_video_crf,
  }

  for k in ['color', 'normals', 'acc', 'distance_mean', 'distance_median']:
    video_file = os.path.join(base_dir, f'{video_prefix}_{k}.mp4')
    input_format = 'gray' if k == 'acc' else 'rgb'
    is_u8_image = k in ['color', 'normals']
    is_distance = k.startswith('distance')
    file_ext = 'png' if is_u8_image else 'tiff'
    file0 = os.path.join(out_dir, f'{k}_{idx_to_str(0)}.{file_ext}')
    if not utils.file_exists(file0):
      print(f'Images missing for tag {k}')
      continue
    print(f'Making video {video_file}...')
    # Look the colormap up once per video rather than once per frame.
    colormap = cm.get_cmap('turbo') if is_distance else None
    with media.VideoWriter(
        video_file, **video_kwargs, input_format=input_format) as writer:
      for idx in range(num_frames):
        img_file = os.path.join(out_dir, f'{k}_{idx_to_str(idx)}.{file_ext}')
        if not utils.file_exists(img_file):
          # The exception used to be constructed but never raised, so a
          # missing frame was not reported at the point it was detected.
          raise ValueError(f'Image file {img_file} does not exist.')
        img = utils.load_img(img_file)
        if is_u8_image:
          img = img / 255.
        elif is_distance:
          # Map distances through the display curve, normalize them to [0, 1]
          # using the limits taken from the first frame, then colorize.
          img = config.render_dist_curve_fn(img)
          img = np.clip((img - np.minimum(lo, hi)) / np.abs(hi - lo), 0, 1)
          img = colormap(img)[..., :3]

        frame = (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(np.uint8)
        writer.add_image(frame)


def main(unused_argv):
  """Renders every test-set (or path) view of a checkpoint and makes videos.

  Work can be sharded over several jobs: job `render_job_id` renders the
  indices congruent to it modulo `render_num_jobs`. Images are only written by
  host 0, and videos are only assembled once all 'acc' frames are on disk.

  Args:
    unused_argv: positional command-line arguments forwarded by absl; ignored.
  """

  config = configs.load_config(save_config=False)

  dataset = datasets.load_dataset('test', config.data_dir, config)

  key = random.PRNGKey(20200823)
  _, state, render_eval_pfn, _, _ = train_utils.setup_model(config, key)

  if config.rawnerf_mode:
    postprocess_fn = dataset.metadata['postprocess_fn']
  else:
    postprocess_fn = lambda z: z

  state = checkpoints.restore_checkpoint(config.checkpoint_dir, state)
  step = int(state.step)
  print(f'Rendering checkpoint at step {step}.')

  out_name = 'path_renders' if config.render_path else 'test_preds'
  out_name = f'{out_name}_step_{step}'
  base_dir = config.render_dir
  if base_dir is None:
    base_dir = os.path.join(config.checkpoint_dir, 'render')
  out_dir = os.path.join(base_dir, out_name)
  if not utils.isdir(out_dir):
    utils.makedirs(out_dir)

  path_fn = lambda x: os.path.join(out_dir, x)

  # Ensure sufficient zero-padding of image indices in output filenames.
  zpad = max(3, len(str(dataset.size - 1)))
  idx_to_str = lambda idx: str(idx).zfill(zpad)

  if config.render_save_async:
    async_executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
    async_futures = []
    def save_fn(fn, *args, **kwargs):
      async_futures.append(async_executor.submit(fn, *args, **kwargs))
  else:
    def save_fn(fn, *args, **kwargs):
      fn(*args, **kwargs)

  for idx in range(dataset.size):
    if idx % config.render_num_jobs != config.render_job_id:
      continue
    # If current image and next image both already exist, skip ahead.
    idx_str = idx_to_str(idx)
    curr_file = path_fn(f'color_{idx_str}.png')
    next_idx_str = idx_to_str(idx + config.render_num_jobs)
    next_file = path_fn(f'color_{next_idx_str}.png')
    if utils.file_exists(curr_file) and utils.file_exists(next_file):
      # Report 1-based positions, consistent with the 'Evaluating' message.
      print(f'Image {idx+1}/{dataset.size} already exists, skipping')
      continue
    print(f'Evaluating image {idx+1}/{dataset.size}')
    eval_start_time = time.time()
    rays = dataset.generate_ray_batch(idx).rays
    train_frac = 1.
    rendering = models.render_image(
        functools.partial(render_eval_pfn, state.params, train_frac),
        rays, None, config)
    print(f'Rendered in {(time.time() - eval_start_time):0.3f}s')

    if jax.host_id() != 0:  # Only record via host 0.
      continue

    rendering['rgb'] = postprocess_fn(rendering['rgb'])

    save_fn(
        utils.save_img_u8, rendering['rgb'], path_fn(f'color_{idx_str}.png'))
    if 'normals' in rendering:
      # Normals are shifted from [-1, 1] to [0, 1] before being stored.
      save_fn(
          utils.save_img_u8, rendering['normals'] / 2. + 0.5,
          path_fn(f'normals_{idx_str}.png'))
    save_fn(
        utils.save_img_f32, rendering['distance_mean'],
        path_fn(f'distance_mean_{idx_str}.tiff'))
    save_fn(
        utils.save_img_f32, rendering['distance_median'],
        path_fn(f'distance_median_{idx_str}.tiff'))
    save_fn(
        utils.save_img_f32, rendering['acc'], path_fn(f'acc_{idx_str}.tiff'))

  if config.render_save_async:
    # Wait until all worker threads finish.
    async_executor.shutdown(wait=True)

    # This will ensure that exceptions in child threads are raised to the
    # main thread.
    for future in async_futures:
      future.result()

  time.sleep(1)
  num_files = len(glob.glob(path_fn('acc_*.tiff')))
  time.sleep(10)
  if jax.host_id() == 0 and num_files == dataset.size:
    print(f'All files found, creating videos (job {config.render_job_id}).')
    create_videos(config, base_dir, out_dir, out_name, dataset.size)

  # A hack that forces Jax to keep all TPUs alive until every TPU is finished.
  x = jax.numpy.ones([jax.local_device_count()])
  x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x))
  print(x)


if __name__ == '__main__':
  with gin.config_scope('eval'):  # Use the same scope as eval.py
    app.run(main)
def lift_gaussian(d, t_mean, t_var, r_var, diag):
  """Lift a Gaussian defined along a ray to 3D coordinates.

  Args:
    d: ray direction(s), with the coordinate axis last. Not assumed to be
      normalized.
    t_mean: mean of the Gaussian along the ray, one value per interval.
    t_var: variance along the ray direction.
    r_var: variance perpendicular to the ray direction.
    diag: bool, if True only the diagonal of the covariance is returned.

  Returns:
    A `(mean, cov)` tuple; `cov` is the diagonal when `diag` is True and the
    full matrix otherwise.
  """
  mean = d[..., None, :] * t_mean[..., None]
  d_sq = d**2
  d_norm_sq = jnp.maximum(1e-10, jnp.sum(d_sq, axis=-1, keepdims=True))

  if not diag:
    # Full covariance: variance along d d^T plus variance on its null space.
    along = d[..., :, None] * d[..., None, :]
    identity = jnp.eye(d.shape[-1])
    across = identity - d[..., :, None] * (d / d_norm_sq)[..., None, :]
    cov_along = t_var[..., None, None] * along[..., None, :, :]
    cov_across = r_var[..., None, None] * across[..., None, :, :]
    return mean, cov_along + cov_across

  # Diagonal-only version of the same two terms.
  across_diag = 1 - d_sq / d_norm_sq
  cov_along_diag = t_var[..., None] * d_sq[..., None, :]
  cov_across_diag = r_var[..., None] * across_diag[..., None, :]
  return mean, cov_along_diag + cov_across_diag


def conical_frustum_to_gaussian(d, t0, t1, base_radius, diag, stable=True):
  """Approximate a conical frustum as a Gaussian distribution (mean+cov).

  Assumes the ray is originating from the origin, and base_radius is the
  radius at dist=1. Doesn't assume `d` is normalized.

  Args:
    d: jnp.float32 3-vector, the axis of the cone
    t0: float, the starting distance of the frustum.
    t1: float, the ending distance of the frustum.
    base_radius: float, the scale of the radius as a function of distance.
    diag: boolean, whether the Gaussian will be diagonal or full-covariance.
    stable: boolean, whether or not to use the stable computation described in
      the paper (setting this to False will cause catastrophic failure).

  Returns:
    a Gaussian (mean and covariance).
  """
  if not stable:
    # Equations 37-39 in the paper.
    cubic_diff = t1**3 - t0**3
    t_mean = (3 * (t1**4 - t0**4)) / (4 * cubic_diff)
    r_var = 3 / 20 * (t1**5 - t0**5) / cubic_diff
    t_second_moment = 3 / 5 * (t1**5 - t0**5) / cubic_diff
    t_var = t_second_moment - t_mean**2
  else:
    # Equation 7 in the paper (https://arxiv.org/abs/2103.13415).
    mu = (t0 + t1) / 2  # The average of the two `t` values.
    hw = (t1 - t0) / 2  # The half-width of the two `t` values.
    eps = jnp.finfo(jnp.float32).eps
    # Shared denominator, floored so it can never be zero.
    denom = jnp.maximum(eps, 3 * mu**2 + hw**2)
    t_mean = mu + (2 * mu * hw**2) / denom
    t_var = (hw**2) / 3 - (4 / 15) * hw**4 * (12 * mu**2 - hw**2) / denom**2
    r_var = (mu**2) / 4 + (5 / 12) * hw**2 - (4 / 15) * (hw**4) / denom
  r_var *= base_radius**2
  return lift_gaussian(d, t_mean, t_var, r_var, diag)


def cylinder_to_gaussian(d, t0, t1, radius, diag):
  """Approximate a cylinder as a Gaussian distribution (mean+cov).

  Assumes the ray is originating from the origin, and radius is the
  radius. Does not renormalize `d`.

  Args:
    d: jnp.float32 3-vector, the axis of the cylinder
    t0: float, the starting distance of the cylinder.
    t1: float, the ending distance of the cylinder.
    radius: float, the radius of the cylinder
    diag: boolean, whether the Gaussian will be diagonal or full-covariance.

  Returns:
    a Gaussian (mean and covariance).
  """
  midpoint = (t0 + t1) / 2
  radial_var = radius**2 / 4
  axial_var = (t1 - t0)**2 / 12
  return lift_gaussian(d, midpoint, axial_var, radial_var, diag)


def cast_rays(tdist, origins, directions, radii, ray_shape, diag=True):
  """Cast rays (cone- or cylinder-shaped) and featurize sections of it.

  Args:
    tdist: float array, the "fencepost" distances along the ray.
    origins: float array, the ray origin coordinates.
    directions: float array, the ray direction vectors.
    radii: float array, the radii (base radii for cones) of the rays.
    ray_shape: string, the shape of the ray, must be 'cone' or 'cylinder'.
    diag: boolean, whether or not the covariance matrices should be diagonal.

  Returns:
    a tuple of arrays of means and covariances.

  Raises:
    ValueError: if `ray_shape` is neither 'cone' nor 'cylinder'.
  """
  if ray_shape != 'cone' and ray_shape != 'cylinder':
    raise ValueError("ray_shape must be 'cone' or 'cylinder'")
  if ray_shape == 'cone':
    to_gaussian = conical_frustum_to_gaussian
  else:
    to_gaussian = cylinder_to_gaussian
  # Consecutive fenceposts delimit the intervals that get featurized.
  starts, ends = tdist[..., :-1], tdist[..., 1:]
  means, covs = to_gaussian(directions, starts, ends, radii, diag)
  # The Gaussians are built around the origin; shift them to the ray origins.
  return means + origins[..., None, :], covs


def compute_alpha_weights(density, tdist, dirs, opaque_background=False):
  """Helper function for computing alpha compositing weights.

  Args:
    density: per-interval density along each ray.
    tdist: fencepost distances along each ray.
    dirs: ray direction vectors; their norm scales the interval lengths.
    opaque_background: bool, if True the last interval is treated as
      infinitely wide.

  Returns:
    A `(weights, alpha, trans)` tuple of compositing weights, per-interval
    opacity and the transmittance reaching each interval.
  """
  interval = tdist[..., 1:] - tdist[..., :-1]
  ray_scale = jnp.linalg.norm(dirs[..., None, :], axis=-1)
  optical_depth = density * (interval * ray_scale)

  if opaque_background:
    # Equivalent to making the final t-interval infinitely wide.
    head = optical_depth[..., :-1]
    tail = jnp.full_like(optical_depth[..., -1:], jnp.inf)
    optical_depth = jnp.concatenate([head, tail], axis=-1)

  alpha = 1 - jnp.exp(-optical_depth)
  # Optical depth accumulated *before* each interval (zero for the first).
  accumulated = jnp.concatenate([
      jnp.zeros_like(optical_depth[..., :1]),
      jnp.cumsum(optical_depth[..., :-1], axis=-1)
  ],
                                axis=-1)
  trans = jnp.exp(-accumulated)
  return alpha * trans, alpha, trans


def volumetric_rendering(rgbs,
                         weights,
                         tdist,
                         bg_rgbs,
                         t_far,
                         compute_extras,
                         extras=None):
  """Volumetric Rendering Function.

  Args:
    rgbs: jnp.ndarray(float32), color, [batch_size, num_samples, 3]
    weights: jnp.ndarray(float32), weights, [batch_size, num_samples].
    tdist: jnp.ndarray(float32), fencepost distances along each ray; adjacent
      pairs are averaged to get one midpoint per entry of `weights`.
    bg_rgbs: jnp.ndarray(float32), the color(s) to use for the background.
    t_far: jnp.ndarray(float32), [batch_size, 1], the distance of the far plane.
    compute_extras: bool, if True, compute extra quantities besides color.
    extras: dict, a set of values along rays to render by alpha compositing.

  Returns:
    rendering: a dict containing an rgb image of size [batch_size, 3], and other
      visualizations if compute_extras=True.
  """
  eps = jnp.finfo(jnp.float32).eps

  acc = weights.sum(axis=-1)
  bg_w = jnp.maximum(0, 1 - acc[..., None])  # The weight of the background.
  rendering = {
      'rgb': (weights[..., None] * rgbs).sum(axis=-2) + bg_w * bg_rgbs,
  }

  if not compute_extras:
    return rendering

  rendering['acc'] = acc

  if extras is not None:
    for name, values in extras.items():
      if values is None:
        continue
      rendering[name] = (weights[..., None] * values).sum(axis=-2)

  def expectation(x):
    # Weighted mean along the ray, guarded against a zero total weight.
    return (weights * x).sum(axis=-1) / jnp.maximum(eps, acc)

  t_mids = 0.5 * (tdist[..., :-1] + tdist[..., 1:])
  # For numerical stability this expectation is computing using log-distance.
  log_mean = jnp.exp(expectation(jnp.log(t_mids)))
  rendering['distance_mean'] = jnp.clip(
      jnp.nan_to_num(log_mean, jnp.inf), tdist[..., 0], tdist[..., -1])

  # Add an extra fencepost with the far distance at the end of each ray, with
  # whatever weight is needed to make the new weight vector sum to exactly 1
  # (`weights` is only guaranteed to sum to <= 1, not == 1).
  t_aug = jnp.concatenate([tdist, t_far], axis=-1)
  weights_aug = jnp.concatenate([weights, bg_w], axis=-1)

  ps = [5, 50, 95]
  distance_percentiles = stepfun.weighted_percentile(t_aug, weights_aug, ps)

  for col, pct in enumerate(ps):
    if pct == 50:
      suffix = 'median'
    else:
      suffix = 'percentile_' + str(pct)
    rendering['distance_' + suffix] = distance_percentiles[..., col]

  return rendering
#-------------------------------------------------------------------------------
#
# camera distortion functions for arrays of size (..., 2)
#
#-------------------------------------------------------------------------------

def simple_radial_distortion(camera, x):
  """Apply the one-coefficient radial model (`camera.k1`) to points `x`."""
  return x * (1. + camera.k1 * np.square(x).sum(axis=-1, keepdims=True))


def radial_distortion(camera, x):
  """Apply the two-coefficient radial model (`camera.k1`, `camera.k2`)."""
  r_sq = np.square(x).sum(axis=-1, keepdims=True)
  return x * (1. + r_sq * (camera.k1 + camera.k2 * r_sq))


def opencv_distortion(camera, x):
  """Apply radial (`k1`, `k2`) plus tangential (`p1`, `p2`) distortion.

  Args:
    camera: object exposing the coefficients `k1`, `k2`, `p1` and `p2`.
    x: array of shape (..., 2) holding (x, y) coordinates.

  Returns:
    The distorted points, same shape as `x`.
  """
  x_sq = np.square(x)
  xy = np.prod(x, axis=-1, keepdims=True)
  r_sq = x_sq.sum(axis=-1, keepdims=True)
  # Split the squared coordinates per axis. The tangential x-term needs x^2
  # only and the y-term needs y^2 only; the previous code used the whole
  # (..., 2) array for the former and an undefined `y_sq` for the latter.
  u_sq = x_sq[..., 0:1]
  v_sq = x_sq[..., 1:2]

  radial = 1. + r_sq * (camera.k1 + camera.k2 * r_sq)
  tangential = np.concatenate(
      (2. * camera.p1 * xy + camera.p2 * (r_sq + 2. * u_sq),
       camera.p1 * (r_sq + 2. * v_sq) + 2. * camera.p2 * xy),
      axis=-1)
  return x * radial + tangential


def _as_float_copy(x):
  """Return `x` as an at-least-2D floating-point array that owns its data."""
  x = np.atleast_2d(x)
  if np.issubdtype(x.dtype, np.floating):
    return x.copy()
  return x.astype(np.float64)


#-------------------------------------------------------------------------------
#
# Camera
#
#-------------------------------------------------------------------------------

class Camera:
  """A camera with intrinsics and an optional distortion model.

  Supported models, addressable by integer id or by name:
  0 'SIMPLE_PINHOLE', 1 'PINHOLE', 2 'SIMPLE_RADIAL', 3 'RADIAL', 4 'OPENCV'.
  """

  @staticmethod
  def GetNumParams(type_):
    """Return the number of intrinsic parameters of a camera model.

    Args:
      type_: integer id or string name of the camera model.

    Raises:
      Exception: if the camera model is not supported.
    """
    if type_ == 0 or type_ == 'SIMPLE_PINHOLE':
      return 3
    if type_ == 1 or type_ == 'PINHOLE':
      return 4
    if type_ == 2 or type_ == 'SIMPLE_RADIAL':
      return 4
    if type_ == 3 or type_ == 'RADIAL':
      return 5
    if type_ == 4 or type_ == 'OPENCV':
      return 8
    #if type_ == 5 or type_ == 'OPENCV_FISHEYE':
    #  return 8
    #if type_ == 6 or type_ == 'FULL_OPENCV':
    #  return 12
    #if type_ == 7 or type_ == 'FOV':
    #  return 5
    #if type_ == 8 or type_ == 'SIMPLE_RADIAL_FISHEYE':
    #  return 4
    #if type_ == 9 or type_ == 'RADIAL_FISHEYE':
    #  return 5
    #if type_ == 10 or type_ == 'THIN_PRISM_FISHEYE':
    #  return 12

    # TODO: not supporting other camera types, currently
    raise Exception('Camera type not supported')

  #---------------------------------------------------------------------------

  @staticmethod
  def GetNameFromType(type_):
    """Return the model name for an integer camera type id.

    Raises:
      Exception: if the camera model is not supported.
    """
    if type_ == 0:
      return 'SIMPLE_PINHOLE'
    if type_ == 1:
      return 'PINHOLE'
    if type_ == 2:
      return 'SIMPLE_RADIAL'
    if type_ == 3:
      return 'RADIAL'
    if type_ == 4:
      return 'OPENCV'
    #if type_ == 5: return 'OPENCV_FISHEYE'
    #if type_ == 6: return 'FULL_OPENCV'
    #if type_ == 7: return 'FOV'
    #if type_ == 8: return 'SIMPLE_RADIAL_FISHEYE'
    #if type_ == 9: return 'RADIAL_FISHEYE'
    #if type_ == 10: return 'THIN_PRISM_FISHEYE'

    raise Exception('Camera type not supported')

  #---------------------------------------------------------------------------

  def __init__(self, type_, width_, height_, params):
    """Build a camera from a model id/name, an image size and parameters.

    Args:
      type_: integer id or string name of the camera model.
      width_: image width.
      height_: image height.
      params: sequence of intrinsics, ordered as the model prescribes (see
        `get_params` for the order of each model).

    Raises:
      Exception: if the camera model is not supported.
    """
    self.width = width_
    self.height = height_

    if type_ == 0 or type_ == 'SIMPLE_PINHOLE':
      self.fx, self.cx, self.cy = params
      self.fy = self.fx
      self.distortion_func = None
      self.camera_type = 0

    elif type_ == 1 or type_ == 'PINHOLE':
      self.fx, self.fy, self.cx, self.cy = params
      self.distortion_func = None
      self.camera_type = 1

    elif type_ == 2 or type_ == 'SIMPLE_RADIAL':
      self.fx, self.cx, self.cy, self.k1 = params
      self.fy = self.fx
      self.distortion_func = simple_radial_distortion
      self.camera_type = 2

    elif type_ == 3 or type_ == 'RADIAL':
      self.fx, self.cx, self.cy, self.k1, self.k2 = params
      self.fy = self.fx
      self.distortion_func = radial_distortion
      self.camera_type = 3

    elif type_ == 4 or type_ == 'OPENCV':
      self.fx, self.fy, self.cx, self.cy = params[:4]
      self.k1, self.k2, self.p1, self.p2 = params[4:]
      self.distortion_func = opencv_distortion
      self.camera_type = 4

    else:
      raise Exception('Camera type not supported')

  #---------------------------------------------------------------------------

  def __str__(self):
    """Format as '<MODEL> <width> <height> <params...>'."""
    s = (
        self.GetNameFromType(self.camera_type) +
        ' {} {} {}'.format(self.width, self.height, self.fx))

    if self.camera_type in (1, 4):  # PINHOLE, OPENCV
      s += ' {}'.format(self.fy)

    s += ' {} {}'.format(self.cx, self.cy)

    if self.camera_type == 2:  # SIMPLE_RADIAL
      s += ' {}'.format(self.k1)

    elif self.camera_type == 3:  # RADIAL
      s += ' {} {}'.format(self.k1, self.k2)

    elif self.camera_type == 4:  # OPENCV
      s += ' {} {} {} {}'.format(self.k1, self.k2, self.p1, self.p2)

    return s

  #---------------------------------------------------------------------------

  # return the camera parameters in the same order as the colmap output format
  def get_params(self):
    """Return the intrinsics as an array, in the order `__init__` takes them."""
    if self.camera_type == 0:
      return np.array((self.fx, self.cx, self.cy))
    if self.camera_type == 1:
      return np.array((self.fx, self.fy, self.cx, self.cy))
    if self.camera_type == 2:
      return np.array((self.fx, self.cx, self.cy, self.k1))
    if self.camera_type == 3:
      return np.array((self.fx, self.cx, self.cy, self.k1, self.k2))
    if self.camera_type == 4:
      return np.array((self.fx, self.fy, self.cx, self.cy, self.k1, self.k2,
                       self.p1, self.p2))

  #---------------------------------------------------------------------------

  def get_camera_matrix(self):
    """Return the 3x3 intrinsics matrix built from fx, fy, cx and cy."""
    return np.array(((self.fx, 0, self.cx), (0, self.fy, self.cy), (0, 0, 1)))

  def get_inverse_camera_matrix(self):
    """Return the closed-form inverse of the 3x3 intrinsics matrix."""
    return np.array(((1. / self.fx, 0, -self.cx / self.fx),
                     (0, 1. / self.fy, -self.cy / self.fy), (0, 0, 1)))

  @property
  def K(self):
    """The intrinsics matrix; alias for `get_camera_matrix()`."""
    return self.get_camera_matrix()

  @property
  def K_inv(self):
    """The inverse intrinsics matrix; alias for `get_inverse_camera_matrix()`."""
    return self.get_inverse_camera_matrix()

  #---------------------------------------------------------------------------

  # return the inverse camera matrix
  def get_inv_camera_matrix(self):
    """Return the inverse intrinsics matrix (kept alongside the longer name)."""
    inv_fx, inv_fy = 1. / self.fx, 1. / self.fy
    return np.array(((inv_fx, 0, -inv_fx * self.cx),
                     (0, inv_fy, -inv_fy * self.cy), (0, 0, 1)))

  #---------------------------------------------------------------------------

  # return an (x, y) pixel coordinate grid for this camera
  def get_image_grid(self):
    """Return a meshgrid of pixel-center coordinates in normalized units.

    The grid spans pixel centers 0.5 .. width-0.5 and 0.5 .. height-0.5, each
    mapped through `(p - c) / f`.
    """
    xmin = (0.5 - self.cx) / self.fx
    xmax = (self.width - 0.5 - self.cx) / self.fx
    ymin = (0.5 - self.cy) / self.fy
    ymax = (self.height - 0.5 - self.cy) / self.fy
    return np.meshgrid(
        np.linspace(xmin, xmax, self.width),
        np.linspace(ymin, ymax, self.height))

  #---------------------------------------------------------------------------

  # x: array of shape (N,2) or (2,)
  # normalized: False if the input points are in pixel coordinates
  # denormalize: True if the points should be put back into pixel coordinates
  def distort_points(self, x, normalized=True, denormalize=True):
    """Apply this camera's distortion model to points.

    Args:
      x: array of shape (N,2) or (2,). It is never modified; integer input is
        converted to floating point.
      normalized: False if the input points are in pixel coordinates.
      denormalize: True if the points should be put back into pixel
        coordinates.

    Returns:
      The distorted points as an array with at least two dimensions.
    """
    # Work on a private floating-point copy: the in-place arithmetic below
    # would otherwise write into the caller's array (np.atleast_2d returns a
    # view) and would fail outright on integer input.
    x = _as_float_copy(x)

    # put the points into normalized camera coordinates
    if not normalized:
      x -= np.array([[self.cx, self.cy]])
      x /= np.array([[self.fx, self.fy]])

    # distort, if necessary
    if self.distortion_func is not None:
      x = self.distortion_func(self, x)

    if denormalize:
      x *= np.array([[self.fx, self.fy]])
      x += np.array([[self.cx, self.cy]])

    return x

  #---------------------------------------------------------------------------

  # x: array of shape (N1,N2,...,2), (N,2), or (2,)
  # normalized: False if the input points are in pixel coordinates
  # denormalize: True if the points should be put back into pixel coordinates
  def undistort_points(self, x, normalized=False, denormalize=True):
    """Invert this camera's distortion model for points.

    The undistorted points are found numerically with `scipy.optimize.root`,
    starting from the distorted points.

    Args:
      x: array of shape (N1,N2,...,2), (N,2), or (2,). It is never modified.
      normalized: False if the input points are in pixel coordinates.
      denormalize: True if the points should be put back into pixel
        coordinates.

    Returns:
      The undistorted points as an array with at least two dimensions.
    """
    # Private copy, so the in-place denormalization below can never reach the
    # caller's array (it did when `normalized` was True and the camera had no
    # distortion model).
    x = _as_float_copy(x)

    # put the points into normalized camera coordinates
    if not normalized:
      x = x - np.array([self.cx, self.cy])
      x /= np.array([self.fx, self.fy])

    # undistort, if necessary
    if self.distortion_func is not None:

      def objective(xu):
        # Residual between the observed points and the re-distorted guess.
        return (x - self.distortion_func(self, xu.reshape(*x.shape))).ravel()

      xu = root(objective, x).x.reshape(*x.shape)
    else:
      xu = x

    if denormalize:
      xu *= np.array([[self.fx, self.fy]])
      xu += np.array([[self.cx, self.cy]])

    return xu
gin.add_config_file_search_path('experimental/users/barron/mipnerf360/')

# Callables that gin configs may reference, keyed by the gin module name they
# are registered under.
configurables = {
    'jnp': [jnp.reciprocal, jnp.log, jnp.log1p, jnp.exp, jnp.sqrt, jnp.square],
    'jax.nn': [jax.nn.relu, jax.nn.softplus, jax.nn.silu],
    'jax.nn.initializers.he_normal': [jax.nn.initializers.he_normal()],
    'jax.nn.initializers.he_uniform': [jax.nn.initializers.he_uniform()],
    'jax.nn.initializers.glorot_normal': [jax.nn.initializers.glorot_normal()],
    'jax.nn.initializers.glorot_uniform': [
        jax.nn.initializers.glorot_uniform()
    ],
}

# Use dedicated loop names: rebinding `configurables` inside the loop would
# leave the module-level name pointing at the last list instead of the dict.
for module_name, module_configurables in configurables.items():
    for configurable in module_configurables:
        gin.config.external_configurable(configurable, module=module_name)


@gin.configurable()
@dataclasses.dataclass
class Config:
    """Configuration flags for everything."""
    dataset_loader: str = 'llff'  # The type of dataset loader to use.
    batching: str = 'all_images'  # Batch composition, [single_image, all_images].
    batch_size: int = 16384  # The number of rays/pixels in each batch.
    patch_size: int = 1  # Resolution of patches sampled for training batches.
    factor: int = 0  # The downsample factor of images, 0 for no downsampling.
    load_alphabetical: bool = True  # Load images in COLMAP vs alphabetical
    # ordering (affects heldout test set).
    forward_facing: bool = False  # Set to True for forward-facing LLFF captures.
    render_path: bool = False  # If True, render a path. Used only by LLFF.
    llffhold: int = 8  # Use every Nth image for the test set. Used only by LLFF.
    # If true, use all input images for training.
    llff_use_all_images_for_training: bool = False
    use_tiffs: bool = False  # If True, use 32-bit TIFFs. Used only by Blender.
    compute_disp_metrics: bool = False  # If True, load and compute disparity MSE.
    compute_normal_metrics: bool = False  # If True, load and compute normal MAE.
    gc_every: int = 10000  # The number of steps between garbage collections.
    disable_multiscale_loss: bool = False  # If True, disable multiscale loss.
    randomized: bool = True  # Use randomized stratified sampling.
    near: float = 2.  # Near plane distance.
    far: float = 6.  # Far plane distance.
    checkpoint_dir: Optional[str] = None  # Where to log checkpoints.
    render_dir: Optional[str] = None  # Output rendering directory.
    data_dir: Optional[str] = None  # Input data directory.
    vocab_tree_path: Optional[str] = None  # Path to vocab tree for COLMAP.
    render_chunk_size: int = 16384  # Chunk size for whole-image renderings.
    num_showcase_images: int = 5  # The number of test-set images to showcase.
    deterministic_showcase: bool = True  # If True, showcase the same images.
    vis_num_rays: int = 16  # The number of rays to visualize.
    # Decimate images for tensorboard (ie, x[::d, ::d]) to conserve memory usage.
    vis_decimate: int = 0

    # Only used by train.py:
    max_steps: int = 250000  # The number of optimization steps.
    early_exit_steps: Optional[int] = None  # Early stopping, for debugging.
    checkpoint_every: int = 25000  # The number of steps to save a checkpoint.
    print_every: int = 100  # The number of steps between reports to tensorboard.
    train_render_every: int = 5000  # Steps between test set renders when training
    cast_rays_in_train_step: bool = False  # If True, compute rays in train step.
    data_loss_type: str = 'charb'  # What kind of loss to use ('mse' or 'charb').
    charb_padding: float = 0.001  # The padding used for Charbonnier loss.
    data_loss_mult: float = 1.0  # Mult for the finest data term in the loss.
    data_coarse_loss_mult: float = 0.  # Multiplier for the coarser data terms.
    interlevel_loss_mult: float = 1.0  # Mult. for the loss on the proposal MLP.
    orientation_loss_mult: float = 0.0  # Multiplier on the orientation loss.
    orientation_coarse_loss_mult: float = 0.0  # Coarser orientation loss weights.
    # What that loss is imposed on, options are 'normals' or 'normals_pred'.
    orientation_loss_target: str = 'normals_pred'
    predicted_normal_loss_mult: float = 0.0  # Mult. on the predicted normal loss.
    # Mult. on the coarser predicted normal loss.
    predicted_normal_coarse_loss_mult: float = 0.0
    weight_decay_mults: FrozenDict[str, Any] = FrozenDict({})  # Weight decays.
    # An example that regularizes the NeRF and the first layer of the prop MLP:
    #   weight_decay_mults = {
    #       'NerfMLP_0': 0.00001,
    #       'PropMLP_0/Dense_0': 0.001,
    #   }
    # Any model parameter that isn't specified gets a mult of 0. See the
    # train_weight_l2_* parameters in TensorBoard to know what can be regularized.

    lr_init: float = 0.002  # The initial learning rate.
    lr_final: float = 0.00002  # The final learning rate.
    lr_delay_steps: int = 512  # The number of "warmup" learning steps.
    lr_delay_mult: float = 0.01  # How severe the "warmup" should be.
    adam_beta1: float = 0.9  # Adam's beta1 hyperparameter.
    adam_beta2: float = 0.999  # Adam's beta2 hyperparameter.
    adam_eps: float = 1e-6  # Adam's epsilon hyperparameter.
    grad_max_norm: float = 0.001  # Gradient clipping magnitude, disabled if == 0.
    grad_max_val: float = 0.  # Gradient clipping value, disabled if == 0.
    distortion_loss_mult: float = 0.01  # Multiplier on the distortion loss.

    # Only used by eval.py:
    eval_only_once: bool = True  # If True, evaluate the model only once, otherwise loop.
    eval_save_output: bool = True  # If True save predicted images to disk.
    eval_save_ray_data: bool = False  # If True save individual ray traces.
    eval_render_interval: int = 1  # The interval between images saved to disk.
    eval_dataset_limit: int = jnp.iinfo(jnp.int32).max  # Num test images to eval.
    eval_quantize_metrics: bool = True  # If True, run metrics on 8-bit images.
    eval_crop_borders: int = 0  # Ignore c border pixels in eval (x[c:-c, c:-c]).

    # Only used by render.py
    render_video_fps: int = 60  # Framerate in frames-per-second.
    render_video_crf: int = 18  # Constant rate factor for ffmpeg video quality.
    render_path_frames: int = 120  # Number of frames in render path.
    z_variation: float = 0.  # How much height variation in render path.
    z_phase: float = 0.  # Phase offset for height variation in render path.
    render_dist_percentile: float = 0.5  # How much to trim from near/far planes.
    render_dist_curve_fn: Callable[..., Any] = jnp.log  # How depth is curved.
    render_path_file: Optional[str] = None  # Numpy render pose file to load.
    render_job_id: int = 0  # Render job id.
    render_num_jobs: int = 1  # Total number of render jobs.
    render_resolution: Optional[Tuple[int, int]] = None  # Render resolution, as
    # (width, height).
    render_focal: Optional[float] = None  # Render focal length.
    render_camtype: Optional[str] = None  # 'perspective', 'fisheye', or 'pano'.
    render_spherical: bool = False  # Render spherical 360 panoramas.
    render_save_async: bool = True  # Save to CNS using a separate thread.

    render_spline_keyframes: Optional[str] = None  # Text file containing names of
    # images to be used as spline
    # keyframes, OR directory
    # containing those images.
    render_spline_n_interp: int = 30  # Num. frames to interpolate per keyframe.
    render_spline_degree: int = 5  # Polynomial degree of B-spline interpolation.
    render_spline_smoothness: float = .03  # B-spline smoothing factor, 0 for
    # exact interpolation of keyframes.
    # Interpolate per-frame exposure value from spline keyframes.
    render_spline_interpolate_exposure: bool = False

    # Flags for raw datasets.
    rawnerf_mode: bool = False  # Load raw images and train in raw color space.
    exposure_percentile: float = 97.  # Image percentile to expose as white.
    num_border_pixels_to_mask: int = 0  # During training, discard N-pixel border
    # around each input image.
    apply_bayer_mask: bool = False  # During training, apply Bayer mosaic mask.
    autoexpose_renders: bool = False  # During rendering, autoexpose each image.
    # For raw test scenes, use affine raw-space color correction.
    eval_raw_affine_cc: bool = False


def define_common_flags():
    """Define the absl flags used by both train.py and eval.py."""
    flags.DEFINE_string('mode', None, 'Required by GINXM, not used.')
    flags.DEFINE_string('base_folder', None, 'Required by GINXM, not used.')
    flags.DEFINE_multi_string('gin_bindings', None, 'Gin parameter bindings.')
    flags.DEFINE_multi_string('gin_configs', None, 'Gin config files.')


def load_config(save_config=True):
    """Load the config, and optionally checkpoint it.

    Args:
        save_config: If True, process 0 writes the resolved gin config to
            `<checkpoint_dir>/config.gin`.

    Returns:
        The `Config` built from the gin files and bindings given as flags.
    """
    gin.parse_config_files_and_bindings(
        flags.FLAGS.gin_configs, flags.FLAGS.gin_bindings, skip_unknown=True)
    config = Config()
    # `jax.process_index()` replaces the deprecated `jax.host_id()` alias.
    if save_config and jax.process_index() == 0:
        utils.makedirs(config.checkpoint_dir)
        with utils.open_file(config.checkpoint_dir + '/config.gin', 'w') as f:
            f.write(gin.config_str())
    return config
"License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Helper functions for visualizing things.""" 16 | 17 | from internal import stepfun 18 | import jax.numpy as jnp 19 | from matplotlib import cm 20 | 21 | 22 | def weighted_percentile(x, w, ps, assume_sorted=False): 23 | """Compute the weighted percentile(s) of a single vector.""" 24 | x = x.reshape([-1]) 25 | w = w.reshape([-1]) 26 | if not assume_sorted: 27 | sortidx = jnp.argsort(x) 28 | x, w = x[sortidx], w[sortidx] 29 | acc_w = jnp.cumsum(w) 30 | return jnp.interp(jnp.array(ps) * (acc_w[-1] / 100), acc_w, x) 31 | 32 | 33 | def sinebow(h): 34 | """A cyclic and uniform colormap, see http://basecase.org/env/on-rainbows.""" 35 | f = lambda x: jnp.sin(jnp.pi * x)**2 36 | return jnp.stack([f(3 / 6 - h), f(5 / 6 - h), f(7 / 6 - h)], -1) 37 | 38 | 39 | def matte(vis, acc, dark=0.8, light=1.0, width=8): 40 | """Set non-accumulated pixels to a Photoshop-esque checker pattern.""" 41 | bg_mask = jnp.logical_xor( 42 | (jnp.arange(acc.shape[0]) % (2 * width) // width)[:, None], 43 | (jnp.arange(acc.shape[1]) % (2 * width) // width)[None, :]) 44 | bg = jnp.where(bg_mask, light, dark) 45 | return vis * acc[:, :, None] + (bg * (1 - acc))[:, :, None] 46 | 47 | 48 | def visualize_cmap(value, 49 | weight, 50 | colormap, 51 | lo=None, 52 | hi=None, 53 | percentile=99., 54 | curve_fn=lambda x: x, 55 | modulus=None, 56 | matte_background=True): 57 | """Visualize a 1D image and a 1D weighting according to some colormap. 
58 | 59 | Args: 60 | value: A 1D image. 61 | weight: A weight map, in [0, 1]. 62 | colormap: A colormap function. 63 | lo: The lower bound to use when rendering, if None then use a percentile. 64 | hi: The upper bound to use when rendering, if None then use a percentile. 65 | percentile: What percentile of the value map to crop to when automatically 66 | generating `lo` and `hi`. Depends on `weight` as well as `value'. 67 | curve_fn: A curve function that gets applied to `value`, `lo`, and `hi` 68 | before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps). 69 | modulus: If not None, mod the normalized value by `modulus`. Use (0, 1]. If 70 | `modulus` is not None, `lo`, `hi` and `percentile` will have no effect. 71 | matte_background: If True, matte the image over a checkerboard. 72 | 73 | Returns: 74 | A colormap rendering. 75 | """ 76 | # Identify the values that bound the middle of `value' according to `weight`. 77 | lo_auto, hi_auto = weighted_percentile( 78 | value, weight, [50 - percentile / 2, 50 + percentile / 2]) 79 | 80 | # If `lo` or `hi` are None, use the automatically-computed bounds above. 81 | eps = jnp.finfo(jnp.float32).eps 82 | lo = lo or (lo_auto - eps) 83 | hi = hi or (hi_auto + eps) 84 | 85 | # Curve all values. 86 | value, lo, hi = [curve_fn(x) for x in [value, lo, hi]] 87 | 88 | # Wrap the values around if requested. 89 | if modulus: 90 | value = jnp.mod(value, modulus) / modulus 91 | else: 92 | # Otherwise, just scale to [0, 1]. 
93 | value = jnp.nan_to_num( 94 | jnp.clip((value - jnp.minimum(lo, hi)) / jnp.abs(hi - lo), 0, 1)) 95 | 96 | if colormap: 97 | colorized = colormap(value)[:, :, :3] 98 | else: 99 | if len(value.shape) != 3: 100 | raise ValueError(f'value must have 3 dims but has {len(value.shape)}') 101 | if value.shape[-1] != 3: 102 | raise ValueError( 103 | f'value must have 3 channels but has {len(value.shape[-1])}') 104 | colorized = value 105 | 106 | return matte(colorized, weight) if matte_background else colorized 107 | 108 | 109 | def visualize_coord_mod(coords, acc): 110 | """Visualize the coordinate of each point within its "cell".""" 111 | return matte(((coords + 1) % 2) / 2, acc) 112 | 113 | 114 | def visualize_rays(dist, 115 | dist_range, 116 | weights, 117 | rgbs, 118 | accumulate=False, 119 | renormalize=False, 120 | resolution=2048, 121 | bg_color=0.8): 122 | """Visualize a bundle of rays.""" 123 | dist_vis = jnp.linspace(*dist_range, resolution + 1) 124 | vis_rgb, vis_alpha = [], [] 125 | for ds, ws, rs in zip(dist, weights, rgbs): 126 | vis_rs, vis_ws = [], [] 127 | for d, w, r in zip(ds, ws, rs): 128 | if accumulate: 129 | # Produce the accumulated color and weight at each point along the ray. 130 | w_csum = jnp.cumsum(w, axis=0) 131 | rw_csum = jnp.cumsum((r * w[:, None]), axis=0) 132 | eps = jnp.finfo(jnp.float32).eps 133 | r, w = (rw_csum + eps) / (w_csum[:, None] + 2 * eps), w_csum 134 | vis_rs.append(stepfun.resample(dist_vis, d, r.T, use_avg=True).T) 135 | vis_ws.append(stepfun.resample(dist_vis, d, w.T, use_avg=True).T) 136 | vis_rgb.append(jnp.stack(vis_rs)) 137 | vis_alpha.append(jnp.stack(vis_ws)) 138 | vis_rgb = jnp.stack(vis_rgb, axis=1) 139 | vis_alpha = jnp.stack(vis_alpha, axis=1) 140 | 141 | if renormalize: 142 | # Scale the alphas so that the largest value is 1, for visualization. 
143 | vis_alpha /= jnp.maximum(jnp.finfo(jnp.float32).eps, jnp.max(vis_alpha)) 144 | 145 | if resolution > vis_rgb.shape[0]: 146 | rep = resolution // (vis_rgb.shape[0] * vis_rgb.shape[1] + 1) 147 | stride = rep * vis_rgb.shape[1] 148 | 149 | vis_rgb = vis_rgb.tile((1, 1, rep, 1)).reshape((-1,) + vis_rgb.shape[2:]) 150 | vis_alpha = vis_alpha.tile((1, 1, rep)).reshape((-1,) + vis_alpha.shape[2:]) 151 | 152 | # Add a strip of background pixels after each set of levels of rays. 153 | vis_rgb = vis_rgb.reshape((-1, stride) + vis_rgb.shape[1:]) 154 | vis_alpha = vis_alpha.reshape((-1, stride) + vis_alpha.shape[1:]) 155 | vis_rgb = jnp.concatenate([vis_rgb, jnp.zeros_like(vis_rgb[:, :1])], 156 | axis=1).reshape((-1,) + vis_rgb.shape[2:]) 157 | vis_alpha = jnp.concatenate( 158 | [vis_alpha, jnp.zeros_like(vis_alpha[:, :1])], 159 | axis=1).reshape((-1,) + vis_alpha.shape[2:]) 160 | 161 | # Matte the RGB image over the background. 162 | vis = vis_rgb * vis_alpha[..., None] + (bg_color * (1 - vis_alpha))[..., None] 163 | 164 | # Remove the final row of background pixels. 165 | vis = vis[:-1] 166 | vis_alpha = vis_alpha[:-1] 167 | return vis, vis_alpha 168 | 169 | 170 | def visualize_suite(rendering, rays): 171 | """A wrapper around other visualizations for easy integration.""" 172 | 173 | depth_curve_fn = lambda x: -jnp.log(x + jnp.finfo(jnp.float32).eps) 174 | 175 | rgb = rendering['rgb'] 176 | acc = rendering['acc'] 177 | 178 | distance_mean = rendering['distance_mean'] 179 | distance_median = rendering['distance_median'] 180 | distance_p5 = rendering['distance_percentile_5'] 181 | distance_p95 = rendering['distance_percentile_95'] 182 | acc = jnp.where(jnp.isnan(distance_mean), jnp.zeros_like(acc), acc) 183 | 184 | # The xyz coordinates where rays terminate. 
185 | coords = rays.origins + rays.directions * distance_mean[:, :, None] 186 | 187 | vis_depth_mean, vis_depth_median = [ 188 | visualize_cmap(x, acc, cm.get_cmap('turbo'), curve_fn=depth_curve_fn) 189 | for x in [distance_mean, distance_median] 190 | ] 191 | 192 | # Render three depth percentiles directly to RGB channels, where the spacing 193 | # determines the color. delta == big change, epsilon = small change. 194 | # Gray: A strong discontinuitiy, [x-epsilon, x, x+epsilon] 195 | # Purple: A thin but even density, [x-delta, x, x+delta] 196 | # Red: A thin density, then a thick density, [x-delta, x, x+epsilon] 197 | # Blue: A thick density, then a thin density, [x-epsilon, x, x+delta] 198 | vis_depth_triplet = visualize_cmap( 199 | jnp.stack( 200 | [2 * distance_median - distance_p5, distance_median, distance_p95], 201 | axis=-1), 202 | acc, 203 | None, 204 | curve_fn=lambda x: jnp.log(x + jnp.finfo(jnp.float32).eps)) 205 | 206 | dist = rendering['ray_sdist'] 207 | dist_range = (0, 1) 208 | weights = rendering['ray_weights'] 209 | rgbs = [jnp.clip(r, 0, 1) for r in rendering['ray_rgbs']] 210 | 211 | vis_ray_colors, _ = visualize_rays(dist, dist_range, weights, rgbs) 212 | 213 | sqrt_weights = [jnp.sqrt(w) for w in weights] 214 | sqrt_ray_weights, ray_alpha = visualize_rays( 215 | dist, 216 | dist_range, 217 | [jnp.ones_like(lw) for lw in sqrt_weights], 218 | [lw[..., None] for lw in sqrt_weights], 219 | bg_color=0, 220 | ) 221 | sqrt_ray_weights = sqrt_ray_weights[..., 0] 222 | 223 | null_color = jnp.array([1., 0., 0.]) 224 | vis_ray_weights = jnp.where( 225 | ray_alpha[:, :, None] == 0, 226 | null_color[None, None], 227 | visualize_cmap( 228 | sqrt_ray_weights, 229 | jnp.ones_like(sqrt_ray_weights), 230 | cm.get_cmap('gray'), 231 | lo=0, 232 | hi=1, 233 | matte_background=False, 234 | ), 235 | ) 236 | 237 | vis = { 238 | 'color': rgb, 239 | 'acc': acc, 240 | 'color_matte': matte(rgb, acc), 241 | 'depth_mean': vis_depth_mean, 242 | 'depth_median': 
vis_depth_median, 243 | 'depth_triplet': vis_depth_triplet, 244 | 'coords_mod': visualize_coord_mod(coords, acc), 245 | 'ray_colors': vis_ray_colors, 246 | 'ray_weights': vis_ray_weights, 247 | } 248 | 249 | if 'rgb_cc' in rendering: 250 | vis['color_corrected'] = rendering['rgb_cc'] 251 | 252 | # Render every item named "normals*". 253 | for key, val in rendering.items(): 254 | if key.startswith('normals'): 255 | vis[key] = matte(val / 2. + 0.5, acc) 256 | 257 | if 'roughness' in rendering: 258 | vis['roughness'] = matte(jnp.tanh(rendering['roughness']), acc) 259 | 260 | return vis 261 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def main(unused_argv):
    """Evaluate checkpoints on the test set, once or until training finishes.

    For every new checkpoint found in `config.checkpoint_dir`, each test image
    is rendered, color corrected against the ground truth and scored. Process
    0 optionally writes images and metric text files to disk and, when
    `eval_only_once` is False, reports to TensorBoard and keeps polling for
    newer checkpoints.

    Args:
        unused_argv: Positional command-line arguments from absl; ignored.
    """
    config = configs.load_config(save_config=False)

    dataset = datasets.load_dataset('test', config.data_dir, config)

    key = random.PRNGKey(20200823)
    _, state, render_eval_pfn, _, _ = train_utils.setup_model(config, key)

    if config.rawnerf_mode:
        postprocess_fn = dataset.metadata['postprocess_fn']
    else:
        postprocess_fn = lambda z: z

    if config.eval_raw_affine_cc:
        cc_fun = raw_utils.match_images_affine
    else:
        cc_fun = image.color_correct

    metric_harness = image.MetricHarness()

    last_step = 0
    out_dir = path.join(config.checkpoint_dir,
                        'path_renders' if config.render_path else 'test_preds')
    path_fn = lambda x: path.join(out_dir, x)

    if not config.eval_only_once:
        summary_writer = tensorboard.SummaryWriter(
            path.join(config.checkpoint_dir, 'eval'))
    while True:
        state = checkpoints.restore_checkpoint(config.checkpoint_dir, state)
        step = int(state.step)
        if step <= last_step:
            print(f'Checkpoint step {step} <= last step {last_step}, sleeping.')
            time.sleep(10)
            continue
        print(f'Evaluating checkpoint at step {step}.')
        if config.eval_save_output and (not utils.isdir(out_dir)):
            utils.makedirs(out_dir)

        num_eval = min(dataset.size, config.eval_dataset_limit)
        key = random.PRNGKey(0 if config.deterministic_showcase else step)
        perm = random.permutation(key, num_eval)
        showcase_indices = np.sort(perm[:config.num_showcase_images])

        metrics = []
        metrics_cc = []
        showcases = []
        render_times = []
        for idx in range(dataset.size):
            eval_start_time = time.time()
            # The batch is always consumed, even for images that are skipped.
            batch = next(dataset)
            if idx >= num_eval:
                print(f'Skipping image {idx+1}/{dataset.size}')
                continue
            print(f'Evaluating image {idx+1}/{dataset.size}')
            rays = batch.rays
            train_frac = state.step / config.max_steps
            rendering = models.render_image(
                functools.partial(
                    render_eval_pfn,
                    state.params,
                    train_frac,
                ),
                rays,
                None,
                config,
            )

            # `jax.process_index()` replaces the deprecated `jax.host_id()`.
            if jax.process_index() != 0:  # Only record via host 0.
                continue

            render_times.append((time.time() - eval_start_time))
            print(f'Rendered in {render_times[-1]:0.3f}s')

            # Cast to 64-bit to ensure high precision for color correction
            # function.
            gt_rgb = np.array(batch.rgb, dtype=np.float64)
            rendering['rgb'] = np.array(rendering['rgb'], dtype=np.float64)

            cc_start_time = time.time()
            rendering['rgb_cc'] = cc_fun(rendering['rgb'], gt_rgb)
            print(f'Color corrected in {(time.time() - cc_start_time):0.3f}s')

            if not config.eval_only_once and idx in showcase_indices:
                showcase_idx = (
                    idx if config.deterministic_showcase else len(showcases))
                showcases.append((showcase_idx, rendering, batch))
            if not config.render_path:
                rgb = postprocess_fn(rendering['rgb'])
                rgb_cc = postprocess_fn(rendering['rgb_cc'])
                rgb_gt = postprocess_fn(gt_rgb)

                if config.eval_quantize_metrics:
                    # Ensures that the images written to disk reproduce the
                    # metrics.
                    rgb = np.round(rgb * 255) / 255
                    rgb_cc = np.round(rgb_cc * 255) / 255

                if config.eval_crop_borders > 0:
                    crop_fn = lambda x, c=config.eval_crop_borders: x[c:-c, c:-c]
                    rgb = crop_fn(rgb)
                    rgb_cc = crop_fn(rgb_cc)
                    rgb_gt = crop_fn(rgb_gt)

                metric = metric_harness(rgb, rgb_gt)
                metric_cc = metric_harness(rgb_cc, rgb_gt)

                if config.compute_disp_metrics:
                    for tag in ['mean', 'median']:
                        key = f'distance_{tag}'
                        if key in rendering:
                            disparity = 1 / (1 + rendering[key])
                            metric[f'disparity_{tag}_mse'] = float(
                                ((disparity - batch.disps)**2).mean())

                if config.compute_normal_metrics:
                    weights = rendering['acc'] * batch.alphas
                    normalized_normals_gt = ref_utils.l2_normalize(batch.normals)
                    for key, val in rendering.items():
                        if key.startswith('normals') and val is not None:
                            normalized_normals = ref_utils.l2_normalize(val)
                            metric[key + '_mae'] = ref_utils.compute_weighted_mae(
                                weights, normalized_normals,
                                normalized_normals_gt)

                for m, v in metric.items():
                    print(f'{m:30s} = {v:.4f}')

                metrics.append(metric)
                metrics_cc.append(metric_cc)

            if config.eval_save_output and (config.eval_render_interval > 0):
                if (idx % config.eval_render_interval) == 0:
                    utils.save_img_u8(postprocess_fn(rendering['rgb']),
                                      path_fn(f'color_{idx:03d}.png'))
                    utils.save_img_u8(postprocess_fn(rendering['rgb_cc']),
                                      path_fn(f'color_cc_{idx:03d}.png'))

                    for key in ['distance_mean', 'distance_median']:
                        if key in rendering:
                            utils.save_img_f32(rendering[key],
                                               path_fn(f'{key}_{idx:03d}.tiff'))

                    for key in ['normals']:
                        if key in rendering:
                            utils.save_img_u8(rendering[key] / 2. + 0.5,
                                              path_fn(f'{key}_{idx:03d}.png'))

                    utils.save_img_f32(rendering['acc'],
                                       path_fn(f'acc_{idx:03d}.tiff'))

        if (not config.eval_only_once) and (jax.process_index() == 0):
            summary_writer.scalar('eval_median_render_time',
                                  np.median(render_times), step)
            # The metric lists stay empty when `render_path` is set (or nothing
            # was evaluated); indexing `[0]` unguarded would raise IndexError.
            for name in (metrics[0] if metrics else ()):
                scores = [m[name] for m in metrics]
                summary_writer.scalar('eval_metrics/' + name, np.mean(scores),
                                      step)
                summary_writer.histogram('eval_metrics/' + 'perimage_' + name,
                                         scores, step)
            for name in (metrics_cc[0] if metrics_cc else ()):
                scores = [m[name] for m in metrics_cc]
                summary_writer.scalar('eval_metrics_cc/' + name, np.mean(scores),
                                      step)
                summary_writer.histogram('eval_metrics_cc/' + 'perimage_' + name,
                                         scores, step)

            for i, r, b in showcases:
                if config.vis_decimate > 1:
                    d = config.vis_decimate
                    decimate_fn = lambda x, d=d: None if x is None else x[::d, ::d]
                else:
                    decimate_fn = lambda x: x
                r = jax.tree_util.tree_map(decimate_fn, r)
                b = jax.tree_util.tree_map(decimate_fn, b)
                visualizations = vis.visualize_suite(r, b.rays)
                for k, v in visualizations.items():
                    if k == 'color':
                        v = postprocess_fn(v)
                    summary_writer.image(f'output_{k}_{i}', v, step)
                if not config.render_path:
                    target = postprocess_fn(b.rgb)
                    summary_writer.image(f'true_color_{i}', target, step)
                    pred = postprocess_fn(visualizations['color'])
                    residual = np.clip(pred - target + 0.5, 0, 1)
                    summary_writer.image(f'true_residual_{i}', residual, step)
                    if config.compute_normal_metrics:
                        summary_writer.image(f'true_normals_{i}',
                                             b.normals / 2. + 0.5, step)

        if (config.eval_save_output and (not config.render_path) and
                (jax.process_index() == 0)):
            with utils.open_file(path_fn(f'render_times_{step}.txt'), 'w') as f:
                f.write(' '.join([str(r) for r in render_times]))
            for name in (metrics[0] if metrics else ()):
                with utils.open_file(path_fn(f'metric_{name}_{step}.txt'),
                                     'w') as f:
                    f.write(' '.join([str(m[name]) for m in metrics]))
            for name in (metrics_cc[0] if metrics_cc else ()):
                with utils.open_file(path_fn(f'metric_cc_{name}_{step}.txt'),
                                     'w') as f:
                    f.write(' '.join([str(m[name]) for m in metrics_cc]))
            if config.eval_save_ray_data:
                for i, r, b in showcases:
                    rays = {k: v for k, v in r.items() if 'ray_' in k}
                    np.set_printoptions(threshold=sys.maxsize)
                    with utils.open_file(path_fn(f'ray_data_{step}_{i}.txt'),
                                         'w') as f:
                        f.write(repr(rays))

        # A hack that forces Jax to keep all TPUs alive until every TPU is
        # finished.
        x = jnp.ones([jax.local_device_count()])
        x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x))
        print(x)

        if config.eval_only_once:
            break
        if config.early_exit_steps is not None:
            num_steps = config.early_exit_steps
        else:
            num_steps = config.max_steps
        if int(step) >= num_steps:
            break
        last_step = step


if __name__ == '__main__':
    with gin.config_scope('eval'):
        app.run(main)
np.frombuffer(blob, dtype).reshape(*shape) 15 | 16 | 17 | #------------------------------------------------------------------------------- 18 | # convert to/from image pair ids 19 | 20 | MAX_IMAGE_ID = 2**31 - 1 21 | 22 | 23 | def get_pair_id(image_id1, image_id2): 24 | if image_id1 > image_id2: 25 | image_id1, image_id2 = image_id2, image_id1 26 | return image_id1 * MAX_IMAGE_ID + image_id2 27 | 28 | 29 | def get_image_ids_from_pair_id(pair_id): 30 | image_id2 = pair_id % MAX_IMAGE_ID 31 | return (pair_id - image_id2) / MAX_IMAGE_ID, image_id2 32 | 33 | 34 | #------------------------------------------------------------------------------- 35 | # create table commands 36 | 37 | CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras ( 38 | camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, 39 | model INTEGER NOT NULL, 40 | width INTEGER NOT NULL, 41 | height INTEGER NOT NULL, 42 | params BLOB, 43 | prior_focal_length INTEGER NOT NULL)""" 44 | 45 | CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors ( 46 | image_id INTEGER PRIMARY KEY NOT NULL, 47 | rows INTEGER NOT NULL, 48 | cols INTEGER NOT NULL, 49 | data BLOB, 50 | FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)""" 51 | 52 | CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images ( 53 | image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, 54 | name TEXT NOT NULL UNIQUE, 55 | camera_id INTEGER NOT NULL, 56 | prior_qw REAL, 57 | prior_qx REAL, 58 | prior_qy REAL, 59 | prior_qz REAL, 60 | prior_tx REAL, 61 | prior_ty REAL, 62 | prior_tz REAL, 63 | CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < 2147483647), 64 | FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))""" 65 | 66 | CREATE_INLIER_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS two_view_geometries ( 67 | pair_id INTEGER PRIMARY KEY NOT NULL, 68 | rows INTEGER NOT NULL, 69 | cols INTEGER NOT NULL, 70 | data BLOB, 71 | config INTEGER NOT NULL, 72 | F BLOB, 73 | E BLOB, 74 | H BLOB)""" 75 | 76 | 
def add_inlier_matches(db,
                       image_id1,
                       image_id2,
                       matches,
                       config=2,
                       F=None,
                       E=None,
                       H=None):
  """Insert geometrically verified matches for an image pair.

  Args:
    db: object exposing a sqlite3-style `execute(sql, params)` method.
    image_id1: id of the first image.
    image_id2: id of the second image.
    matches: (N, 2) array-like of keypoint index pairs; stored as uint32. The
      columns are swapped when `image_id1 > image_id2` so that they follow the
      sorted order used by `get_pair_id`.
    config: integer stored in the `config` column. The original comment on this
      function says the default (2) means "fundamental matrix".
    F: optional matrix stored in the `F` column as a float64 blob.
    E: optional matrix stored in the `E` column as a float64 blob.
    H: optional matrix stored in the `H` column as a float64 blob.

  Raises:
    AssertionError: if `matches` is not two-dimensional with two columns.
  """
  matches = np.asarray(matches)
  assert (len(matches.shape) == 2)
  assert (matches.shape[1] == 2)

  if image_id1 > image_id2:
    matches = matches[:, ::-1]
  # The reversed view is not contiguous; materialize it before serializing.
  matches = np.ascontiguousarray(matches, np.uint32)

  def _matrix_blob(mat):
    # sqlite3 cannot bind numpy arrays directly, so serialize them; None is
    # kept as-is and stored as NULL.
    if mat is None:
      return None
    return array_to_blob(np.asarray(mat, np.float64))

  pair_id = get_pair_id(image_id1, image_id2)
  # CREATE_INLIER_MATCHES_TABLE creates a table named `two_view_geometries`
  # with exactly these eight columns; no `inlier_matches` table is created by
  # this module.
  db.execute(
      "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
      (pair_id,) + matches.shape +
      (array_to_blob(matches), config, _matrix_blob(F), _matrix_blob(E),
       _matrix_blob(H)))
def main(args):
  """Self-test: build a throwaway COLMAP database and verify round-trips.

  Creates a new database at `args.database_path`, fills it with dummy cameras,
  images, keypoints and matches, reads them back, and asserts that the values
  survive the round-trip. The database file is removed afterwards, whether or
  not the checks pass.

  Args:
    args: object with a `database_path` attribute naming a file that must not
      exist yet.

  Raises:
    SystemExit: with status 1 if `args.database_path` already exists.
    AssertionError: if any value read back differs from what was written.
  """
  if os.path.exists(args.database_path):
    print("Error: database path already exists -- will not modify it.")
    # Exit with a failure status (the bare `exit()` used before reported
    # success and depends on the `site` module being loaded).
    raise SystemExit(1)

  db = COLMAPDatabase.connect(args.database_path)
  try:
    #
    # for convenience, try creating all the tables upfront
    #

    db.initialize_tables()

    #
    # create dummy cameras
    #

    model1, w1, h1, params1 = 0, 1024, 768, np.array((1024., 512., 384.))
    model2, w2, h2, params2 = 2, 1024, 768, np.array((1024., 512., 384., 0.1))

    db.add_camera(model1, w1, h1, params1)
    db.add_camera(model2, w2, h2, params2)

    # add_camera() does not return the assigned id, so read the ids back
    # rather than guessing them (the old hard-coded 0 and 2 did not match the
    # AUTOINCREMENT ids of the two cameras above).
    camera_id1, camera_id2 = [
        row[0]
        for row in db.execute("SELECT camera_id FROM cameras ORDER BY camera_id")
    ]

    #
    # create dummy images
    #

    db.add_image("image1.png", camera_id1)
    db.add_image("image2.png", camera_id1)
    db.add_image("image3.png", camera_id2)
    db.add_image("image4.png", camera_id2)

    #
    # create dummy keypoints; note that COLMAP supports 2D keypoints (x, y),
    # 4D keypoints (x, y, theta, scale), and 6D affine keypoints
    # (x, y, a_11, a_12, a_21, a_22)
    #

    N = 1000
    kp1 = np.random.rand(N, 2) * (1024., 768.)
    kp2 = np.random.rand(N, 2) * (1024., 768.)
    kp3 = np.random.rand(N, 2) * (1024., 768.)
    kp4 = np.random.rand(N, 2) * (1024., 768.)

    db.add_keypoints(1, kp1)
    db.add_keypoints(2, kp2)
    db.add_keypoints(3, kp3)
    db.add_keypoints(4, kp4)

    #
    # create dummy matches
    #

    M = 50
    m12 = np.random.randint(N, size=(M, 2))
    m23 = np.random.randint(N, size=(M, 2))
    m34 = np.random.randint(N, size=(M, 2))

    db.add_matches(1, 2, m12)
    db.add_matches(2, 3, m23)
    db.add_matches(3, 4, m34)

    #
    # check cameras
    #

    rows = db.execute("SELECT * FROM cameras")

    # add_camera() stores params as float64, so they must be decoded as
    # float64; decoding as float32 yields twice as many garbage values.
    _, model, width, height, params, _ = next(rows)
    params = blob_to_array(params, np.float64)
    assert model == model1 and width == w1 and height == h1
    assert np.allclose(params, params1)

    _, model, width, height, params, _ = next(rows)
    params = blob_to_array(params, np.float64)
    assert model == model2 and width == w2 and height == h2
    assert np.allclose(params, params2)

    #
    # check keypoints (stored as float32 by add_keypoints)
    #

    kps = {
        image_id: blob_to_array(data, np.float32, (-1, 2))
        for image_id, data in db.execute("SELECT image_id, data FROM keypoints")
    }

    assert np.allclose(kps[1], kp1)
    assert np.allclose(kps[2], kp2)
    assert np.allclose(kps[3], kp3)
    assert np.allclose(kps[4], kp4)

    #
    # check matches (stored as uint32 by add_matches)
    #

    matches = {
        get_image_ids_from_pair_id(pair_id):
            blob_to_array(data, np.uint32, (-1, 2))
        for pair_id, data in db.execute("SELECT pair_id, data FROM matches")
    }

    assert np.all(matches[(1, 2)] == m12)
    assert np.all(matches[(2, 3)] == m23)
    assert np.all(matches[(3, 4)] == m34)
  finally:
    #
    # clean up, even when a check above fails, so the dummy database does not
    # linger and block the next run
    #

    db.close()
    if os.path.exists(args.database_path):
      os.remove(args.database_path)
class CoordTest(parameterized.TestCase):
  """Unit tests for the functions in `coord`."""

  def test_stable_pos_enc(self):
    """Test that the stable posenc implementation works on multiples of pi/2."""
    n = 10
    x = np.linspace(-np.pi, np.pi, 5)
    z = stable_pos_enc(x, n).reshape([-1, 2, n])
    z0_true = np.zeros_like(z[:, 0, :])
    z1_true = np.ones_like(z[:, 1, :])
    z0_true[:, 0] = [0, -1, 0, 1, 0]
    z1_true[:, 0] = [-1, 0, 1, 0, -1]
    z1_true[:, 1] = [1, -1, 1, -1, 1]
    z_true = np.stack([z0_true, z1_true], axis=1)
    np.testing.assert_allclose(z, z_true, atol=1e-10)

  def test_contract_matches_special_case(self):
    """Test the math for Figure 2 of https://arxiv.org/abs/2111.12077."""
    n = 10
    _, s_to_t = coord.construct_ray_warps(jnp.reciprocal, 1, jnp.inf)
    s = jnp.linspace(0, 1 - jnp.finfo(jnp.float32).eps, n + 1)
    tc = coord.contract(s_to_t(s)[:, None])[:, 0]
    delta_tc = tc[1:] - tc[:-1]
    np.testing.assert_allclose(
        delta_tc, np.full_like(delta_tc, 1 / n), atol=1E-5, rtol=1E-5)

  def test_contract_is_bounded(self):
    """Contracted coordinates must stay within [-2, 2] for any input scale."""
    n, d = 10000, 3
    rng = random.PRNGKey(0)
    key0, key1, rng = random.split(rng, 3)
    # Random signs times magnitudes spanning exp(-10) to exp(10).
    x = jnp.where(random.bernoulli(key0, shape=[n, d]), 1, -1) * jnp.exp(
        random.uniform(key1, [n, d], minval=-10, maxval=10))
    y = coord.contract(x)
    # Check the magnitude rather than the largest signed value, so that the
    # negative coordinates sampled above are covered as well.
    self.assertLessEqual(jnp.max(jnp.abs(y)), 2)

  def test_contract_is_noop_when_norm_is_leq_one(self):
    """Points with norm <= 1 must be returned unchanged by contract()."""
    n, d = 10000, 3
    rng = random.PRNGKey(0)
    key, rng = random.split(rng)
    x = random.normal(key, shape=[n, d])
    xc = x / jnp.maximum(1, jnp.linalg.norm(x, axis=-1, keepdims=True))

    # Sanity check on the test itself.
    assert jnp.abs(jnp.max(jnp.linalg.norm(xc, axis=-1)) - 1) < 1e-6

    yc = coord.contract(xc)
    np.testing.assert_allclose(xc, yc, atol=1E-5, rtol=1E-5)

  def test_contract_gradients_are_finite(self):
    """Gradients of contract() must be finite on a grid that includes 0."""
    # Construct x such that we probe x == 0, where things are unstable.
    x = jnp.stack(jnp.meshgrid(*[jnp.linspace(-4, 4, 11)] * 2), axis=-1)
    grad = jax.grad(lambda x: jnp.sum(coord.contract(x)))(x)
    self.assertTrue(jnp.all(jnp.isfinite(grad)))

  def test_inv_contract_gradients_are_finite(self):
    """Gradients of inv_contract() must be finite on the sampled grid."""
    z = jnp.stack(jnp.meshgrid(*[jnp.linspace(-2, 2, 21)] * 2), axis=-1)
    z = z.reshape([-1, 2])
    # Keep only the grid points whose squared norm is below 2.
    z = z[jnp.sum(z**2, axis=-1) < 2, :]
    grad = jax.grad(lambda z: jnp.sum(coord.inv_contract(z)))(z)
    self.assertTrue(jnp.all(jnp.isfinite(grad)))

  def test_inv_contract_inverts_contract(self):
    """Do a round-trip from metric space to contracted space and back."""
    x = jnp.stack(jnp.meshgrid(*[jnp.linspace(-4, 4, 11)] * 2), axis=-1)
    x_recon = coord.inv_contract(coord.contract(x))
    np.testing.assert_allclose(x, x_recon, atol=1E-5, rtol=1E-5)

  @parameterized.named_parameters(
      ('05_1e-5', 5, 1e-5),
      ('10_1e-4', 10, 1e-4),
      ('15_0.005', 15, 0.005),
      ('20_0.2', 20, 0.2),  # At high degrees, our implementation is unstable.
      ('25_2', 25, 2),  # 2 is the maximum possible error.
      ('30_2', 30, 2),
  )
  def test_pos_enc(self, n, tol):
    """test pos_enc against a stable recursive implementation."""
    x = np.linspace(-np.pi, np.pi, 10001)
    z = coord.pos_enc(x[:, None], 0, n, append_identity=False)
    z_stable = stable_pos_enc(x, n)
    max_err = np.max(np.abs(z - z_stable))
    print(f'PE of degree {n} has a maximum error of {max_err}')
    self.assertLess(max_err, tol)

  def test_pos_enc_matches_integrated(self):
    """Integrated positional encoding with a variance of zero must be pos_enc."""
    min_deg = 0
    max_deg = 10
    # (Two earlier statements here built values that were discarded or
    # immediately overwritten; only this `x` was ever used.)
    x = np.linspace(-jnp.pi, jnp.pi, 10000)
    z_ipe = coord.integrated_pos_enc(x, jnp.zeros_like(x), min_deg, max_deg)
    z_pe = coord.pos_enc(x, min_deg, max_deg, append_identity=False)
    # We're using a pretty wide tolerance because IPE uses safe_sin().
    np.testing.assert_allclose(z_pe, z_ipe, atol=1e-4)

  def test_track_linearize(self):
    """track_linearize() must be exact when the tracked function is affine."""
    rng = random.PRNGKey(0)
    batch_size = 20
    for _ in range(30):
      # Construct some random Gaussians with dimensionalities in [1, 10)
      # (the upper bound of random.randint is exclusive).
      key, rng = random.split(rng)
      in_dims = random.randint(key, (), 1, 10)
      key, rng = random.split(rng)
      mean = jax.random.normal(key, [batch_size, in_dims])
      key, rng = random.split(rng)
      cov = sample_covariance(key, batch_size, in_dims)
      key, rng = random.split(rng)
      out_dims = random.randint(key, (), 1, 10)

      # Construct a random affine transformation.
      key, rng = random.split(rng)
      a_mat = jax.random.normal(key, [out_dims, in_dims])
      key, rng = random.split(rng)
      b = jax.random.normal(key, [out_dims])

      def fn(x):
        x_vec = x.reshape([-1, x.shape[-1]])
        y_vec = jax.vmap(lambda z: math.matmul(a_mat, z))(x_vec) + b  # pylint:disable=cell-var-from-loop
        y = y_vec.reshape(list(x.shape[:-1]) + [y_vec.shape[-1]])
        return y

      # Apply the affine function to the Gaussians.
      fn_mean_true = fn(mean)
      fn_cov_true = math.matmul(math.matmul(a_mat, cov), a_mat.T)

      # Tracking the Gaussians through a linearized function of a linear
      # operator should be the same.
      fn_mean, fn_cov = coord.track_linearize(fn, mean, cov)
      np.testing.assert_allclose(fn_mean, fn_mean_true, atol=1E-5, rtol=1E-5)
      np.testing.assert_allclose(fn_cov, fn_cov_true, atol=1e-5, rtol=1e-5)

  @parameterized.named_parameters(('reciprocal', jnp.reciprocal),
                                  ('log', jnp.log), ('sqrt', jnp.sqrt))
  def test_construct_ray_warps_extents(self, fn):
    """The warps must map t_near <-> 0 and t_far <-> 1 in both directions."""
    n = 100
    rng = random.PRNGKey(0)
    key, rng = random.split(rng)
    t_near = jnp.exp(jax.random.normal(key, [n]))
    key, rng = random.split(rng)
    t_far = t_near + jnp.exp(jax.random.normal(key, [n]))

    t_to_s, s_to_t = coord.construct_ray_warps(fn, t_near, t_far)

    np.testing.assert_allclose(
        t_to_s(t_near), jnp.zeros_like(t_near), atol=1E-5, rtol=1E-5)
    np.testing.assert_allclose(
        t_to_s(t_far), jnp.ones_like(t_far), atol=1E-5, rtol=1E-5)
    np.testing.assert_allclose(
        s_to_t(jnp.zeros_like(t_near)), t_near, atol=1E-5, rtol=1E-5)
    np.testing.assert_allclose(
        s_to_t(jnp.ones_like(t_near)), t_far, atol=1E-5, rtol=1E-5)

  def test_construct_ray_warps_special_reciprocal(self):
    """Test fn=1/x against its closed form."""
    n = 100
    rng = random.PRNGKey(0)
    key, rng = random.split(rng)
    t_near = jnp.exp(jax.random.normal(key, [n]))
    key, rng = random.split(rng)
    t_far = t_near + jnp.exp(jax.random.normal(key, [n]))

    key, rng = random.split(rng)
    u = jax.random.uniform(key, [n])
    t = t_near * (1 - u) + t_far * u
    key, rng = random.split(rng)
    s = jax.random.uniform(key, [n])

    t_to_s, s_to_t = coord.construct_ray_warps(jnp.reciprocal, t_near, t_far)

    # Special cases for fn=reciprocal.
    s_to_t_ref = lambda s: 1 / (s / t_far + (1 - s) / t_near)
    t_to_s_ref = lambda t: (t_far * (t - t_near)) / (t * (t_far - t_near))

    np.testing.assert_allclose(t_to_s(t), t_to_s_ref(t), atol=1E-5, rtol=1E-5)
    np.testing.assert_allclose(s_to_t(s), s_to_t_ref(s), atol=1E-5, rtol=1E-5)

  def test_expected_sin(self):
    """expected_sin() must agree with a Monte Carlo estimate of E[sin(x)]."""
    normal_samples = random.normal(random.PRNGKey(0), (10000,))
    for mu, var in [(0, 1), (1, 3), (-2, .2), (10, 10)]:
      sin_mu = coord.expected_sin(mu, var)
      x = jnp.sin(jnp.sqrt(var) * normal_samples + mu)
      np.testing.assert_allclose(sin_mu, jnp.mean(x), atol=1e-2)

  def test_integrated_pos_enc(self):
    """integrated_pos_enc() must match the mean encoding of drawn samples."""
    num_dims = 2  # The number of input dimensions.
    min_deg = 0  # Must be 0 for this test to work.
    max_deg = 4
    num_samples = 100000
    rng = random.PRNGKey(0)
    for _ in range(5):
      # Generate a coordinate's mean and covariance matrix.
      key, rng = random.split(rng)
      mean = random.normal(key, (num_dims,))
      key, rng = random.split(rng)
      half_cov = jax.random.normal(key, [num_dims] * 2)
      cov = half_cov @ half_cov.T
      var = jnp.diag(cov)
      # Generate an IPE.
      enc = coord.integrated_pos_enc(
          mean,
          var,
          min_deg,
          max_deg,
      )

      # Draw samples, encode them, and take their mean.
      key, rng = random.split(rng)
      samples = random.multivariate_normal(key, mean, cov, [num_samples])
      assert min_deg == 0
      enc_samples = np.concatenate(
          [stable_pos_enc(x, max_deg) for x in tuple(samples.T)], axis=-1)
      # Correct for a different dimension ordering in stable_pos_enc.
      enc_gt = jnp.mean(enc_samples, 0)
      enc_gt = enc_gt.reshape([num_dims, max_deg * 2]).T.reshape([-1])
      np.testing.assert_allclose(enc, enc_gt, rtol=1e-2, atol=1e-2)
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. 
You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. 
# returns the cross product matrix representation of a 3-vector v
def cross_prod_matrix(v):
  """Return the 3x3 skew-symmetric matrix K such that K.dot(w) == cross(v, w).

  Args:
    v: indexable with at least three elements (v[0], v[1], v[2] are read).

  Returns:
    A 3x3 numpy array.
  """
  return np.array(((0., -v[2], v[1]), (v[2], 0., -v[0]), (-v[1], v[0], 0.)))


#-------------------------------------------------------------------------------


# www.euclideanspace.com/maths/geometry/rotations/conversions/angleToMatrix/
# if angle is None, assume ||axis|| == angle, in radians
# if angle is not None, assume that axis is a unit vector
def axis_angle_to_rotation_matrix(axis, angle=None):
  """Convert an axis-angle rotation into a 3x3 rotation matrix.

  Args:
    axis: 3-vector (array, list or tuple). When `angle` is None its norm is
      taken as the rotation angle in radians; otherwise it is used as given
      and is expected to already be a unit vector.
    angle: optional rotation angle in radians.

  Returns:
    A 3x3 numpy array computed as I + sin(angle) * K + (1 - cos(angle)) * K^2,
    where K is the cross-product matrix of the (normalized) axis.
  """
  # Coerce to an array so that plain lists/tuples can be divided by the angle
  # below instead of raising TypeError.
  axis = np.asarray(axis, dtype=float)

  if angle is None:
    angle = np.linalg.norm(axis)
    # Skip the normalization for a (near-)zero rotation to avoid dividing by
    # zero; the result is then (numerically) the identity.
    if np.abs(angle) > np.finfo('float').eps:
      axis = axis / angle

  cp_axis = cross_prod_matrix(axis)
  return np.eye(3) + (
      np.sin(angle) * cp_axis + (1. - np.cos(angle)) * cp_axis.dot(cp_axis))
+ R[2, 2] - R[0, 0] - R[1, 1]) 76 | qw = (R[1, 0] - R[0, 1]) / s 77 | qx = (R[0, 2] + R[2, 0]) / s 78 | qy = (R[1, 2] + R[2, 1]) / s 79 | qz = 0.25 * s 80 | 81 | return Quaternion(np.array((qw, qx, qy, qz))) 82 | 83 | # if angle is None, assume ||axis|| == angle, in radians 84 | # if angle is not None, assume that axis is a unit vector 85 | @staticmethod 86 | def FromAxisAngle(axis, angle=None): 87 | if angle is None: 88 | angle = np.linalg.norm(axis) 89 | if np.abs(angle) > np.finfo('float').eps: 90 | axis = axis / angle 91 | 92 | qw = np.cos(0.5 * angle) 93 | axis = axis * np.sin(0.5 * angle) 94 | 95 | return Quaternion(np.array((qw, axis[0], axis[1], axis[2]))) 96 | 97 | #--------------------------------------------------------------------------- 98 | 99 | def __init__(self, q=np.array((1., 0., 0., 0.))): 100 | if isinstance(q, Quaternion): 101 | self.q = q.q.copy() 102 | else: 103 | q = np.asarray(q) 104 | if q.size == 4: 105 | self.q = q.copy() 106 | elif q.size == 3: # convert from a 3-vector to a quaternion 107 | self.q = np.empty(4) 108 | self.q[0], self.q[1:] = 0., q.ravel() 109 | else: 110 | raise Exception('Input quaternion should be a 3- or 4-vector') 111 | 112 | def __add__(self, other): 113 | return Quaternion(self.q + other.q) 114 | 115 | def __iadd__(self, other): 116 | self.q += other.q 117 | return self 118 | 119 | # conjugation via the ~ operator 120 | def __invert__(self): 121 | return Quaternion(np.array((self.q[0], -self.q[1], -self.q[2], -self.q[3]))) 122 | 123 | # returns: self.q * other.q if other is a Quaternion; otherwise performs 124 | # scalar multiplication 125 | def __mul__(self, other): 126 | if isinstance(other, Quaternion): # quaternion multiplication 127 | return Quaternion( 128 | np.array((self.q[0] * other.q[0] - self.q[1] * other.q[1] - 129 | self.q[2] * other.q[2] - self.q[3] * other.q[3], 130 | self.q[0] * other.q[1] + self.q[1] * other.q[0] + 131 | self.q[2] * other.q[3] - self.q[3] * other.q[2], 132 | self.q[0] * other.q[2] 
- self.q[1] * other.q[3] + 133 | self.q[2] * other.q[0] + self.q[3] * other.q[1], 134 | self.q[0] * other.q[3] + self.q[1] * other.q[2] - 135 | self.q[2] * other.q[1] + self.q[3] * other.q[0]))) 136 | else: # scalar multiplication (assumed) 137 | return Quaternion(other * self.q) 138 | 139 | def __rmul__(self, other): 140 | return self * other 141 | 142 | def __imul__(self, other): 143 | self.q[:] = (self * other).q 144 | return self 145 | 146 | def __irmul__(self, other): 147 | self.q[:] = (self * other).q 148 | return self 149 | 150 | def __neg__(self): 151 | return Quaternion(-self.q) 152 | 153 | def __sub__(self, other): 154 | return Quaternion(self.q - other.q) 155 | 156 | def __isub__(self, other): 157 | self.q -= other.q 158 | return self 159 | 160 | def __str__(self): 161 | return str(self.q) 162 | 163 | def copy(self): 164 | return Quaternion(self) 165 | 166 | def dot(self, other): 167 | return self.q.dot(other.q) 168 | 169 | # assume the quaternion is nonzero! 170 | def inverse(self): 171 | return Quaternion((~self).q / self.q.dot(self.q)) 172 | 173 | def norm(self): 174 | return np.linalg.norm(self.q) 175 | 176 | def normalize(self): 177 | self.q /= np.linalg.norm(self.q) 178 | return self 179 | 180 | # assume x is a Nx3 numpy array or a numpy 3-vector 181 | def rotate_points(self, x): 182 | x = np.atleast_2d(x) 183 | return x.dot(self.ToR().T) 184 | 185 | # convert to a rotation matrix 186 | def ToR(self): 187 | return np.eye(3) + 2 * np.array(( 188 | (-self.q[2] * self.q[2] - self.q[3] * self.q[3], self.q[1] * self.q[2] - 189 | self.q[3] * self.q[0], self.q[1] * self.q[3] + self.q[2] * self.q[0]), 190 | (self.q[1] * self.q[2] + self.q[3] * self.q[0], -self.q[1] * self.q[1] - 191 | self.q[3] * self.q[3], self.q[2] * self.q[3] - self.q[1] * self.q[0]), 192 | (self.q[1] * self.q[3] - self.q[2] * self.q[0], 193 | self.q[2] * self.q[3] + self.q[1] * self.q[0], 194 | -self.q[1] * self.q[1] - self.q[2] * self.q[2]))) 195 | 196 | # convert to axis-angle 
representation, with angle encoded by the length 197 | def ToAxisAngle(self): 198 | # recall that for axis-angle representation (a, angle), with "a" unit: 199 | # q = (cos(angle/2), a * sin(angle/2)) 200 | # below, for readability, "theta" actually means half of the angle 201 | 202 | sin_sq_theta = self.q[1:].dot(self.q[1:]) 203 | 204 | # if theta is non-zero, then we can compute a unique rotation 205 | if np.abs(sin_sq_theta) > np.finfo('float').eps: 206 | sin_theta = np.sqrt(sin_sq_theta) 207 | cos_theta = self.q[0] 208 | 209 | # atan2 is more stable, so we use it to compute theta 210 | # note that we multiply by 2 to get the actual angle 211 | angle = 2. * ( 212 | np.arctan2(-sin_theta, -cos_theta) if cos_theta < 0. else np.arctan2( 213 | sin_theta, cos_theta)) 214 | 215 | return self.q[1:] * (angle / sin_theta) 216 | 217 | # otherwise, the result is singular, and we avoid dividing by 218 | # sin(angle/2) = 0 219 | return np.zeros(3) 220 | 221 | # euclideanspace.com/maths/geometry/rotations/conversions/quaternionToEuler 222 | # this assumes the quaternion is non-zero 223 | # returns yaw, pitch, roll, with application in that order 224 | def ToEulerAngles(self): 225 | qsq = self.q**2 226 | k = 2. * (self.q[0] * self.q[3] + self.q[1] * self.q[2]) / qsq.sum() 227 | 228 | if (1. - k) < np.finfo('float').eps: # north pole singularity 229 | return 2. * np.arctan2(self.q[1], self.q[0]), 0.5 * np.pi, 0. 230 | if (1. + k) < np.finfo('float').eps: # south pole singularity 231 | return -2. * np.arctan2(self.q[1], self.q[0]), -0.5 * np.pi, 0. 232 | 233 | yaw = np.arctan2(2. * (self.q[0] * self.q[2] - self.q[1] * self.q[3]), 234 | qsq[0] + qsq[1] - qsq[2] - qsq[3]) 235 | pitch = np.arcsin(k) 236 | roll = np.arctan2(2. 
* (self.q[0] * self.q[1] - self.q[2] * self.q[3]), 237 | qsq[0] - qsq[1] + qsq[2] - qsq[3]) 238 | 239 | return yaw, pitch, roll 240 | 241 | 242 | #------------------------------------------------------------------------------- 243 | # 244 | # DualQuaternion 245 | # 246 | #------------------------------------------------------------------------------- 247 | 248 | 249 | class DualQuaternion: 250 | # DualQuaternion from an existing rotation + translation 251 | @staticmethod 252 | def FromQT(q, t): 253 | return DualQuaternion(qe=(0.5 * np.asarray(t))) * DualQuaternion(q) 254 | 255 | def __init__(self, q0=np.array((1., 0., 0., 0.)), qe=np.zeros(4)): 256 | self.q0, self.qe = Quaternion(q0), Quaternion(qe) 257 | 258 | def __add__(self, other): 259 | return DualQuaternion(self.q0 + other.q0, self.qe + other.qe) 260 | 261 | def __iadd__(self, other): 262 | self.q0 += other.q0 263 | self.qe += other.qe 264 | return self 265 | 266 | # conguation via the ~ operator 267 | def __invert__(self): 268 | return DualQuaternion(~self.q0, ~self.qe) 269 | 270 | def __mul__(self, other): 271 | if isinstance(other, DualQuaternion): 272 | return DualQuaternion(self.q0 * other.q0, 273 | self.q0 * other.qe + self.qe * other.q0) 274 | elif isinstance(other, complex): # multiplication by a dual number 275 | return DualQuaternion(self.q0 * other.real, 276 | self.q0 * other.imag + self.qe * other.real) 277 | else: # scalar multiplication (assumed) 278 | return DualQuaternion(other * self.q0, other * self.qe) 279 | 280 | def __rmul__(self, other): 281 | return self.__mul__(other) 282 | 283 | def __imul__(self, other): 284 | tmp = self * other 285 | self.q0, self.qe = tmp.q0, tmp.qe 286 | return self 287 | 288 | def __neg__(self): 289 | return DualQuaternion(-self.q0, -self.qe) 290 | 291 | def __sub__(self, other): 292 | return DualQuaternion(self.q0 - other.q0, self.qe - other.qe) 293 | 294 | def __isub__(self, other): 295 | self.q0 -= other.q0 296 | self.qe -= other.qe 297 | return self 298 | 
299 | # q^-1 = q* / ||q||^2 300 | # assume that q0 is nonzero! 301 | def inverse(self): 302 | normsq = complex(q0.dot(q0), 2. * self.q0.q.dot(self.qe.q)) 303 | inv_len_real = 1. / normsq.real 304 | return ~self * complex(inv_len_real, 305 | -normsq.imag * inv_len_real * inv_len_real) 306 | 307 | # returns a complex representation of the real and imaginary parts of the norm 308 | # assume that q0 is nonzero! 309 | def norm(self): 310 | q0_norm = self.q0.norm() 311 | return complex(q0_norm, self.q0.dot(self.qe) / q0_norm) 312 | 313 | # assume that q0 is nonzero! 314 | def normalize(self): 315 | # current length is ||q0|| + eps * ( / ||q0||) 316 | # writing this as a + eps * b, the inverse is 317 | # 1/||q|| = 1/a - eps * b / a^2 318 | norm = self.norm() 319 | inv_len_real = 1. / norm.real 320 | self *= complex(inv_len_real, -norm.imag * inv_len_real * inv_len_real) 321 | return self 322 | 323 | # return the translation vector for this dual quaternion 324 | def getT(self): 325 | return 2 * (self.qe * ~self.q0).q[1:] 326 | 327 | def ToQT(self): 328 | return self.q0, self.getT() 329 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MultiNeRF: A Code Release for Mip-NeRF 360, Ref-NeRF, and RawNeRF 2 | 3 | *This is not an officially supported Google product.* 4 | 5 | This repository contains the code release for three CVPR 2022 papers: 6 | [Mip-NeRF 360](https://jonbarron.info/mipnerf360/), 7 | [Ref-NeRF](https://dorverbin.github.io/refnerf/), and 8 | [RawNeRF](https://bmild.github.io/rawnerf/). 9 | This codebase was written by 10 | integrating our internal implementations of Ref-NeRF and RawNeRF into our 11 | mip-NeRF 360 implementation. As such, this codebase should exactly 12 | reproduce the results shown in mip-NeRF 360, but may differ slightly when 13 | reproducing Ref-NeRF or RawNeRF results. 
14 | 15 | This implementation is written in [JAX](https://github.com/google/jax), and 16 | is a fork of [mip-NeRF](https://github.com/google/mipnerf). 17 | This is research code, and should be treated accordingly. 18 | 19 | ## Setup 20 | 21 | ``` 22 | # Clone the repo. 23 | git clone https://github.com/google-research/multinerf.git 24 | cd multinerf 25 | 26 | # Make a conda environment. 27 | conda create --name multinerf python=3.9 28 | conda activate multinerf 29 | 30 | # Prepare pip. 31 | conda install pip 32 | pip install --upgrade pip 33 | 34 | # Install requirements. 35 | pip install -r requirements.txt 36 | 37 | # Manually install rmbrualla's `pycolmap` (don't use pip's! It's different). 38 | git clone https://github.com/rmbrualla/pycolmap.git ./internal/pycolmap 39 | 40 | # Confirm that all the unit tests pass. 41 | ./scripts/run_all_unit_tests.sh 42 | ``` 43 | You'll probably also need to update your JAX installation to support GPUs or TPUs. 44 | 45 | ## Running 46 | 47 | Example scripts for training, evaluating, and rendering can be found in 48 | `scripts/`. You'll need to change the paths to point to wherever the datasets 49 | are located. [Gin](https://github.com/google/gin-config) configuration files 50 | for our model and some ablations can be found in `configs/`. 51 | After evaluating on the test set of each scene in one of the datasets, you can 52 | use `scripts/generate_tables.ipynb` to produce error metrics across all scenes 53 | in the same format as was used in tables in the paper. 54 | 55 | ### OOM errors 56 | 57 | You may need to reduce the batch size (`Config.batch_size`) to avoid out of memory 58 | errors. If you do this, but want to preserve quality, be sure to increase the number 59 | of training iterations and decrease the learning rate by whatever scale factor you 60 | decrease batch size by. 
61 | 62 | ## Using your own data 63 | 64 | ### Running COLMAP to get camera poses 65 | 66 | In order to run MultiNeRF on your own captured images of a scene, you must first run [COLMAP](https://colmap.github.io/install.html) to calculate camera poses. You can do this using our provided script `scripts/local_colmap_and_resize.sh`. Just make a directory `my_dataset_dir/` and copy your input images into a folder `my_dataset_dir/images/`, then run: 67 | ``` 68 | bash scripts/local_colmap_and_resize.sh my_dataset_dir 69 | ``` 70 | This will run COLMAP and create 2x, 4x, and 8x downsampled versions of your images. These lower resolution images can be used in NeRF by setting, e.g., the `Config.factor = 4` gin flag. 71 | 72 | By default, `local_colmap_and_resize.sh` uses the OPENCV camera model, which is a perspective pinhole camera with k1, k2 radial and t1, t2 tangential distortion coefficients. To switch to another COLMAP camera model, for example OPENCV_FISHEYE, you can run 73 | ``` 74 | bash scripts/local_colmap_and_resize.sh my_dataset_dir OPENCV_FISHEYE 75 | ``` 76 | 77 | If you have a very large capture of more than around 500 images, we recommend switching from the exhaustive matcher to the vocabulary tree matcher in COLMAP (see the script for a commented-out example). 78 | 79 | Our script is simply a thin wrapper for COLMAP--if you have run COLMAP yourself, all you need to do to load your scene in NeRF is ensure it has the following format: 80 | ``` 81 | my_dataset_dir/images/ <--- all input images 82 | my_dataset_dir/sparse/0/ <--- COLMAP sparse reconstruction files (cameras, images, points) 83 | ``` 84 | 85 | ### Writing a custom dataloader 86 | 87 | If you already have poses for your own data, you may prefer to write your own custom dataloader. 88 | 89 | MultiNeRF includes a variety of dataloaders, all of which inherit from the 90 | base 91 | [Dataset class](https://github.com/google-research/multinerf/blob/main/internal/datasets.py#L152). 
92 | 93 | The job of this class is to load all image and pose information from disk, then 94 | create batches of ray and color data for training or rendering a NeRF model. 95 | 96 | Any inherited subclass is responsible for loading images and camera poses from 97 | disk by implementing the `_load_renderings` method (which is marked as 98 | abstract by the decorator `@abc.abstractmethod`). This data is then used to 99 | generate train and test batches of ray + color data for feeding through the NeRF 100 | model. The ray parameters are calculated in `_make_ray_batch`. 101 | 102 | #### Existing data loaders 103 | 104 | To work from an example, you can see how this function is overloaded for the 105 | different dataloaders we have already implemented: 106 | 107 | - [Blender](https://github.com/google-research/multinerf/blob/main/internal/datasets.py#L470) 108 | - [DTU dataset](https://github.com/google-research/multinerf/blob/main/internal/datasets.py#L793) 109 | - [Tanks and Temples](https://github.com/google-research/multinerf/blob/main/internal/datasets.py#L680), 110 | as processed by the NeRF++ paper 111 | - [Tanks and Temples](https://github.com/google-research/multinerf/blob/main/internal/datasets.py#L728), 112 | as processed by the Free View Synthesis paper 113 | 114 | The main data loader we rely on is 115 | [LLFF](https://github.com/google-research/multinerf/blob/main/internal/datasets.py#L526) 116 | (named for historical reasons), which is the loader for a dataset that has been 117 | posed by COLMAP. 118 | 119 | #### Making your own loader by implementing `_load_renderings` 120 | 121 | To make a new dataset, make a class inheriting from `Dataset` and overload the 122 | `_load_renderings` method: 123 | 124 | ``` 125 | class MyNewDataset(Dataset): 126 | def _load_renderings(self, config): 127 | ... 
128 | ``` 129 | 130 | In this function, you **must** set the following public attributes: 131 | 132 | - images 133 | - camtoworlds 134 | - pixtocams 135 | - height, width 136 | 137 | Many of our dataset loaders also set other useful attributes, but these are the 138 | critical ones for generating rays. You can see how they are used (along with a batch of pixel coordinates) to create rays in [`camera_utils.pixels_to_rays`](https://github.com/google-research/multinerf/blob/main/internal/camera_utils.py#L520). 139 | 140 | **Images** 141 | 142 | `images` = [N, height, width, 3] numpy array of RGB images. Currently we 143 | require all images to have the same resolution. 144 | 145 | **Extrinsic camera poses** 146 | 147 | `camtoworlds` = [N, 3, 4] numpy array of extrinsic pose matrices. 148 | `camtoworlds[i]` should be in **camera-to-world** format, such that we can run 149 | 150 | ``` 151 | pose = camtoworlds[i] 152 | x_world = pose[:3, :3] @ x_camera + pose[:3, 3:4] 153 | ``` 154 | 155 | to convert a 3D camera space point `x_camera` into a world space point `x_world`. 156 | 157 | These matrices must be stored in the **OpenGL** coordinate system convention for camera rotation: 158 | x-axis to the right, y-axis upward, and z-axis backward along the camera's focal 159 | axis. 160 | 161 | The most common conventions are 162 | 163 | - `[right, up, backwards]`: OpenGL, NeRF, most graphics code. 164 | - `[right, down, forwards]`: OpenCV, COLMAP, most computer vision code. 
165 | 166 | Fortunately switching from OpenCV/COLMAP to NeRF is 167 | [simple](https://github.com/google-research/multinerf/blob/main/internal/datasets.py#L108): 168 | you just need to right-multiply the OpenCV pose matrices by `np.diag([1, -1, -1, 1])`, 169 | which will flip the sign of the y-axis (from down to up) and z-axis (from 170 | forwards to backwards): 171 | ``` 172 | camtoworlds_opengl = camtoworlds_opencv @ np.diag([1, -1, -1, 1]) 173 | ``` 174 | 175 | You may also want to **scale** your camera pose translations such that they all 176 | lie within the `[-1, 1]^3` cube for best performance with the default mipnerf360 177 | config files. 178 | 179 | We provide a useful helper function [`camera_utils.transform_poses_pca`](https://github.com/google-research/multinerf/blob/main/internal/camera_utils.py#L191) that computes a translation/rotation/scaling transform for the input poses that aligns the world space x-y plane with the ground (based on PCA) and scales the scene so that all input pose positions lie within `[-1, 1]^3`. (This function is applied by default when loading mip-NeRF 360 scenes with the LLFF data loader.) For a scene where this transformation has been applied, [`camera_utils.generate_ellipse_path`](https://github.com/google-research/multinerf/blob/main/internal/camera_utils.py#L230) can be used to generate a nice elliptical camera path for rendering videos. 180 | 181 | **Intrinsic camera poses** 182 | 183 | `pixtocams` = [N, 3, 3] numpy array of inverse intrinsic matrices, OR [3, 3] 184 | numpy array of a single shared inverse intrinsic matrix. These should be in 185 | **OpenCV** format, e.g.
186 | 187 | ``` 188 | camtopix = np.array([ 189 | [focal, 0, width/2], 190 | [ 0, focal, height/2], 191 | [ 0, 0, 1], 192 | ]) 193 | pixtocam = np.linalg.inv(camtopix) 194 | ``` 195 | 196 | Given a focal length and image size (and assuming a centered principal point), 197 | this matrix can be created using 198 | [`camera_utils.get_pixtocam`](https://github.com/google-research/multinerf/blob/main/internal/camera_utils.py#L411). 199 | 200 | Alternatively, it can be created by using 201 | [`camera_utils.intrinsic_matrix`](https://github.com/google-research/multinerf/blob/main/internal/camera_utils.py#L398) 202 | and inverting the resulting matrix. 203 | 204 | **Resolution** 205 | 206 | `height` = int, height of images. 207 | 208 | `width` = int, width of images. 209 | 210 | **Distortion parameters (optional)** 211 | 212 | `distortion_params` = dict, camera lens distortion model parameters. This 213 | dictionary must map from strings -> floats, and the allowed keys are `['k1', 214 | 'k2', 'k3', 'k4', 'p1', 'p2']` (up to four radial coefficients and up to two 215 | tangential coefficients). By default, this is set to the empty dictionary `{}`, 216 | in which case undistortion is not run. 217 | 218 | ### Details of the inner workings of Dataset 219 | 220 | The public interface mimics the behavior of a standard machine learning pipeline 221 | dataset provider that can provide infinite batches of data to the 222 | training/testing pipelines without exposing any details of how the batches are 223 | loaded/created or how this is parallelized. Therefore, the initializer runs all 224 | setup, including data loading from disk using `_load_renderings`, and begins 225 | the thread using its parent start() method. After the initializer returns, the 226 | caller can request batches of data straight away.
227 | 228 | The internal `self._queue` is initialized as `queue.Queue(3)`, so the infinite 229 | loop in `run()` will block on the call `self._queue.put(self._next_fn())` once 230 | there are 3 elements. The main thread training job runs in a loop that pops 1 231 | element at a time off the front of the queue. The Dataset thread's `run()` loop 232 | will populate the queue with 3 elements, then wait until a batch has been 233 | removed and push one more onto the end. 234 | 235 | This repeats indefinitely until the main thread's training loop completes 236 | (typically hundreds of thousands of iterations), then the main thread will exit 237 | and the Dataset thread will automatically be killed since it is a daemon. 238 | 239 | 240 | ## Citation 241 | If you use this software package, please cite whichever constituent paper(s) 242 | you build upon, or feel free to cite this entire codebase as: 243 | 244 | ``` 245 | @misc{multinerf2022, 246 | title={MultiNeRF: A Code Release for Mip-NeRF 360, Ref-NeRF, and RawNeRF}, 247 | author={Ben Mildenhall and Dor Verbin and Pratul P. Srinivasan and Peter Hedman and Ricardo Martin-Brualla and Jonathan T. Barron}, 248 | year={2022}, 249 | } 250 | ``` 251 | --------------------------------------------------------------------------------